├── README.md
├── gen_probing_samples.py
├── label_inference.py
├── prepare_model.py
├── recovery
│   ├── __init__.py
│   ├── consts.py
│   ├── data_processing.py
│   ├── metrics.py
│   ├── nn
│   │   ├── __init__.py
│   │   ├── densenet.py
│   │   ├── models.py
│   │   ├── models
│   │   │   ├── attention.py
│   │   │   ├── densenet.py
│   │   │   ├── googlenet.py
│   │   │   ├── inceptionv3.py
│   │   │   ├── inceptionv4.py
│   │   │   ├── mobilenet.py
│   │   │   ├── mobilenetv2.py
│   │   │   ├── nasnet.py
│   │   │   ├── preactresnet.py
│   │   │   ├── resnet.py
│   │   │   ├── resnext.py
│   │   │   ├── rir.py
│   │   │   ├── senet.py
│   │   │   ├── shufflenet.py
│   │   │   ├── shufflenetv2.py
│   │   │   ├── squeezenet.py
│   │   │   ├── stochasticdepth.py
│   │   │   ├── vgg.py
│   │   │   ├── wideresidual.py
│   │   │   └── xception.py
│   │   ├── modules.py
│   │   ├── revnet.py
│   │   └── revnet_utils.py
│   ├── optimization_strategy.py
│   ├── recovery_algo.py
│   ├── training.py
│   └── utils.py
└── recovery_data.py

/README.md:
--------------------------------------------------------------------------------
1 | # Unlearning Inversion Attacks
2 | 
3 | This repository contains the Python code for the inversion attack on machine unlearning from the paper "Learn What You Want to Unlearn: Unlearning Inversion Attacks against Machine Unlearning".
4 | 
5 | The code has been tested under CUDA 11.8, Python 3.9.13 and PyTorch 2.0.1.
6 | 
7 | Our code has two components: preparing the pretrained model (for later learning and unlearning) and performing data recovery/label inference on the unlearned model.
8 | 
9 | 
10 | ## Part 1: Prepare the pretrained model
11 | Run the following command (with ResNet18 and STL-10 as an example) to prepare the pretrained model:
12 | ```
13 | python prepare_model.py --model ResNet18 --dataset stl10 --exclude_num 1000 --seed 0 --save_folder results/models
14 | ```
15 | where `--exclude_num` is the number of leave-out samples reserved for later learning/unlearning. The saved checkpoints are stored under `--save_folder`, and the datasets are stored under `./datasets`.
16 | 
17 | ## Part 2-1: Recover the unlearned samples
18 | The script `recovery_data.py` automatically tests both exact unlearning and approximate unlearning. An example command is
19 | ```
20 | python recovery_data.py --model ResNet18 --dataset stl10 --ft_samples 1000 --unlearn_samples 1 --seed 0 --model_save_folder results/models
21 | ```
22 | where `--ft_samples` is the size of the finetuning dataset and `--unlearn_samples` controls the number of unlearned samples.
23 | 
24 | ## Part 2-2: Infer the unlearned labels
25 | 
26 | There are two steps. First, we use `gen_probing_samples.py` to generate probing samples for a given model and for each class.
27 | For example, to generate probing samples for class 0, we run
28 | ```
29 | python gen_probing_samples.py --classid 0 --model_save_folder results/models --redo_ft
30 | ```
31 | where `--redo_ft` enables re-finetuning the pretrained model. It can be disabled to save time if a finetuned model already exists.
32 | 
33 | 
34 | Second, we use the probing samples (after manually combining the probing samples from all classes into a pickled dictionary named `query_sample_dict.pkl` under the `probing_samples` subfolder of the model folder) and `label_inference.py` to infer the unlearned label under both exact and approximate unlearning.
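A minimal sketch of that combining step (assuming each run of `gen_probing_samples.py` leaves one pickled tensor of probing samples per class, e.g. `probing_class{c}.pkl`, in the same subfolder; the filenames here are an assumption, so adjust them to whatever your runs produced):
```
import os
import pickle

folder = 'results/models/resnet18_stl10_ex1000_s0/probing_samples'
query_sample_dict = {}
for c in range(10):  # one entry per class: class id -> tensor of probing samples
    with open(os.path.join(folder, f'probing_class{c}.pkl'), 'rb') as f:
        query_sample_dict[c] = pickle.load(f)
with open(os.path.join(folder, 'query_sample_dict.pkl'), 'wb') as f:
    pickle.dump(query_sample_dict, f)
```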
The example command can be 35 | 36 | ``` 37 | python label_inference.py --model_save_folder results/models --load_folder_name resnet18_stl10_ex1000_s0 38 | ``` 39 | -------------------------------------------------------------------------------- /label_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import copy 5 | from torchvision import transforms 6 | import numpy as np 7 | import pickle 8 | import recovery 9 | import pandas as pd 10 | 11 | 12 | def make_dict(tmp_dict, output, labels, name): 13 | if tmp_dict is None: 14 | tmp_dict = {'Confidence':[], 'Type':[], 'Class':[]} 15 | 16 | for c in range(output.shape[1]): 17 | idx_per_c = torch.where(labels == c)[0][:10] 18 | conf_list = output[idx_per_c, c].tolist() 19 | 20 | tmp_dict['Confidence'].extend(conf_list) 21 | tmp_dict['Type'].extend([name] * len(conf_list)) 22 | tmp_dict['Class'].extend([c] * len(conf_list)) 23 | return tmp_dict 24 | 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser(description='') 29 | parser.add_argument('--load_folder_name', type=str, default='resnet18_stl10_ex1000_s0', help='folder of pretrained model') 30 | parser.add_argument('--model_save_folder', type=str, default='results/models', help='folder of pretrained model') 31 | args = parser.parse_args() 32 | 33 | model = 'ResNet18' 34 | dataset = 'stl10' 35 | num_classes = 10 36 | seed = 0 37 | ft_samples = 512 38 | img_size = 32 if dataset == 'cifar10' else 96 39 | excluded_num = 10000 if dataset == 'cifar10' else 1000 40 | np.random.seed(seed) 41 | torch.manual_seed(seed) 42 | torch.cuda.manual_seed(seed) 43 | 44 | 45 | final_dict = torch.load(os.path.join(args.model_save_folder, args.load_folder_name, 'final.pth')) 46 | setup = recovery.utils.system_startup() 47 | defs = recovery.training_strategy('conservative') 48 | defs.lr = 1e-3 49 | defs.epochs = 1 50 | defs.batch_size = 128 51 | defs.optimizer = 'SGD' 52 | defs.scheduler = 'linear' 53 | defs.warmup = False 54 | defs.weight_decay = 0.0 55 | defs.dropout = 0.0 56 | defs.augmentations = False 57 | defs.dryrun = False 58 | 59 | 60 | loss_fn, _org_trainloader, validloader, num_classes, _exd, dmlist, dslist = recovery.construct_dataloaders(dataset.lower(), defs, data_path=f'datasets/{dataset.lower()}', normalize=False, exclude_num=excluded_num) 61 | dm = torch.as_tensor(dmlist, **setup)[:, None, None] 62 | ds = torch.as_tensor(dslist, **setup)[:, None, None] 63 | normalizer = transforms.Normalize(dmlist, dslist) 64 | 65 | 66 | # *** used for batch case *** 67 | excluded_data = final_dict['excluded_data'] 68 | index = torch.tensor(np.random.choice(len(excluded_data[0]), ft_samples, replace=False)) 69 | print("Batch index", index.tolist()) 70 | X_all, y_all = excluded_data[0][index], excluded_data[1][index] 71 | print("FT data size", X_all.shape, y_all.shape) 72 | trainset_all = recovery.data_processing.SubTrainDataset(X_all, y_all, transform=transforms.Normalize(dmlist, dslist)) 73 | trainloader_all = torch.utils.data.DataLoader(trainset_all, batch_size=min(defs.batch_size, len(trainset_all)), shuffle=True, num_workers=8, pin_memory=True) 74 | 75 | 76 | 77 | model_pretrain, _ = recovery.construct_model(model, num_classes=num_classes, num_channels=3) 78 | model_pretrain.load_state_dict(final_dict['net_sd']) 79 | model_pretrain.eval() 80 | 81 | ft_folder = os.path.join(args.model_save_folder, args.load_folder_name, 'probing_samples') 82 | ft_path = os.path.join(ft_folder, 
                           f'finetune_{defs.epochs}ep.pt')
83 | 
84 |     model_ft, _ = recovery.construct_model(model, num_classes=num_classes, num_channels=3)
85 |     model_ft.load_state_dict(torch.load(ft_path))
86 |     model_ft.eval()
87 | 
88 | 
89 |     class_pred_dict = {}
90 |     for class_id in range(num_classes):
91 |         print(class_id)
92 |         # ## Unlearn one
93 | 
94 | 
95 |         model_ft = model_ft.cpu()
96 |         unlearn_ids = torch.where(y_all == class_id)[0]
97 |         unlearn_ids = unlearn_ids[:len(unlearn_ids) // 10]
98 |         print(f"Unlearn {unlearn_ids}")
99 | 
100 |         X = torch.stack([xt for i, xt in enumerate(X_all) if i not in unlearn_ids])
101 |         y = torch.tensor([yt for i, yt in enumerate(y_all) if i not in unlearn_ids])
102 |         print("Exact unlearn data size", X.shape, y.shape)
103 |         trainset_unlearn = recovery.data_processing.SubTrainDataset(X, y, transform=transforms.Normalize(dmlist, dslist))
104 |         trainloader_unlearn = torch.utils.data.DataLoader(trainset_unlearn, batch_size=min(defs.batch_size, len(trainset_unlearn)), shuffle=True, num_workers=8, pin_memory=True)
105 | 
106 |         X_unlearn = torch.stack([xt for i, xt in enumerate(X_all) if i in unlearn_ids])
107 |         y_unlearn = torch.tensor([yt for i, yt in enumerate(y_all) if i in unlearn_ids])
108 | 
109 |         print(f"***** Train unlearned model (without {unlearn_ids}) *****")
110 |         model_unlearn_one, _ = recovery.construct_model(model, num_classes=num_classes, num_channels=3)
111 |         model_unlearn_one.load_state_dict(final_dict['net_sd'])
112 |         model_unlearn_one.eval()
113 |         model_unlearn_one.to(**setup)
114 |         unlearn_stats = recovery.train(model_unlearn_one, loss_fn, trainloader_unlearn, validloader, defs, setup=setup, ckpt_path=None, finetune=True)
115 |         model_unlearn_one.cpu()
116 | 
117 |         # approx
118 |         batch_size = min(defs.batch_size, len(X_unlearn))
119 |         approx_diff = [p.detach() for p in recovery.recovery_algo.loss_steps(model_ft, normalizer(X_unlearn), y_unlearn, lr=defs.lr, local_steps=defs.epochs * len(X_unlearn) // batch_size, batch_size=batch_size)]
120 |         model_app_unlearn_one = copy.deepcopy(model_ft)
121 | 
122 |         old_params = {}
123 |         for i, (name, params) in enumerate(model_app_unlearn_one.named_parameters()):
124 |             old_params[name] = params.clone()
125 |             old_params[name] += approx_diff[i]
126 |         for name, params in model_app_unlearn_one.named_parameters():
127 |             params.data.copy_(old_params[name])
128 | 
129 | 
130 |         # ## unlearn half
131 | 
132 | 
133 |         unlearn_ids = torch.where(y_all == class_id)[0]
134 |         unlearn_ids = unlearn_ids[:len(unlearn_ids) // 2]
135 |         print(f"Unlearn {unlearn_ids}")
136 | 
137 |         X = torch.stack([xt for i, xt in enumerate(X_all) if i not in unlearn_ids])
138 |         y = torch.tensor([yt for i, yt in enumerate(y_all) if i not in unlearn_ids])
139 |         print("Exact unlearn data size", X.shape, y.shape)
140 |         trainset_unlearn = recovery.data_processing.SubTrainDataset(X, y, transform=transforms.Normalize(dmlist, dslist))
141 |         trainloader_unlearn = torch.utils.data.DataLoader(trainset_unlearn, batch_size=min(defs.batch_size, len(trainset_unlearn)), shuffle=True, num_workers=8, pin_memory=True)
142 | 
143 |         X_unlearn = torch.stack([xt for i, xt in enumerate(X_all) if i in unlearn_ids])
144 |         y_unlearn = torch.tensor([yt for i, yt in enumerate(y_all) if i in unlearn_ids])
145 | 
146 |         print(f"***** Train unlearned model (without {unlearn_ids}) *****")
147 |         model_unlearn_half, _ = recovery.construct_model(model, num_classes=num_classes, num_channels=3)
148 |         model_unlearn_half.load_state_dict(final_dict['net_sd'])
149 |         model_unlearn_half.eval()
150 |         model_unlearn_half.to(**setup)
151 |         unlearn_stats = recovery.train(model_unlearn_half, loss_fn, trainloader_unlearn, validloader, defs, setup=setup, ckpt_path=None, finetune=True)
152 |         model_unlearn_half.cpu()
153 | 
154 |         # approx
155 |         batch_size = min(defs.batch_size, len(X_unlearn))
156 |         approx_diff = [p.detach() for p in recovery.recovery_algo.loss_steps(model_ft, normalizer(X_unlearn), y_unlearn, lr=defs.lr, local_steps=defs.epochs * len(X_unlearn) // batch_size, batch_size=batch_size)]
157 |         model_app_unlearn_half = copy.deepcopy(model_ft)
158 | 
159 |         old_params = {}
160 |         for i, (name, params) in enumerate(model_app_unlearn_half.named_parameters()):
161 |             old_params[name] = params.clone()
162 |             old_params[name] += approx_diff[i]
163 |         for name, params in model_app_unlearn_half.named_parameters():
164 |             params.data.copy_(old_params[name])
165 | 
166 | 
167 |         # ## unlearn all
168 | 
169 | 
170 |         unlearn_ids = torch.where(y_all == class_id)[0]
171 |         print(f"Unlearn {unlearn_ids}")
172 | 
173 |         X = torch.stack([xt for i, xt in enumerate(X_all) if i not in unlearn_ids])
174 |         y = torch.tensor([yt for i, yt in enumerate(y_all) if i not in unlearn_ids])
175 |         print("Exact unlearn data size", X.shape, y.shape)
176 |         trainset_unlearn = recovery.data_processing.SubTrainDataset(X, y, transform=transforms.Normalize(dmlist, dslist))
177 |         trainloader_unlearn = torch.utils.data.DataLoader(trainset_unlearn, batch_size=min(defs.batch_size, len(trainset_unlearn)), shuffle=True, num_workers=8, pin_memory=True)
178 | 
179 |         X_unlearn = torch.stack([xt for i, xt in enumerate(X_all) if i in unlearn_ids])
180 |         y_unlearn = torch.tensor([yt for i, yt in enumerate(y_all) if i in unlearn_ids])
181 | 
182 |         print(f"***** Train unlearned model (without {unlearn_ids}) *****")
183 |         model_unlearn, _ = recovery.construct_model(model, num_classes=num_classes, num_channels=3)
184 |         model_unlearn.load_state_dict(final_dict['net_sd'])
185 |         model_unlearn.eval()
186 |         model_unlearn.to(**setup)
187 |         unlearn_stats = recovery.train(model_unlearn, loss_fn, trainloader_unlearn, validloader, defs, setup=setup, ckpt_path=None, finetune=True)
188 |         model_unlearn.cpu()
189 | 
190 |         # approx
191 |         batch_size = min(defs.batch_size, len(X_unlearn))
192 |         approx_diff = [p.detach() for p in recovery.recovery_algo.loss_steps(model_ft, normalizer(X_unlearn), y_unlearn, lr=defs.lr, local_steps=defs.epochs * len(X_unlearn) // batch_size, batch_size=batch_size)]  # lr is not important in cosine
193 |         model_app_unlearn = copy.deepcopy(model_ft)
194 |         old_params = {}
195 |         for i, (name, params) in enumerate(model_app_unlearn.named_parameters()):
196 |             old_params[name] = params.clone()
197 |             old_params[name] += approx_diff[i]
198 |         for name, params in model_app_unlearn.named_parameters():
199 |             params.data.copy_(old_params[name])
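        # Editor's sketch: the clone/add/copy pattern used three times above is
        # equivalent to a single in-place update (assuming loss_steps returns one
        # delta per parameter, in named_parameters() order):
        #
        #   with torch.no_grad():
        #       for p, d in zip(model_app_unlearn.parameters(), approx_diff):
        #           p.add_(d)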
200 | 
201 | 
202 |         # # Plot
203 | 
204 | 
205 |         sample_dict = pickle.load(open(os.path.join(args.model_save_folder, args.load_folder_name, 'probing_samples', 'query_sample_dict.pkl'), 'rb'))
206 |         sample_datats = torch.cat([x[:10] for x in sample_dict.values()])
207 | 
208 | 
209 |         with torch.no_grad():
210 |             model_ft.cuda()
211 |             test_output = model_ft(normalizer(sample_datats.cuda())).softmax(dim=1).detach().cpu()
212 |             model_ft.cpu()
213 |             sample_labelts = test_output.argmax(dim=1)
214 | 
215 |             model_unlearn_one.cuda()
216 |             exact_unlearn_class1_one = model_unlearn_one(normalizer(sample_datats.cuda())).detach().softmax(dim=1).cpu()
217 |             model_unlearn_one.cpu()
218 | 
219 |             model_app_unlearn_one.cuda()
220 |             approx_unlearn_class1_one = model_app_unlearn_one(normalizer(sample_datats.cuda())).softmax(dim=1).cpu()
221 |             model_app_unlearn_one.cpu()
222 | 
223 |             model_unlearn_half.cuda()
224 |             exact_unlearn_class1_half = model_unlearn_half(normalizer(sample_datats.cuda())).softmax(dim=1).cpu()
225 |             model_unlearn_half.cpu()
226 | 
227 |             model_app_unlearn_half.cuda()
228 |             approx_unlearn_class1_half = model_app_unlearn_half(normalizer(sample_datats.cuda())).softmax(dim=1).cpu()
229 |             model_app_unlearn_half.cpu()
230 | 
231 |             model_unlearn.cuda()
232 |             exact_unlearn_class1 = model_unlearn(normalizer(sample_datats.cuda())).softmax(dim=1).cpu()
233 |             model_unlearn.cpu()
234 | 
235 |             model_app_unlearn.cuda()
236 |             approx_unlearn_class1 = model_app_unlearn(normalizer(sample_datats.cuda())).softmax(dim=1).cpu()
237 |             model_app_unlearn.cpu()
238 | 
239 | 
240 |         v, i = test_output.max(dim=1)
241 |         print(v, i, np.unique(i, return_counts=True))
242 | 
243 | 
244 |         all_dict = None
245 |         all_dict = make_dict(all_dict, exact_unlearn_class1_one - test_output, sample_labelts, 'Exact $(p_u=0.1)$')
246 |         all_dict = make_dict(all_dict, approx_unlearn_class1_one - test_output, sample_labelts, 'Approx. $(p_u=0.1)$')
247 |         all_dict = make_dict(all_dict, exact_unlearn_class1_half - test_output, sample_labelts, 'Exact $(p_u=0.5)$')
248 |         all_dict = make_dict(all_dict, approx_unlearn_class1_half - test_output, sample_labelts, 'Approx. $(p_u=0.5)$')
249 |         all_dict = make_dict(all_dict, exact_unlearn_class1 - test_output, sample_labelts, 'Exact $(p_u=1.0)$')
250 |         all_dict = make_dict(all_dict, approx_unlearn_class1 - test_output, sample_labelts, 'Approx. $(p_u=1.0)$')
251 |         plot_df = pd.DataFrame.from_dict(all_dict)
252 | 
253 | 
254 | 
255 |         tmp = plot_df.groupby(['Type', 'Class']).agg({'Confidence': 'mean'}).reset_index()
256 | 
257 |         tmp['CC'] = tmp[['Class', 'Confidence']].apply(tuple, axis=1)
258 | 
259 |         # for each unlearning type, pick the class whose probing-sample confidence dropped the most (the inferred unlearned label)
            class_pred_dict[class_id] = tmp.groupby('Type')['CC'].agg(lambda x: list(x)[np.array([x0[1] for x0 in list(x)]).argmin()][0])
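        # Editor's sketch of the rule above, without pandas: with
        # mean_conf[t][c] = mean confidence change of unlearning type t on the
        # class-c probing samples, the inferred unlearned label per type is
        #
        #   {t: min(mean_conf[t], key=mean_conf[t].get) for t in mean_conf}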
260 | 
261 |     print(class_pred_dict)
262 |     pickle.dump(class_pred_dict, open(os.path.join(args.model_save_folder, f'{dataset}_unlearnlabel_{defs.epochs}.pkl'), 'wb'))
263 | 

--------------------------------------------------------------------------------
/prepare_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | import numpy as np
5 | import random
6 | import recovery
7 | 
8 | if __name__ == "__main__":
9 |     parser = argparse.ArgumentParser(description='prepare models for exact and approximate unlearning.')
10 |     parser.add_argument('--model', default='ConvNet', type=str, help='Vision model.')
11 |     parser.add_argument('--dataset', default='cifar10', type=str)
12 |     parser.add_argument('--tr_strat', default='conservative', type=str, help='training strategy')
13 |     parser.add_argument('--epochs', default=None, type=int, help='updated epochs')
14 | 
15 |     parser.add_argument('--seed', default=None, type=int, help='random seed')
16 | 
17 |     parser.add_argument('--exclude_num', default=0, type=int, help='excluded samples during training')
18 | 
19 |     parser.add_argument('--save_steps', default=10, type=int, help='steps to save ckpt')
20 | 
21 |     parser.add_argument('--save_folder', default='results/models', type=str, help='result folder')
22 | 
23 |     parser.add_argument('--save_name', default='', type=str, help='saving file name')
24 |     args = parser.parse_args()
25 |     if args.seed is not None:
26 |         np.random.seed(args.seed)
27 |         torch.manual_seed(args.seed)
28 |         torch.cuda.manual_seed(args.seed)
29 |         random.seed(args.seed)
30 | 
31 | 
32 | 
33 | 
34 |     save_folder = os.path.join(args.save_folder, f"{args.model.lower()}_{args.dataset.lower()}_ex{args.exclude_num}_{args.save_name}")
35 |     os.makedirs(save_folder, exist_ok=True)
36 |     setup = recovery.utils.system_startup()
37 |     defs = recovery.training_strategy(args.tr_strat)
38 |     defs.validate = args.save_steps
39 | 
40 |     if args.epochs is not None:
41 |         defs.epochs = args.epochs
42 | 
43 |     loss_fn, trainloader, validloader, num_classes, excluded_data, data_mean, data_std = recovery.construct_dataloaders(args.dataset, defs, data_path=f'datasets/{args.dataset.lower()}', shuffle=True, normalize=True, exclude_num=args.exclude_num)
44 | 
45 | 
46 |     model, _ = recovery.construct_model(args.model, num_classes=num_classes, seed=args.seed, num_channels=3)
47 |     model.to(**setup)
48 | 
49 | 
50 |     stats = recovery.train(model, loss_fn, trainloader, validloader, defs, setup=setup, ckpt_path=save_folder)
51 | 
52 |     resdict = {
53 |         'tr_args': args.__dict__,
54 |         'tr_strat': defs.__dict__,
55 |         'stats': stats,
56 |         'net_sd': model.state_dict(),
57 |         'excluded_data': excluded_data
58 |     }
59 |     torch.save(resdict, os.path.join(save_folder, 'final.pth'))

--------------------------------------------------------------------------------
/recovery/__init__.py:
--------------------------------------------------------------------------------
1 | """ build upon https://github.com/JonasGeiping/invertinggradients"""
2 | """Library of routines."""
3 | 
4 | from recovery import nn
5 | from recovery.nn import construct_model, MetaMonkey
6 | 
7 | from .data_processing import construct_dataloaders
8 | from .training import train
9 | from recovery import utils
10 | 
11 | from .optimization_strategy import training_strategy
12 | 
13 | 
14 | from .recovery_algo import GradientReconstructor, UnlearnReconstructor
15 | 
16 | from recovery import data_processing
17 | __all__ = ['train', 'construct_dataloaders', 'construct_model', 'data_processing', 'MetaMonkey',
18 |            'training_strategy', 'nn', 'utils', 'consts',
19 |            'metrics', 'GradientReconstructor', 'UnlearnReconstructor']

--------------------------------------------------------------------------------
/recovery/consts.py:
--------------------------------------------------------------------------------
1 | """ build upon https://github.com/JonasGeiping/invertinggradients"""
2 | """Setup constants, ymmv."""
3 | 
4 | PIN_MEMORY = True
5 | NON_BLOCKING = False
6 | BENCHMARK = True
7 | MULTITHREAD_DATAPROCESSING = 4
8 | 
9 | 
10 | cifar10_mean = [0.4914672374725342, 0.4822617471218109, 0.4467701315879822]
11 | cifar10_std = [0.24703224003314972, 0.24348513782024384, 0.26158785820007324]
12 | cifar100_mean = [0.5071598291397095, 0.4866936206817627, 0.44120192527770996]
13 | cifar100_std = [0.2673342823982239, 0.2564384639263153, 0.2761504650115967]
14 | stl10_mean = [0.44671064615249634, 0.4398098886013031, 0.4066464304924011]
15 | stl10_std = [0.26034098863601685, 0.2565772831439972, 0.2712673842906952]

--------------------------------------------------------------------------------
/recovery/data_processing.py:
--------------------------------------------------------------------------------
1 | """ build upon https://github.com/JonasGeiping/invertinggradients"""
2 | """Repeatable code parts concerning data loading."""
3 | 
4 | import numpy as np
5 | import torch
6 | import torchvision
7 | import torchvision.transforms as transforms
8 | from torch.utils.data
import Dataset 9 | import os 10 | 11 | from .consts import * 12 | 13 | class Loss: 14 | """Abstract class, containing necessary methods. 15 | 16 | Abstract class to collect information about the 'higher-level' loss function, used to train an energy-based model 17 | containing the evaluation of the loss function, its gradients w.r.t. to first and second argument and evaluations 18 | of the actual metric that is targeted. 19 | 20 | """ 21 | 22 | def __init__(self): 23 | """Init.""" 24 | pass 25 | 26 | def __call__(self, reference, argmin): 27 | """Return l(x, y).""" 28 | raise NotImplementedError() 29 | return value, name, format 30 | 31 | def metric(self, reference, argmin): 32 | """The actually sought metric.""" 33 | raise NotImplementedError() 34 | return value, name, format 35 | 36 | 37 | class PSNR(Loss): 38 | """A classical MSE target. 39 | 40 | The minimized criterion is MSE Loss, the actual metric is average PSNR. 41 | """ 42 | 43 | def __init__(self): 44 | """Init with torch MSE.""" 45 | self.loss_fn = torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean') 46 | 47 | def __call__(self, x=None, y=None): 48 | """Return l(x, y).""" 49 | name = 'MSE' 50 | format = '.6f' 51 | if x is None: 52 | return name, format 53 | else: 54 | value = 0.5 * self.loss_fn(x, y) 55 | return value, name, format 56 | 57 | def metric(self, x=None, y=None): 58 | """The actually sought metric.""" 59 | name = 'avg PSNR' 60 | format = '.3f' 61 | if x is None: 62 | return name, format 63 | else: 64 | value = self.psnr_compute(x, y) 65 | return value, name, format 66 | 67 | @staticmethod 68 | def psnr_compute(img_batch, ref_batch, batched=False, factor=1.0): 69 | """Standard PSNR.""" 70 | def get_psnr(img_in, img_ref): 71 | mse = ((img_in - img_ref)**2).mean() 72 | if mse > 0 and torch.isfinite(mse): 73 | return (10 * torch.log10(factor**2 / mse)).item() 74 | elif not torch.isfinite(mse): 75 | return float('nan') 76 | else: 77 | return float('inf') 78 | 79 | if batched: 80 | psnr = get_psnr(img_batch.detach(), ref_batch) 81 | else: 82 | [B, C, m, n] = img_batch.shape 83 | psnrs = [] 84 | for sample in range(B): 85 | psnrs.append(get_psnr(img_batch.detach()[sample, :, :, :], ref_batch[sample, :, :, :])) 86 | psnr = np.mean(psnrs) 87 | 88 | return psnr 89 | 90 | 91 | class Classification(Loss): 92 | """A classical NLL loss for classification. Evaluation has the softmax baked in. 93 | 94 | The minimized criterion is cross entropy, the actual metric is total accuracy. 
95 |     """
96 | 
97 |     def __init__(self):
98 |         """Init with torch CrossEntropyLoss."""
99 |         self.loss_fn = torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100,
100 |                                                  reduce=None, reduction='mean')
101 | 
102 |     def __call__(self, x=None, y=None):
103 |         """Return l(x, y)."""
104 |         name = 'CrossEntropy'
105 |         format = '1.5f'
106 |         if x is None:
107 |             return name, format
108 |         else:
109 |             value = self.loss_fn(x, y)
110 |             return value, name, format
111 | 
112 |     def metric(self, x=None, y=None):
113 |         """The actually sought metric."""
114 |         name = 'Accuracy'
115 |         format = '6.2%'
116 |         if x is None:
117 |             return name, format
118 |         else:
119 |             value = (x.data.argmax(dim=1) == y).sum().float() / y.shape[0]
120 |             return value.detach(), name, format
121 | 
122 | 
123 | class SubTrainDataset(Dataset):
124 |     def __init__(self, data, targets, transform=None, target_transform=None):
125 |         self.data = data
126 |         self.targets = targets
127 |         self.transform = transform
128 |         self.target_transform = target_transform
129 |     def __getitem__(self, index):
130 |         img, target = self.data[index], self.targets[index]
131 | 
132 |         if self.transform is not None:
133 |             img = self.transform(img)
134 | 
135 |         if self.target_transform is not None:
136 |             target = self.target_transform(target)
137 | 
138 |         return img, target
139 |     def __len__(self):
140 |         return len(self.data)
141 | 
142 | 
143 | def construct_dataloaders(dataset, defs, data_path='~/data', shuffle=True, normalize=True, exclude_num=0):
144 |     """Return dataloaders for the given dataset, with optional augmentation and normalization."""
145 |     path = os.path.expanduser(data_path)
146 | 
147 |     if dataset == 'cifar10':
148 |         trainset, validset, excluded_data, data_mean, data_std = _build_cifar10(path, defs.augmentations, normalize, exclude_num)
149 |         loss_fn = Classification()
150 |     elif dataset == 'cifar100':
151 |         trainset, validset, excluded_data, data_mean, data_std = _build_cifar100(path, defs.augmentations, normalize, exclude_num)
152 |         loss_fn = Classification()
153 |     elif dataset == 'stl10':
154 |         trainset, validset, excluded_data, data_mean, data_std = _build_stl10(path, defs.augmentations, normalize, exclude_num)
155 |         loss_fn = Classification()
156 | 
157 |     num_classes = len(np.unique([y for x, y in validset]))
158 |     if MULTITHREAD_DATAPROCESSING:
159 |         num_workers = min(torch.get_num_threads(), MULTITHREAD_DATAPROCESSING) if torch.get_num_threads() > 1 else 0
160 |     else:
161 |         num_workers = 0
162 | 
163 |     trainloader = torch.utils.data.DataLoader(trainset, batch_size=min(defs.batch_size, len(trainset)),
164 |                                               shuffle=shuffle, drop_last=True, num_workers=num_workers, pin_memory=PIN_MEMORY)
165 |     validloader = torch.utils.data.DataLoader(validset, batch_size=min(defs.batch_size, len(trainset)),
166 |                                               shuffle=False, drop_last=False, num_workers=num_workers, pin_memory=PIN_MEMORY)
167 | 
168 |     return loss_fn, trainloader, validloader, num_classes, excluded_data, data_mean, data_std
169 | 
170 | 
171 | def _build_cifar10(data_path, augmentations=True, normalize=True, exclude_num=0):
172 |     """Define CIFAR-10 with everything considered."""
173 |     # Load data
174 |     trainset = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True, transform=transforms.ToTensor())
175 |     validset = torchvision.datasets.CIFAR10(root=data_path, train=False, download=True, transform=transforms.ToTensor())
176 | 
177 |     if cifar10_mean is None:
178 |         data_mean, data_std = _get_meanstd(trainset)
179 |     else:
180 |         data_mean, data_std = cifar10_mean, cifar10_std
181 | 
182 | 
183 |     data, target
= [], [] 184 | for x, y in trainset: 185 | data.append(x) 186 | target.append(y) 187 | data = torch.stack(data) 188 | target = torch.tensor(target, dtype=torch.long) 189 | 190 | trainset = SubTrainDataset(data[exclude_num:], target[exclude_num:]) 191 | excluded_data = (data[:exclude_num], target[:exclude_num]) 192 | 193 | # Organize preprocessing 194 | transform = transforms.Compose([transforms.Normalize(data_mean, data_std) if normalize else transforms.Lambda(lambda x: x)]) 195 | if augmentations: 196 | transform_train = transforms.Compose([ 197 | transforms.RandomCrop(32, padding=4), 198 | transforms.RandomHorizontalFlip(), 199 | transform]) 200 | trainset.transform = transform_train 201 | else: 202 | trainset.transform = transform 203 | 204 | validset.transform = transforms.Compose([ 205 | transforms.ToTensor(), 206 | transforms.Normalize(data_mean, data_std) if normalize else transforms.Lambda(lambda x: x)]) 207 | 208 | return trainset, validset, excluded_data, data_mean, data_std 209 | 210 | 211 | def _build_stl10(data_path, augmentations=True, normalize=True, exclude_num=0): 212 | """Define STL-10 with everything considered.""" 213 | # Load data 214 | trainset = torchvision.datasets.STL10(root=data_path, split='train', download=True, transform=transforms.ToTensor()) 215 | validset = torchvision.datasets.STL10(root=data_path, split='test', download=True, transform=transforms.ToTensor()) 216 | 217 | if stl10_mean is None: 218 | data_mean, data_std = _get_meanstd(trainset) 219 | else: 220 | data_mean, data_std = stl10_mean, stl10_std 221 | 222 | 223 | data, target = [], [] 224 | for x, y in trainset: 225 | data.append(x) 226 | target.append(y) 227 | data = torch.stack(data) 228 | target = torch.tensor(target, dtype=torch.long) 229 | 230 | trainset = SubTrainDataset(data[exclude_num:], target[exclude_num:]) 231 | excluded_data = (data[:exclude_num], target[:exclude_num]) 232 | 233 | # Organize preprocessing 234 | transform = transforms.Compose([transforms.Normalize(data_mean, data_std) if normalize else transforms.Lambda(lambda x: x)]) 235 | if augmentations: 236 | transform_train = transforms.Compose([ 237 | transforms.RandomCrop(96, padding=4), 238 | transforms.RandomHorizontalFlip(), 239 | transform]) 240 | trainset.transform = transform_train 241 | else: 242 | trainset.transform = transform 243 | validset.transform = transforms.Compose([ 244 | transforms.ToTensor(), 245 | transforms.Normalize(data_mean, data_std) if normalize else transforms.Lambda(lambda x: x)]) 246 | 247 | return trainset, validset, excluded_data, data_mean, data_std 248 | 249 | def _build_cifar100(data_path, augmentations=True, normalize=True, exclude_num=0): 250 | """Define CIFAR-100 with everything considered.""" 251 | # Load data 252 | trainset = torchvision.datasets.CIFAR100(root=data_path, train=True, download=True, transform=transforms.ToTensor()) 253 | validset = torchvision.datasets.CIFAR100(root=data_path, train=False, download=True, transform=transforms.ToTensor()) 254 | 255 | if cifar100_mean is None: 256 | data_mean, data_std = _get_meanstd(trainset) 257 | else: 258 | data_mean, data_std = cifar100_mean, cifar100_std 259 | 260 | 261 | data, target = [], [] 262 | for x, y in trainset: 263 | data.append(x) 264 | target.append(y) 265 | data = torch.stack(data) 266 | target = torch.tensor(target, dtype=torch.long) 267 | trainset = SubTrainDataset(data[exclude_num:], target[exclude_num:]) 268 | excluded_data = (data[:exclude_num], target[:exclude_num]) 269 | 270 | # Organize preprocessing 271 | transform = 
transforms.Compose([transforms.Normalize(data_mean, data_std) if normalize else transforms.Lambda(lambda x: x)]) 272 | if augmentations: 273 | transform_train = transforms.Compose([ 274 | transforms.RandomCrop(32, padding=4), 275 | transforms.RandomHorizontalFlip(), 276 | transform]) 277 | trainset.transform = transform_train 278 | else: 279 | trainset.transform = transform 280 | validset.transform = transforms.Compose([ 281 | transforms.ToTensor(), 282 | transforms.Normalize(data_mean, data_std) if normalize else transforms.Lambda(lambda x: x)]) 283 | 284 | return trainset, validset, excluded_data, data_mean, data_std 285 | 286 | 287 | def _get_meanstd(trainset): 288 | cc = torch.cat([trainset[i][0].reshape(3, -1) for i in range(len(trainset))], dim=1) 289 | data_mean = torch.mean(cc, dim=1).tolist() 290 | data_std = torch.std(cc, dim=1).tolist() 291 | return data_mean, data_std 292 | 293 | 294 | -------------------------------------------------------------------------------- /recovery/metrics.py: -------------------------------------------------------------------------------- 1 | """ build upon https://github.com/JonasGeiping/invertinggradients""" 2 | """This is code based on https://sudomake.ai/inception-score-explained/.""" 3 | import torch 4 | import torchvision 5 | 6 | from collections import defaultdict 7 | 8 | class InceptionScore(torch.nn.Module): 9 | """Class that manages and returns the inception score of images.""" 10 | 11 | def __init__(self, batch_size=32, setup=dict(device=torch.device('cpu'), dtype=torch.float)): 12 | """Initialize with setup and target inception batch size.""" 13 | super().__init__() 14 | self.preprocessing = torch.nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False) 15 | self.model = torchvision.models.inception_v3(pretrained=True).to(**setup) 16 | self.model.eval() 17 | self.batch_size = batch_size 18 | 19 | def forward(self, image_batch): 20 | """Image batch should have dimensions BCHW and should be normalized. 21 | 22 | B should be divisible by self.batch_size. 
23 | """ 24 | B, C, H, W = image_batch.shape 25 | batches = B // self.batch_size 26 | scores = [] 27 | for batch in range(batches): 28 | input = self.preprocessing(image_batch[batch * self.batch_size: (batch + 1) * self.batch_size]) 29 | scores.append(self.model(input)) 30 | prob_yx = torch.nn.functional.softmax(torch.cat(scores, 0), dim=1) 31 | entropy = torch.where(prob_yx > 0, -prob_yx * prob_yx.log(), torch.zeros_like(prob_yx)) 32 | return entropy.sum() 33 | 34 | 35 | def psnr(img_batch, ref_batch, batched=False, factor=1.0): 36 | """Standard PSNR.""" 37 | def get_psnr(img_in, img_ref): 38 | mse = ((img_in - img_ref)**2).mean() 39 | if mse > 0 and torch.isfinite(mse): 40 | return (10 * torch.log10(factor**2 / mse)) 41 | elif not torch.isfinite(mse): 42 | return img_batch.new_tensor(float('nan')) 43 | else: 44 | return img_batch.new_tensor(float('inf')) 45 | 46 | if batched: 47 | psnr = get_psnr(img_batch.detach(), ref_batch) 48 | else: 49 | [B, C, m, n] = img_batch.shape 50 | psnrs = [] 51 | for sample in range(B): 52 | psnrs.append(get_psnr(img_batch.detach()[sample, :, :, :], ref_batch[sample, :, :, :])) 53 | psnr = torch.stack(psnrs, dim=0).mean() 54 | 55 | return psnr.item() 56 | 57 | 58 | def total_variation(x): 59 | """Anisotropic TV.""" 60 | dx = torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])) 61 | dy = torch.mean(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])) 62 | return dx + dy 63 | 64 | 65 | 66 | def activation_errors(model, x1, x2): 67 | """Compute activation-level error metrics for every module in the network.""" 68 | model.eval() 69 | 70 | device = next(model.parameters()).device 71 | 72 | hooks = [] 73 | data = defaultdict(dict) 74 | inputs = torch.cat((x1, x2), dim=0) 75 | separator = x1.shape[0] 76 | 77 | def check_activations(self, input, output): 78 | module_name = str(*[name for name, mod in model.named_modules() if self is mod]) 79 | try: 80 | layer_inputs = input[0].detach() 81 | residual = (layer_inputs[:separator] - layer_inputs[separator:]).pow(2) 82 | se_error = residual.sum() 83 | mse_error = residual.mean() 84 | sim = torch.nn.functional.cosine_similarity(layer_inputs[:separator].flatten(), 85 | layer_inputs[separator:].flatten(), 86 | dim=0, eps=1e-8).detach() 87 | data['se'][module_name] = se_error.item() 88 | data['mse'][module_name] = mse_error.item() 89 | data['sim'][module_name] = sim.item() 90 | except (KeyboardInterrupt, SystemExit): 91 | raise 92 | except AttributeError: 93 | pass 94 | 95 | for name, module in model.named_modules(): 96 | hooks.append(module.register_forward_hook(check_activations)) 97 | 98 | try: 99 | outputs = model(inputs.to(device)) 100 | for hook in hooks: 101 | hook.remove() 102 | except Exception as e: 103 | for hook in hooks: 104 | hook.remove() 105 | raise 106 | 107 | return data 108 | -------------------------------------------------------------------------------- /recovery/nn/__init__.py: -------------------------------------------------------------------------------- 1 | """Experimental modules and unexperimental model hooks.""" 2 | 3 | from .models import construct_model 4 | from .modules import MetaMonkey 5 | 6 | __all__ = ['construct_model', 'MetaMonkey'] 7 | -------------------------------------------------------------------------------- /recovery/nn/densenet.py: -------------------------------------------------------------------------------- 1 | """DenseNet in PyTorch.""" 2 | """Adaptation we did with ******.""" 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 
| 9 | 10 | class _Bottleneck(nn.Module): 11 | def __init__(self, in_planes, growth_rate): 12 | super().__init__() 13 | self.bn1 = nn.BatchNorm2d(in_planes) 14 | self.conv1 = nn.Conv2d(in_planes, 4 * growth_rate, kernel_size=1, bias=False) 15 | self.bn2 = nn.BatchNorm2d(4 * growth_rate) 16 | self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False) 17 | 18 | def forward(self, x): 19 | out = self.conv1(F.relu(self.bn1(x))) 20 | out = self.conv2(F.relu(self.bn2(out))) 21 | out = torch.cat([out, x], 1) 22 | return out 23 | 24 | 25 | class _Transition(nn.Module): 26 | def __init__(self, in_planes, out_planes): 27 | super().__init__() 28 | self.bn = nn.BatchNorm2d(in_planes) 29 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 30 | 31 | def forward(self, x): 32 | out = self.conv(F.relu(self.bn(x))) 33 | out = F.avg_pool2d(out, 2) 34 | return out 35 | 36 | 37 | class _DenseNet(nn.Module): 38 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10): 39 | super().__init__() 40 | self.growth_rate = growth_rate 41 | 42 | num_planes = 2 * growth_rate 43 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) 44 | 45 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 46 | num_planes += nblocks[0] * growth_rate 47 | out_planes = int(math.floor(num_planes * reduction)) 48 | self.trans1 = _Transition(num_planes, out_planes) 49 | num_planes = out_planes 50 | 51 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 52 | num_planes += nblocks[1] * growth_rate 53 | out_planes = int(math.floor(num_planes * reduction)) 54 | self.trans2 = _Transition(num_planes, out_planes) 55 | num_planes = out_planes 56 | 57 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 58 | num_planes += nblocks[2] * growth_rate 59 | out_planes = int(math.floor(num_planes * reduction)) 60 | # self.trans3 = Transition(num_planes, out_planes) 61 | # num_planes = out_planes 62 | 63 | # self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3]) 64 | # num_planes += nblocks[3]*growth_rate 65 | 66 | self.bn = nn.BatchNorm2d(num_planes) 67 | num_planes = 132 * growth_rate // 12 * 2 * 2 68 | self.linear = nn.Linear(num_planes, num_classes) 69 | 70 | def _make_dense_layers(self, block, in_planes, nblock): 71 | layers = [] 72 | for i in range(nblock): 73 | layers.append(block(in_planes, self.growth_rate)) 74 | in_planes += self.growth_rate 75 | return nn.Sequential(*layers) 76 | 77 | def forward(self, x): 78 | out = self.conv1(x) 79 | out = self.trans1(self.dense1(out)) 80 | out = self.trans2(self.dense2(out)) 81 | out = self.dense3(out) 82 | # out = self.trans3(self.dense3(out)) 83 | # out = self.dense4(out) 84 | out = F.avg_pool2d(F.relu(self.bn(out)), 4) 85 | out = out.view(out.size(0), -1) 86 | out = self.linear(out) 87 | return out 88 | 89 | 90 | def densenet_cifar(num_classes=10): 91 | """Instantiate the smallest DenseNet.""" 92 | return _DenseNet(_Bottleneck, [6, 6, 6, 0], growth_rate=12, num_classes=num_classes) 93 | -------------------------------------------------------------------------------- /recovery/nn/models.py: -------------------------------------------------------------------------------- 1 | """Define basic models and translate some torchvision stuff.""" 2 | """Stuff from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py.""" 3 | import torch 4 | import torchvision 5 | import torch.nn as nn 6 | 7 | from torchvision.models.resnet 
import Bottleneck 8 | from .revnet import iRevNet 9 | from .densenet import _DenseNet, _Bottleneck 10 | 11 | from collections import OrderedDict 12 | import numpy as np 13 | from ..utils import set_random_seed 14 | 15 | 16 | 17 | 18 | def construct_model(model, num_classes=10, seed=None, num_channels=3, modelkey=None): 19 | """Return various models.""" 20 | if modelkey is None: 21 | if seed is None: 22 | model_init_seed = np.random.randint(0, 2**32 - 10) 23 | else: 24 | model_init_seed = seed 25 | else: 26 | model_init_seed = modelkey 27 | set_random_seed(model_init_seed) 28 | 29 | if model in ['ConvNet']: 30 | model = ConvNet(width=64, num_channels=num_channels, num_classes=num_classes) 31 | elif model == 'BeyondInferringCifar': 32 | model = torch.nn.Sequential(OrderedDict([ 33 | ('conv1', torch.nn.Conv2d(3, 32, 3, stride=2, padding=1)), 34 | ('relu0', torch.nn.LeakyReLU()), 35 | ('conv2', torch.nn.Conv2d(32, 64, 3, stride=1, padding=1)), 36 | ('relu1', torch.nn.LeakyReLU()), 37 | ('conv3', torch.nn.Conv2d(64, 128, 3, stride=2, padding=1)), 38 | ('relu2', torch.nn.LeakyReLU()), 39 | ('conv4', torch.nn.Conv2d(128, 256, 3, stride=1, padding=1)), 40 | ('relu3', torch.nn.LeakyReLU()), 41 | ('flatt', torch.nn.Flatten()), 42 | ('linear0', torch.nn.Linear(12544, 12544)), 43 | ('relu4', torch.nn.LeakyReLU()), 44 | ('linear1', torch.nn.Linear(12544, 10)), 45 | ('softmax', torch.nn.Softmax(dim=1)) 46 | ])) 47 | elif model == 'MLP': 48 | width = 1024 49 | model = torch.nn.Sequential(OrderedDict([ 50 | ('flatten', torch.nn.Flatten()), 51 | ('linear0', torch.nn.Linear(3072, width)), 52 | ('relu0', torch.nn.ReLU()), 53 | ('linear1', torch.nn.Linear(width, width)), 54 | ('relu1', torch.nn.ReLU()), 55 | ('linear2', torch.nn.Linear(width, width)), 56 | ('relu2', torch.nn.ReLU()), 57 | ('linear3', torch.nn.Linear(width, num_classes))])) 58 | elif model == 'TwoLP': 59 | width = 2048 60 | model = torch.nn.Sequential(OrderedDict([ 61 | ('flatten', torch.nn.Flatten()), 62 | ('linear0', torch.nn.Linear(3072, width)), 63 | ('relu0', torch.nn.ReLU()), 64 | ('linear3', torch.nn.Linear(width, num_classes))])) 65 | elif model == 'ResNet20': 66 | model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16) 67 | elif model == 'ResNet20-nostride': 68 | model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16, 69 | strides=[1, 1, 1, 1]) 70 | elif model == 'ResNet20-10': 71 | model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16 * 10) 72 | elif model == 'ResNet20-4': 73 | model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16 * 4) 74 | elif model == 'ResNet20-4-unpooled': 75 | model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16 * 4, 76 | pool='max') 77 | elif model == 'ResNet28-10': 78 | model = ResNet(torchvision.models.resnet.BasicBlock, [4, 4, 4], num_classes=num_classes, base_width=16 * 10) 79 | elif model == 'ResNet32': 80 | model = ResNet(torchvision.models.resnet.BasicBlock, [5, 5, 5], num_classes=num_classes, base_width=16) 81 | elif model == 'ResNet32-10': 82 | model = ResNet(torchvision.models.resnet.BasicBlock, [5, 5, 5], num_classes=num_classes, base_width=16 * 10) 83 | elif model == 'ResNet44': 84 | model = ResNet(torchvision.models.resnet.BasicBlock, [7, 7, 7], num_classes=num_classes, base_width=16) 85 | elif model == 'ResNet56': 86 | model = 
ResNet(torchvision.models.resnet.BasicBlock, [9, 9, 9], num_classes=num_classes, base_width=16) 87 | elif model == 'ResNet110': 88 | model = ResNet(torchvision.models.resnet.BasicBlock, [18, 18, 18], num_classes=num_classes, base_width=16) 89 | elif model == 'ResNet18': 90 | model = ResNet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], num_classes=num_classes, base_width=64) 91 | elif model == 'ResNet34': 92 | model = ResNet(torchvision.models.resnet.BasicBlock, [3, 4, 6, 3], num_classes=num_classes, base_width=64) 93 | elif model == 'ResNet50': 94 | model = ResNet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], num_classes=num_classes, base_width=64) 95 | elif model == 'ResNet50-2': 96 | model = ResNet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], num_classes=num_classes, base_width=64 * 2) 97 | elif model == 'ResNet101': 98 | model = ResNet(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], num_classes=num_classes, base_width=64) 99 | elif model == 'ResNet152': 100 | model = ResNet(torchvision.models.resnet.Bottleneck, [3, 8, 36, 3], num_classes=num_classes, base_width=64) 101 | elif model == 'MobileNet': 102 | inverted_residual_setting = [ 103 | # t, c, n, s 104 | [1, 16, 1, 1], 105 | [6, 24, 2, 1], # cifar adaptation, cf.https://github.com/kuangliu/pytorch-cifar/blob/master/models/mobilenetv2.py 106 | [6, 32, 3, 2], 107 | [6, 64, 4, 2], 108 | [6, 96, 3, 1], 109 | [6, 160, 3, 2], 110 | [6, 320, 1, 1], 111 | ] 112 | model = torchvision.models.MobileNetV2(num_classes=num_classes, 113 | inverted_residual_setting=inverted_residual_setting, 114 | width_mult=1.0) 115 | model.features[0] = torchvision.models.mobilenet.ConvBNReLU(num_channels, 32, stride=1) # this is fixed to width=1 116 | elif model == 'MNASNet': 117 | model = torchvision.models.MNASNet(1.0, num_classes=num_classes, dropout=0.2) 118 | elif model == 'DenseNet121': 119 | model = torchvision.models.DenseNet(growth_rate=32, block_config=(6, 12, 24, 16), 120 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=num_classes, 121 | memory_efficient=False) 122 | elif model == 'DenseNet40': 123 | model = _DenseNet(_Bottleneck, [6, 6, 6, 0], growth_rate=12, num_classes=num_classes) 124 | elif model == 'DenseNet40-4': 125 | model = _DenseNet(_Bottleneck, [6, 6, 6, 0], growth_rate=12 * 4, num_classes=num_classes) 126 | elif model == 'iRevNet': 127 | if num_classes <= 100: 128 | in_shape = [num_channels, 32, 32] # only for cifar right now 129 | model = iRevNet(nBlocks=[18, 18, 18], nStrides=[1, 2, 2], 130 | nChannels=[16, 64, 256], nClasses=num_classes, 131 | init_ds=0, dropout_rate=0.1, affineBN=True, 132 | in_shape=in_shape, mult=4) 133 | else: 134 | in_shape = [3, 224, 224] # only for imagenet 135 | model = iRevNet(nBlocks=[6, 16, 72, 6], nStrides=[2, 2, 2, 2], 136 | nChannels=[24, 96, 384, 1536], nClasses=num_classes, 137 | init_ds=2, dropout_rate=0.1, affineBN=True, 138 | in_shape=in_shape, mult=4) 139 | else: 140 | raise NotImplementedError('Model not implemented.') 141 | 142 | print(f'Model initialized with random key {model_init_seed}.') 143 | return model, model_init_seed 144 | 145 | 146 | class ResNet(torchvision.models.ResNet): 147 | """ResNet generalization for CIFAR thingies.""" 148 | 149 | def __init__(self, block, layers, num_classes=10, 150 | groups=1, base_width=64, replace_stride_with_dilation=None, 151 | norm_layer=None, strides=[1, 2, 2, 2], pool='avg'): 152 | """Initialize as usual. 
Layers and strides are scriptable.""" 153 | super(torchvision.models.ResNet, self).__init__() # nn.Module 154 | if norm_layer is None: 155 | norm_layer = nn.BatchNorm2d 156 | self._norm_layer = norm_layer 157 | 158 | 159 | self.dilation = 1 160 | if replace_stride_with_dilation is None: 161 | # each element in the tuple indicates if we should replace 162 | # the 2x2 stride with a dilated convolution instead 163 | replace_stride_with_dilation = [False, False, False, False] 164 | if len(replace_stride_with_dilation) != 4: 165 | raise ValueError("replace_stride_with_dilation should be None " 166 | "or a 4-element tuple, got {}".format(replace_stride_with_dilation)) 167 | self.groups = groups 168 | 169 | self.inplanes = base_width 170 | self.base_width = 64 # Do this to circumvent BasicBlock errors. The value is not actually used. 171 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 172 | self.bn1 = norm_layer(self.inplanes) 173 | self.relu = nn.ReLU(inplace=True) 174 | 175 | self.layers = torch.nn.ModuleList() 176 | width = self.inplanes 177 | for idx, layer in enumerate(layers): 178 | self.layers.append(self._make_layer(block, width, layer, stride=strides[idx], dilate=replace_stride_with_dilation[idx])) 179 | width *= 2 180 | 181 | self.pool = nn.AdaptiveAvgPool2d((1, 1)) if pool == 'avg' else nn.AdaptiveMaxPool2d((1, 1)) 182 | self.fc = nn.Linear(width // 2 * block.expansion, num_classes) 183 | 184 | for m in self.modules(): 185 | if isinstance(m, nn.Conv2d): 186 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 187 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 188 | nn.init.constant_(m.weight, 1) 189 | nn.init.constant_(m.bias, 0) 190 | 191 | 192 | 193 | def _forward_impl(self, x): 194 | # See note [TorchScript super()] 195 | x = self.conv1(x) 196 | x = self.bn1(x) 197 | x = self.relu(x) 198 | 199 | for layer in self.layers: 200 | x = layer(x) 201 | 202 | x = self.pool(x) 203 | x = torch.flatten(x, 1) 204 | x = self.fc(x) 205 | 206 | return x 207 | 208 | 209 | class ConvNet(torch.nn.Module): 210 | """ConvNetBN.""" 211 | 212 | def __init__(self, width=32, num_classes=10, num_channels=3): 213 | """Init with width and num classes.""" 214 | super().__init__() 215 | self.model = torch.nn.Sequential(OrderedDict([ 216 | ('conv0', torch.nn.Conv2d(num_channels, 1 * width, kernel_size=3, padding=1)), 217 | ('bn0', torch.nn.BatchNorm2d(1 * width)), 218 | ('relu0', torch.nn.ReLU()), 219 | 220 | ('conv1', torch.nn.Conv2d(1 * width, 2 * width, kernel_size=3, padding=1)), 221 | ('bn1', torch.nn.BatchNorm2d(2 * width)), 222 | ('relu1', torch.nn.ReLU()), 223 | 224 | ('conv2', torch.nn.Conv2d(2 * width, 2 * width, kernel_size=3, padding=1)), 225 | ('bn2', torch.nn.BatchNorm2d(2 * width)), 226 | ('relu2', torch.nn.ReLU()), 227 | 228 | ('conv3', torch.nn.Conv2d(2 * width, 4 * width, kernel_size=3, padding=1)), 229 | ('bn3', torch.nn.BatchNorm2d(4 * width)), 230 | ('relu3', torch.nn.ReLU()), 231 | 232 | ('conv4', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)), 233 | ('bn4', torch.nn.BatchNorm2d(4 * width)), 234 | ('relu4', torch.nn.ReLU()), 235 | 236 | ('conv5', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)), 237 | ('bn5', torch.nn.BatchNorm2d(4 * width)), 238 | ('relu5', torch.nn.ReLU()), 239 | 240 | ('pool0', torch.nn.MaxPool2d(3)), 241 | 242 | ('conv6', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)), 243 | ('bn6', torch.nn.BatchNorm2d(4 * width)), 244 | ('relu6', torch.nn.ReLU()), 
245 | 
            # NOTE: the 'conv6'/'bn6'/'relu6' keys below repeat the keys used
            # directly above; in the OrderedDict the later values overwrite the
            # earlier block, so only one conv6 stage survives in the built model
            # (kept as-is to stay compatible with checkpoints saved by this repo).
246 |             ('conv6', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)),
247 |             ('bn6', torch.nn.BatchNorm2d(4 * width)),
248 |             ('relu6', torch.nn.ReLU()),
249 | 
250 |             ('conv7', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)),
251 |             ('bn7', torch.nn.BatchNorm2d(4 * width)),
252 |             ('relu7', torch.nn.ReLU()),
253 | 
254 |             ('pool1', torch.nn.MaxPool2d(3)),
255 |             ('flatten', torch.nn.Flatten()),
256 |             ('linear', torch.nn.Linear(36 * width, num_classes))
257 |         ]))
258 | 
259 |     def forward(self, input):
260 |         return self.model(input)

--------------------------------------------------------------------------------
/recovery/nn/models/attention.py:
--------------------------------------------------------------------------------
1 | """residual attention network in pytorch
2 | 
3 | 
4 | 
5 | [1] Fei Wang, Mengqing Jiang, Chen Qian, Shuo Yang, Cheng Li, Honggang Zhang, Xiaogang Wang, Xiaoou Tang
6 | 
7 |     Residual Attention Network for Image Classification
8 |     https://arxiv.org/abs/1704.06904
9 | """
10 | 
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | 
15 | #"""The Attention Module is built by pre-activation Residual Unit [11] with the
16 | #number of channels in each stage is the same as ResNet [10]."""
17 | 
18 | class PreActResidualUnit(nn.Module):
19 |     """PreAct Residual Unit
20 |     Args:
21 |         in_channels: residual unit input channel number
22 |         out_channels: residual unit output channel number
23 |         stride: stride of residual unit; when stride = 2, the feature map is downsampled
24 |     """
25 | 
26 |     def __init__(self, in_channels, out_channels, stride):
27 |         super().__init__()
28 | 
29 |         bottleneck_channels = int(out_channels / 4)
30 |         self.residual_function = nn.Sequential(
31 |             #1x1 conv
32 |             nn.BatchNorm2d(in_channels),
33 |             nn.ReLU(inplace=True),
34 |             nn.Conv2d(in_channels, bottleneck_channels, 1, stride),
35 | 
36 |             #3x3 conv
37 |             nn.BatchNorm2d(bottleneck_channels),
38 |             nn.ReLU(inplace=True),
39 |             nn.Conv2d(bottleneck_channels, bottleneck_channels, 3, padding=1),
40 | 
41 |             #1x1 conv
42 |             nn.BatchNorm2d(bottleneck_channels),
43 |             nn.ReLU(inplace=True),
44 |             nn.Conv2d(bottleneck_channels, out_channels, 1)
45 |         )
46 | 
47 |         self.shortcut = nn.Sequential()
48 |         if stride != 2 or (in_channels != out_channels):
49 |             self.shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=stride)
50 | 
51 |     def forward(self, x):
52 | 
53 |         res = self.residual_function(x)
54 |         shortcut = self.shortcut(x)
55 | 
56 |         return res + shortcut
57 | 
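# e.g. (editor's sketch): a strided unit halves the spatial dims and can change
# the channel count:
#   y = PreActResidualUnit(64, 256, 2)(torch.randn(1, 64, 32, 32))
#   assert y.shape == (1, 256, 16, 16)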
58 | class AttentionModule1(nn.Module):
59 | 
60 |     def __init__(self, in_channels, out_channels, p=1, t=2, r=1):
61 |         super().__init__()
62 |         #"""The hyperparameter p denotes the number of preprocessing Residual
63 |         #Units before splitting into trunk branch and mask branch. t denotes
64 |         #the number of Residual Units in trunk branch. r denotes the number of
65 |         #Residual Units between adjacent pooling layer in the mask branch."""
66 |         assert in_channels == out_channels
67 | 
68 |         self.pre = self._make_residual(in_channels, out_channels, p)
69 |         self.trunk = self._make_residual(in_channels, out_channels, t)
70 |         self.soft_resdown1 = self._make_residual(in_channels, out_channels, r)
71 |         self.soft_resdown2 = self._make_residual(in_channels, out_channels, r)
72 |         self.soft_resdown3 = self._make_residual(in_channels, out_channels, r)
73 |         self.soft_resdown4 = self._make_residual(in_channels, out_channels, r)
74 | 
75 |         self.soft_resup1 = self._make_residual(in_channels, out_channels, r)
76 |         self.soft_resup2 = self._make_residual(in_channels, out_channels, r)
77 |         self.soft_resup3 = self._make_residual(in_channels, out_channels, r)
78 |         self.soft_resup4 = self._make_residual(in_channels, out_channels, r)
79 | 
80 |         self.shortcut_short = PreActResidualUnit(in_channels, out_channels, 1)
81 |         self.shortcut_long = PreActResidualUnit(in_channels, out_channels, 1)
82 | 
83 |         self.sigmoid = nn.Sequential(
84 |             nn.BatchNorm2d(out_channels),
85 |             nn.ReLU(inplace=True),
86 |             nn.Conv2d(out_channels, out_channels, kernel_size=1),
87 |             nn.BatchNorm2d(out_channels),
88 |             nn.ReLU(inplace=True),
89 |             nn.Conv2d(out_channels, out_channels, kernel_size=1),
90 |             nn.Sigmoid()
91 |         )
92 | 
93 |         self.last = self._make_residual(in_channels, out_channels, p)
94 | 
95 |     def forward(self, x):
96 |         ###We make the size of the smallest output map in each mask branch 7*7 to be consistent
97 |         #with the smallest trunk output map size.
98 |         ###Thus 3,2,1 max-pooling layers are used in mask branch with input size 56 * 56, 28 * 28, 14 * 14 respectively.
99 |         x = self.pre(x)
100 |         input_size = (x.size(2), x.size(3))
101 | 
102 |         x_t = self.trunk(x)
103 | 
104 |         #first downsample out 28
105 |         x_s = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
106 |         x_s = self.soft_resdown1(x_s)
107 | 
108 |         #28 shortcut
109 |         shape1 = (x_s.size(2), x_s.size(3))
110 |         shortcut_long = self.shortcut_long(x_s)
111 | 
112 |         #second downsample out 14
113 |         x_s = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
114 |         x_s = self.soft_resdown2(x_s)
115 | 
116 |         #14 shortcut
117 |         shape2 = (x_s.size(2), x_s.size(3))
118 |         shortcut_short = self.soft_resdown3(x_s)
119 | 
120 |         #third downsample out 7
121 |         x_s = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
122 |         x_s = self.soft_resdown3(x_s)
123 | 
124 |         #mid
125 |         x_s = self.soft_resdown4(x_s)
126 |         x_s = self.soft_resup1(x_s)
127 | 
128 |         #first upsample out 14
129 |         x_s = self.soft_resup2(x_s)
130 |         x_s = F.interpolate(x_s, size=shape2)
131 |         x_s += shortcut_short
132 | 
133 |         #second upsample out 28
134 |         x_s = self.soft_resup3(x_s)
135 |         x_s = F.interpolate(x_s, size=shape1)
136 |         x_s += shortcut_long
137 | 
138 |         #third upsample out 56
139 |         x_s = self.soft_resup4(x_s)
140 |         x_s = F.interpolate(x_s, size=input_size)
141 | 
142 |         x_s = self.sigmoid(x_s)
143 |         x = (1 + x_s) * x_t
144 |         x = self.last(x)
145 | 
146 |         return x
147 | 
148 |     def _make_residual(self, in_channels, out_channels, p):
149 | 
150 |         layers = []
151 |         for _ in range(p):
152 |             layers.append(PreActResidualUnit(in_channels, out_channels, 1))
153 | 
154 |         return nn.Sequential(*layers)
155 | 
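# Editor's note (sketch): all three attention modules compose their output as
#   out = last((1 + mask(x)) * trunk(x))
# where mask(x) is in (0, 1); the "1 +" keeps trunk features from being fully
# suppressed even where the mask is near zero.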
156 | class AttentionModule2(nn.Module):
157 | 
158 |     def __init__(self, in_channels, out_channels, p=1, t=2, r=1):
159 |         super().__init__()
160 |         #"""The hyperparameter p denotes the number of preprocessing Residual
161 |         #Units before splitting into trunk branch and mask branch. t denotes
162 |         #the number of Residual Units in trunk branch. r denotes the number of
163 |         #Residual Units between adjacent pooling layer in the mask branch."""
164 |         assert in_channels == out_channels
165 | 
166 |         self.pre = self._make_residual(in_channels, out_channels, p)
167 |         self.trunk = self._make_residual(in_channels, out_channels, t)
168 |         self.soft_resdown1 = self._make_residual(in_channels, out_channels, r)
169 |         self.soft_resdown2 = self._make_residual(in_channels, out_channels, r)
170 |         self.soft_resdown3 = self._make_residual(in_channels, out_channels, r)
171 | 
172 |         self.soft_resup1 = self._make_residual(in_channels, out_channels, r)
173 |         self.soft_resup2 = self._make_residual(in_channels, out_channels, r)
174 |         self.soft_resup3 = self._make_residual(in_channels, out_channels, r)
175 | 
176 |         self.shortcut = PreActResidualUnit(in_channels, out_channels, 1)
177 | 
178 |         self.sigmoid = nn.Sequential(
179 |             nn.BatchNorm2d(out_channels),
180 |             nn.ReLU(inplace=True),
181 |             nn.Conv2d(out_channels, out_channels, kernel_size=1),
182 |             nn.BatchNorm2d(out_channels),
183 |             nn.ReLU(inplace=True),
184 |             nn.Conv2d(out_channels, out_channels, kernel_size=1),
185 |             nn.Sigmoid()
186 |         )
187 | 
188 |         self.last = self._make_residual(in_channels, out_channels, p)
189 | 
190 |     def forward(self, x):
191 |         x = self.pre(x)
192 |         input_size = (x.size(2), x.size(3))
193 | 
194 |         x_t = self.trunk(x)
195 | 
196 |         #first downsample out 14
197 |         x_s = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
198 |         x_s = self.soft_resdown1(x_s)
199 | 
200 |         #14 shortcut
201 |         shape1 = (x_s.size(2), x_s.size(3))
202 |         shortcut = self.shortcut(x_s)
203 | 
204 |         #second downsample out 7
205 |         x_s = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
206 |         x_s = self.soft_resdown2(x_s)
207 | 
208 |         #mid
209 |         x_s = self.soft_resdown3(x_s)
210 |         x_s = self.soft_resup1(x_s)
211 | 
212 |         #first upsample out 14
213 |         x_s = self.soft_resup2(x_s)
214 |         x_s = F.interpolate(x_s, size=shape1)
215 |         x_s += shortcut
216 | 
217 |         #second upsample out 28
218 |         x_s = self.soft_resup3(x_s)
219 |         x_s = F.interpolate(x_s, size=input_size)
220 | 
221 |         x_s = self.sigmoid(x_s)
222 |         x = (1 + x_s) * x_t
223 |         x = self.last(x)
224 | 
225 |         return x
226 | 
227 |     def _make_residual(self, in_channels, out_channels, p):
228 | 
229 |         layers = []
230 |         for _ in range(p):
231 |             layers.append(PreActResidualUnit(in_channels, out_channels, 1))
232 | 
233 |         return nn.Sequential(*layers)
234 | 
235 | class AttentionModule3(nn.Module):
236 | 
237 |     def __init__(self, in_channels, out_channels, p=1, t=2, r=1):
238 |         super().__init__()
239 | 
240 |         assert in_channels == out_channels
241 | 
242 |         self.pre = self._make_residual(in_channels, out_channels, p)
243 |         self.trunk = self._make_residual(in_channels, out_channels, t)
244 |         self.soft_resdown1 = self._make_residual(in_channels, out_channels, r)
245 |         self.soft_resdown2 = self._make_residual(in_channels, out_channels, r)
246 | 
247 |         self.soft_resup1 = self._make_residual(in_channels, out_channels, r)
248 |         self.soft_resup2 = self._make_residual(in_channels, out_channels, r)
249 | 
250 |         self.shortcut = PreActResidualUnit(in_channels, out_channels, 1)
251 | 
252 |         self.sigmoid = nn.Sequential(
253 |             nn.BatchNorm2d(out_channels),
254 |             nn.ReLU(inplace=True),
255 |             nn.Conv2d(out_channels, out_channels, kernel_size=1),
256 |             nn.BatchNorm2d(out_channels),
257 |             nn.ReLU(inplace=True),
258 |             nn.Conv2d(out_channels, out_channels, kernel_size=1),
259 |             nn.Sigmoid()
260 |         )
261 | 
262 |         self.last =
235 | class AttentionModule3(nn.Module):
236 | 
237 |     def __init__(self, in_channels, out_channels, p=1, t=2, r=1):
238 |         super().__init__()
239 | 
240 |         assert in_channels == out_channels
241 | 
242 |         self.pre = self._make_residual(in_channels, out_channels, p)
243 |         self.trunk = self._make_residual(in_channels, out_channels, t)
244 |         self.soft_resdown1 = self._make_residual(in_channels, out_channels, r)
245 |         self.soft_resdown2 = self._make_residual(in_channels, out_channels, r)
246 | 
247 |         self.soft_resup1 = self._make_residual(in_channels, out_channels, r)
248 |         self.soft_resup2 = self._make_residual(in_channels, out_channels, r)
249 | 
250 |         self.shortcut = PreActResidualUnit(in_channels, out_channels, 1)
251 | 
252 |         self.sigmoid = nn.Sequential(
253 |             nn.BatchNorm2d(out_channels),
254 |             nn.ReLU(inplace=True),
255 |             nn.Conv2d(out_channels, out_channels, kernel_size=1),
256 |             nn.BatchNorm2d(out_channels),
257 |             nn.ReLU(inplace=True),
258 |             nn.Conv2d(out_channels, out_channels, kernel_size=1),
259 |             nn.Sigmoid()
260 |         )
261 | 
262 |         self.last = self._make_residual(in_channels, out_channels, p)
263 | 
264 |     def forward(self, x):
265 |         x = self.pre(x)
266 |         input_size = (x.size(2), x.size(3))
267 | 
268 |         x_t = self.trunk(x)
269 | 
270 |         #first downsample out 7
271 |         x_s = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
272 |         x_s = self.soft_resdown1(x_s)
273 | 
274 |         #mid
275 |         x_s = self.soft_resdown2(x_s)
276 |         x_s = self.soft_resup1(x_s)
277 | 
278 |         #first upsample out 14
279 |         x_s = self.soft_resup2(x_s)
280 |         x_s = F.interpolate(x_s, size=input_size)
281 | 
282 |         x_s = self.sigmoid(x_s)
283 |         x = (1 + x_s) * x_t
284 |         x = self.last(x)
285 | 
286 |         return x
287 | 
288 |     def _make_residual(self, in_channels, out_channels, p):
289 | 
290 |         layers = []
291 |         for _ in range(p):
292 |             layers.append(PreActResidualUnit(in_channels, out_channels, 1))
293 | 
294 |         return nn.Sequential(*layers)
295 | 
296 | class Attention(nn.Module):
297 |     """residual attention network
298 |     Args:
299 |         block_num: attention module number for each stage
300 |     """
301 | 
302 |     def __init__(self, block_num, class_num=100):
303 | 
304 |         super().__init__()
305 |         self.pre_conv = nn.Sequential(
306 |             nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
307 |             nn.BatchNorm2d(64),
308 |             nn.ReLU(inplace=True)
309 |         )
310 | 
311 |         self.stage1 = self._make_stage(64, 256, block_num[0], AttentionModule1)
312 |         self.stage2 = self._make_stage(256, 512, block_num[1], AttentionModule2)
313 |         self.stage3 = self._make_stage(512, 1024, block_num[2], AttentionModule3)
314 |         self.stage4 = nn.Sequential(
315 |             PreActResidualUnit(1024, 2048, 2),
316 |             PreActResidualUnit(2048, 2048, 1),
317 |             PreActResidualUnit(2048, 2048, 1)
318 |         )
319 |         self.avg = nn.AdaptiveAvgPool2d(1)
320 |         self.linear = nn.Linear(2048, class_num)
321 | 
322 |     def forward(self, x):
323 |         x = self.pre_conv(x)
324 |         x = self.stage1(x)
325 |         x = self.stage2(x)
326 |         x = self.stage3(x)
327 |         x = self.stage4(x)
328 |         x = self.avg(x)
329 |         x = x.view(x.size(0), -1)
330 |         x = self.linear(x)
331 | 
332 |         return x
333 | 
334 |     def _make_stage(self, in_channels, out_channels, num, block):
335 | 
336 |         layers = []
337 |         layers.append(PreActResidualUnit(in_channels, out_channels, 2))
338 | 
339 |         for _ in range(num):
340 |             layers.append(block(out_channels, out_channels))
341 | 
342 |         return nn.Sequential(*layers)
343 | 
344 | def attention56(class_num=100):
345 |     return Attention([1, 1, 1], class_num)
346 | 
347 | def attention92(class_num=100):
348 |     return Attention([1, 2, 3], class_num)
349 | 
350 | 
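A quick smoke test for the stages above, assuming CIFAR-style 32x32 inputs (this snippet is illustrative and not part of the repo):

```
import torch

net = attention56()                      # Attention([1, 1, 1]), 100 classes by default
logits = net(torch.randn(2, 3, 32, 32))  # the stages halve the resolution: 32->16->8->4->2
assert logits.shape == (2, 100)
```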
--------------------------------------------------------------------------------
/recovery/nn/models/densenet.py:
--------------------------------------------------------------------------------
 1 | """dense net in pytorch
 2 | 
 3 | 
 4 | 
 5 | [1] Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger.
 6 | 
 7 |     Densely Connected Convolutional Networks
 8 |     https://arxiv.org/abs/1608.06993v5
 9 | """
10 | 
11 | import torch
12 | import torch.nn as nn
13 | 
14 | 
15 | 
16 | #"""Bottleneck layers. Although each layer only produces k
17 | #output feature-maps, it typically has many more inputs. It
18 | #has been noted in [37, 11] that a 1×1 convolution can be
19 | #introduced as bottleneck layer before each 3×3 convolution
20 | #to reduce the number of input feature-maps, and thus to
21 | #improve computational efficiency."""
22 | class Bottleneck(nn.Module):
23 |     def __init__(self, in_channels, growth_rate):
24 |         super().__init__()
25 |         #"""In our experiments, we let each 1×1 convolution
26 |         #produce 4k feature-maps."""
27 |         inner_channel = 4 * growth_rate
28 | 
29 |         #"""We find this design especially effective for DenseNet and
30 |         #we refer to our network with such a bottleneck layer, i.e.,
31 |         #to the BN-ReLU-Conv(1×1)-BN-ReLU-Conv(3×3) version of H_l,
32 |         #as DenseNet-B."""
33 |         self.bottle_neck = nn.Sequential(
34 |             nn.BatchNorm2d(in_channels),
35 |             nn.ReLU(inplace=True),
36 |             nn.Conv2d(in_channels, inner_channel, kernel_size=1, bias=False),
37 |             nn.BatchNorm2d(inner_channel),
38 |             nn.ReLU(inplace=True),
39 |             nn.Conv2d(inner_channel, growth_rate, kernel_size=3, padding=1, bias=False)
40 |         )
41 | 
42 |     def forward(self, x):
43 |         return torch.cat([x, self.bottle_neck(x)], 1)
44 | 
45 | #"""We refer to layers between blocks as transition
46 | #layers, which do convolution and pooling."""
47 | class Transition(nn.Module):
48 |     def __init__(self, in_channels, out_channels):
49 |         super().__init__()
50 |         #"""The transition layers used in our experiments
51 |         #consist of a batch normalization layer and an 1×1
52 |         #convolutional layer followed by a 2×2 average pooling
53 |         #layer"""
54 |         self.down_sample = nn.Sequential(
55 |             nn.BatchNorm2d(in_channels),
56 |             nn.Conv2d(in_channels, out_channels, 1, bias=False),
57 |             nn.AvgPool2d(2, stride=2)
58 |         )
59 | 
60 |     def forward(self, x):
61 |         return self.down_sample(x)
62 | 
63 | #DenseNet-BC
64 | #B stands for bottleneck layer(BN-RELU-CONV(1x1)-BN-RELU-CONV(3x3))
65 | #C stands for compression factor(0<=theta<=1)
66 | class DenseNet(nn.Module):
67 |     def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_class=100):
68 |         super().__init__()
69 |         self.growth_rate = growth_rate
70 | 
71 |         #"""Before entering the first dense block, a convolution
72 |         #with 16 (or twice the growth rate for DenseNet-BC)
73 |         #output channels is performed on the input images."""
74 |         inner_channels = 2 * growth_rate
75 | 
76 |         #For convolutional layers with kernel size 3×3, each
77 |         #side of the inputs is zero-padded by one pixel to keep
78 |         #the feature-map size fixed.
79 |         self.conv1 = nn.Conv2d(3, inner_channels, kernel_size=3, padding=1, bias=False)
80 | 
81 |         self.features = nn.Sequential()
82 | 
83 |         for index in range(len(nblocks) - 1):
84 |             self.features.add_module("dense_block_layer_{}".format(index), self._make_dense_layers(block, inner_channels, nblocks[index]))
85 |             inner_channels += growth_rate * nblocks[index]
86 | 
87 |             #"""If a dense block contains m feature-maps, we let the
88 |             #following transition layer generate θm output feature-maps,
89 |             #where 0 < θ ≤ 1 is referred to as the compression
90 |             #factor."""
 91 |             out_channels = int(reduction * inner_channels)  # int() automatically floors the value
 92 |             self.features.add_module("transition_layer_{}".format(index), Transition(inner_channels, out_channels))
 93 |             inner_channels = out_channels
 94 | 
 95 |         self.features.add_module("dense_block{}".format(len(nblocks) - 1), self._make_dense_layers(block, inner_channels, nblocks[len(nblocks)-1]))
 96 |         inner_channels += growth_rate * nblocks[len(nblocks) - 1]
 97 |         self.features.add_module('bn', nn.BatchNorm2d(inner_channels))
 98 |         self.features.add_module('relu', nn.ReLU(inplace=True))
 99 | 
100 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
101 | 
102 |         self.linear = nn.Linear(inner_channels, num_class)
103 | 
104 |     def forward(self, x):
105 |         output = self.conv1(x)
106 |         output = self.features(output)
107 |         output = self.avgpool(output)
108 |         output = output.view(output.size()[0], -1)
109 |         output = self.linear(output)
110 |         return output
111 | 
112 |     def _make_dense_layers(self, block, in_channels, nblocks):
113 |         dense_block = nn.Sequential()
114 |         for index in range(nblocks):
115 |             dense_block.add_module('bottle_neck_layer_{}'.format(index), block(in_channels, self.growth_rate))
116 |             in_channels += self.growth_rate
117 |         return dense_block
118 | 
119 | def densenet121(num_classes):
120 |     return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32, num_class=num_classes)
121 | 
122 | def densenet169(num_classes):
123 |     return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32, num_class=num_classes)
124 | 
125 | def densenet201(num_classes):
126 |     return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32, num_class=num_classes)
127 | 
128 | def densenet161(num_classes):
129 |     return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48, num_class=num_classes)
130 | 
131 | 
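The factories above fix the channel bookkeeping: each bottleneck adds `growth_rate` channels, and each transition keeps a `reduction` fraction of them. A small check of that arithmetic for densenet121 (illustrative only):

```
# inner_channels starts at 2*k, grows by k per layer in a dense block,
# and is floored to reduction * inner_channels at every transition
k = 32
inner = 2 * k
for n in [6, 12, 24]:          # three dense blocks, each followed by a transition
    inner += k * n
    inner = int(0.5 * inner)   # transition with reduction = 0.5
inner += k * 16                # the final dense block has no transition
print(inner)                   # 1024 features feed the classifier of densenet121
```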
--------------------------------------------------------------------------------
/recovery/nn/models/googlenet.py:
--------------------------------------------------------------------------------
 1 | """google net in pytorch
 2 | 
 3 | 
 4 | 
 5 | [1] Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
 6 |     Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
 7 | 
 8 |     Going Deeper with Convolutions
 9 |     https://arxiv.org/abs/1409.4842v1
10 | """
11 | 
12 | import torch
13 | import torch.nn as nn
14 | 
15 | class Inception(nn.Module):
16 |     def __init__(self, input_channels, n1x1, n3x3_reduce, n3x3, n5x5_reduce, n5x5, pool_proj):
17 |         super().__init__()
18 | 
19 |         #1x1conv branch
20 |         self.b1 = nn.Sequential(
21 |             nn.Conv2d(input_channels, n1x1, kernel_size=1),
22 |             nn.BatchNorm2d(n1x1),
23 |             nn.ReLU(inplace=True)
24 |         )
25 | 
26 |         #1x1conv -> 3x3conv branch
27 |         self.b2 = nn.Sequential(
28 |             nn.Conv2d(input_channels, n3x3_reduce, kernel_size=1),
29 |             nn.BatchNorm2d(n3x3_reduce),
30 |             nn.ReLU(inplace=True),
31 |             nn.Conv2d(n3x3_reduce, n3x3, kernel_size=3, padding=1),
32 |             nn.BatchNorm2d(n3x3),
33 |             nn.ReLU(inplace=True)
34 |         )
35 | 
36 |         #1x1conv -> 5x5conv branch
37 |         #we use two stacked 3x3 conv filters instead
38 |         #of one 5x5 filter to obtain the same receptive
39 |         #field with fewer parameters
40 |         self.b3 = nn.Sequential(
41 |             nn.Conv2d(input_channels, n5x5_reduce, kernel_size=1),
42 |             nn.BatchNorm2d(n5x5_reduce),
43 |             nn.ReLU(inplace=True),
44 |             nn.Conv2d(n5x5_reduce, n5x5, kernel_size=3, padding=1),
45 |             nn.BatchNorm2d(n5x5),
46 |             nn.ReLU(inplace=True),
47 |             nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1),
48 |             nn.BatchNorm2d(n5x5),
49 |             nn.ReLU(inplace=True)
50 |         )
51 | 
52 |         #3x3pooling -> 1x1conv
53 |         #same conv
54 |         self.b4 = nn.Sequential(
55 |             nn.MaxPool2d(3, stride=1, padding=1),
56 |             nn.Conv2d(input_channels, pool_proj, kernel_size=1),
57 |             nn.BatchNorm2d(pool_proj),
58 |             nn.ReLU(inplace=True)
59 |         )
60 | 
61 |     def forward(self, x):
62 |         return torch.cat([self.b1(x), self.b2(x), self.b3(x), self.b4(x)], dim=1)
63 | 
64 | 
65 | class GoogleNet(nn.Module):
66 | 
67 |     def __init__(self, num_class=100):
68 |         super().__init__()
69 |         self.prelayer = nn.Sequential(
70 |             nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
71 |             nn.BatchNorm2d(64),
72 |             nn.ReLU(inplace=True),
73 |             nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
74 |             nn.BatchNorm2d(64),
75 |             nn.ReLU(inplace=True),
76 |             nn.Conv2d(64, 192, kernel_size=3, padding=1, bias=False),
77 |             nn.BatchNorm2d(192),
78 |             nn.ReLU(inplace=True),
79 |         )
80 | 
81 |         #although we only use one conv block as the prelayer,
82 |         #we still use the stage names a3, b3, ... from the paper
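Each `Inception(...)` call below widens the feature map to the sum of its four branch outputs, which is what determines the next block's input width. A tiny sanity check of that arithmetic (illustrative, not part of the repo):

```
# output width = n1x1 + n3x3 + n5x5 + pool_proj
def inception_out_channels(n1x1, n3x3, n5x5, pool_proj):
    return n1x1 + n3x3 + n5x5 + pool_proj

assert inception_out_channels(64, 128, 32, 32) == 256   # a3, so b3 takes 256
assert inception_out_channels(128, 192, 96, 64) == 480  # b3, so a4 takes 480
```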
83 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32) 84 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) 85 | 86 | ##"""In general, an Inception network is a network consisting of 87 | ##modules of the above type stacked upon each other, with occasional 88 | ##max-pooling layers with stride 2 to halve the resolution of the 89 | ##grid""" 90 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 91 | 92 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64) 93 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64) 94 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64) 95 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64) 96 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) 97 | 98 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128) 99 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) 100 | 101 | #input feature size: 8*8*1024 102 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 103 | self.dropout = nn.Dropout2d(p=0.4) 104 | self.linear = nn.Linear(1024, num_class) 105 | 106 | def forward(self, x): 107 | x = self.prelayer(x) 108 | x = self.maxpool(x) 109 | x = self.a3(x) 110 | x = self.b3(x) 111 | 112 | x = self.maxpool(x) 113 | 114 | x = self.a4(x) 115 | x = self.b4(x) 116 | x = self.c4(x) 117 | x = self.d4(x) 118 | x = self.e4(x) 119 | 120 | x = self.maxpool(x) 121 | 122 | x = self.a5(x) 123 | x = self.b5(x) 124 | 125 | #"""It was found that a move from fully connected layers to 126 | #average pooling improved the top-1 accuracy by about 0.6%, 127 | #however the use of dropout remained essential even after 128 | #removing the fully connected layers.""" 129 | x = self.avgpool(x) 130 | x = self.dropout(x) 131 | x = x.view(x.size()[0], -1) 132 | x = self.linear(x) 133 | 134 | return x 135 | 136 | def googlenet(): 137 | return GoogleNet() 138 | 139 | 140 | -------------------------------------------------------------------------------- /recovery/nn/models/inceptionv3.py: -------------------------------------------------------------------------------- 1 | """ inceptionv3 in pytorch 2 | 3 | 4 | [1] Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, Zbigniew Wojna 5 | 6 | Rethinking the Inception Architecture for Computer Vision 7 | https://arxiv.org/abs/1512.00567v3 8 | """ 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | class BasicConv2d(nn.Module): 15 | 16 | def __init__(self, input_channels, output_channels, **kwargs): 17 | super().__init__() 18 | self.conv = nn.Conv2d(input_channels, output_channels, bias=False, **kwargs) 19 | self.bn = nn.BatchNorm2d(output_channels) 20 | self.relu = nn.ReLU(inplace=True) 21 | 22 | def forward(self, x): 23 | x = self.conv(x) 24 | x = self.bn(x) 25 | x = self.relu(x) 26 | 27 | return x 28 | 29 | #same naive inception module 30 | class InceptionA(nn.Module): 31 | 32 | def __init__(self, input_channels, pool_features): 33 | super().__init__() 34 | self.branch1x1 = BasicConv2d(input_channels, 64, kernel_size=1) 35 | 36 | self.branch5x5 = nn.Sequential( 37 | BasicConv2d(input_channels, 48, kernel_size=1), 38 | BasicConv2d(48, 64, kernel_size=5, padding=2) 39 | ) 40 | 41 | self.branch3x3 = nn.Sequential( 42 | BasicConv2d(input_channels, 64, kernel_size=1), 43 | BasicConv2d(64, 96, kernel_size=3, padding=1), 44 | BasicConv2d(96, 96, kernel_size=3, padding=1) 45 | ) 46 | 47 | self.branchpool = nn.Sequential( 48 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 49 | BasicConv2d(input_channels, pool_features, kernel_size=3, padding=1) 50 | ) 51 | 52 | def forward(self, x): 53 | 54 | #x -> 1x1(same) 55 | 
branch1x1 = self.branch1x1(x) 56 | 57 | #x -> 1x1 -> 5x5(same) 58 | branch5x5 = self.branch5x5(x) 59 | #branch5x5 = self.branch5x5_2(branch5x5) 60 | 61 | #x -> 1x1 -> 3x3 -> 3x3(same) 62 | branch3x3 = self.branch3x3(x) 63 | 64 | #x -> pool -> 1x1(same) 65 | branchpool = self.branchpool(x) 66 | 67 | outputs = [branch1x1, branch5x5, branch3x3, branchpool] 68 | 69 | return torch.cat(outputs, 1) 70 | 71 | #downsample 72 | #Factorization into smaller convolutions 73 | class InceptionB(nn.Module): 74 | 75 | def __init__(self, input_channels): 76 | super().__init__() 77 | 78 | self.branch3x3 = BasicConv2d(input_channels, 384, kernel_size=3, stride=2) 79 | 80 | self.branch3x3stack = nn.Sequential( 81 | BasicConv2d(input_channels, 64, kernel_size=1), 82 | BasicConv2d(64, 96, kernel_size=3, padding=1), 83 | BasicConv2d(96, 96, kernel_size=3, stride=2) 84 | ) 85 | 86 | self.branchpool = nn.MaxPool2d(kernel_size=3, stride=2) 87 | 88 | def forward(self, x): 89 | 90 | #x - > 3x3(downsample) 91 | branch3x3 = self.branch3x3(x) 92 | 93 | #x -> 3x3 -> 3x3(downsample) 94 | branch3x3stack = self.branch3x3stack(x) 95 | 96 | #x -> avgpool(downsample) 97 | branchpool = self.branchpool(x) 98 | 99 | #"""We can use two parallel stride 2 blocks: P and C. P is a pooling 100 | #layer (either average or maximum pooling) the activation, both of 101 | #them are stride 2 the filter banks of which are concatenated as in 102 | #figure 10.""" 103 | outputs = [branch3x3, branch3x3stack, branchpool] 104 | 105 | return torch.cat(outputs, 1) 106 | 107 | #Factorizing Convolutions with Large Filter Size 108 | class InceptionC(nn.Module): 109 | def __init__(self, input_channels, channels_7x7): 110 | super().__init__() 111 | self.branch1x1 = BasicConv2d(input_channels, 192, kernel_size=1) 112 | 113 | c7 = channels_7x7 114 | 115 | #In theory, we could go even further and argue that one can replace any n × n 116 | #convolution by a 1 × n convolution followed by a n × 1 convolution and the 117 | #computational cost saving increases dramatically as n grows (see figure 6). 
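To make the saving concrete, a back-of-envelope weight count for the factorization the comment describes, ignoring biases and BatchNorm (illustrative numbers only):

```
# an n x n conv over C channels costs n*n*C*C weights; the 1xn + nx1
# pair used by InceptionC costs 2*n*C*C, a 3.5x reduction for n = 7
C, n = 128, 7
full = n * n * C * C
factored = 2 * n * C * C
print(full, factored, full / factored)   # the ratio is n/2 = 3.5
```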
118 | self.branch7x7 = nn.Sequential( 119 | BasicConv2d(input_channels, c7, kernel_size=1), 120 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)), 121 | BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 122 | ) 123 | 124 | self.branch7x7stack = nn.Sequential( 125 | BasicConv2d(input_channels, c7, kernel_size=1), 126 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)), 127 | BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)), 128 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)), 129 | BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 130 | ) 131 | 132 | self.branch_pool = nn.Sequential( 133 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 134 | BasicConv2d(input_channels, 192, kernel_size=1), 135 | ) 136 | 137 | def forward(self, x): 138 | 139 | #x -> 1x1(same) 140 | branch1x1 = self.branch1x1(x) 141 | 142 | #x -> 1layer 1*7 and 7*1 (same) 143 | branch7x7 = self.branch7x7(x) 144 | 145 | #x-> 2layer 1*7 and 7*1(same) 146 | branch7x7stack = self.branch7x7stack(x) 147 | 148 | #x-> avgpool (same) 149 | branchpool = self.branch_pool(x) 150 | 151 | outputs = [branch1x1, branch7x7, branch7x7stack, branchpool] 152 | 153 | return torch.cat(outputs, 1) 154 | 155 | class InceptionD(nn.Module): 156 | 157 | def __init__(self, input_channels): 158 | super().__init__() 159 | 160 | self.branch3x3 = nn.Sequential( 161 | BasicConv2d(input_channels, 192, kernel_size=1), 162 | BasicConv2d(192, 320, kernel_size=3, stride=2) 163 | ) 164 | 165 | self.branch7x7 = nn.Sequential( 166 | BasicConv2d(input_channels, 192, kernel_size=1), 167 | BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)), 168 | BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)), 169 | BasicConv2d(192, 192, kernel_size=3, stride=2) 170 | ) 171 | 172 | self.branchpool = nn.AvgPool2d(kernel_size=3, stride=2) 173 | 174 | def forward(self, x): 175 | 176 | #x -> 1x1 -> 3x3(downsample) 177 | branch3x3 = self.branch3x3(x) 178 | 179 | #x -> 1x1 -> 1x7 -> 7x1 -> 3x3 (downsample) 180 | branch7x7 = self.branch7x7(x) 181 | 182 | #x -> avgpool (downsample) 183 | branchpool = self.branchpool(x) 184 | 185 | outputs = [branch3x3, branch7x7, branchpool] 186 | 187 | return torch.cat(outputs, 1) 188 | 189 | 190 | #same 191 | class InceptionE(nn.Module): 192 | def __init__(self, input_channels): 193 | super().__init__() 194 | self.branch1x1 = BasicConv2d(input_channels, 320, kernel_size=1) 195 | 196 | self.branch3x3_1 = BasicConv2d(input_channels, 384, kernel_size=1) 197 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 198 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 199 | 200 | self.branch3x3stack_1 = BasicConv2d(input_channels, 448, kernel_size=1) 201 | self.branch3x3stack_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) 202 | self.branch3x3stack_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 203 | self.branch3x3stack_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 204 | 205 | self.branch_pool = nn.Sequential( 206 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 207 | BasicConv2d(input_channels, 192, kernel_size=1) 208 | ) 209 | 210 | def forward(self, x): 211 | 212 | #x -> 1x1 (same) 213 | branch1x1 = self.branch1x1(x) 214 | 215 | # x -> 1x1 -> 3x1 216 | # x -> 1x1 -> 1x3 217 | # concatenate(3x1, 1x3) 218 | #"""7. Inception modules with expanded the filter bank outputs. 
219 | #This architecture is used on the coarsest (8 × 8) grids to promote 220 | #high dimensional representations, as suggested by principle 221 | #2 of Section 2.""" 222 | branch3x3 = self.branch3x3_1(x) 223 | branch3x3 = [ 224 | self.branch3x3_2a(branch3x3), 225 | self.branch3x3_2b(branch3x3) 226 | ] 227 | branch3x3 = torch.cat(branch3x3, 1) 228 | 229 | # x -> 1x1 -> 3x3 -> 1x3 230 | # x -> 1x1 -> 3x3 -> 3x1 231 | #concatenate(1x3, 3x1) 232 | branch3x3stack = self.branch3x3stack_1(x) 233 | branch3x3stack = self.branch3x3stack_2(branch3x3stack) 234 | branch3x3stack = [ 235 | self.branch3x3stack_3a(branch3x3stack), 236 | self.branch3x3stack_3b(branch3x3stack) 237 | ] 238 | branch3x3stack = torch.cat(branch3x3stack, 1) 239 | 240 | branchpool = self.branch_pool(x) 241 | 242 | outputs = [branch1x1, branch3x3, branch3x3stack, branchpool] 243 | 244 | return torch.cat(outputs, 1) 245 | 246 | class InceptionV3(nn.Module): 247 | 248 | def __init__(self, num_classes=100): 249 | super().__init__() 250 | self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, padding=1) 251 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3, padding=1) 252 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) 253 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) 254 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) 255 | 256 | #naive inception module 257 | self.Mixed_5b = InceptionA(192, pool_features=32) 258 | self.Mixed_5c = InceptionA(256, pool_features=64) 259 | self.Mixed_5d = InceptionA(288, pool_features=64) 260 | 261 | #downsample 262 | self.Mixed_6a = InceptionB(288) 263 | 264 | self.Mixed_6b = InceptionC(768, channels_7x7=128) 265 | self.Mixed_6c = InceptionC(768, channels_7x7=160) 266 | self.Mixed_6d = InceptionC(768, channels_7x7=160) 267 | self.Mixed_6e = InceptionC(768, channels_7x7=192) 268 | 269 | #downsample 270 | self.Mixed_7a = InceptionD(768) 271 | 272 | self.Mixed_7b = InceptionE(1280) 273 | self.Mixed_7c = InceptionE(2048) 274 | 275 | #6*6 feature size 276 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 277 | self.dropout = nn.Dropout2d() 278 | self.linear = nn.Linear(2048, num_classes) 279 | 280 | def forward(self, x): 281 | 282 | #32 -> 30 283 | x = self.Conv2d_1a_3x3(x) 284 | x = self.Conv2d_2a_3x3(x) 285 | x = self.Conv2d_2b_3x3(x) 286 | x = self.Conv2d_3b_1x1(x) 287 | x = self.Conv2d_4a_3x3(x) 288 | 289 | #30 -> 30 290 | x = self.Mixed_5b(x) 291 | x = self.Mixed_5c(x) 292 | x = self.Mixed_5d(x) 293 | 294 | #30 -> 14 295 | #Efficient Grid Size Reduction to avoid representation 296 | #bottleneck 297 | x = self.Mixed_6a(x) 298 | 299 | #14 -> 14 300 | #"""In practice, we have found that employing this factorization does not 301 | #work well on early layers, but it gives very good results on medium 302 | #grid-sizes (On m × m feature maps, where m ranges between 12 and 20). 
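The downsampling modules above (InceptionB, InceptionD) implement the "efficient grid size reduction" the comments quote: stride-2 conv branches run in parallel with a stride-2 pool, and their outputs concatenate. A quick check of the channel arithmetic for Mixed_7a (illustrative):

```
# output width of an InceptionD block: 320 (3x3 branch) + 192 (7x7 branch)
# + input_channels, since the pooling branch keeps the input width
def inception_d_out_channels(input_channels):
    return 320 + 192 + input_channels

assert inception_d_out_channels(768) == 1280   # Mixed_7a feeds InceptionE(1280)
```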
303 | #On that level, very good results can be achieved by using 1 × 7 convolutions 304 | #followed by 7 × 1 convolutions.""" 305 | x = self.Mixed_6b(x) 306 | x = self.Mixed_6c(x) 307 | x = self.Mixed_6d(x) 308 | x = self.Mixed_6e(x) 309 | 310 | #14 -> 6 311 | #Efficient Grid Size Reduction 312 | x = self.Mixed_7a(x) 313 | 314 | #6 -> 6 315 | #We are using this solution only on the coarsest grid, 316 | #since that is the place where producing high dimensional 317 | #sparse representation is the most critical as the ratio of 318 | #local processing (by 1 × 1 convolutions) is increased compared 319 | #to the spatial aggregation.""" 320 | x = self.Mixed_7b(x) 321 | x = self.Mixed_7c(x) 322 | 323 | #6 -> 1 324 | x = self.avgpool(x) 325 | x = self.dropout(x) 326 | x = x.view(x.size(0), -1) 327 | x = self.linear(x) 328 | return x 329 | 330 | 331 | def inceptionv3(): 332 | return InceptionV3() 333 | 334 | 335 | 336 | -------------------------------------------------------------------------------- /recovery/nn/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | """mobilenet in pytorch 2 | 3 | 4 | 5 | [1] Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam 6 | 7 | MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications 8 | https://arxiv.org/abs/1704.04861 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | 15 | class DepthSeperabelConv2d(nn.Module): 16 | 17 | def __init__(self, input_channels, output_channels, kernel_size, **kwargs): 18 | super().__init__() 19 | self.depthwise = nn.Sequential( 20 | nn.Conv2d( 21 | input_channels, 22 | input_channels, 23 | kernel_size, 24 | groups=input_channels, 25 | **kwargs), 26 | nn.BatchNorm2d(input_channels), 27 | nn.ReLU(inplace=True) 28 | ) 29 | 30 | self.pointwise = nn.Sequential( 31 | nn.Conv2d(input_channels, output_channels, 1), 32 | nn.BatchNorm2d(output_channels), 33 | nn.ReLU(inplace=True) 34 | ) 35 | 36 | def forward(self, x): 37 | x = self.depthwise(x) 38 | x = self.pointwise(x) 39 | 40 | return x 41 | 42 | 43 | class BasicConv2d(nn.Module): 44 | 45 | def __init__(self, input_channels, output_channels, kernel_size, **kwargs): 46 | 47 | super().__init__() 48 | self.conv = nn.Conv2d( 49 | input_channels, output_channels, kernel_size, **kwargs) 50 | self.bn = nn.BatchNorm2d(output_channels) 51 | self.relu = nn.ReLU(inplace=True) 52 | 53 | def forward(self, x): 54 | x = self.conv(x) 55 | x = self.bn(x) 56 | x = self.relu(x) 57 | 58 | return x 59 | 60 | 61 | class MobileNet(nn.Module): 62 | 63 | """ 64 | Args: 65 | width multipler: The role of the width multiplier α is to thin 66 | a network uniformly at each layer. For a given 67 | layer and width multiplier α, the number of 68 | input channels M becomes αM and the number of 69 | output channels N becomes αN. 
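The width multiplier thins every layer the same way: channel counts are scaled by alpha and truncated with `int()`, exactly as the `int(32 * alpha)` expressions below do. A quick table of the effect (illustrative):

```
for alpha in (1.0, 0.75, 0.5, 0.25):
    print(alpha, int(32 * alpha), int(1024 * alpha))
# e.g. alpha = 0.25 shrinks the stem from 32 to 8 channels and the last
# depthwise-separable stage from 1024 to 256 channels
```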
70 | """ 71 | 72 | def __init__(self, width_multiplier=1, class_num=100): 73 | super().__init__() 74 | 75 | alpha = width_multiplier 76 | self.stem = nn.Sequential( 77 | BasicConv2d(3, int(32 * alpha), 3, padding=1, bias=False), 78 | DepthSeperabelConv2d( 79 | int(32 * alpha), 80 | int(64 * alpha), 81 | 3, 82 | padding=1, 83 | bias=False 84 | ) 85 | ) 86 | 87 | #downsample 88 | self.conv1 = nn.Sequential( 89 | DepthSeperabelConv2d( 90 | int(64 * alpha), 91 | int(128 * alpha), 92 | 3, 93 | stride=2, 94 | padding=1, 95 | bias=False 96 | ), 97 | DepthSeperabelConv2d( 98 | int(128 * alpha), 99 | int(128 * alpha), 100 | 3, 101 | padding=1, 102 | bias=False 103 | ) 104 | ) 105 | 106 | #downsample 107 | self.conv2 = nn.Sequential( 108 | DepthSeperabelConv2d( 109 | int(128 * alpha), 110 | int(256 * alpha), 111 | 3, 112 | stride=2, 113 | padding=1, 114 | bias=False 115 | ), 116 | DepthSeperabelConv2d( 117 | int(256 * alpha), 118 | int(256 * alpha), 119 | 3, 120 | padding=1, 121 | bias=False 122 | ) 123 | ) 124 | 125 | #downsample 126 | self.conv3 = nn.Sequential( 127 | DepthSeperabelConv2d( 128 | int(256 * alpha), 129 | int(512 * alpha), 130 | 3, 131 | stride=2, 132 | padding=1, 133 | bias=False 134 | ), 135 | 136 | DepthSeperabelConv2d( 137 | int(512 * alpha), 138 | int(512 * alpha), 139 | 3, 140 | padding=1, 141 | bias=False 142 | ), 143 | DepthSeperabelConv2d( 144 | int(512 * alpha), 145 | int(512 * alpha), 146 | 3, 147 | padding=1, 148 | bias=False 149 | ), 150 | DepthSeperabelConv2d( 151 | int(512 * alpha), 152 | int(512 * alpha), 153 | 3, 154 | padding=1, 155 | bias=False 156 | ), 157 | DepthSeperabelConv2d( 158 | int(512 * alpha), 159 | int(512 * alpha), 160 | 3, 161 | padding=1, 162 | bias=False 163 | ), 164 | DepthSeperabelConv2d( 165 | int(512 * alpha), 166 | int(512 * alpha), 167 | 3, 168 | padding=1, 169 | bias=False 170 | ) 171 | ) 172 | 173 | #downsample 174 | self.conv4 = nn.Sequential( 175 | DepthSeperabelConv2d( 176 | int(512 * alpha), 177 | int(1024 * alpha), 178 | 3, 179 | stride=2, 180 | padding=1, 181 | bias=False 182 | ), 183 | DepthSeperabelConv2d( 184 | int(1024 * alpha), 185 | int(1024 * alpha), 186 | 3, 187 | padding=1, 188 | bias=False 189 | ) 190 | ) 191 | 192 | self.fc = nn.Linear(int(1024 * alpha), class_num) 193 | self.avg = nn.AdaptiveAvgPool2d(1) 194 | 195 | def forward(self, x): 196 | x = self.stem(x) 197 | 198 | x = self.conv1(x) 199 | x = self.conv2(x) 200 | x = self.conv3(x) 201 | x = self.conv4(x) 202 | 203 | x = self.avg(x) 204 | x = x.view(x.size(0), -1) 205 | x = self.fc(x) 206 | return x 207 | 208 | 209 | def mobilenet(alpha=1, class_num=100): 210 | return MobileNet(alpha, class_num) 211 | 212 | -------------------------------------------------------------------------------- /recovery/nn/models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """mobilenetv2 in pytorch 2 | 3 | 4 | 5 | [1] Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen 6 | 7 | MobileNetV2: Inverted Residuals and Linear Bottlenecks 8 | https://arxiv.org/abs/1801.04381 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | class LinearBottleNeck(nn.Module): 17 | 18 | def __init__(self, in_channels, out_channels, stride, t=6, class_num=100): 19 | super().__init__() 20 | 21 | self.residual = nn.Sequential( 22 | nn.Conv2d(in_channels, in_channels * t, 1), 23 | nn.BatchNorm2d(in_channels * t), 24 | nn.ReLU6(inplace=True), 25 | 26 | nn.Conv2d(in_channels * t, 
in_channels * t, 3, stride=stride, padding=1, groups=in_channels * t), 27 | nn.BatchNorm2d(in_channels * t), 28 | nn.ReLU6(inplace=True), 29 | 30 | nn.Conv2d(in_channels * t, out_channels, 1), 31 | nn.BatchNorm2d(out_channels) 32 | ) 33 | 34 | self.stride = stride 35 | self.in_channels = in_channels 36 | self.out_channels = out_channels 37 | 38 | def forward(self, x): 39 | 40 | residual = self.residual(x) 41 | 42 | if self.stride == 1 and self.in_channels == self.out_channels: 43 | residual += x 44 | 45 | return residual 46 | 47 | class MobileNetV2(nn.Module): 48 | 49 | def __init__(self, class_num=100): 50 | super().__init__() 51 | 52 | self.pre = nn.Sequential( 53 | nn.Conv2d(3, 32, 1, padding=1), 54 | nn.BatchNorm2d(32), 55 | nn.ReLU6(inplace=True) 56 | ) 57 | 58 | self.stage1 = LinearBottleNeck(32, 16, 1, 1) 59 | self.stage2 = self._make_stage(2, 16, 24, 2, 6) 60 | self.stage3 = self._make_stage(3, 24, 32, 2, 6) 61 | self.stage4 = self._make_stage(4, 32, 64, 2, 6) 62 | self.stage5 = self._make_stage(3, 64, 96, 1, 6) 63 | self.stage6 = self._make_stage(3, 96, 160, 1, 6) 64 | self.stage7 = LinearBottleNeck(160, 320, 1, 6) 65 | 66 | self.conv1 = nn.Sequential( 67 | nn.Conv2d(320, 1280, 1), 68 | nn.BatchNorm2d(1280), 69 | nn.ReLU6(inplace=True) 70 | ) 71 | 72 | self.conv2 = nn.Conv2d(1280, class_num, 1) 73 | 74 | def forward(self, x): 75 | x = self.pre(x) 76 | x = self.stage1(x) 77 | x = self.stage2(x) 78 | x = self.stage3(x) 79 | x = self.stage4(x) 80 | x = self.stage5(x) 81 | x = self.stage6(x) 82 | x = self.stage7(x) 83 | x = self.conv1(x) 84 | x = F.adaptive_avg_pool2d(x, 1) 85 | x = self.conv2(x) 86 | x = x.view(x.size(0), -1) 87 | 88 | return x 89 | 90 | def _make_stage(self, repeat, in_channels, out_channels, stride, t): 91 | 92 | layers = [] 93 | layers.append(LinearBottleNeck(in_channels, out_channels, stride, t)) 94 | 95 | while repeat - 1: 96 | layers.append(LinearBottleNeck(out_channels, out_channels, 1, t)) 97 | repeat -= 1 98 | 99 | return nn.Sequential(*layers) 100 | 101 | def mobilenetv2(): 102 | return MobileNetV2() -------------------------------------------------------------------------------- /recovery/nn/models/nasnet.py: -------------------------------------------------------------------------------- 1 | """nasnet in pytorch 2 | 3 | 4 | 5 | [1] Barret Zoph, Vijay Vasudevan, Jonathon Shlens, Quoc V. 
Le 6 | 7 | Learning Transferable Architectures for Scalable Image Recognition 8 | https://arxiv.org/abs/1707.07012 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | class SeperableConv2d(nn.Module): 15 | 16 | def __init__(self, input_channels, output_channels, kernel_size, **kwargs): 17 | 18 | super().__init__() 19 | self.depthwise = nn.Conv2d( 20 | input_channels, 21 | input_channels, 22 | kernel_size, 23 | groups=input_channels, 24 | **kwargs 25 | ) 26 | 27 | self.pointwise = nn.Conv2d( 28 | input_channels, 29 | output_channels, 30 | 1 31 | ) 32 | def forward(self, x): 33 | x = self.depthwise(x) 34 | x = self.pointwise(x) 35 | 36 | return x 37 | 38 | class SeperableBranch(nn.Module): 39 | 40 | def __init__(self, input_channels, output_channels, kernel_size, **kwargs): 41 | """Adds 2 blocks of [relu-separable conv-batchnorm].""" 42 | super().__init__() 43 | self.block1 = nn.Sequential( 44 | nn.ReLU(), 45 | SeperableConv2d(input_channels, output_channels, kernel_size, **kwargs), 46 | nn.BatchNorm2d(output_channels) 47 | ) 48 | 49 | self.block2 = nn.Sequential( 50 | nn.ReLU(), 51 | SeperableConv2d(output_channels, output_channels, kernel_size, stride=1, padding=int(kernel_size / 2)), 52 | nn.BatchNorm2d(output_channels) 53 | ) 54 | 55 | def forward(self, x): 56 | x = self.block1(x) 57 | x = self.block2(x) 58 | 59 | return x 60 | 61 | class Fit(nn.Module): 62 | """Make the cell outputs compatible 63 | 64 | Args: 65 | prev_filters: filter number of tensor prev, needs to be modified 66 | filters: filter number of normal cell branch output filters 67 | """ 68 | 69 | def __init__(self, prev_filters, filters): 70 | super().__init__() 71 | self.relu = nn.ReLU() 72 | 73 | self.p1 = nn.Sequential( 74 | nn.AvgPool2d(1, stride=2), 75 | nn.Conv2d(prev_filters, int(filters / 2), 1) 76 | ) 77 | 78 | #make sure there is no information loss 79 | self.p2 = nn.Sequential( 80 | nn.ConstantPad2d((0, 1, 0, 1), 0), 81 | nn.ConstantPad2d((-1, 0, -1, 0), 0), #cropping 82 | nn.AvgPool2d(1, stride=2), 83 | nn.Conv2d(prev_filters, int(filters / 2), 1) 84 | ) 85 | 86 | self.bn = nn.BatchNorm2d(filters) 87 | 88 | self.dim_reduce = nn.Sequential( 89 | nn.ReLU(), 90 | nn.Conv2d(prev_filters, filters, 1), 91 | nn.BatchNorm2d(filters) 92 | ) 93 | 94 | self.filters = filters 95 | 96 | def forward(self, inputs): 97 | x, prev = inputs 98 | if prev is None: 99 | return x 100 | 101 | #image size does not match 102 | elif x.size(2) != prev.size(2): 103 | prev = self.relu(prev) 104 | p1 = self.p1(prev) 105 | p2 = self.p2(prev) 106 | prev = torch.cat([p1, p2], 1) 107 | prev = self.bn(prev) 108 | 109 | elif prev.size(1) != self.filters: 110 | prev = self.dim_reduce(prev) 111 | 112 | return prev 113 | 114 | 115 | class NormalCell(nn.Module): 116 | 117 | def __init__(self, x_in, prev_in, output_channels): 118 | super().__init__() 119 | 120 | self.dem_reduce = nn.Sequential( 121 | nn.ReLU(), 122 | nn.Conv2d(x_in, output_channels, 1, bias=False), 123 | nn.BatchNorm2d(output_channels) 124 | ) 125 | 126 | self.block1_left = SeperableBranch( 127 | output_channels, 128 | output_channels, 129 | kernel_size=3, 130 | padding=1, 131 | bias=False 132 | ) 133 | self.block1_right = nn.Sequential() 134 | 135 | self.block2_left = SeperableBranch( 136 | output_channels, 137 | output_channels, 138 | kernel_size=3, 139 | padding=1, 140 | bias=False 141 | ) 142 | self.block2_right = SeperableBranch( 143 | output_channels, 144 | output_channels, 145 | kernel_size=5, 146 | padding=2, 147 | bias=False 148 | ) 149 | 150 | self.block3_left 
= nn.AvgPool2d(3, stride=1, padding=1) 151 | self.block3_right = nn.Sequential() 152 | 153 | self.block4_left = nn.AvgPool2d(3, stride=1, padding=1) 154 | self.block4_right = nn.AvgPool2d(3, stride=1, padding=1) 155 | 156 | self.block5_left = SeperableBranch( 157 | output_channels, 158 | output_channels, 159 | kernel_size=5, 160 | padding=2, 161 | bias=False 162 | ) 163 | self.block5_right = SeperableBranch( 164 | output_channels, 165 | output_channels, 166 | kernel_size=3, 167 | padding=1, 168 | bias=False 169 | ) 170 | 171 | self.fit = Fit(prev_in, output_channels) 172 | 173 | def forward(self, x): 174 | x, prev = x 175 | 176 | #return transformed x as new x, and original x as prev 177 | #only prev tensor needs to be modified 178 | prev = self.fit((x, prev)) 179 | 180 | h = self.dem_reduce(x) 181 | 182 | x1 = self.block1_left(h) + self.block1_right(h) 183 | x2 = self.block2_left(prev) + self.block2_right(h) 184 | x3 = self.block3_left(h) + self.block3_right(h) 185 | x4 = self.block4_left(prev) + self.block4_right(prev) 186 | x5 = self.block5_left(prev) + self.block5_right(prev) 187 | 188 | return torch.cat([prev, x1, x2, x3, x4, x5], 1), x 189 | 190 | class ReductionCell(nn.Module): 191 | 192 | def __init__(self, x_in, prev_in, output_channels): 193 | super().__init__() 194 | 195 | self.dim_reduce = nn.Sequential( 196 | nn.ReLU(), 197 | nn.Conv2d(x_in, output_channels, 1), 198 | nn.BatchNorm2d(output_channels) 199 | ) 200 | 201 | #block1 202 | self.layer1block1_left = SeperableBranch(output_channels, output_channels, 7, stride=2, padding=3) 203 | self.layer1block1_right = SeperableBranch(output_channels, output_channels, 5, stride=2, padding=2) 204 | 205 | #block2 206 | self.layer1block2_left = nn.MaxPool2d(3, stride=2, padding=1) 207 | self.layer1block2_right = SeperableBranch(output_channels, output_channels, 7, stride=2, padding=3) 208 | 209 | #block3 210 | self.layer1block3_left = nn.AvgPool2d(3, 2, 1) 211 | self.layer1block3_right = SeperableBranch(output_channels, output_channels, 5, stride=2, padding=2) 212 | 213 | #block5 214 | self.layer2block1_left = nn.MaxPool2d(3, 2, 1) 215 | self.layer2block1_right = SeperableBranch(output_channels, output_channels, 3, stride=1, padding=1) 216 | 217 | #block4 218 | self.layer2block2_left = nn.AvgPool2d(3, 1, 1) 219 | self.layer2block2_right = nn.Sequential() 220 | 221 | self.fit = Fit(prev_in, output_channels) 222 | 223 | def forward(self, x): 224 | x, prev = x 225 | prev = self.fit((x, prev)) 226 | 227 | h = self.dim_reduce(x) 228 | 229 | layer1block1 = self.layer1block1_left(prev) + self.layer1block1_right(h) 230 | layer1block2 = self.layer1block2_left(h) + self.layer1block2_right(prev) 231 | layer1block3 = self.layer1block3_left(h) + self.layer1block3_right(prev) 232 | layer2block1 = self.layer2block1_left(h) + self.layer2block1_right(layer1block1) 233 | layer2block2 = self.layer2block2_left(layer1block1) + self.layer2block2_right(layer1block2) 234 | 235 | return torch.cat([ 236 | layer1block2, #https://github.com/keras-team/keras-applications/blob/master/keras_applications/nasnet.py line 739 237 | layer1block3, 238 | layer2block1, 239 | layer2block2 240 | ], 1), x 241 | 242 | 243 | class NasNetA(nn.Module): 244 | 245 | def __init__(self, repeat_cell_num, reduction_num, filters, stemfilter, class_num=100): 246 | super().__init__() 247 | 248 | self.stem = nn.Sequential( 249 | nn.Conv2d(3, stemfilter, 3, padding=1, bias=False), 250 | nn.BatchNorm2d(stemfilter) 251 | ) 252 | 253 | self.prev_filters = stemfilter 254 | self.x_filters = 
stemfilter 255 | self.filters = filters 256 | 257 | self.cell_layers = self._make_layers(repeat_cell_num, reduction_num) 258 | 259 | self.relu = nn.ReLU() 260 | self.avg = nn.AdaptiveAvgPool2d(1) 261 | self.fc = nn.Linear(self.filters * 6, class_num) 262 | 263 | 264 | def _make_normal(self, block, repeat, output): 265 | """make normal cell 266 | Args: 267 | block: cell type 268 | repeat: number of repeated normal cell 269 | output: output filters for each branch in normal cell 270 | Returns: 271 | stacked normal cells 272 | """ 273 | 274 | layers = [] 275 | for r in range(repeat): 276 | layers.append(block(self.x_filters, self.prev_filters, output)) 277 | self.prev_filters = self.x_filters 278 | self.x_filters = output * 6 #concatenate 6 branches 279 | 280 | return layers 281 | 282 | def _make_reduction(self, block, output): 283 | """make normal cell 284 | Args: 285 | block: cell type 286 | output: output filters for each branch in reduction cell 287 | Returns: 288 | reduction cell 289 | """ 290 | 291 | reduction = block(self.x_filters, self.prev_filters, output) 292 | self.prev_filters = self.x_filters 293 | self.x_filters = output * 4 #stack for 4 branches 294 | 295 | return reduction 296 | 297 | def _make_layers(self, repeat_cell_num, reduction_num): 298 | 299 | layers = [] 300 | for i in range(reduction_num): 301 | 302 | layers.extend(self._make_normal(NormalCell, repeat_cell_num, self.filters)) 303 | self.filters *= 2 304 | layers.append(self._make_reduction(ReductionCell, self.filters)) 305 | 306 | layers.extend(self._make_normal(NormalCell, repeat_cell_num, self.filters)) 307 | 308 | return nn.Sequential(*layers) 309 | 310 | 311 | def forward(self, x): 312 | 313 | x = self.stem(x) 314 | prev = None 315 | x, prev = self.cell_layers((x, prev)) 316 | x = self.relu(x) 317 | x = self.avg(x) 318 | x = x.view(x.size(0), -1) 319 | x = self.fc(x) 320 | 321 | return x 322 | 323 | 324 | def nasnet(): 325 | 326 | #stem filters must be 44, it's a pytorch workaround, cant change to other number 327 | return NasNetA(4, 2, 44, 44) 328 | 329 | -------------------------------------------------------------------------------- /recovery/nn/models/preactresnet.py: -------------------------------------------------------------------------------- 1 | """preactresnet in pytorch 2 | 3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 4 | 5 | Identity Mappings in Deep Residual Networks 6 | https://arxiv.org/abs/1603.05027 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | class PreActBasic(nn.Module): 14 | 15 | expansion = 1 16 | def __init__(self, in_channels, out_channels, stride): 17 | super().__init__() 18 | self.residual = nn.Sequential( 19 | nn.BatchNorm2d(in_channels), 20 | nn.ReLU(inplace=True), 21 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1), 22 | nn.BatchNorm2d(out_channels), 23 | nn.ReLU(inplace=True), 24 | nn.Conv2d(out_channels, out_channels * PreActBasic.expansion, kernel_size=3, padding=1) 25 | ) 26 | 27 | self.shortcut = nn.Sequential() 28 | if stride != 1 or in_channels != out_channels * PreActBasic.expansion: 29 | self.shortcut = nn.Conv2d(in_channels, out_channels * PreActBasic.expansion, 1, stride=stride) 30 | 31 | def forward(self, x): 32 | 33 | res = self.residual(x) 34 | shortcut = self.shortcut(x) 35 | 36 | return res + shortcut 37 | 38 | 39 | class PreActBottleNeck(nn.Module): 40 | 41 | expansion = 4 42 | def __init__(self, in_channels, out_channels, stride): 43 | super().__init__() 44 | 45 | 
self.residual = nn.Sequential( 46 | nn.BatchNorm2d(in_channels), 47 | nn.ReLU(inplace=True), 48 | nn.Conv2d(in_channels, out_channels, 1, stride=stride), 49 | 50 | nn.BatchNorm2d(out_channels), 51 | nn.ReLU(inplace=True), 52 | nn.Conv2d(out_channels, out_channels, 3, padding=1), 53 | 54 | nn.BatchNorm2d(out_channels), 55 | nn.ReLU(inplace=True), 56 | nn.Conv2d(out_channels, out_channels * PreActBottleNeck.expansion, 1) 57 | ) 58 | 59 | self.shortcut = nn.Sequential() 60 | 61 | if stride != 1 or in_channels != out_channels * PreActBottleNeck.expansion: 62 | self.shortcut = nn.Conv2d(in_channels, out_channels * PreActBottleNeck.expansion, 1, stride=stride) 63 | 64 | def forward(self, x): 65 | 66 | res = self.residual(x) 67 | shortcut = self.shortcut(x) 68 | 69 | return res + shortcut 70 | 71 | class PreActResNet(nn.Module): 72 | 73 | def __init__(self, block, num_block, class_num=100): 74 | super().__init__() 75 | self.input_channels = 64 76 | 77 | self.pre = nn.Sequential( 78 | nn.Conv2d(3, 64, 3, padding=1), 79 | nn.BatchNorm2d(64), 80 | nn.ReLU(inplace=True) 81 | ) 82 | 83 | self.stage1 = self._make_layers(block, num_block[0], 64, 1) 84 | self.stage2 = self._make_layers(block, num_block[1], 128, 2) 85 | self.stage3 = self._make_layers(block, num_block[2], 256, 2) 86 | self.stage4 = self._make_layers(block, num_block[3], 512, 2) 87 | 88 | self.linear = nn.Linear(self.input_channels, class_num) 89 | 90 | def _make_layers(self, block, block_num, out_channels, stride): 91 | layers = [] 92 | 93 | layers.append(block(self.input_channels, out_channels, stride)) 94 | self.input_channels = out_channels * block.expansion 95 | 96 | while block_num - 1: 97 | layers.append(block(self.input_channels, out_channels, 1)) 98 | self.input_channels = out_channels * block.expansion 99 | block_num -= 1 100 | 101 | return nn.Sequential(*layers) 102 | 103 | def forward(self, x): 104 | x = self.pre(x) 105 | 106 | x = self.stage1(x) 107 | x = self.stage2(x) 108 | x = self.stage3(x) 109 | x = self.stage4(x) 110 | 111 | x = F.adaptive_avg_pool2d(x, 1) 112 | x = x.view(x.size(0), -1) 113 | x = self.linear(x) 114 | 115 | return x 116 | 117 | def preactresnet18(): 118 | return PreActResNet(PreActBasic, [2, 2, 2, 2]) 119 | 120 | def preactresnet34(): 121 | return PreActResNet(PreActBasic, [3, 4, 6, 3]) 122 | 123 | def preactresnet50(): 124 | return PreActResNet(PreActBottleNeck, [3, 4, 6, 3]) 125 | 126 | def preactresnet101(): 127 | return PreActResNet(PreActBottleNeck, [3, 4, 23, 3]) 128 | 129 | def preactresnet152(): 130 | return PreActResNet(PreActBottleNeck, [3, 8, 36, 3]) 131 | 132 | -------------------------------------------------------------------------------- /recovery/nn/models/resnet.py: -------------------------------------------------------------------------------- 1 | """resnet in pytorch 2 | 3 | 4 | 5 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 
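The file above (preactresnet.py) and the resnet.py file below differ only in where normalization and activation sit relative to the convolution. A minimal side-by-side of the two orderings (a sketch for comparison, not code from either file):

```
import torch.nn as nn

def preact_unit(c):
    # pre-activation: BN -> ReLU -> Conv; the identity is added afterwards,
    # with no activation applied to the sum
    return nn.Sequential(nn.BatchNorm2d(c), nn.ReLU(inplace=True),
                         nn.Conv2d(c, c, 3, padding=1))

def postact_unit(c):
    # classic ResNet: Conv -> BN, with ReLU applied after the addition
    return nn.Sequential(nn.Conv2d(c, c, 3, padding=1, bias=False),
                         nn.BatchNorm2d(c))
```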
6 | 7 | Deep Residual Learning for Image Recognition 8 | https://arxiv.org/abs/1512.03385v1 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | class BasicBlock(nn.Module): 15 | """Basic Block for resnet 18 and resnet 34 16 | 17 | """ 18 | 19 | #BasicBlock and BottleNeck block 20 | #have different output size 21 | #we use class attribute expansion 22 | #to distinct 23 | expansion = 1 24 | 25 | def __init__(self, in_channels, out_channels, stride=1): 26 | super().__init__() 27 | 28 | #residual function 29 | self.residual_function = nn.Sequential( 30 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), 31 | nn.BatchNorm2d(out_channels), 32 | nn.ReLU(inplace=True), 33 | nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False), 34 | nn.BatchNorm2d(out_channels * BasicBlock.expansion) 35 | ) 36 | 37 | #shortcut 38 | self.shortcut = nn.Sequential() 39 | 40 | #the shortcut output dimension is not the same with residual function 41 | #use 1*1 convolution to match the dimension 42 | if stride != 1 or in_channels != BasicBlock.expansion * out_channels: 43 | self.shortcut = nn.Sequential( 44 | nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False), 45 | nn.BatchNorm2d(out_channels * BasicBlock.expansion) 46 | ) 47 | 48 | def forward(self, x): 49 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) 50 | 51 | class BottleNeck(nn.Module): 52 | """Residual block for resnet over 50 layers 53 | 54 | """ 55 | expansion = 4 56 | def __init__(self, in_channels, out_channels, stride=1): 57 | super().__init__() 58 | self.residual_function = nn.Sequential( 59 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 60 | nn.BatchNorm2d(out_channels), 61 | nn.ReLU(inplace=True), 62 | nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False), 63 | nn.BatchNorm2d(out_channels), 64 | nn.ReLU(inplace=True), 65 | nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False), 66 | nn.BatchNorm2d(out_channels * BottleNeck.expansion), 67 | ) 68 | 69 | self.shortcut = nn.Sequential() 70 | 71 | if stride != 1 or in_channels != out_channels * BottleNeck.expansion: 72 | self.shortcut = nn.Sequential( 73 | nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False), 74 | nn.BatchNorm2d(out_channels * BottleNeck.expansion) 75 | ) 76 | 77 | def forward(self, x): 78 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) 79 | 80 | class ResNet(nn.Module): 81 | 82 | def __init__(self, block, num_block, num_classes=100): 83 | super().__init__() 84 | 85 | self.in_channels = 64 86 | 87 | self.conv1 = nn.Sequential( 88 | nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), 89 | nn.BatchNorm2d(64), 90 | nn.ReLU(inplace=True)) 91 | #we use a different inputsize than the original paper 92 | #so conv2_x's stride is 1 93 | self.conv2_x = self._make_layer(block, 64, num_block[0], 1) 94 | self.conv3_x = self._make_layer(block, 128, num_block[1], 2) 95 | self.conv4_x = self._make_layer(block, 256, num_block[2], 2) 96 | self.conv5_x = self._make_layer(block, 512, num_block[3], 2) 97 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 98 | self.fc = nn.Linear(512 * block.expansion, num_classes) 99 | 100 | def _make_layer(self, block, out_channels, num_blocks, stride): 101 | """make resnet layers(by layer i didnt mean this 'layer' was the 102 | same as 
a neural network layer, e.g. a conv layer), one layer may
103 |         contain more than one residual block
104 | 
105 |         Args:
106 |             block: block type, basic block or bottle neck block
107 |             out_channels: output depth channel number of this layer
108 |             num_blocks: how many blocks per layer
109 |             stride: the stride of the first block of this layer
110 | 
111 |         Return:
112 |             return a resnet layer
113 |         """
114 | 
115 |         # we have num_blocks blocks per layer; the first block's stride
116 |         # could be 1 or 2, the other blocks always use stride 1
117 |         strides = [stride] + [1] * (num_blocks - 1)
118 |         layers = []
119 |         for stride in strides:
120 |             layers.append(block(self.in_channels, out_channels, stride))
121 |             self.in_channels = out_channels * block.expansion
122 | 
123 |         return nn.Sequential(*layers)
124 | 
125 |     def forward(self, x):
126 |         output = self.conv1(x)
127 |         output = self.conv2_x(output)
128 |         output = self.conv3_x(output)
129 |         output = self.conv4_x(output)
130 |         output = self.conv5_x(output)
131 |         output = self.avg_pool(output)
132 |         output = output.view(output.size(0), -1)
133 |         output = self.fc(output)
134 | 
135 |         return output
136 | 
137 | def resnet18(num_classes=100):
138 |     """ return a ResNet 18 object
139 |     """
140 |     return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
141 | 
142 | def resnet34(num_classes=100):
143 |     """ return a ResNet 34 object
144 |     """
145 |     return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
146 | 
147 | def resnet50(num_classes=100):
148 |     """ return a ResNet 50 object
149 |     """
150 |     return ResNet(BottleNeck, [3, 4, 6, 3], num_classes=num_classes)
151 | 
152 | def resnet101(num_classes=100):
153 |     """ return a ResNet 101 object
154 |     """
155 |     return ResNet(BottleNeck, [3, 4, 23, 3], num_classes=num_classes)
156 | 
157 | def resnet152(num_classes=100):
158 |     """ return a ResNet 152 object
159 |     """
160 |     return ResNet(BottleNeck, [3, 8, 36, 3], num_classes=num_classes)
161 | 
162 | 
163 | 
164 | 
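Since the README prepares ResNet18 on STL-10, a minimal usage check of the factory above (illustrative; 96x96 matches the STL-10 setting):

```
import torch

# the stride list [stride] + [1]*(num_blocks-1) means only the first block of
# each layer downsamples, and adaptive pooling keeps the input size flexible
net = resnet18(num_classes=10)
out = net(torch.randn(1, 3, 96, 96))
assert out.shape == (1, 10)
```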
--------------------------------------------------------------------------------
/recovery/nn/models/resnext.py:
--------------------------------------------------------------------------------
  1 | """resnext in pytorch
  2 | 
  3 | 
  4 | 
  5 | [1] Saining Xie, Ross Girshick, Piotr Dollár, Zhuowen Tu, Kaiming He.
  6 | 
  7 |     Aggregated Residual Transformations for Deep Neural Networks
  8 |     https://arxiv.org/abs/1611.05431
  9 | """
 10 | 
 11 | import math
 12 | import torch
 13 | import torch.nn as nn
 14 | import torch.nn.functional as F
 15 | 
 16 | #only implements ResNext bottleneck c
 17 | 
 18 | 
 19 | #"""This strategy exposes a new dimension, which we call “cardinality”
 20 | #(the size of the set of transformations), as an essential factor
 21 | #in addition to the dimensions of depth and width."""
 22 | CARDINALITY = 32
 23 | DEPTH = 4
 24 | BASEWIDTH = 64
 25 | 
 26 | #"""The grouped convolutional layer in Fig. 3(c) performs 32 groups
 27 | #of convolutions whose input and output channels are 4-dimensional.
 28 | #The grouped convolutional layer concatenates them as the outputs
 29 | #of the layer."""
 30 | 
 31 | class ResNextBottleNeckC(nn.Module):
 32 | 
 33 |     def __init__(self, in_channels, out_channels, stride):
 34 |         super().__init__()
 35 | 
 36 |         C = CARDINALITY #how many groups a feature map is split into
 37 | 
 38 |         #"""We note that the input/output width of the template is fixed as
 39 |         #256-d (Fig. 3), and all widths are doubled each time
 40 |         #when the feature map is subsampled (see Table 1)."""
 41 | 
 42 |         D = int(DEPTH * out_channels / BASEWIDTH) #number of channels per group
 43 |         self.split_transforms = nn.Sequential(
 44 |             nn.Conv2d(in_channels, C * D, kernel_size=1, groups=C, bias=False),
 45 |             nn.BatchNorm2d(C * D),
 46 |             nn.ReLU(inplace=True),
 47 |             nn.Conv2d(C * D, C * D, kernel_size=3, stride=stride, groups=C, padding=1, bias=False),
 48 |             nn.BatchNorm2d(C * D),
 49 |             nn.ReLU(inplace=True),
 50 |             nn.Conv2d(C * D, out_channels * 4, kernel_size=1, bias=False),
 51 |             nn.BatchNorm2d(out_channels * 4),
 52 |         )
 53 | 
 54 |         self.shortcut = nn.Sequential()
 55 | 
 56 |         if stride != 1 or in_channels != out_channels * 4:
 57 |             self.shortcut = nn.Sequential(
 58 |                 nn.Conv2d(in_channels, out_channels * 4, stride=stride, kernel_size=1, bias=False),
 59 |                 nn.BatchNorm2d(out_channels * 4)
 60 |             )
 61 | 
 62 |     def forward(self, x):
 63 |         return F.relu(self.split_transforms(x) + self.shortcut(x))
 64 | 
 65 | class ResNext(nn.Module):
 66 | 
 67 |     def __init__(self, block, num_blocks, class_names=100):
 68 |         super().__init__()
 69 |         self.in_channels = 64
 70 | 
 71 |         self.conv1 = nn.Sequential(
 72 |             nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False),
 73 |             nn.BatchNorm2d(64),
 74 |             nn.ReLU(inplace=True)
 75 |         )
 76 | 
 77 |         self.conv2 = self._make_layer(block, num_blocks[0], 64, 1)
 78 |         self.conv3 = self._make_layer(block, num_blocks[1], 128, 2)
 79 |         self.conv4 = self._make_layer(block, num_blocks[2], 256, 2)
 80 |         self.conv5 = self._make_layer(block, num_blocks[3], 512, 2)
 81 |         self.avg = nn.AdaptiveAvgPool2d((1, 1))
 82 |         self.fc = nn.Linear(512 * 4, class_names)
 83 | 
 84 |     def forward(self, x):
 85 |         x = self.conv1(x)
 86 |         x = self.conv2(x)
 87 |         x = self.conv3(x)
 88 |         x = self.conv4(x)
 89 |         x = self.conv5(x)
 90 |         x = self.avg(x)
 91 |         x = x.view(x.size(0), -1)
 92 |         x = self.fc(x)
 93 |         return x
 94 | 
 95 |     def _make_layer(self, block, num_block, out_channels, stride):
 96 |         """Building resnext block
 97 |         Args:
 98 |             block: block type(default resnext bottleneck c)
 99 |             num_block: number of blocks per layer
100 |             out_channels: output channels per block
101 |             stride: block stride
102 | 
103 |         Returns:
104 |             a resnext layer
105 |         """
106 |         strides = [stride] + [1] * (num_block - 1)
107 |         layers = []
108 |         for stride in strides:
109 |             layers.append(block(self.in_channels, out_channels, stride))
110 |             self.in_channels = out_channels * 4
111 | 
112 |         return nn.Sequential(*layers)
113 | 
114 | def resnext50():
115 |     """ return a resnext50(c32x4d) network
116 |     """
117 |     return ResNext(ResNextBottleNeckC, [3, 4, 6, 3])
118 | 
119 | def resnext101():
120 |     """ return a resnext101(c32x4d) network
121 |     """
122 |     return ResNext(ResNextBottleNeckC, [3, 4, 23, 3])
123 | 
124 | def resnext152():
125 |     """ return a resnext152(c32x4d) network
126 |     """
127 |     return ResNext(ResNextBottleNeckC, [3, 4, 36, 3])
128 | 
129 | 
130 | 
131 | 
--------------------------------------------------------------------------------
/recovery/nn/models/rir.py:
--------------------------------------------------------------------------------
 1 | """resnet in resnet in pytorch
 2 | 
 3 | 
 4 | 
 5 | [1] Sasha Targ, Diogo Almeida, Kevin Lyman.
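Back in resnext.py above, the group width follows `D = DEPTH * out_channels / BASEWIDTH`, so the grouped 3x3 convolution always runs CARDINALITY groups of D channels each. The arithmetic per stage, as a quick check (illustrative):

```
CARDINALITY, DEPTH, BASEWIDTH = 32, 4, 64
for out_channels in (64, 128, 256, 512):
    D = int(DEPTH * out_channels / BASEWIDTH)
    print(out_channels, D, CARDINALITY * D)
# e.g. out_channels = 64 gives 32 groups of 4 channels, a 128-wide grouped conv
```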
6 | 7 | Resnet in Resnet: Generalizing Residual Architectures 8 | https://arxiv.org/abs/1603.08029v1 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | #geralized 15 | class ResnetInit(nn.Module): 16 | def __init__(self, in_channel, out_channel, stride): 17 | super().__init__() 18 | 19 | #"""The modular unit of the generalized residual network architecture is a 20 | #generalized residual block consisting of parallel states for a residual stream, 21 | #r, which contains identity shortcut connections and is similar to the structure 22 | #of a residual block from the original ResNet with a single convolutional layer 23 | #(parameters W l,r→r ) 24 | self.residual_stream_conv = nn.Conv2d(in_channel, out_channel, 3, padding=1, stride=stride) 25 | 26 | #"""and a transient stream, t, which is a standard convolutional layer 27 | #(W l,t→t ).""" 28 | self.transient_stream_conv = nn.Conv2d(in_channel, out_channel, 3, padding=1, stride=stride) 29 | 30 | #"""Two additional sets of convolutional filters in each block (W l,r→t , W l,t→r ) 31 | #also transfer information across streams.""" 32 | self.residual_stream_conv_across = nn.Conv2d(in_channel, out_channel, 3, padding=1, stride=stride) 33 | 34 | #"""We use equal numbers of filters for the residual and transient streams of the 35 | #generalized residual network, but optimizing this hyperparameter could lead to 36 | #further potential improvements.""" 37 | self.transient_stream_conv_across = nn.Conv2d(in_channel, out_channel, 3, padding=1, stride=stride) 38 | 39 | self.residual_bn_relu = nn.Sequential( 40 | nn.BatchNorm2d(out_channel), 41 | nn.ReLU(inplace=True) 42 | ) 43 | 44 | self.transient_bn_relu = nn.Sequential( 45 | nn.BatchNorm2d(out_channel), 46 | nn.ReLU(inplace=True) 47 | ) 48 | 49 | #"""The form of the shortcut connection can be an identity function with 50 | #the appropriate padding or a projection as in He et al. 
(2015b).""" 51 | self.short_cut = nn.Sequential() 52 | if in_channel != out_channel or stride != 1: 53 | self.short_cut = nn.Sequential( 54 | nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride) 55 | ) 56 | 57 | 58 | def forward(self, x): 59 | x_residual, x_transient = x 60 | residual_r_r = self.residual_stream_conv(x_residual) 61 | residual_r_t = self.residual_stream_conv_across(x_residual) 62 | residual_shortcut = self.short_cut(x_residual) 63 | 64 | transient_t_t = self.transient_stream_conv(x_transient) 65 | transient_t_r = self.transient_stream_conv_across(x_transient) 66 | 67 | #transient_t_t = self.transient_stream_conv(x_residual) 68 | #transient_t_r = self.transient_stream_conv_across(x_residual) 69 | #"""Same-stream and cross-stream activations are summed (along with the 70 | #shortcut connection for the residual stream) before applying batch 71 | #normalization and ReLU nonlinearities (together σ) to get the output 72 | #states of the block (Equation 1) (Ioffe & Szegedy, 2015).""" 73 | x_residual = self.residual_bn_relu(residual_r_r + transient_t_r + residual_shortcut) 74 | x_transient = self.transient_bn_relu(residual_r_t + transient_t_t) 75 | 76 | return x_residual, x_transient 77 | 78 | 79 | 80 | class RiRBlock(nn.Module): 81 | def __init__(self, in_channel, out_channel, layer_num, stride, layer=ResnetInit): 82 | super().__init__() 83 | self.resnetinit = self._make_layers(in_channel, out_channel, layer_num, stride) 84 | 85 | #self.short_cut = nn.Sequential() 86 | #if stride != 1 or in_channel != out_channel: 87 | # self.short_cut = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride) 88 | 89 | def forward(self, x): 90 | x_residual, x_transient = self.resnetinit(x) 91 | #x_residual = x_residual + self.short_cut(x[0]) 92 | #x_transient = x_transient + self.short_cut(x[1]) 93 | 94 | return (x_residual, x_transient) 95 | 96 | #"""Replacing each of the convolutional layers within a residual 97 | #block from the original ResNet (Figure 1a) with a generalized residual block 98 | #(Figure 1b) leads us to a new architecture we call ResNet in ResNet (RiR) 99 | #(Figure 1d).""" 100 | def _make_layers(self, in_channel, out_channel, layer_num, stride, layer=ResnetInit): 101 | strides = [stride] + [1] * (layer_num - 1) 102 | layers = nn.Sequential() 103 | for index, s in enumerate(strides): 104 | layers.add_module("generalized layers{}".format(index), layer(in_channel, out_channel, s)) 105 | in_channel = out_channel 106 | 107 | return layers 108 | 109 | class ResnetInResneet(nn.Module): 110 | def __init__(self, num_classes=100): 111 | super().__init__() 112 | base = int(96 / 2) 113 | self.residual_pre_conv = nn.Sequential( 114 | nn.Conv2d(3, base, 3, padding=1), 115 | nn.BatchNorm2d(base), 116 | nn.ReLU(inplace=True) 117 | ) 118 | self.transient_pre_conv = nn.Sequential( 119 | nn.Conv2d(3, base, 3, padding=1), 120 | nn.BatchNorm2d(base), 121 | nn.ReLU(inplace=True) 122 | ) 123 | 124 | self.rir1 = RiRBlock(base, base, 2, 1) 125 | self.rir2 = RiRBlock(base, base, 2, 1) 126 | self.rir3 = RiRBlock(base, base * 2, 2, 2) 127 | self.rir4 = RiRBlock(base * 2, base * 2, 2, 1) 128 | self.rir5 = RiRBlock(base * 2, base * 2, 2, 1) 129 | self.rir6 = RiRBlock(base * 2, base * 4, 2, 2) 130 | self.rir7 = RiRBlock(base * 4, base * 4, 2, 1) 131 | self.rir8 = RiRBlock(base * 4, base * 4, 2, 1) 132 | 133 | self.conv1 = nn.Sequential( 134 | nn.Conv2d(384, num_classes, kernel_size=3, stride=2), #without this convolution, loss will soon be nan 135 | nn.BatchNorm2d(num_classes), 136 | 
nn.ReLU(inplace=True),
137 |         )
138 | 
139 |         self.classifier = nn.Sequential(
140 |             nn.Linear(9 * num_classes, 450),  # conv1 emits num_classes channels on a 3x3 map for 32x32 inputs
141 |             nn.ReLU(),
142 |             nn.Dropout(),
143 |             nn.Linear(450, num_classes),
144 |         )
145 | 
146 |         self._weight_init()
147 | 
148 |     def forward(self, x):
149 |         x_residual = self.residual_pre_conv(x)
150 |         x_transient = self.transient_pre_conv(x)
151 | 
152 |         x_residual, x_transient = self.rir1((x_residual, x_transient))
153 |         x_residual, x_transient = self.rir2((x_residual, x_transient))
154 |         x_residual, x_transient = self.rir3((x_residual, x_transient))
155 |         x_residual, x_transient = self.rir4((x_residual, x_transient))
156 |         x_residual, x_transient = self.rir5((x_residual, x_transient))
157 |         x_residual, x_transient = self.rir6((x_residual, x_transient))
158 |         x_residual, x_transient = self.rir7((x_residual, x_transient))
159 |         x_residual, x_transient = self.rir8((x_residual, x_transient))
160 |         h = torch.cat([x_residual, x_transient], 1)
161 |         h = self.conv1(h)
162 |         h = h.view(h.size()[0], -1)
163 |         h = self.classifier(h)
164 | 
165 |         return h
166 | 
167 |     def _weight_init(self):
168 |         for m in self.modules():
169 |             if isinstance(m, nn.Conv2d):
170 |                 torch.nn.init.kaiming_normal_(m.weight)
171 |                 m.bias.data.fill_(0.01)
172 | 
173 | 
174 | def resnet_in_resnet():
175 |     return ResnetInResneet()
176 | 
-------------------------------------------------------------------------------- /recovery/nn/models/senet.py: --------------------------------------------------------------------------------
1 | """senet in pytorch
2 | 
3 | 
4 | 
5 | [1] Jie Hu, Li Shen, Samuel Albanie, Gang Sun, Enhua Wu
6 | 
7 | Squeeze-and-Excitation Networks
8 | https://arxiv.org/abs/1709.01507
9 | """
10 | 
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | 
15 | class BasicResidualSEBlock(nn.Module):
16 | 
17 |     expansion = 1
18 | 
19 |     def __init__(self, in_channels, out_channels, stride, r=16):
20 |         super().__init__()
21 | 
22 |         self.residual = nn.Sequential(
23 |             nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1),
24 |             nn.BatchNorm2d(out_channels),
25 |             nn.ReLU(inplace=True),
26 | 
27 |             nn.Conv2d(out_channels, out_channels * self.expansion, 3, padding=1),
28 |             nn.BatchNorm2d(out_channels * self.expansion),
29 |             nn.ReLU(inplace=True)
30 |         )
31 | 
32 |         self.shortcut = nn.Sequential()
33 |         if stride != 1 or in_channels != out_channels * self.expansion:
34 |             self.shortcut = nn.Sequential(
35 |                 nn.Conv2d(in_channels, out_channels * self.expansion, 1, stride=stride),
36 |                 nn.BatchNorm2d(out_channels * self.expansion)
37 |             )
38 | 
39 |         self.squeeze = nn.AdaptiveAvgPool2d(1)
40 |         self.excitation = nn.Sequential(
41 |             nn.Linear(out_channels * self.expansion, out_channels * self.expansion // r),
42 |             nn.ReLU(inplace=True),
43 |             nn.Linear(out_channels * self.expansion // r, out_channels * self.expansion),
44 |             nn.Sigmoid()
45 |         )
46 | 
47 |     def forward(self, x):
48 |         shortcut = self.shortcut(x)
49 |         residual = self.residual(x)
50 | 
51 |         squeeze = self.squeeze(residual)
52 |         squeeze = squeeze.view(squeeze.size(0), -1)
53 |         excitation = self.excitation(squeeze)
54 |         excitation = excitation.view(residual.size(0), residual.size(1), 1, 1)
55 | 
56 |         x = residual * excitation.expand_as(residual) + shortcut
57 | 
58 |         return F.relu(x)
59 | 
60 | class BottleneckResidualSEBlock(nn.Module):
61 | 
62 |     expansion = 4
63 | 
64 |     def __init__(self, in_channels, out_channels, stride, r=16):
65 |         super().__init__()
66 | 
67 |         self.residual = nn.Sequential(
68 |             nn.Conv2d(in_channels, out_channels, 1),
69 | 
nn.BatchNorm2d(out_channels), 70 | nn.ReLU(inplace=True), 71 | 72 | nn.Conv2d(out_channels, out_channels, 3, stride=stride, padding=1), 73 | nn.BatchNorm2d(out_channels), 74 | nn.ReLU(inplace=True), 75 | 76 | nn.Conv2d(out_channels, out_channels * self.expansion, 1), 77 | nn.BatchNorm2d(out_channels * self.expansion), 78 | nn.ReLU(inplace=True) 79 | ) 80 | 81 | self.squeeze = nn.AdaptiveAvgPool2d(1) 82 | self.excitation = nn.Sequential( 83 | nn.Linear(out_channels * self.expansion, out_channels * self.expansion // r), 84 | nn.ReLU(inplace=True), 85 | nn.Linear(out_channels * self.expansion // r, out_channels * self.expansion), 86 | nn.Sigmoid() 87 | ) 88 | 89 | self.shortcut = nn.Sequential() 90 | if stride != 1 or in_channels != out_channels * self.expansion: 91 | self.shortcut = nn.Sequential( 92 | nn.Conv2d(in_channels, out_channels * self.expansion, 1, stride=stride), 93 | nn.BatchNorm2d(out_channels * self.expansion) 94 | ) 95 | 96 | def forward(self, x): 97 | 98 | shortcut = self.shortcut(x) 99 | 100 | residual = self.residual(x) 101 | squeeze = self.squeeze(residual) 102 | squeeze = squeeze.view(squeeze.size(0), -1) 103 | excitation = self.excitation(squeeze) 104 | excitation = excitation.view(residual.size(0), residual.size(1), 1, 1) 105 | 106 | x = residual * excitation.expand_as(residual) + shortcut 107 | 108 | return F.relu(x) 109 | 110 | class SEResNet(nn.Module): 111 | 112 | def __init__(self, block, block_num, class_num=100): 113 | super().__init__() 114 | 115 | self.in_channels = 64 116 | 117 | self.pre = nn.Sequential( 118 | nn.Conv2d(3, 64, 3, padding=1), 119 | nn.BatchNorm2d(64), 120 | nn.ReLU(inplace=True) 121 | ) 122 | 123 | self.stage1 = self._make_stage(block, block_num[0], 64, 1) 124 | self.stage2 = self._make_stage(block, block_num[1], 128, 2) 125 | self.stage3 = self._make_stage(block, block_num[2], 256, 2) 126 | self.stage4 = self._make_stage(block, block_num[3], 512, 2) 127 | 128 | self.linear = nn.Linear(self.in_channels, class_num) 129 | 130 | def forward(self, x): 131 | x = self.pre(x) 132 | 133 | x = self.stage1(x) 134 | x = self.stage2(x) 135 | x = self.stage3(x) 136 | x = self.stage4(x) 137 | 138 | x = F.adaptive_avg_pool2d(x, 1) 139 | x = x.view(x.size(0), -1) 140 | 141 | x = self.linear(x) 142 | 143 | return x 144 | 145 | 146 | def _make_stage(self, block, num, out_channels, stride): 147 | 148 | layers = [] 149 | layers.append(block(self.in_channels, out_channels, stride)) 150 | self.in_channels = out_channels * block.expansion 151 | 152 | while num - 1: 153 | layers.append(block(self.in_channels, out_channels, 1)) 154 | num -= 1 155 | 156 | return nn.Sequential(*layers) 157 | 158 | def seresnet18(): 159 | return SEResNet(BasicResidualSEBlock, [2, 2, 2, 2]) 160 | 161 | def seresnet34(): 162 | return SEResNet(BasicResidualSEBlock, [3, 4, 6, 3]) 163 | 164 | def seresnet50(): 165 | return SEResNet(BottleneckResidualSEBlock, [3, 4, 6, 3]) 166 | 167 | def seresnet101(): 168 | return SEResNet(BottleneckResidualSEBlock, [3, 4, 23, 3]) 169 | 170 | def seresnet152(): 171 | return SEResNet(BottleneckResidualSEBlock, [3, 8, 36, 3]) 172 | -------------------------------------------------------------------------------- /recovery/nn/models/shufflenet.py: -------------------------------------------------------------------------------- 1 | """shufflenet in pytorch 2 | 3 | 4 | 5 | [1] Xiangyu Zhang, Xinyu Zhou, Mengxiao Lin, Jian Sun. 
6 | 
7 | ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices
8 | https://arxiv.org/abs/1707.01083v2
9 | """
10 | 
11 | from functools import partial
12 | 
13 | import torch
14 | import torch.nn as nn
15 | 
16 | 
17 | class BasicConv2d(nn.Module):
18 | 
19 |     def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
20 |         super().__init__()
21 |         self.conv = nn.Conv2d(input_channels, output_channels, kernel_size, **kwargs)
22 |         self.bn = nn.BatchNorm2d(output_channels)
23 |         self.relu = nn.ReLU(inplace=True)
24 | 
25 |     def forward(self, x):
26 |         x = self.conv(x)
27 |         x = self.bn(x)
28 |         x = self.relu(x)
29 |         return x
30 | 
31 | class ChannelShuffle(nn.Module):
32 | 
33 |     def __init__(self, groups):
34 |         super().__init__()
35 |         self.groups = groups
36 | 
37 |     def forward(self, x):
38 |         batchsize, channels, height, width = x.data.size()
39 |         channels_per_group = int(channels / self.groups)
40 | 
41 |         #"""suppose a convolutional layer with g groups whose output has
42 |         #g x n channels; we first reshape the output channel dimension
43 |         #into (g, n)"""
44 |         x = x.view(batchsize, self.groups, channels_per_group, height, width)
45 | 
46 |         #"""transposing and then flattening it back as the input of next layer."""
47 |         x = x.transpose(1, 2).contiguous()
48 |         x = x.view(batchsize, -1, height, width)
49 | 
50 |         return x
51 | 
52 | class DepthwiseConv2d(nn.Module):
53 | 
54 |     def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
55 |         super().__init__()
56 |         self.depthwise = nn.Sequential(
57 |             nn.Conv2d(input_channels, output_channels, kernel_size, **kwargs),
58 |             nn.BatchNorm2d(output_channels)
59 |         )
60 | 
61 |     def forward(self, x):
62 |         return self.depthwise(x)
63 | 
64 | class PointwiseConv2d(nn.Module):
65 |     def __init__(self, input_channels, output_channels, **kwargs):
66 |         super().__init__()
67 |         self.pointwise = nn.Sequential(
68 |             nn.Conv2d(input_channels, output_channels, 1, **kwargs),
69 |             nn.BatchNorm2d(output_channels)
70 |         )
71 | 
72 |     def forward(self, x):
73 |         return self.pointwise(x)
74 | 
75 | class ShuffleNetUnit(nn.Module):
76 | 
77 |     def __init__(self, input_channels, output_channels, stage, stride, groups):
78 |         super().__init__()
79 | 
80 |         #"""Similar to [9], we set the number of bottleneck channels to 1/4
81 |         #of the output channels for each ShuffleNet unit."""
82 |         self.bottleneck = nn.Sequential(
83 |             PointwiseConv2d(
84 |                 input_channels,
85 |                 int(output_channels / 4),
86 |                 groups=groups
87 |             ),
88 |             nn.ReLU(inplace=True)
89 |         )
90 | 
91 |         #"""Note that for Stage 2, we do not apply group convolution on the first pointwise
92 |         #layer because the number of input channels is relatively small."""
93 |         if stage == 2:
94 |             self.bottleneck = nn.Sequential(
95 |                 PointwiseConv2d(
96 |                     input_channels,
97 |                     int(output_channels / 4),
98 |                     groups=1
99 |                 ),
100 |                 nn.ReLU(inplace=True)
101 |             )
102 | 
103 |         self.channel_shuffle = ChannelShuffle(groups)
104 | 
105 |         self.depthwise = DepthwiseConv2d(
106 |             int(output_channels / 4),
107 |             int(output_channels / 4),
108 |             3,
109 |             groups=int(output_channels / 4),
110 |             stride=stride,
111 |             padding=1
112 |         )
113 | 
114 |         self.expand = PointwiseConv2d(
115 |             int(output_channels / 4),
116 |             output_channels,
117 |             groups=groups
118 |         )
119 | 
120 |         self.relu = nn.ReLU(inplace=True)
121 |         self.fusion = self._add
122 |         self.shortcut = nn.Sequential()
123 | 
124 |         #"""As for the case where ShuffleNet is applied with stride,
125 |         #we simply make two modifications (see Fig 2 (c)):
126 |         #(i) add a 3 × 3 average pooling on the shortcut path;
127 |         #(ii) replace the element-wise addition with channel concatenation,
128 |         #which makes it easy to enlarge channel dimension with little extra
129 |         #computation cost."""
130 |         if stride != 1 or input_channels != output_channels:
131 |             self.shortcut = nn.AvgPool2d(3, stride=2, padding=1)
132 | 
133 |             self.expand = PointwiseConv2d(
134 |                 int(output_channels / 4),
135 |                 output_channels - input_channels,
136 |                 groups=groups
137 |             )
138 | 
139 |             self.fusion = self._cat
140 | 
141 |     def _add(self, x, y):
142 |         return torch.add(x, y)
143 | 
144 |     def _cat(self, x, y):
145 |         return torch.cat([x, y], dim=1)
146 | 
147 |     def forward(self, x):
148 |         shortcut = self.shortcut(x)
149 | 
150 |         shuffled = self.bottleneck(x)
151 |         shuffled = self.channel_shuffle(shuffled)
152 |         shuffled = self.depthwise(shuffled)
153 |         shuffled = self.expand(shuffled)
154 | 
155 |         output = self.fusion(shortcut, shuffled)
156 |         output = self.relu(output)
157 | 
158 |         return output
159 | 
160 | class ShuffleNet(nn.Module):
161 | 
162 |     def __init__(self, num_blocks, num_classes=100, groups=3):
163 |         super().__init__()
164 | 
165 |         if groups == 1:
166 |             out_channels = [24, 144, 288, 576]
167 |         elif groups == 2:
168 |             out_channels = [24, 200, 400, 800]
169 |         elif groups == 3:
170 |             out_channels = [24, 240, 480, 960]
171 |         elif groups == 4:
172 |             out_channels = [24, 272, 544, 1088]
173 |         elif groups == 8:
174 |             out_channels = [24, 384, 768, 1536]
175 | 
176 |         self.conv1 = BasicConv2d(3, out_channels[0], 3, padding=1, stride=1)
177 |         self.input_channels = out_channels[0]
178 | 
179 |         self.stage2 = self._make_stage(
180 |             ShuffleNetUnit,
181 |             num_blocks[0],
182 |             out_channels[1],
183 |             stride=2,
184 |             stage=2,
185 |             groups=groups
186 |         )
187 | 
188 |         self.stage3 = self._make_stage(
189 |             ShuffleNetUnit,
190 |             num_blocks[1],
191 |             out_channels[2],
192 |             stride=2,
193 |             stage=3,
194 |             groups=groups
195 |         )
196 | 
197 |         self.stage4 = self._make_stage(
198 |             ShuffleNetUnit,
199 |             num_blocks[2],
200 |             out_channels[3],
201 |             stride=2,
202 |             stage=4,
203 |             groups=groups
204 |         )
205 | 
206 |         self.avg = nn.AdaptiveAvgPool2d((1, 1))
207 |         self.fc = nn.Linear(out_channels[3], num_classes)
208 | 
209 |     def forward(self, x):
210 |         x = self.conv1(x)
211 |         x = self.stage2(x)
212 |         x = self.stage3(x)
213 |         x = self.stage4(x)
214 |         x = self.avg(x)
215 |         x = x.view(x.size(0), -1)
216 |         x = self.fc(x)
217 | 
218 |         return x
219 | 
220 |     def _make_stage(self, block, num_blocks, output_channels, stride, stage, groups):
221 |         """make shufflenet stage
222 | 
223 |         Args:
224 |             block: block type, shuffle unit
225 |             output_channels: output depth channel number of this stage
226 |             num_blocks: how many blocks per stage
227 |             stride: the stride of the first block of this stage
228 |             stage: stage index
229 |             groups: group number of group convolution
230 |         Return:
231 |             return a shuffle net stage
232 |         """
233 |         strides = [stride] + [1] * (num_blocks - 1)
234 | 
235 |         layers = []  # collect blocks in a separate list so the integer stage index stays intact
236 | 
237 |         for stride in strides:
238 |             layers.append(
239 |                 block(
240 |                     self.input_channels,
241 |                     output_channels,
242 |                     stride=stride,
243 |                     stage=stage,
244 |                     groups=groups
245 |                 )
246 |             )
247 |             self.input_channels = output_channels
248 | 
249 |         return nn.Sequential(*layers)
250 | 
251 | def shufflenet():
252 |     return ShuffleNet([4, 8, 4])
253 | 
254 | 
255 | 
256 | 
257 | 
-------------------------------------------------------------------------------- /recovery/nn/models/shufflenetv2.py: --------------------------------------------------------------------------------
"""shufflenetv2 in pytorch 2 | 3 | 4 | 5 | [1] Ningning Ma, Xiangyu Zhang, Hai-Tao Zheng, Jian Sun 6 | 7 | ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design 8 | https://arxiv.org/abs/1807.11164 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def channel_split(x, split): 17 | """split a tensor into two pieces along channel dimension 18 | Args: 19 | x: input tensor 20 | split:(int) channel size for each pieces 21 | """ 22 | assert x.size(1) == split * 2 23 | return torch.split(x, split, dim=1) 24 | 25 | def channel_shuffle(x, groups): 26 | """channel shuffle operation 27 | Args: 28 | x: input tensor 29 | groups: input branch number 30 | """ 31 | 32 | batch_size, channels, height, width = x.size() 33 | channels_per_group = int(channels // groups) 34 | 35 | x = x.view(batch_size, groups, channels_per_group, height, width) 36 | x = x.transpose(1, 2).contiguous() 37 | x = x.view(batch_size, -1, height, width) 38 | 39 | return x 40 | 41 | class ShuffleUnit(nn.Module): 42 | 43 | def __init__(self, in_channels, out_channels, stride): 44 | super().__init__() 45 | 46 | self.stride = stride 47 | self.in_channels = in_channels 48 | self.out_channels = out_channels 49 | 50 | if stride != 1 or in_channels != out_channels: 51 | self.residual = nn.Sequential( 52 | nn.Conv2d(in_channels, in_channels, 1), 53 | nn.BatchNorm2d(in_channels), 54 | nn.ReLU(inplace=True), 55 | nn.Conv2d(in_channels, in_channels, 3, stride=stride, padding=1, groups=in_channels), 56 | nn.BatchNorm2d(in_channels), 57 | nn.Conv2d(in_channels, int(out_channels / 2), 1), 58 | nn.BatchNorm2d(int(out_channels / 2)), 59 | nn.ReLU(inplace=True) 60 | ) 61 | 62 | self.shortcut = nn.Sequential( 63 | nn.Conv2d(in_channels, in_channels, 3, stride=stride, padding=1, groups=in_channels), 64 | nn.BatchNorm2d(in_channels), 65 | nn.Conv2d(in_channels, int(out_channels / 2), 1), 66 | nn.BatchNorm2d(int(out_channels / 2)), 67 | nn.ReLU(inplace=True) 68 | ) 69 | else: 70 | self.shortcut = nn.Sequential() 71 | 72 | in_channels = int(in_channels / 2) 73 | self.residual = nn.Sequential( 74 | nn.Conv2d(in_channels, in_channels, 1), 75 | nn.BatchNorm2d(in_channels), 76 | nn.ReLU(inplace=True), 77 | nn.Conv2d(in_channels, in_channels, 3, stride=stride, padding=1, groups=in_channels), 78 | nn.BatchNorm2d(in_channels), 79 | nn.Conv2d(in_channels, in_channels, 1), 80 | nn.BatchNorm2d(in_channels), 81 | nn.ReLU(inplace=True) 82 | ) 83 | 84 | 85 | def forward(self, x): 86 | 87 | if self.stride == 1 and self.out_channels == self.in_channels: 88 | shortcut, residual = channel_split(x, int(self.in_channels / 2)) 89 | else: 90 | shortcut = x 91 | residual = x 92 | 93 | shortcut = self.shortcut(shortcut) 94 | residual = self.residual(residual) 95 | x = torch.cat([shortcut, residual], dim=1) 96 | x = channel_shuffle(x, 2) 97 | 98 | return x 99 | 100 | class ShuffleNetV2(nn.Module): 101 | 102 | def __init__(self, ratio=1, class_num=100): 103 | super().__init__() 104 | if ratio == 0.5: 105 | out_channels = [48, 96, 192, 1024] 106 | elif ratio == 1: 107 | out_channels = [116, 232, 464, 1024] 108 | elif ratio == 1.5: 109 | out_channels = [176, 352, 704, 1024] 110 | elif ratio == 2: 111 | out_channels = [244, 488, 976, 2048] 112 | else: 113 | ValueError('unsupported ratio number') 114 | 115 | self.pre = nn.Sequential( 116 | nn.Conv2d(3, 24, 3, padding=1), 117 | nn.BatchNorm2d(24) 118 | ) 119 | 120 | self.stage2 = self._make_stage(24, out_channels[0], 3) 121 | self.stage3 = 
122 |         self.stage4 = self._make_stage(out_channels[1], out_channels[2], 3)
123 |         self.conv5 = nn.Sequential(
124 |             nn.Conv2d(out_channels[2], out_channels[3], 1),
125 |             nn.BatchNorm2d(out_channels[3]),
126 |             nn.ReLU(inplace=True)
127 |         )
128 | 
129 |         self.fc = nn.Linear(out_channels[3], class_num)
130 | 
131 |     def forward(self, x):
132 |         x = self.pre(x)
133 |         x = self.stage2(x)
134 |         x = self.stage3(x)
135 |         x = self.stage4(x)
136 |         x = self.conv5(x)
137 |         x = F.adaptive_avg_pool2d(x, 1)
138 |         x = x.view(x.size(0), -1)
139 |         x = self.fc(x)
140 | 
141 |         return x
142 | 
143 |     def _make_stage(self, in_channels, out_channels, repeat):
144 |         layers = []
145 |         layers.append(ShuffleUnit(in_channels, out_channels, 2))
146 | 
147 |         while repeat:
148 |             layers.append(ShuffleUnit(out_channels, out_channels, 1))
149 |             repeat -= 1
150 | 
151 |         return nn.Sequential(*layers)
152 | 
153 | def shufflenetv2():
154 |     return ShuffleNetV2()
155 | 
156 | 
157 | 
158 | 
159 | 
160 | 
-------------------------------------------------------------------------------- /recovery/nn/models/squeezenet.py: --------------------------------------------------------------------------------
1 | """squeezenet in pytorch
2 | 
3 | 
4 | 
5 | [1] Forrest N. Iandola, Song Han, Matthew W. Moskewicz, Khalid Ashraf, William J. Dally, Kurt Keutzer
6 | 
7 | SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size
8 | https://arxiv.org/abs/1602.07360
9 | """
10 | 
11 | import torch
12 | import torch.nn as nn
13 | 
14 | 
15 | class Fire(nn.Module):
16 | 
17 |     def __init__(self, in_channel, out_channel, squeeze_channel):
18 | 
19 |         super().__init__()
20 |         self.squeeze = nn.Sequential(
21 |             nn.Conv2d(in_channel, squeeze_channel, 1),
22 |             nn.BatchNorm2d(squeeze_channel),
23 |             nn.ReLU(inplace=True)
24 |         )
25 | 
26 |         self.expand_1x1 = nn.Sequential(
27 |             nn.Conv2d(squeeze_channel, int(out_channel / 2), 1),
28 |             nn.BatchNorm2d(int(out_channel / 2)),
29 |             nn.ReLU(inplace=True)
30 |         )
31 | 
32 |         self.expand_3x3 = nn.Sequential(
33 |             nn.Conv2d(squeeze_channel, int(out_channel / 2), 3, padding=1),
34 |             nn.BatchNorm2d(int(out_channel / 2)),
35 |             nn.ReLU(inplace=True)
36 |         )
37 | 
38 |     def forward(self, x):
39 | 
40 |         x = self.squeeze(x)
41 |         x = torch.cat([
42 |             self.expand_1x1(x),
43 |             self.expand_3x3(x)
44 |         ], 1)
45 | 
46 |         return x
47 | 
48 | class SqueezeNet(nn.Module):
49 | 
50 |     """squeezenet with simple bypass"""
51 |     def __init__(self, class_num=100):
52 | 
53 |         super().__init__()
54 |         self.stem = nn.Sequential(
55 |             nn.Conv2d(3, 96, 3, padding=1),
56 |             nn.BatchNorm2d(96),
57 |             nn.ReLU(inplace=True),
58 |             nn.MaxPool2d(2, 2)
59 |         )
60 | 
61 |         self.fire2 = Fire(96, 128, 16)
62 |         self.fire3 = Fire(128, 128, 16)
63 |         self.fire4 = Fire(128, 256, 32)
64 |         self.fire5 = Fire(256, 256, 32)
65 |         self.fire6 = Fire(256, 384, 48)
66 |         self.fire7 = Fire(384, 384, 48)
67 |         self.fire8 = Fire(384, 512, 64)
68 |         self.fire9 = Fire(512, 512, 64)
69 | 
70 |         self.conv10 = nn.Conv2d(512, class_num, 1)
71 |         self.avg = nn.AdaptiveAvgPool2d(1)
72 |         self.maxpool = nn.MaxPool2d(2, 2)
73 | 
74 |     def forward(self, x):
75 |         x = self.stem(x)
76 | 
77 |         f2 = self.fire2(x)
78 |         f3 = self.fire3(f2) + f2
79 |         f4 = self.fire4(f3)
80 |         f4 = self.maxpool(f4)
81 | 
82 |         f5 = self.fire5(f4) + f4
83 |         f6 = self.fire6(f5)
84 |         f7 = self.fire7(f6) + f6
85 |         f8 = self.fire8(f7)
86 |         f8 = self.maxpool(f8)
87 | 
88 |         f9 = self.fire9(f8)
89 |         c10 = self.conv10(f9)
90 | 
91 |         x = self.avg(c10)
92 |         x = x.view(x.size(0), -1)
93 | 
94 |         return x
95 | 
96 | def squeezenet(class_num=100):
97 |     return
SqueezeNet(class_num=class_num) 98 | -------------------------------------------------------------------------------- /recovery/nn/models/stochasticdepth.py: -------------------------------------------------------------------------------- 1 | """ 2 | resnet with stochastic depth 3 | 4 | [1] Gao Huang, Yu Sun, Zhuang Liu, Daniel Sedra, Kilian Weinberger 5 | Deep Networks with Stochastic Depth 6 | 7 | https://arxiv.org/abs/1603.09382v3 8 | """ 9 | import torch 10 | import torch.nn as nn 11 | from torch.distributions.bernoulli import Bernoulli 12 | import random 13 | 14 | 15 | class StochasticDepthBasicBlock(torch.jit.ScriptModule): 16 | 17 | expansion=1 18 | 19 | def __init__(self, p, in_channels, out_channels, stride=1): 20 | super().__init__() 21 | 22 | #self.p = torch.tensor(p).float() 23 | self.p = p 24 | self.residual = nn.Sequential( 25 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1), 26 | nn.BatchNorm2d(out_channels), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(out_channels, out_channels * StochasticDepthBasicBlock.expansion, kernel_size=3, padding=1), 29 | nn.BatchNorm2d(out_channels) 30 | ) 31 | 32 | self.shortcut = nn.Sequential() 33 | 34 | if stride != 1 or in_channels != out_channels * StochasticDepthBasicBlock.expansion: 35 | self.shortcut = nn.Sequential( 36 | nn.Conv2d(in_channels, out_channels * StochasticDepthBasicBlock.expansion, kernel_size=1, stride=stride), 37 | nn.BatchNorm2d(out_channels) 38 | ) 39 | def survival(self): 40 | var = torch.bernoulli(torch.tensor(self.p).float()) 41 | return torch.equal(var, torch.tensor(1).float().to(var.device)) 42 | 43 | @torch.jit.script_method 44 | def forward(self, x): 45 | 46 | if self.training: 47 | if self.survival(): 48 | # official torch implementation 49 | # function ResidualDrop:updateOutput(input) 50 | # local skip_forward = self.skip:forward(input) 51 | # self.output:resizeAs(skip_forward):copy(skip_forward) 52 | # if self.train then 53 | # if self.gate then -- only compute convolutional output when gate is open 54 | # self.output:add(self.net:forward(input)) 55 | # end 56 | # else 57 | # self.output:add(self.net:forward(input):mul(1-self.deathRate)) 58 | # end 59 | # return self.output 60 | # end 61 | 62 | # paper: 63 | # Hl = ReLU(bl*fl(Hl−1) + id(Hl−1)). 
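                # Here bl ~ Bernoulli(pl) is the per-block survival coin flip drawn
                # in survival() above. At test time (the else branch further below)
                # the coin flip is replaced by its expectation, since
                #   E[bl * fl(H)] = pl * fl(H),
                # so evaluation computes
                #   x = self.residual(x) * self.p + self.shortcut(x)
                # and the test-time activations match the training-time average.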
64 | 
65 |                 # paper and their official implementation are different
66 |                 # paper uses relu after the output
67 |                 # official implementation doesn't
68 |                 #
69 |                 # other implementations which use relu:
70 |                 # https://github.com/jiweeo/pytorch-stochastic-depth/blob/a6f95aaffee82d273c1cd73d9ed6ef0718c6683d/models/resnet.py
71 |                 # https://github.com/dblN/stochastic_depth_keras/blob/master/train.py
72 | 
73 |                 # implementations which don't use relu:
74 |                 # https://github.com/transcranial/stochastic-depth/blob/master/stochastic-depth.ipynb
75 |                 # https://github.com/shamangary/Pytorch-Stochastic-Depth-Resnet/blob/master/TYY_stodepth_lineardecay.py
76 | 
77 |                 # I will just stick with the official implementation; I think
78 |                 # whether we add a relu after the residual won't affect the
79 |                 # network performance too much
80 |                 x = self.residual(x) + self.shortcut(x)
81 |             else:
82 |                 # If bl = 0, the ResBlock reduces to the identity function
83 |                 x = self.shortcut(x)
84 | 
85 |         else:
86 |             x = self.residual(x) * self.p + self.shortcut(x)
87 | 
88 |         return x
89 | 
90 | 
91 | class StochasticDepthBottleNeck(torch.jit.ScriptModule):
92 |     """Residual block for resnet over 50 layers
93 | 
94 |     """
95 |     expansion = 4
96 |     def __init__(self, p, in_channels, out_channels, stride=1):
97 |         super().__init__()
98 | 
99 |         self.p = p
100 |         self.residual = nn.Sequential(
101 |             nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
102 |             nn.BatchNorm2d(out_channels),
103 |             nn.ReLU(inplace=True),
104 |             nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
105 |             nn.BatchNorm2d(out_channels),
106 |             nn.ReLU(inplace=True),
107 |             nn.Conv2d(out_channels, out_channels * StochasticDepthBottleNeck.expansion, kernel_size=1, bias=False),
108 |             nn.BatchNorm2d(out_channels * StochasticDepthBottleNeck.expansion),
109 |         )
110 | 
111 |         self.shortcut = nn.Sequential()
112 | 
113 |         if stride != 1 or in_channels != out_channels * StochasticDepthBottleNeck.expansion:
114 |             self.shortcut = nn.Sequential(
115 |                 nn.Conv2d(in_channels, out_channels * StochasticDepthBottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
116 |                 nn.BatchNorm2d(out_channels * StochasticDepthBottleNeck.expansion)
117 |             )
118 | 
119 |     def survival(self):
120 |         var = torch.bernoulli(torch.tensor(self.p).float())
121 |         return torch.equal(var, torch.tensor(1).float().to(var.device))
122 | 
123 |     @torch.jit.script_method
124 |     def forward(self, x):
125 | 
126 |         if self.training:
127 |             if self.survival():
128 |                 x = self.residual(x) + self.shortcut(x)
129 |             else:
130 |                 x = self.shortcut(x)
131 |         else:
132 |             x = self.residual(x) * self.p + self.shortcut(x)
133 | 
134 |         return x
135 | 
136 | class StochasticDepthResNet(nn.Module):
137 | 
138 |     def __init__(self, block, num_block, num_classes=100):
139 |         super().__init__()
140 | 
141 |         self.in_channels = 64
142 |         self.conv1 = nn.Sequential(
143 |             nn.Conv2d(3, 64, kernel_size=3, padding=1),
144 |             nn.BatchNorm2d(64),
145 |             nn.ReLU(inplace=True)
146 |         )
147 | 
148 |         self.step = (1 - 0.5) / (sum(num_block) - 1)
149 |         self.pl = 1
150 |         self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
151 |         self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
152 |         self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
153 |         self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
154 |         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
155 |         self.fc = nn.Linear(512 * block.expansion, num_classes)
156 | 
157 |     def _make_layer(self, block, out_channels, num_blocks, stride):
158 | 
159 |         strides = [stride] +
[1] * (num_blocks - 1) 160 | layers = [] 161 | for stride in strides: 162 | layers.append(block(self.pl, self.in_channels, out_channels, stride)) 163 | self.in_channels = out_channels * block.expansion 164 | self.pl -= self.step 165 | 166 | return nn.Sequential(*layers) 167 | 168 | def forward(self, x): 169 | output = self.conv1(x) 170 | output = self.conv2_x(output) 171 | output = self.conv3_x(output) 172 | output = self.conv4_x(output) 173 | output = self.conv5_x(output) 174 | output = self.avg_pool(output) 175 | output = output.view(output.size(0), -1) 176 | output = self.fc(output) 177 | 178 | return output 179 | 180 | 181 | def stochastic_depth_resnet18(): 182 | """ return a ResNet 18 object 183 | """ 184 | return StochasticDepthResNet(StochasticDepthBasicBlock, [2, 2, 2, 2]) 185 | 186 | def stochastic_depth_resnet34(): 187 | """ return a ResNet 34 object 188 | """ 189 | return StochasticDepthResNet(StochasticDepthBasicBlock, [3, 4, 6, 3]) 190 | 191 | def stochastic_depth_resnet50(): 192 | 193 | """ return a ResNet 50 object 194 | """ 195 | return StochasticDepthResNet(StochasticDepthBottleNeck, [3, 4, 6, 3]) 196 | 197 | def stochastic_depth_resnet101(): 198 | """ return a ResNet 101 object 199 | """ 200 | return StochasticDepthResNet(StochasticDepthBottleNeck, [3, 4, 23, 3]) 201 | 202 | def stochastic_depth_resnet152(): 203 | """ return a ResNet 152 object 204 | """ 205 | return StochasticDepthResNet(StochasticDepthBottleNeck, [3, 8, 36, 3]) 206 | 207 | -------------------------------------------------------------------------------- /recovery/nn/models/vgg.py: -------------------------------------------------------------------------------- 1 | """vgg in pytorch 2 | 3 | 4 | [1] Karen Simonyan, Andrew Zisserman 5 | 6 | Very Deep Convolutional Networks for Large-Scale Image Recognition. 
7 | https://arxiv.org/abs/1409.1556v6 8 | """ 9 | '''VGG11/13/16/19 in Pytorch.''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | cfg = { 15 | 'A' : [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 16 | 'B' : [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 17 | 'D' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 18 | 'E' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'] 19 | } 20 | 21 | class VGG(nn.Module): 22 | 23 | def __init__(self, features, num_class=100): 24 | super().__init__() 25 | self.features = features 26 | 27 | self.classifier = nn.Sequential( 28 | nn.Linear(512, 4096), 29 | nn.ReLU(inplace=True), 30 | nn.Dropout(), 31 | nn.Linear(4096, 4096), 32 | nn.ReLU(inplace=True), 33 | nn.Dropout(), 34 | nn.Linear(4096, num_class) 35 | ) 36 | 37 | def forward(self, x): 38 | output = self.features(x) 39 | output = output.view(output.size()[0], -1) 40 | output = self.classifier(output) 41 | 42 | return output 43 | 44 | def make_layers(cfg, batch_norm=False): 45 | layers = [] 46 | 47 | input_channel = 3 48 | for l in cfg: 49 | if l == 'M': 50 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 51 | continue 52 | 53 | layers += [nn.Conv2d(input_channel, l, kernel_size=3, padding=1)] 54 | 55 | if batch_norm: 56 | layers += [nn.BatchNorm2d(l)] 57 | 58 | layers += [nn.ReLU(inplace=True)] 59 | input_channel = l 60 | 61 | return nn.Sequential(*layers) 62 | 63 | def vgg11_bn(num_classes): 64 | return VGG(make_layers(cfg['A'], batch_norm=True), num_class=num_classes) 65 | 66 | def vgg13_bn(num_classes): 67 | return VGG(make_layers(cfg['B'], batch_norm=True), num_class=num_classes) 68 | 69 | def vgg16_bn(num_classes): 70 | return VGG(make_layers(cfg['D'], batch_norm=True), num_class=num_classes) 71 | 72 | def vgg19_bn(num_classes): 73 | return VGG(make_layers(cfg['E'], batch_norm=True), num_class=num_classes) 74 | 75 | 76 | -------------------------------------------------------------------------------- /recovery/nn/models/wideresidual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class WideBasic(nn.Module): 6 | 7 | def __init__(self, in_channels, out_channels, stride=1): 8 | super().__init__() 9 | self.residual = nn.Sequential( 10 | nn.BatchNorm2d(in_channels), 11 | nn.ReLU(inplace=True), 12 | nn.Conv2d( 13 | in_channels, 14 | out_channels, 15 | kernel_size=3, 16 | stride=stride, 17 | padding=1 18 | ), 19 | nn.BatchNorm2d(out_channels), 20 | nn.ReLU(inplace=True), 21 | nn.Dropout(), 22 | nn.Conv2d( 23 | out_channels, 24 | out_channels, 25 | kernel_size=3, 26 | stride=1, 27 | padding=1 28 | ) 29 | ) 30 | 31 | self.shortcut = nn.Sequential() 32 | 33 | if in_channels != out_channels or stride != 1: 34 | self.shortcut = nn.Sequential( 35 | nn.Conv2d(in_channels, out_channels, 1, stride=stride) 36 | ) 37 | 38 | def forward(self, x): 39 | 40 | residual = self.residual(x) 41 | shortcut = self.shortcut(x) 42 | 43 | return residual + shortcut 44 | 45 | class WideResNet(nn.Module): 46 | def __init__(self, num_classes, block, depth=50, widen_factor=1): 47 | super().__init__() 48 | 49 | self.depth = depth 50 | k = widen_factor 51 | l = int((depth - 4) / 6) 52 | self.in_channels = 16 53 | self.init_conv = nn.Conv2d(3, self.in_channels, 3, 1, padding=1) 54 | self.conv2 = self._make_layer(block, 16 * k, l, 1) 55 | self.conv3 = self._make_layer(block, 32 * k, l, 2) 
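        # Each of the three groups (conv2-conv4) stacks l = (depth - 4) // 6
        # WideBasic blocks of two 3x3 convs, the usual WRN depth accounting.
        # For the default wideresnet(depth=40, widen_factor=10) defined below,
        # l = 6 and the group widths 16*k / 32*k / 64*k come out to 160/320/640.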
56 | self.conv4 = self._make_layer(block, 64 * k, l, 2) 57 | self.bn = nn.BatchNorm2d(64 * k) 58 | self.relu = nn.ReLU(inplace=True) 59 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 60 | self.linear = nn.Linear(64 * k, num_classes) 61 | 62 | def forward(self, x): 63 | x = self.init_conv(x) 64 | x = self.conv2(x) 65 | x = self.conv3(x) 66 | x = self.conv4(x) 67 | x = self.bn(x) 68 | x = self.relu(x) 69 | x = self.avg_pool(x) 70 | x = x.view(x.size(0), -1) 71 | x = self.linear(x) 72 | 73 | return x 74 | 75 | def _make_layer(self, block, out_channels, num_blocks, stride): 76 | """make resnet layers(by layer i didnt mean this 'layer' was the 77 | same as a neuron netowork layer, ex. conv layer), one layer may 78 | contain more than one residual block 79 | 80 | Args: 81 | block: block type, basic block or bottle neck block 82 | out_channels: output depth channel number of this layer 83 | num_blocks: how many blocks per layer 84 | stride: the stride of the first block of this layer 85 | 86 | Return: 87 | return a resnet layer 88 | """ 89 | 90 | # we have num_block blocks per layer, the first block 91 | # could be 1 or 2, other blocks would always be 1 92 | strides = [stride] + [1] * (num_blocks - 1) 93 | layers = [] 94 | for stride in strides: 95 | layers.append(block(self.in_channels, out_channels, stride)) 96 | self.in_channels = out_channels 97 | 98 | return nn.Sequential(*layers) 99 | 100 | 101 | # Table 9: Best WRN performance over various datasets, single run results. 102 | def wideresnet(depth=40, widen_factor=10): 103 | net = WideResNet(100, WideBasic, depth=depth, widen_factor=widen_factor) 104 | return net -------------------------------------------------------------------------------- /recovery/nn/models/xception.py: -------------------------------------------------------------------------------- 1 | """xception in pytorch 2 | 3 | 4 | [1] François Chollet 5 | 6 | Xception: Deep Learning with Depthwise Separable Convolutions 7 | https://arxiv.org/abs/1610.02357 8 | """ 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | class SeperableConv2d(nn.Module): 14 | 15 | #***Figure 4. 
An “extreme” version of our Inception module, 16 | #with one spatial convolution per output channel of the 1x1 17 | #convolution.""" 18 | def __init__(self, input_channels, output_channels, kernel_size, **kwargs): 19 | 20 | super().__init__() 21 | self.depthwise = nn.Conv2d( 22 | input_channels, 23 | input_channels, 24 | kernel_size, 25 | groups=input_channels, 26 | bias=False, 27 | **kwargs 28 | ) 29 | 30 | self.pointwise = nn.Conv2d(input_channels, output_channels, 1, bias=False) 31 | 32 | def forward(self, x): 33 | x = self.depthwise(x) 34 | x = self.pointwise(x) 35 | 36 | return x 37 | 38 | class EntryFlow(nn.Module): 39 | 40 | def __init__(self): 41 | 42 | super().__init__() 43 | self.conv1 = nn.Sequential( 44 | nn.Conv2d(3, 32, 3, padding=1, bias=False), 45 | nn.BatchNorm2d(32), 46 | nn.ReLU(inplace=True) 47 | ) 48 | 49 | self.conv2 = nn.Sequential( 50 | nn.Conv2d(32, 64, 3, padding=1, bias=False), 51 | nn.BatchNorm2d(64), 52 | nn.ReLU(inplace=True) 53 | ) 54 | 55 | self.conv3_residual = nn.Sequential( 56 | SeperableConv2d(64, 128, 3, padding=1), 57 | nn.BatchNorm2d(128), 58 | nn.ReLU(inplace=True), 59 | SeperableConv2d(128, 128, 3, padding=1), 60 | nn.BatchNorm2d(128), 61 | nn.MaxPool2d(3, stride=2, padding=1), 62 | ) 63 | 64 | self.conv3_shortcut = nn.Sequential( 65 | nn.Conv2d(64, 128, 1, stride=2), 66 | nn.BatchNorm2d(128), 67 | ) 68 | 69 | self.conv4_residual = nn.Sequential( 70 | nn.ReLU(inplace=True), 71 | SeperableConv2d(128, 256, 3, padding=1), 72 | nn.BatchNorm2d(256), 73 | nn.ReLU(inplace=True), 74 | SeperableConv2d(256, 256, 3, padding=1), 75 | nn.BatchNorm2d(256), 76 | nn.MaxPool2d(3, stride=2, padding=1) 77 | ) 78 | 79 | self.conv4_shortcut = nn.Sequential( 80 | nn.Conv2d(128, 256, 1, stride=2), 81 | nn.BatchNorm2d(256), 82 | ) 83 | 84 | #no downsampling 85 | self.conv5_residual = nn.Sequential( 86 | nn.ReLU(inplace=True), 87 | SeperableConv2d(256, 728, 3, padding=1), 88 | nn.BatchNorm2d(728), 89 | nn.ReLU(inplace=True), 90 | SeperableConv2d(728, 728, 3, padding=1), 91 | nn.BatchNorm2d(728), 92 | nn.MaxPool2d(3, 1, padding=1) 93 | ) 94 | 95 | #no downsampling 96 | self.conv5_shortcut = nn.Sequential( 97 | nn.Conv2d(256, 728, 1), 98 | nn.BatchNorm2d(728) 99 | ) 100 | 101 | def forward(self, x): 102 | x = self.conv1(x) 103 | x = self.conv2(x) 104 | residual = self.conv3_residual(x) 105 | shortcut = self.conv3_shortcut(x) 106 | x = residual + shortcut 107 | residual = self.conv4_residual(x) 108 | shortcut = self.conv4_shortcut(x) 109 | x = residual + shortcut 110 | residual = self.conv5_residual(x) 111 | shortcut = self.conv5_shortcut(x) 112 | x = residual + shortcut 113 | 114 | return x 115 | 116 | class MiddleFLowBlock(nn.Module): 117 | 118 | def __init__(self): 119 | super().__init__() 120 | 121 | self.shortcut = nn.Sequential() 122 | self.conv1 = nn.Sequential( 123 | nn.ReLU(inplace=True), 124 | SeperableConv2d(728, 728, 3, padding=1), 125 | nn.BatchNorm2d(728) 126 | ) 127 | self.conv2 = nn.Sequential( 128 | nn.ReLU(inplace=True), 129 | SeperableConv2d(728, 728, 3, padding=1), 130 | nn.BatchNorm2d(728) 131 | ) 132 | self.conv3 = nn.Sequential( 133 | nn.ReLU(inplace=True), 134 | SeperableConv2d(728, 728, 3, padding=1), 135 | nn.BatchNorm2d(728) 136 | ) 137 | 138 | def forward(self, x): 139 | residual = self.conv1(x) 140 | residual = self.conv2(residual) 141 | residual = self.conv3(residual) 142 | 143 | shortcut = self.shortcut(x) 144 | 145 | return shortcut + residual 146 | 147 | class MiddleFlow(nn.Module): 148 | def __init__(self, block): 149 | super().__init__() 150 
| 151 | #"""then through the middle flow which is repeated eight times""" 152 | self.middel_block = self._make_flow(block, 8) 153 | 154 | def forward(self, x): 155 | x = self.middel_block(x) 156 | return x 157 | 158 | def _make_flow(self, block, times): 159 | flows = [] 160 | for i in range(times): 161 | flows.append(block()) 162 | 163 | return nn.Sequential(*flows) 164 | 165 | 166 | class ExitFLow(nn.Module): 167 | 168 | def __init__(self): 169 | super().__init__() 170 | self.residual = nn.Sequential( 171 | nn.ReLU(), 172 | SeperableConv2d(728, 728, 3, padding=1), 173 | nn.BatchNorm2d(728), 174 | nn.ReLU(), 175 | SeperableConv2d(728, 1024, 3, padding=1), 176 | nn.BatchNorm2d(1024), 177 | nn.MaxPool2d(3, stride=2, padding=1) 178 | ) 179 | 180 | self.shortcut = nn.Sequential( 181 | nn.Conv2d(728, 1024, 1, stride=2), 182 | nn.BatchNorm2d(1024) 183 | ) 184 | 185 | self.conv = nn.Sequential( 186 | SeperableConv2d(1024, 1536, 3, padding=1), 187 | nn.BatchNorm2d(1536), 188 | nn.ReLU(inplace=True), 189 | SeperableConv2d(1536, 2048, 3, padding=1), 190 | nn.BatchNorm2d(2048), 191 | nn.ReLU(inplace=True) 192 | ) 193 | 194 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 195 | 196 | def forward(self, x): 197 | shortcut = self.shortcut(x) 198 | residual = self.residual(x) 199 | output = shortcut + residual 200 | output = self.conv(output) 201 | output = self.avgpool(output) 202 | 203 | return output 204 | 205 | class Xception(nn.Module): 206 | 207 | def __init__(self, block, num_class=100): 208 | super().__init__() 209 | self.entry_flow = EntryFlow() 210 | self.middel_flow = MiddleFlow(block) 211 | self.exit_flow = ExitFLow() 212 | 213 | self.fc = nn.Linear(2048, num_class) 214 | 215 | def forward(self, x): 216 | x = self.entry_flow(x) 217 | x = self.middel_flow(x) 218 | x = self.exit_flow(x) 219 | x = x.view(x.size(0), -1) 220 | x = self.fc(x) 221 | 222 | return x 223 | 224 | def xception(): 225 | return Xception(MiddleFLowBlock) 226 | 227 | 228 | -------------------------------------------------------------------------------- /recovery/nn/modules.py: -------------------------------------------------------------------------------- 1 | """For monkey-patching into meta-learning frameworks.""" 2 | from typing import Iterator 3 | import torch 4 | import torch.nn.functional as F 5 | from collections import OrderedDict 6 | from functools import partial 7 | import warnings 8 | 9 | from torch.nn.parameter import Parameter 10 | 11 | from ..consts import BENCHMARK 12 | torch.backends.cudnn.benchmark = BENCHMARK 13 | 14 | DEBUG = False # Emit warning messages when patching. Use this to bootstrap new architectures. 15 | 16 | class MetaMonkey(torch.nn.Module): 17 | """Trace a networks and then replace its module calls with functional calls. 18 | 19 | This allows for backpropagation w.r.t to weights for "normal" PyTorch networks. 20 | """ 21 | 22 | def __init__(self, net): 23 | """Init with network.""" 24 | super().__init__() 25 | self.net = net 26 | self.nparameters = OrderedDict(net.named_parameters()) 27 | 28 | 29 | def forward(self, inputs, parameters=None): 30 | """Live Patch ... :> ...""" 31 | # If no parameter dictionary is given, everything is normal 32 | if parameters is None: 33 | return self.net(inputs) 34 | 35 | # But if not ... 
36 | param_gen = iter(parameters.values()) 37 | method_pile = [] 38 | counter = 0 39 | 40 | for name, module in self.net.named_modules(): 41 | if isinstance(module, torch.nn.Conv2d): 42 | ext_weight = next(param_gen) 43 | if module.bias is not None: 44 | ext_bias = next(param_gen) 45 | else: 46 | ext_bias = None 47 | 48 | method_pile.append(module.forward) 49 | module.forward = partial(F.conv2d, weight=ext_weight, bias=ext_bias, stride=module.stride, 50 | padding=module.padding, dilation=module.dilation, groups=module.groups) 51 | elif isinstance(module, torch.nn.BatchNorm2d): 52 | if module.momentum is None: 53 | exponential_average_factor = 0.0 54 | else: 55 | exponential_average_factor = module.momentum 56 | 57 | if module.training and module.track_running_stats: 58 | if module.num_batches_tracked is not None: 59 | module.num_batches_tracked += 1 60 | if module.momentum is None: # use cumulative moving average 61 | exponential_average_factor = 1.0 / float(module.num_batches_tracked) 62 | else: # use exponential moving average 63 | exponential_average_factor = module.momentum 64 | 65 | ext_weight = next(param_gen) 66 | ext_bias = next(param_gen) 67 | method_pile.append(module.forward) 68 | module.forward = partial(F.batch_norm, running_mean=module.running_mean, running_var=module.running_var, 69 | weight=ext_weight, bias=ext_bias, 70 | training=module.training or not module.track_running_stats, 71 | momentum=exponential_average_factor, eps=module.eps) 72 | 73 | elif isinstance(module, torch.nn.Linear): 74 | lin_weights = next(param_gen) 75 | lin_bias = next(param_gen) 76 | method_pile.append(module.forward) 77 | module.forward = partial(F.linear, weight=lin_weights, bias=lin_bias) 78 | 79 | elif next(module.parameters(), None) is None: 80 | # Pass over modules that do not contain parameters 81 | pass 82 | elif isinstance(module, torch.nn.Sequential): 83 | # Pass containers 84 | pass 85 | else: 86 | # Warn for other containers 87 | if DEBUG: 88 | warnings.warn(f'Patching for module {module.__class__} is not implemented.') 89 | 90 | output = self.net(inputs) 91 | 92 | # Undo Patch 93 | for name, module in self.net.named_modules(): 94 | if isinstance(module, torch.nn.modules.conv.Conv2d): 95 | module.forward = method_pile.pop(0) 96 | elif isinstance(module, torch.nn.BatchNorm2d): 97 | module.forward = method_pile.pop(0) 98 | elif isinstance(module, torch.nn.Linear): 99 | module.forward = method_pile.pop(0) 100 | 101 | return output 102 | -------------------------------------------------------------------------------- /recovery/nn/revnet.py: -------------------------------------------------------------------------------- 1 | """https://github.com/jhjacobsen/pytorch-i-revnet/blob/master/models/iRevNet.py. 
2 | 3 | Code for "i-RevNet: Deep Invertible Networks" 4 | https://openreview.net/pdf?id=HJsjkMb0Z 5 | ICLR, 2018 6 | 7 | 8 | (c) Joern-Henrik Jacobsen, 2018 9 | """ 10 | 11 | """ 12 | MIT License 13 | 14 | Copyright (c) 2018 Jörn Jacobsen 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all 24 | copies or substantial portions of the Software. 25 | 26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 32 | SOFTWARE. 33 | """ 34 | 35 | import torch 36 | import torch.nn as nn 37 | import torch.nn.functional as F 38 | from .revnet_utils import split, merge, injective_pad, psi 39 | 40 | 41 | class irevnet_block(nn.Module): 42 | """This is an i-revnet block from Jacobsen et al.""" 43 | 44 | def __init__(self, in_ch, out_ch, stride=1, first=False, dropout_rate=0., 45 | affineBN=True, mult=4): 46 | """Build invertible bottleneck block.""" 47 | super(irevnet_block, self).__init__() 48 | self.first = first 49 | self.pad = 2 * out_ch - in_ch 50 | self.stride = stride 51 | self.inj_pad = injective_pad(self.pad) 52 | self.psi = psi(stride) 53 | if self.pad != 0 and stride == 1: 54 | in_ch = out_ch * 2 55 | print('') 56 | print('| Injective iRevNet |') 57 | print('') 58 | layers = [] 59 | if not first: 60 | layers.append(nn.BatchNorm2d(in_ch // 2, affine=affineBN)) 61 | layers.append(nn.ReLU(inplace=True)) 62 | layers.append(nn.Conv2d(in_ch // 2, int(out_ch // mult), kernel_size=3, 63 | stride=stride, padding=1, bias=False)) 64 | layers.append(nn.BatchNorm2d(int(out_ch // mult), affine=affineBN)) 65 | layers.append(nn.ReLU(inplace=True)) 66 | layers.append(nn.Conv2d(int(out_ch // mult), int(out_ch // mult), 67 | kernel_size=3, padding=1, bias=False)) 68 | layers.append(nn.Dropout(p=dropout_rate)) 69 | layers.append(nn.BatchNorm2d(int(out_ch // mult), affine=affineBN)) 70 | layers.append(nn.ReLU(inplace=True)) 71 | layers.append(nn.Conv2d(int(out_ch // mult), out_ch, kernel_size=3, 72 | padding=1, bias=False)) 73 | self.bottleneck_block = nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | """Bijective or injective block forward.""" 77 | if self.pad != 0 and self.stride == 1: 78 | x = merge(x[0], x[1]) 79 | x = self.inj_pad.forward(x) 80 | x1, x2 = split(x) 81 | x = (x1, x2) 82 | x1 = x[0] 83 | x2 = x[1] 84 | Fx2 = self.bottleneck_block(x2) 85 | if self.stride == 2: 86 | x1 = self.psi.forward(x1) 87 | x2 = self.psi.forward(x2) 88 | y1 = Fx2 + x1 89 | return (x2, y1) 90 | 91 | def inverse(self, x): 92 | """Bijective or injecitve block inverse.""" 93 | x2, y1 = x[0], x[1] 94 | if self.stride == 2: 95 | x2 = self.psi.inverse(x2) 96 | Fx2 = - self.bottleneck_block(x2) 97 | x1 = Fx2 + y1 
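        # This inverts the forward coupling y1 = Fx2 + x1 exactly: Fx2 above
        # holds -F(x2), so x1 = y1 - F(x2). No intermediate activations need
        # to be stored to run the block backwards, which is the point of the
        # invertible architecture.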
98 | if self.stride == 2: 99 | x1 = self.psi.inverse(x1) 100 | if self.pad != 0 and self.stride == 1: 101 | x = merge(x1, x2) 102 | x = self.inj_pad.inverse(x) 103 | x1, x2 = split(x) 104 | x = (x1, x2) 105 | else: 106 | x = (x1, x2) 107 | return x 108 | 109 | 110 | class iRevNet(nn.Module): 111 | """This is an i-revnet from Jacobsen et al.""" 112 | 113 | def __init__(self, nBlocks, nStrides, nClasses, nChannels=None, init_ds=2, 114 | dropout_rate=0., affineBN=True, in_shape=None, mult=4): 115 | """Init with e.g. nBlocks=[18, 18, 18], nStrides = [1, 2, 2].""" 116 | super(iRevNet, self).__init__() 117 | self.ds = in_shape[2] // 2**(nStrides.count(2) + init_ds // 2) 118 | self.init_ds = init_ds 119 | self.in_ch = in_shape[0] * 2**self.init_ds 120 | self.nBlocks = nBlocks 121 | self.first = True 122 | 123 | print('') 124 | print(' == Building iRevNet %d == ' % (sum(nBlocks) * 3 + 1)) 125 | if not nChannels: 126 | nChannels = [self.in_ch // 2, self.in_ch // 2 * 4, 127 | self.in_ch // 2 * 4**2, self.in_ch // 2 * 4**3] 128 | 129 | self.init_psi = psi(self.init_ds) 130 | self.stack = self.irevnet_stack(irevnet_block, nChannels, nBlocks, 131 | nStrides, dropout_rate=dropout_rate, 132 | affineBN=affineBN, in_ch=self.in_ch, 133 | mult=mult) 134 | self.bn1 = nn.BatchNorm2d(nChannels[-1] * 2, momentum=0.9) 135 | self.linear = nn.Linear(nChannels[-1] * 2, nClasses) 136 | 137 | def irevnet_stack(self, _block, nChannels, nBlocks, nStrides, dropout_rate, 138 | affineBN, in_ch, mult): 139 | """Create stack of irevnet blocks.""" 140 | block_list = nn.ModuleList() 141 | strides = [] 142 | channels = [] 143 | for channel, depth, stride in zip(nChannels, nBlocks, nStrides): 144 | strides = strides + ([stride] + [1] * (depth - 1)) 145 | channels = channels + ([channel] * depth) 146 | for channel, stride in zip(channels, strides): 147 | block_list.append(_block(in_ch, channel, stride, 148 | first=self.first, 149 | dropout_rate=dropout_rate, 150 | affineBN=affineBN, mult=mult)) 151 | in_ch = 2 * channel 152 | self.first = False 153 | return block_list 154 | 155 | def forward(self, x, return_bijection=False): 156 | """Irevnet forward.""" 157 | n = self.in_ch // 2 158 | if self.init_ds != 0: 159 | x = self.init_psi.forward(x) 160 | out = (x[:, :n, :, :], x[:, n:, :, :]) 161 | for block in self.stack: 162 | out = block.forward(out) 163 | out_bij = merge(out[0], out[1]) 164 | out = F.relu(self.bn1(out_bij)) 165 | out = F.avg_pool2d(out, self.ds) 166 | out = out.view(out.size(0), -1) 167 | out = self.linear(out) 168 | if return_bijection: 169 | return out, out_bij 170 | else: 171 | return out 172 | 173 | def inverse(self, out_bij): 174 | """Irevnet inverse.""" 175 | out = split(out_bij) 176 | for i in range(len(self.stack)): 177 | out = self.stack[-1 - i].inverse(out) 178 | out = merge(out[0], out[1]) 179 | if self.init_ds != 0: 180 | x = self.init_psi.inverse(out) 181 | else: 182 | x = out 183 | return x 184 | 185 | 186 | if __name__ == '__main__': 187 | model = iRevNet(nBlocks=[6, 16, 72, 6], nStrides=[2, 2, 2, 2], 188 | nChannels=None, nClasses=1000, init_ds=2, 189 | dropout_rate=0., affineBN=True, in_shape=[3, 224, 224], 190 | mult=4) 191 | y = model(torch.randn(1, 3, 224, 224)) 192 | print(y.size()) 193 | -------------------------------------------------------------------------------- /recovery/nn/revnet_utils.py: -------------------------------------------------------------------------------- 1 | """https://github.com/jhjacobsen/pytorch-i-revnet/blob/master/models/model_utils.py. 
2 | 3 | Code for "i-RevNet: Deep Invertible Networks" 4 | https://openreview.net/pdf?id=HJsjkMb0Z 5 | ICLR, 2018 6 | 7 | 8 | (c) Joern-Henrik Jacobsen, 2018 9 | """ 10 | 11 | """ 12 | MIT License 13 | 14 | Copyright (c) 2018 Jörn Jacobsen 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all 24 | copies or substantial portions of the Software. 25 | 26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 32 | SOFTWARE. 33 | """ 34 | 35 | import torch 36 | import torch.nn as nn 37 | 38 | from torch.nn import Parameter 39 | 40 | 41 | def split(x): 42 | n = int(x.size()[1] / 2) 43 | x1 = x[:, :n, :, :].contiguous() 44 | x2 = x[:, n:, :, :].contiguous() 45 | return x1, x2 46 | 47 | 48 | def merge(x1, x2): 49 | return torch.cat((x1, x2), 1) 50 | 51 | 52 | class injective_pad(nn.Module): 53 | def __init__(self, pad_size): 54 | super(injective_pad, self).__init__() 55 | self.pad_size = pad_size 56 | self.pad = nn.ZeroPad2d((0, 0, 0, pad_size)) 57 | 58 | def forward(self, x): 59 | x = x.permute(0, 2, 1, 3) 60 | x = self.pad(x) 61 | return x.permute(0, 2, 1, 3) 62 | 63 | def inverse(self, x): 64 | return x[:, :x.size(1) - self.pad_size, :, :] 65 | 66 | 67 | class psi(nn.Module): 68 | def __init__(self, block_size): 69 | super(psi, self).__init__() 70 | self.block_size = block_size 71 | self.block_size_sq = block_size * block_size 72 | 73 | def inverse(self, input): 74 | output = input.permute(0, 2, 3, 1) 75 | (batch_size, d_height, d_width, d_depth) = output.size() 76 | s_depth = int(d_depth / self.block_size_sq) 77 | s_width = int(d_width * self.block_size) 78 | s_height = int(d_height * self.block_size) 79 | t_1 = output.contiguous().view(batch_size, d_height, d_width, self.block_size_sq, s_depth) 80 | spl = t_1.split(self.block_size, 3) 81 | stack = [t_t.contiguous().view(batch_size, d_height, s_width, s_depth) for t_t in spl] 82 | output = torch.stack(stack, 0).transpose(0, 1).permute(0, 2, 1, 3, 4).contiguous().view(batch_size, s_height, s_width, s_depth) 83 | output = output.permute(0, 3, 1, 2) 84 | return output.contiguous() 85 | 86 | def forward(self, input): 87 | output = input.permute(0, 2, 3, 1) 88 | (batch_size, s_height, s_width, s_depth) = output.size() 89 | d_depth = s_depth * self.block_size_sq 90 | d_height = int(s_height / self.block_size) 91 | t_1 = output.split(self.block_size, 2) 92 | stack = [t_t.contiguous().view(batch_size, d_height, d_depth) for t_t in t_1] 93 | output = torch.stack(stack, 1) 94 | output = output.permute(0, 2, 1, 3) 95 | output = output.permute(0, 3, 1, 2) 96 | return output.contiguous() 97 | 98 | 99 | class ListModule(object): 100 | def 
__init__(self, module, prefix, *args):
101 |         self.module = module
102 |         self.prefix = prefix
103 |         self.num_module = 0
104 |         for new_module in args:
105 |             self.append(new_module)
106 | 
107 |     def append(self, new_module):
108 |         if not isinstance(new_module, nn.Module):
109 |             raise ValueError('Not a Module')
110 |         else:
111 |             self.module.add_module(self.prefix + str(self.num_module), new_module)
112 |             self.num_module += 1
113 | 
114 |     def __len__(self):
115 |         return self.num_module
116 | 
117 |     def __getitem__(self, i):
118 |         if i < 0 or i >= self.num_module:
119 |             raise IndexError('Out of bound')
120 |         return getattr(self.module, self.prefix + str(i))
121 | 
122 | 
123 | def get_all_params(var, all_params):
124 |     if isinstance(var, Parameter):
125 |         all_params[id(var)] = var.nelement()
126 |     elif hasattr(var, "creator") and var.creator is not None:
127 |         if var.creator.previous_functions is not None:
128 |             for j in var.creator.previous_functions:
129 |                 get_all_params(j[0], all_params)
130 |     elif hasattr(var, "previous_functions"):
131 |         for j in var.previous_functions:
132 |             get_all_params(j[0], all_params)
133 | 
-------------------------------------------------------------------------------- /recovery/optimization_strategy.py: --------------------------------------------------------------------------------
1 | """ build upon https://github.com/JonasGeiping/invertinggradients"""
2 | """Optimization setups."""
3 | 
4 | from dataclasses import dataclass
5 | 
6 | 
7 | def training_strategy(strategy, lr=None, epochs=None, dryrun=False):
8 |     """Parse training strategy."""
9 |     if strategy == 'conservative':
10 |         defs = ConservativeStrategy(lr, epochs, dryrun)
11 |     elif strategy == 'adam':
12 |         defs = AdamStrategy(lr, epochs, dryrun)
13 |     else:
14 |         raise ValueError('Unknown training strategy.')
15 |     return defs
16 | 
17 | 
18 | @dataclass
19 | class Strategy:
20 |     """Default usual parameters, not intended for parsing."""
21 | 
22 |     epochs : int
23 |     batch_size : int
24 |     optimizer : str
25 |     lr : float
26 |     scheduler : str
27 |     weight_decay : float
28 |     validate : int
29 |     warmup: bool
30 |     dryrun : bool
31 |     dropout : float
32 |     augmentations : bool
33 | 
34 |     def __init__(self, lr=None, epochs=None, dryrun=False):
35 |         """Defaulted parameters. Apply overwrites from args."""
36 |         if epochs is not None:
37 |             self.epochs = epochs
38 |         if lr is not None:
39 |             self.lr = lr
40 |         if dryrun:
41 |             self.dryrun = dryrun
42 |         self.validate = 10
43 | 
44 | @dataclass
45 | class ConservativeStrategy(Strategy):
46 |     """Default usual parameters, defines a config object."""
47 | 
48 |     def __init__(self, lr=None, epochs=None, dryrun=False):
49 |         """Initialize training hyperparameters."""
50 |         self.lr = 0.1
51 |         self.epochs = 120
52 |         self.batch_size = 128
53 |         self.optimizer = 'SGD'
54 |         self.scheduler = 'linear'
55 |         self.warmup = False
56 |         self.weight_decay : float = 5e-4
57 |         self.dropout = 0.0
58 |         self.augmentations = True
59 |         self.dryrun = False
60 |         super().__init__(lr=lr, epochs=epochs, dryrun=dryrun)  # pass caller overrides through
61 | 
62 | 
63 | @dataclass
64 | class AdamStrategy(Strategy):
65 |     """Start slowly. Use a tame Adam."""
Use a tame Adam.""" 66 | 67 | def __init__(self, lr=None, epochs=None, dryrun=False): 68 | """Initialize training hyperparameters.""" 69 | self.lr = 1e-3 / 10 70 | self.epochs = 120 71 | self.batch_size = 32 72 | self.optimizer = 'AdamW' 73 | self.scheduler = 'linear' 74 | self.warmup = True 75 | self.weight_decay : float = 5e-4 76 | self.dropout = 0.0 77 | self.augmentations = True 78 | self.dryrun = False 79 | super().__init__(lr=lr, epochs=epochs, dryrun=dryrun)  # forward the overrides; passing None here would silently discard user-supplied lr/epochs 80 | -------------------------------------------------------------------------------- /recovery/training.py: -------------------------------------------------------------------------------- 1 | """ build upon https://github.com/JonasGeiping/invertinggradients""" 2 | """Implement the .train function.""" 3 | import os 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | from torch.optim.lr_scheduler import _LRScheduler 8 | from torch.optim.lr_scheduler import ReduceLROnPlateau 9 | 10 | 11 | from collections import defaultdict 12 | 13 | from .consts import BENCHMARK, NON_BLOCKING 14 | torch.backends.cudnn.benchmark = BENCHMARK 15 | 16 | class GradualWarmupScheduler(_LRScheduler): 17 | """Gradually warm up (increase) the learning rate: it ramps linearly from base_lr to base_lr * multiplier over total_epoch epochs. 18 | 19 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 20 | 21 | Args: 22 | optimizer (Optimizer): Wrapped optimizer. 23 | multiplier: target learning rate = base lr * multiplier 24 | total_epoch: epoch at which the target learning rate is reached (linear ramp-up) 25 | after_scheduler: scheduler to hand off to after total_epoch (e.g. ReduceLROnPlateau) 26 | 27 | """ 28 | 29 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 30 | """Initialize the warm-up start. 31 | 32 | Usage: 33 | 34 | scheduler_normal = torch.optim.lr_scheduler.MultiStepLR(optimizer) 35 | scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=10, after_scheduler=scheduler_normal) 36 | """ 37 | self.multiplier = multiplier 38 | if self.multiplier < 1.: 39 | raise ValueError('multiplier should be greater than or equal to 1.') 40 | self.total_epoch = total_epoch 41 | self.after_scheduler = after_scheduler 42 | self.finished = False 43 | super().__init__(optimizer) 44 | 45 | def get_lr(self): 46 | if self.last_epoch > self.total_epoch: 47 | if self.after_scheduler: 48 | if not self.finished: 49 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 50 | self.finished = True 51 | return self.after_scheduler.get_lr() 52 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 53 | 54 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 55 | 56 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 57 | if epoch is None: 58 | epoch = self.last_epoch + 1 59 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 60 | if self.last_epoch <= self.total_epoch: 61 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
for base_lr in self.base_lrs] 62 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 63 | param_group['lr'] = lr 64 | else: 65 | if epoch is None: 66 | self.after_scheduler.step(metrics, None) 67 | else: 68 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 69 | 70 | def step(self, epoch=None, metrics=None): 71 | if type(self.after_scheduler) != ReduceLROnPlateau: 72 | if self.finished and self.after_scheduler: 73 | if epoch is None: 74 | self.after_scheduler.step(None) 75 | else: 76 | self.after_scheduler.step(epoch - self.total_epoch) 77 | else: 78 | return super(GradualWarmupScheduler, self).step(epoch) 79 | else: 80 | self.step_ReduceLROnPlateau(metrics, epoch) 81 | 82 | def train(model, loss_fn, trainloader, validloader, defs, setup=dict(dtype=torch.float, device=torch.device('cpu')), ckpt_path=None, finetune=False): 83 | """Run the main interface. Train a network with specifications from the Strategy object.""" 84 | stats = defaultdict(list) 85 | optimizer, scheduler = set_optimizer(model, defs) 86 | 87 | for epoch in tqdm(range(1, defs.epochs+1)): 88 | if finetune: 89 | model.eval()  # eval mode: freeze BatchNorm running statistics and disable dropout while finetuning 90 | else: 91 | model.train() 92 | step(model, loss_fn, trainloader, optimizer, scheduler, defs, setup, stats) 93 | 94 | if epoch % defs.validate == 0 or epoch == defs.epochs: 95 | model.eval() 96 | validate(model, loss_fn, validloader, defs, setup, stats) 97 | # Print information about loss and accuracy 98 | print_status(epoch, loss_fn, optimizer, stats) 99 | if ckpt_path is not None: 100 | torch.save(model.state_dict(), os.path.join(ckpt_path, f'model_{epoch}.pt')) 101 | 102 | if defs.dryrun: 103 | break 104 | if not (np.isfinite(stats['train_losses'][-1])): 105 | print('Loss is NaN/Inf ... terminating early ...') 106 | break 107 | 108 | return stats 109 | 110 | def step(model, loss_fn, dataloader, optimizer, scheduler, defs, setup, stats): 111 | """Step through one epoch.""" 112 | epoch_loss, epoch_metric = 0, 0 113 | for batch, (inputs, targets) in enumerate(dataloader): 114 | # Prep Mini-Batch 115 | optimizer.zero_grad() 116 | 117 | # Transfer to GPU 118 | inputs = inputs.to(**setup) 119 | targets = targets.to(device=setup['device'], non_blocking=NON_BLOCKING) 120 | 121 | # Get loss 122 | outputs = model(inputs) 123 | loss, _, _ = loss_fn(outputs, targets) 124 | 125 | 126 | epoch_loss += loss.item() 127 | 128 | loss.backward() 129 | optimizer.step() 130 | 131 | metric, name, _ = loss_fn.metric(outputs, targets) 132 | epoch_metric += metric.item() 133 | 134 | if defs.scheduler == 'cyclic': 135 | scheduler.step() 136 | if defs.dryrun: 137 | break 138 | if defs.scheduler == 'linear': 139 | scheduler.step() 140 | 141 | stats['train_losses'].append(epoch_loss / (batch + 1)) 142 | stats['train_' + name].append(epoch_metric / (batch + 1)) 143 | 144 | 145 | def validate(model, loss_fn, dataloader, defs, setup, stats): 146 | """Evaluate model loss and metric on the validation dataset.""" 147 | epoch_loss, epoch_metric = 0, 0 148 | with torch.no_grad(): 149 | for batch, (inputs, targets) in enumerate(dataloader): 150 | # Transfer to GPU 151 | inputs = inputs.to(**setup) 152 | targets = targets.to(device=setup['device'], non_blocking=NON_BLOCKING) 153 | 154 | # Get loss and metric 155 | outputs = model(inputs) 156 | loss, _, _ = loss_fn(outputs, targets) 157 | metric, name, _ = loss_fn.metric(outputs, targets) 158 | 159 | epoch_loss += loss.item() 160 | epoch_metric += metric.item() 161 | 162 | if defs.dryrun: 163 | break 164 | 165 | stats['valid_losses'].append(epoch_loss / (batch + 1)) 166 | 
stats['valid_' + name].append(epoch_metric / (batch + 1)) 167 | 168 | def set_optimizer(model, defs): 169 | """Build model optimizer and scheduler from defs. 170 | 171 | The 'linear' scheduler drops the learning rate by a factor of 10 at fixed milestones; 172 | with the hard-coded 120-epoch milestones below, the drops land at roughly epochs 45, 75 and 105. 173 | """ 174 | if defs.optimizer == 'SGD': 175 | optimizer = torch.optim.SGD(model.parameters(), lr=defs.lr, weight_decay=defs.weight_decay) 176 | elif defs.optimizer == 'AdamW': 177 | optimizer = torch.optim.AdamW(model.parameters(), lr=defs.lr, weight_decay=defs.weight_decay) 178 | 179 | if defs.scheduler == 'linear': 180 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 181 | milestones=[120 // 2.667, 120 // 1.6, 182 | 120 // 1.142], gamma=0.1) 183 | # Scheduler is fixed to 120 epochs so that calls with fewer epochs are equal in lr drops. 184 | 185 | if defs.warmup: 186 | scheduler = GradualWarmupScheduler(optimizer, multiplier=10, total_epoch=10, after_scheduler=scheduler) 187 | 188 | return optimizer, scheduler 189 | 190 | 191 | def print_status(epoch, loss_fn, optimizer, stats): 192 | """Print a basic console status line every defs.validate epochs.""" 193 | current_lr = optimizer.param_groups[0]['lr'] 194 | name, fmt = loss_fn.metric()  # 'fmt' avoids shadowing the built-in format 195 | print(f'Epoch: {epoch} | lr: {current_lr:.4f} | ' 196 | f'Train loss is {stats["train_losses"][-1]:6.4f}, Train {name}: {stats["train_" + name][-1]:{fmt}} | ' 197 | f'Val loss is {stats["valid_losses"][-1]:6.4f}, Val {name}: {stats["valid_" + name][-1]:{fmt}} |') 198 | -------------------------------------------------------------------------------- /recovery/utils.py: -------------------------------------------------------------------------------- 1 | """ build upon https://github.com/JonasGeiping/invertinggradients""" 2 | """Various utilities.""" 3 | 4 | import os 5 | import csv 6 | 7 | import torch 8 | import random 9 | import numpy as np 10 | 11 | import socket 12 | import datetime 13 | 14 | 15 | import matplotlib.pyplot as plt 16 | 17 | 18 | def system_startup(args=None, defs=None): 19 | """Print useful system information.""" 20 | # Choose GPU device and print status information: 21 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 22 | setup = dict(device=device, dtype=torch.float) # non_blocking=NON_BLOCKING 23 | print('Currently evaluating -------------------------------:') 24 | print(datetime.datetime.now().strftime("%A, %d. 
%B %Y %I:%M%p")) 25 | print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()} on {socket.gethostname()}.') 26 | if args is not None: 27 | print(args) 28 | if defs is not None: 29 | print(repr(defs)) 30 | if torch.cuda.is_available(): 31 | print(f'GPU : {torch.cuda.get_device_name(device=device)}') 32 | return setup 33 | 34 | def set_random_seed(seed=233): 35 | """233 = 144 + 89 is my favorite number.""" 36 | torch.manual_seed(seed + 1) 37 | torch.cuda.manual_seed(seed + 2) 38 | torch.cuda.manual_seed_all(seed + 3) 39 | np.random.seed(seed + 4) 40 | torch.cuda.manual_seed_all(seed + 5) 41 | random.seed(seed + 6) 42 | 43 | def set_deterministic(): 44 | """Switch pytorch into a deterministic computation mode.""" 45 | torch.backends.cudnn.deterministic = True 46 | torch.backends.cudnn.benchmark = False 47 | 48 | 49 | 50 | def plot(tensor, dm=None, ds=None): 51 | tensor = tensor.clone().detach() 52 | if dm is not None or ds is not None: 53 | tensor.mul_(ds).add_(dm).clamp_(0, 1) 54 | if tensor.shape[0] == 1: 55 | return plt.imshow(tensor[0].permute(1, 2, 0).cpu()); 56 | else: 57 | fig, axes = plt.subplots(1, tensor.shape[0], figsize=(12, tensor.shape[0]*12)) 58 | for i, im in enumerate(tensor): 59 | axes[i].imshow(im.permute(1, 2, 0).cpu()); 60 | 61 | -------------------------------------------------------------------------------- /recovery_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision import transforms 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import recovery as rs 7 | import argparse 8 | 9 | 10 | recons_config = dict(signed=True, 11 | boxed=True, 12 | cost_fn='sim', 13 | indices='def', 14 | weights='equal', 15 | lr=0.04, 16 | optim='adamw', 17 | restarts=5, 18 | max_iterations=10000, 19 | total_variation=1e-2, 20 | init='randn', 21 | filter='none', 22 | lr_decay=True, 23 | scoring_choice='loss') 24 | 25 | def new_plot(tensor, title="", path=None): 26 | if tensor.shape[0] == 1: 27 | return plt.imshow(tensor[0].permute(1, 2, 0).cpu()) 28 | else: 29 | fig, axes = plt.subplots(1, tensor.shape[0], figsize=(2 * tensor.shape[0], 3)) 30 | for i, im in enumerate(tensor): 31 | axes[i].imshow(im.permute(1, 2, 0).cpu()) 32 | plt.title(title) 33 | plt.savefig(path) 34 | 35 | def process_recons_results(result, ground_truth, figpath, recons_path, filename): 36 | output_list, stats, history_list, x_optimal = result 37 | x_optimal = x_optimal.detach().cpu() 38 | test_mse = (x_optimal - ground_truth.cpu()).pow(2).mean() 39 | test_psnr = rs.metrics.psnr(x_optimal, ground_truth, factor=1/ds) 40 | title = f"MSE: {test_mse:2.4f} | PSNR: {test_psnr:4.2f} | " 41 | new_plot(torch.cat([ground_truth, x_optimal]), title, path=os.path.join(figpath, f'{filename}.png')) 42 | torch.save({'output_list': output_list.cpu(), 'stats': stats, 'history_list': history_list, 'x_optimal': x_optimal}, open(os.path.join(recons_path, f'{filename}.pth'), 'wb')) 43 | 44 | 45 | 46 | if __name__ == "__main__": 47 | parser = argparse.ArgumentParser(description='simple argparse.') 48 | parser.add_argument('--model', default='ConvNet', type=str, help='Vision model.') 49 | parser.add_argument('--dataset', default='cifar10', type=str) 50 | parser.add_argument('--ft_samples', default=32, type=int) 51 | parser.add_argument('--unlearn_samples', default=1, type=int) 52 | parser.add_argument('--epochs', default=1, type=int, help='updated epochs') 53 | parser.add_argument('--lr', default=1e-3, type=float, 
help='learning rate') 54 | parser.add_argument('--seed', default=0, type=int, help='random seed') 55 | parser.add_argument('--model_save_folder', default='results/models', type=str, help='folder of pretrained models') 56 | 57 | args = parser.parse_args() 58 | 59 | print(args.__dict__) 60 | 61 | img_size = 32 if 'cifar' in args.dataset else 96 62 | excluded_num = 10000 if 'cifar' in args.dataset else 1000 63 | np.random.seed(args.seed) 64 | torch.manual_seed(args.seed) 65 | torch.cuda.manual_seed(args.seed) 66 | 67 | 68 | load_folder_name = f'{args.model.lower()}_{args.dataset.lower()}_ex{excluded_num}_s0' 69 | save_folder_name = f'ex{args.ft_samples}_un{args.unlearn_samples}_ep{args.epochs}_seed{args.seed}' 70 | save_folder = os.path.join(args.model_save_folder, load_folder_name, save_folder_name) 71 | os.makedirs(save_folder, exist_ok=True) 72 | 73 | final_dict = torch.load(os.path.join(args.model_save_folder, load_folder_name, 'final.pth')) 74 | setup = rs.utils.system_startup() 75 | defs = rs.training_strategy('conservative') 76 | defs.lr = args.lr 77 | defs.epochs = args.epochs 78 | defs.batch_size = 128 79 | defs.optimizer = 'SGD' 80 | defs.scheduler = 'linear' 81 | defs.warmup = False 82 | defs.weight_decay = 0.0 83 | defs.dropout = 0.0 84 | defs.augmentations = False 85 | defs.dryrun = False 86 | 87 | 88 | loss_fn, _tl, validloader, num_classes, _exd, dmlist, dslist = rs.construct_dataloaders(args.dataset.lower(), defs, data_path=f'datasets/{args.dataset.lower()}', normalize=False, exclude_num=excluded_num) 89 | dm = torch.as_tensor(dmlist, **setup)[:, None, None] 90 | ds = torch.as_tensor(dslist, **setup)[:, None, None] 91 | normalizer = transforms.Normalize(dmlist, dslist) 92 | 93 | 94 | # *** used for batch case *** 95 | excluded_data = final_dict['excluded_data'] 96 | index = torch.tensor(np.random.choice(len(excluded_data[0]), args.ft_samples, replace=False)) 97 | print("Batch index", index.tolist()) 98 | X_all, y_all = excluded_data[0][index], excluded_data[1][index] 99 | print("FT data size", X_all.shape, y_all.shape) 100 | trainset_all = rs.data_processing.SubTrainDataset(X_all, y_all, transform=transforms.Normalize(dmlist, dslist)) 101 | trainloader_all = torch.utils.data.DataLoader(trainset_all, batch_size=min(defs.batch_size, len(trainset_all)), shuffle=True, num_workers=8, pin_memory=True) 102 | 103 | 104 | ## load state dict 105 | state_dict = final_dict['net_sd'] 106 | 107 | 108 | model_pretrain, _ = rs.construct_model(args.model, num_classes=num_classes, num_channels=3) 109 | model_pretrain.load_state_dict(state_dict) 110 | model_pretrain.eval() 111 | 112 | 113 | model_ft, _ = rs.construct_model(args.model, num_classes=num_classes, num_channels=3) 114 | model_ft.load_state_dict(state_dict) 115 | model_ft.eval() 116 | 117 | 118 | print("Train full model.") 119 | ft_folder = os.path.join(save_folder, 'full_ft') 120 | os.makedirs(ft_folder, exist_ok=True) 121 | model_ft.to(**setup) 122 | ft_stats = rs.train(model_ft, loss_fn, trainloader_all, validloader, defs, setup=setup, ckpt_path=ft_folder, finetune=True) 123 | model_ft.cpu() 124 | resdict = {'tr_args': args.__dict__, 125 | 'tr_strat': defs.__dict__, 126 | 'stats': ft_stats, 127 | 'batch_index': index, 128 | 'train_data': (X_all, y_all)} 129 | torch.save(resdict, os.path.join(ft_folder, 'finetune_params.pth')) 130 | ft_diffs = [(ft_param.detach().cpu() - org_param.detach().cpu()).detach() for (ft_param, org_param) in zip(model_ft.parameters(), model_pretrain.parameters())] 131 | 132 | 133 | 134 | 135 | 136 | 137 | 
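# --- Editorial note on the loop below ---
# Each iteration removes one batch of `unlearn_samples` from the finetuning set and plays out two unlearning modes:
#   * exact unlearning: retrain the pretrained checkpoint on the retained samples only (model_unlearn);
#   * approximate unlearning: emulated via a single gradient step on the unlearned samples (rs.recovery_algo.loss_steps).
# In both cases, the resulting parameter difference is fed to a GradientReconstructor to invert the unlearned inputs;
# a standalone sketch of the scaling identity used in the exact case is appended at the end of this dump.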
138 | 139 | print("Exact-unlearn each batch of samples and test exact and approximate unlearning") 140 | 141 | model_ft.zero_grad() 142 | model_ft.to(**setup) 143 | rec_machine_ft = rs.GradientReconstructor(model_ft, (dm, ds), recons_config, num_images=args.unlearn_samples) 144 | 145 | model_pretrain.zero_grad() 146 | model_pretrain.to(**setup) 147 | rec_machine_pretrain = rs.GradientReconstructor(model_pretrain, (dm, ds), recons_config, num_images=args.unlearn_samples) 148 | 149 | 150 | for test_id in range(args.ft_samples // args.unlearn_samples): 151 | unlearn_ids = list(range(test_id * args.unlearn_samples, (test_id + 1) * args.unlearn_samples)) 152 | print(f"Unlearn {unlearn_ids}") 153 | unlearn_folder = os.path.join(save_folder, f'unlearn_ft_batch{test_id}') 154 | os.makedirs(unlearn_folder, exist_ok=True) 155 | X_list = [xt for i, xt in enumerate(X_all) if i not in unlearn_ids] 156 | if len(X_list) > 0: 157 | X = torch.stack([xt for i, xt in enumerate(X_all) if i not in unlearn_ids]) 158 | y = torch.tensor([yt for i, yt in enumerate(y_all) if i not in unlearn_ids]) 159 | print("Exact unlearn data size", X.shape, y.shape) 160 | trainset_unlearn = rs.data_processing.SubTrainDataset(X, y, transform=transforms.Normalize(dmlist, dslist)) 161 | trainloader_unlearn = torch.utils.data.DataLoader(trainset_unlearn, batch_size=min(defs.batch_size, len(trainset_unlearn)), shuffle=True, num_workers=8, pin_memory=True) 162 | 163 | X_unlearn = torch.stack([xt for i, xt in enumerate(X_all) if i in unlearn_ids]) 164 | y_unlearn = torch.tensor([yt for i, yt in enumerate(y_all) if i in unlearn_ids]) 165 | 166 | print(f"***** Train unlearned model (without {unlearn_ids}) *****") 167 | model_unlearn, _ = rs.construct_model(args.model, num_classes=num_classes, num_channels=3) 168 | model_unlearn.load_state_dict(state_dict) 169 | model_unlearn.eval() 170 | model_unlearn.to(**setup) 171 | if len(X_list) > 0: 172 | unlearn_stats = rs.train(model_unlearn, loss_fn, trainloader_unlearn, validloader, defs, setup=setup, ckpt_path=unlearn_folder, finetune=True) 173 | else: 174 | unlearn_stats = None 175 | model_unlearn.cpu() 176 | resdict = {'tr_args': args.__dict__, 177 | 'tr_strat': defs.__dict__, 178 | 'stats': unlearn_stats, 179 | 'unlearn_batch_id': test_id} 180 | torch.save(resdict, os.path.join(unlearn_folder, 'finetune_params.pth')) 181 | # unlearn_params = [param.detach() for param in model_unlearn.parameters()] 182 | un_diffs = [(un_param.detach().cpu() - org_param.detach().cpu()).detach() for (un_param, org_param) in zip(model_unlearn.parameters(), model_pretrain.parameters())] 183 | 184 | print("Start reconstruction.") 185 | 186 | 187 | 188 | recons_folder = os.path.join(save_folder, 'recons') 189 | figure_folder = os.path.join(save_folder, 'figures') 190 | os.makedirs(recons_folder, exist_ok=True) 191 | os.makedirs(figure_folder, exist_ok=True) 192 | # reconstruction 193 | 194 | 195 | exact_diff = [-(ft_diff * args.ft_samples - un_diff * len(X_list)).detach().to(**setup) for (ft_diff, un_diff) in zip(ft_diffs, un_diffs)]  # rescale the mean-gradient diffs so the retained samples' contributions cancel, leaving (up to the lr factor) the unlearned samples' gradient 196 | rec_machine_pretrain.model.eval() 197 | result_exact = rec_machine_pretrain.reconstruct(exact_diff, normalizer(X_unlearn.to(**setup)), y_unlearn.to(setup['device']), img_shape=(3, img_size, img_size)) 198 | process_recons_results(result_exact, X_unlearn, figpath=figure_folder, recons_path=recons_folder, filename=f'exact{test_id}_{index[test_id].item()}') 199 | 200 | approx_diff = [p.detach().to(**setup) for p in rs.recovery_algo.loss_steps(model_ft, normalizer(X_unlearn.to(**setup)), 
y_unlearn.to(setup['device']), lr=1, local_steps=1)] # lr is irrelevant here: it only rescales approx_diff, and the cosine ('sim') reconstruction cost is scale-invariant 201 | rec_machine_ft.model.eval() 202 | result_approx = rec_machine_ft.reconstruct(approx_diff, normalizer(X_unlearn.to(**setup)), y_unlearn.to(setup['device']), img_shape=(3, img_size, img_size)) 203 | process_recons_results(result_approx, X_unlearn, figpath=figure_folder, recons_path=recons_folder, filename=f'approx{test_id}_{index[test_id].item()}') 204 | 205 | 206 | 207 | 208 | 209 | 210 | --------------------------------------------------------------------------------
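Editor's note (appended; not part of the repository): the `exact_diff` construction in `recovery_data.py` is easy to misread, so here is a minimal, self-contained sketch of the identity it relies on. For one full-batch SGD step on a mean-reduced loss, theta_ft - theta_0 = -(lr/N) * (sum of all N per-sample gradients) and theta_un - theta_0 = -(lr/(N-k)) * (sum over the N-k retained samples), so N*(theta_ft - theta_0) - (N-k)*(theta_un - theta_0) = -lr * (sum over the k unlearned samples). The toy linear model and all names below (`one_sgd_step`, `w0`, ...) are illustrative assumptions, not code from the repository.

```
import torch

# Toy setup: N finetuning samples; the first k will be "unlearned".
torch.manual_seed(0)
N, k, lr = 8, 2, 0.1
X, y = torch.randn(N, 5), torch.randn(N, 1)
w0 = torch.zeros(5, 1)  # stand-in for the pretrained parameters


def one_sgd_step(Xb, yb, w):
    """One full-batch vanilla SGD step on a mean-reduced squared loss."""
    w = w.clone().requires_grad_(True)
    loss = ((Xb @ w - yb) ** 2).mean()
    (grad,) = torch.autograd.grad(loss, w)
    return (w - lr * grad).detach()


w_ft = one_sgd_step(X, y, w0)          # finetune on all N samples
w_un = one_sgd_step(X[k:], y[k:], w0)  # exact unlearning: retrain without the first k

# The combination recovery_data.py builds as exact_diff (up to an overall sign):
combo = N * (w_ft - w0) - (N - k) * (w_un - w0)

# Direct sum of the unlearned samples' gradients at w0:
w = w0.clone().requires_grad_(True)
(g_unl,) = torch.autograd.grad(((X[:k] @ w - y[:k]) ** 2).sum(), w)

print(torch.allclose(combo, -lr * g_unl, atol=1e-6))  # True: combo is -lr times the unlearned gradient
```

With several epochs or mini-batches the identity holds only approximately, which is presumably why the scale-invariant cosine cost (`cost_fn='sim'`) is used for reconstruction: the unknown -lr factor then does not matter, and the leading minus sign in `exact_diff` merely orients the difference along the gradient direction the reconstructor matches against.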