├── models
│   ├── __init__.py
│   ├── get_model.py
│   ├── perturb.py
│   ├── vgg.py
│   └── resnet.py
├── .gitignore
├── Dockerfile
├── checkpoints
│   └── src
│       └── tiny_imagenet_resnet20_final
│           └── perturb_750.pth
├── run.sh
├── split_dataset.py
├── README.md
├── dataloader.py
├── data
│   └── get_dataset.py
├── train_tgt.py
├── utils.py
└── train_src.py

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.npy
2 | *.npz
3 | .vscode/
4 | __pycache__/
5 | .DS_Store
6 | .env/
7 |
8 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:latest
2 |
3 | WORKDIR /workspace
4 | RUN pip install absl-py scikit-learn tensorboard wandb matplotlib gpustat
--------------------------------------------------------------------------------
/checkpoints/src/tiny_imagenet_resnet20_final/perturb_750.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JWoong148/MetaPerturb/HEAD/checkpoints/src/tiny_imagenet_resnet20_final/perturb_750.pth
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | CODE_DIR=codes/
2 | SAVE_DIR=checkpoints/
3 |
4 | # Meta-training
5 | NUM_SPLIT=10
6 | SRC_LR=1e-3
7 | SRC_DATA=tiny_imagenet
8 | SRC_IMG_SIZE=32
9 | SRC_STEPS=750
10 | SRC_MODEL=resnet20
11 | SRC_DETAIL=final
12 | SRC_EXP_NAME=${SRC_DATA}_${SRC_MODEL}_${SRC_DETAIL}
13 | GPUS=0,1,2,3,4,0,1,2,3,4  # one GPU id per split (5 GPUs reused for 10 splits)
14 |
15 | # Meta-testing
16 | TGT_LR=1e-3
17 | TGT_DATA=stanford_cars
18 | TGT_IMG_SIZE=84
19 | TGT_MODEL=resnet20
20 | NOISE_COEFF=1
21 | TGT_DETAIL=
22 | TGT_EXP_NAME=${SRC_DATA}_${SRC_MODEL}_${SRC_STEPS}_to_${TGT_DATA}_${TGT_MODEL}_${TGT_DETAIL}
23 |
24 | wandb on
25 | if [ "$1" = "src" ]; then
26 |     python train_src.py \
27 |         --num_run 5 \
28 |         --code_dir $CODE_DIR \
29 |         --save_dir $SAVE_DIR \
30 |         --num_split $NUM_SPLIT \
31 |         --lr $SRC_LR \
32 |         --data $SRC_DATA \
33 |         --img_size $SRC_IMG_SIZE \
34 |         --model $SRC_MODEL \
35 |         --train_steps $SRC_STEPS \
36 |         --exp_name $SRC_EXP_NAME \
37 |         --gpus $GPUS \
38 |         --num_workers 2
39 | elif [ "$1" = "tgt" ]; then
40 |     python train_tgt.py \
41 |         --code_dir $CODE_DIR \
42 |         --save_dir $SAVE_DIR \
43 |         --lr $TGT_LR \
44 |         --data $TGT_DATA \
45 |         --img_size $TGT_IMG_SIZE \
46 |         --model $TGT_MODEL \
47 |         --src_name $SRC_EXP_NAME \
48 |         --src_steps $SRC_STEPS \
49 |         --exp_name $TGT_EXP_NAME \
50 |         --gpus $2
51 | else
52 |     echo "Unknown argument: expected 'src' or 'tgt'"
53 | fi
--------------------------------------------------------------------------------
/models/get_model.py:
--------------------------------------------------------------------------------
1 | from .vgg import ConvNet4, ConvNet6, VGG9
2 | from .resnet import resnet20, resnet32, resnet44, resnet56, resnet18, resnet34
3 |
4 |
5 | def get_model(model_name, num_classes, img_size, do_perturb):
6 |     kwargs = {"num_classes": num_classes, "img_size": img_size, "do_perturb": do_perturb}
7 |     conv_channels = -1
8 |     if "conv" in model_name:
9 |         conv_channels = int(model_name.split("_")[-1])
10 |         model_name = "_".join(model_name.split("_")[:-1])
11 |
12 |     if model_name == "lenet":
13 |         raise DeprecationWarning("lenet is no longer supported")
model_name == "conv4": 15 | return ConvNet4(conv_channels=conv_channels, **kwargs) 16 | elif model_name == "conv6": 17 | return ConvNet6(conv_channels=conv_channels, **kwargs) 18 | elif model_name == "vgg9": 19 | return VGG9(**kwargs) 20 | elif model_name == "resnet20": 21 | return resnet20(**kwargs) 22 | elif model_name == "resnet32": 23 | return resnet32(**kwargs) 24 | elif model_name == "resnet44": 25 | return resnet44(**kwargs) 26 | elif model_name == "resnet56": 27 | return resnet56(**kwargs) 28 | elif model_name == "resnet18": 29 | return resnet18(**kwargs) 30 | elif model_name == "resnet34": 31 | return resnet34(**kwargs) 32 | else: 33 | raise NotImplementedError 34 | -------------------------------------------------------------------------------- /split_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | train_labels_np = np.load("data/tiny_imagenet/train_labels.npy") 6 | test_labels_np = np.load("data/tiny_imagenet/valid_labels.npy") 7 | 8 | test_labels = [[] for _ in range(200)] 9 | for i in range(50 * 200): 10 | test_labels[test_labels_np[i]].append(i) 11 | 12 | os.makedirs("data/tiny_imagenet/10_split/") 13 | for i in range(5): 14 | label_list = range(40 * i, 40 * (i + 1)) 15 | idx_train = np.concatenate([np.arange(500 * l, 500 * (l + 1)) for l in label_list]) 16 | idx_test = np.concatenate([test_labels[l] for l in label_list]) 17 | 18 | for idx in idx_train: 19 | assert train_labels_np[idx] in label_list 20 | for idx in idx_test: 21 | assert test_labels_np[idx] in label_list 22 | 23 | np.random.shuffle(idx_train) 24 | idx_train, idx_valid = idx_train[: len(idx_train) // 2], idx_train[len(idx_train) // 2 :] 25 | 26 | np.savez_compressed( 27 | "data/tiny_imagenet/10_split/split_{}.npz".format(2 * i), 28 | label_list=label_list, 29 | idx_train=idx_train, 30 | idx_valid=idx_valid, 31 | idx_test=idx_test, 32 | ) 33 | np.savez_compressed( 34 | "data/tiny_imagenet/10_split/split_{}.npz".format(2 * i + 1), 35 | label_list=label_list, 36 | idx_train=idx_valid, 37 | idx_valid=idx_train, 38 | idx_test=idx_test, 39 | ) 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MetaPerturb: Transferable Regularizer for Heterogeneous Tasks and Architectures 2 | 3 | This is the **Pytorch implementation** for the paper *MetaPerturb: Transferable Regularizer for Heterogeneous Tasks and Architectures* (accepted at **NeurIPS 2020 , spotlight presentation**) 4 | 5 | 6 | 7 | ## Links 8 | 9 | [Paper](https://papers.nips.cc/paper/2020/file/84ddfb34126fc3a48ee38d7044e87276-Paper.pdf) 10 | 11 | 12 | 13 | 14 | ## Abstract 15 | 16 | ![](figures/concept.png) 17 | 18 | Regularization and transfer learning are two popular techniques to enhance model generalization on unseen data, which is a fundamental problem of machine learning. Regularization techniques are versatile, as they are task- and architecture-agnostic, but they do not exploit a large amount of data available. Transfer learning methods learn to transfer knowledge from one domain to another, but may not generalize across tasks and architectures, and may introduce new training cost for adapting to the target task. To bridge the gap between the two, we propose a transferable perturbation, *MetaPerturb*, which is meta-learned to improve generalization performance on unseen data. 
18 | Regularization and transfer learning are two popular techniques to enhance model generalization on unseen data, which is a fundamental problem of machine learning. Regularization techniques are versatile, as they are task- and architecture-agnostic, but they do not exploit the large amounts of data available. Transfer learning methods learn to transfer knowledge from one domain to another, but may not generalize across tasks and architectures, and may introduce additional training cost for adapting to the target task. To bridge the gap between the two, we propose a transferable perturbation, *MetaPerturb*, which is meta-learned to improve generalization performance on unseen data. MetaPerturb is implemented as a lightweight set-based network that is agnostic to the size and order of its input and is shared across layers. We then propose a meta-learning framework that jointly trains the perturbation function over heterogeneous tasks in parallel. As MetaPerturb is a set function trained over diverse distributions across layers and tasks, it can generalize to heterogeneous tasks and architectures. We validate the efficacy and generality of MetaPerturb trained on a specific source domain and architecture by applying it to the training of diverse neural architectures on heterogeneous target datasets, against various regularizers and fine-tuning. The results show that networks trained with MetaPerturb significantly outperform the baselines on most of the tasks and architectures, with a negligible increase in parameter size and no hyperparameters to tune.
19 |
20 |
21 |
22 | __Contribution of this work__
23 |
24 | - We propose a lightweight and versatile perturbation function that can transfer the knowledge of a source task to **heterogeneous target tasks and architectures**.
25 | - We propose **a novel meta-learning framework in the form of joint training**, which allows us to efficiently perform meta-learning on large-scale datasets in the standard learning framework.
26 | - We validate our perturbation function on a large number of datasets and architectures, on which it successfully **outperforms existing regularizers and fine-tuning**.
27 |
28 |
29 |
30 |
50 |
51 | ## Prerequisites
52 | We recommend using the attached Dockerfile.
53 |
54 | ## Running code
55 |
56 | To run meta-training,
57 | ```
58 | bash run.sh src
59 | ```
60 |
61 | To run meta-testing,
62 | ```
63 | bash run.sh tgt
64 | ```
65 |
66 | To change the training configuration, edit the variables at the top of run.sh (e.g. SRC_MODEL, TGT_DATA, ...).
67 |
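The released checkpoint under `checkpoints/` can also be attached to a model directly. A minimal sketch, not part of the original scripts (`num_classes` and the input shape are placeholders; a GPU is required because the Perturb module keeps its running statistics on CUDA):

```python
import torch
from models.get_model import get_model

# Build a ResNet-20 with the shared Perturb module enabled, then load the
# meta-learned perturbation weights shipped with this repository.
model = get_model("resnet20", num_classes=10, img_size=32, do_perturb=True).cuda()
state = torch.load("checkpoints/src/tiny_imagenet_resnet20_final/perturb_750.pth")
model.perturb.load_state_dict(state)

x = torch.randn(8, 3, 32, 32).cuda()
logits, metrics = model(x)  # forward passes return (output, per-layer perturb metrics)
```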
68 | ## Citation
69 | If you find the provided code useful, please cite our work.
70 | ```
71 | @inproceedings{ryu2020metaperturb,
72 |     title={MetaPerturb: Transferable Regularizer for Heterogeneous Tasks and Architectures},
73 |     author={Jeong Un Ryu and JaeWoong Shin and Hae Beom Lee and Sung Ju Hwang},
74 |     booktitle={NeurIPS},
75 |     year={2020}
76 | }
77 | ```
78 |
79 |
80 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import DataLoader, SubsetRandomSampler
3 | from torchvision import transforms
4 | from data.get_dataset import get_dataset, NUM_CLASSES
5 |
6 |
7 | def get_transform(
8 |     img_size, random_crop=False, random_horizontal_flip=False, normalize_mean=(0.5,), normalize_std=(0.5,)
9 | ):
10 |     transform_list = [transforms.Resize((img_size, img_size))]
11 |     if random_crop:
12 |         transform_list.append(transforms.RandomCrop(img_size, padding=(4 if img_size == 32 else 8)))
13 |     if random_horizontal_flip:
14 |         transform_list.append(transforms.RandomHorizontalFlip())
15 |     transform_list.append(transforms.ToTensor())
16 |     transform_list.append(transforms.Normalize(normalize_mean, normalize_std))
17 |     return transforms.Compose(transform_list)
18 |
19 |
20 | def get_src_dataloader(name, split, total_split, img_size, batch_size, num_workers=1):
21 |     if name not in ["tiny_imagenet", "mini_imagenet", "cifar_100", "meta_imagenet"]:
22 |         raise NotImplementedError
23 |
24 |     # Get split information
25 |     split_info = np.load("data/{}/{}_split/split_{}.npz".format(name, total_split, split))
26 |     label_list = list(split_info["label_list"])
27 |     idx_train = list(split_info["idx_train"])
28 |     idx_valid = list(split_info["idx_valid"])
29 |     # idx_test = list(split_info["idx_test"])
30 |
31 |     if name == "meta_imagenet":
32 |         target_transform = None
33 |     else:
34 |
35 |         def target_transform(y):
36 |             return label_list.index(y)
37 |
38 |     transform_train = get_transform(img_size, random_crop=True, random_horizontal_flip=True)
39 |     # transform_test = get_transform(img_size)
40 |
41 |     train_ds = get_dataset(name, train=True, transform=transform_train, target_transform=target_transform)
42 |     # test_ds = get_dataset(name, train=False, transform=transform_test, target_transform=target_transform)
43 |
44 |     kwargs = {"batch_size": batch_size, "num_workers": num_workers, "pin_memory": False, "drop_last": True}
45 |     train_loader = DataLoader(train_ds, sampler=SubsetRandomSampler(idx_train), **kwargs)
46 |     valid_loader = DataLoader(train_ds, sampler=SubsetRandomSampler(idx_valid), **kwargs)
47 |     test_loader = None
48 |     # test_loader = DataLoader(test_ds, sampler=SubsetRandomSampler(idx_test), **kwargs)
49 |
50 |     return train_loader, valid_loader, test_loader, len(label_list)
51 |
52 |
53 | def get_tgt_dataloader(name, img_size, batch_size, num_workers=3):
54 |     num_instances = None
55 |     if "small_svhn" in name:
56 |         if name == "small_svhn":
57 |             num_instances = 500
58 |         elif name == "small_svhn_100":
59 |             num_instances = 100
60 |         elif name == "small_svhn_2500":
61 |             num_instances = 2500
62 |         else:
63 |             raise NotImplementedError
64 |         name = "svhn"
65 |     elif "small_cifar_100" in name:
66 |         if name == "small_cifar_100":
67 |             num_instances = 50
68 |         else:
69 |             raise NotImplementedError
70 |         name = "cifar_100"
71 |     elif "small_fashion_mnist" in name:
72 |         if name == "small_fashion_mnist":
73 |             num_instances = 500
74 |         else:
75 |             raise NotImplementedError
76 |         name = "fashion_mnist"
77 |
78 |     transform_train = get_transform(img_size, random_crop=True, random_horizontal_flip=True)
79 |     transform_test = get_transform(img_size)
80 |
81 |     train_ds = get_dataset(name, train=True, transform=transform_train)
82 |     test_ds = get_dataset(name, train=False, transform=transform_test)
83 |
84 |     kwargs = {"batch_size": batch_size, "num_workers": num_workers, "pin_memory": True, "drop_last": True}
85 |     if num_instances:
86 |         train_idx = []
87 |         for c in range(NUM_CLASSES[name]):
88 |             try:
89 |                 train_idx.extend(list(np.argwhere(train_ds.labels == c)[:num_instances, 0]))
90 |             except AttributeError:
91 |                 train_idx.extend(list(np.argwhere(np.array(train_ds.targets) == c)[:num_instances, 0]))
92 |         train_loader = DataLoader(train_ds, sampler=SubsetRandomSampler(train_idx), **kwargs)
93 |         test_loader = DataLoader(test_ds, **kwargs)
94 |     else:
95 |         train_loader = DataLoader(train_ds, shuffle=True, **kwargs)
96 |         test_loader = DataLoader(test_ds, **kwargs)
97 |     return train_loader, test_loader, NUM_CLASSES[name]
98 |
--------------------------------------------------------------------------------
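A minimal usage sketch for the loaders above, mirroring how train_tgt.py consumes them (the dataset is downloaded on first use; "small_svhn" keeps the first 500 images per class):

```python
from dataloader import get_tgt_dataloader

train_loader, test_loader, num_classes = get_tgt_dataloader(
    "small_svhn", img_size=32, batch_size=128, num_workers=2
)
x, y = next(iter(train_loader))  # x: [128, 3, 32, 32], y: [128], num_classes: 10
```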
/models/perturb.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributions as tdist
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class ChannelEquivarientOp(nn.Module):
8 |     def __init__(self, kernel_size=3):
9 |         super().__init__()
10 |         self.padding = kernel_size // 2
11 |         self.w_identity = nn.Parameter(torch.ones(1, 1, kernel_size, kernel_size))
12 |         self.w_all = nn.Parameter(torch.ones(1, 1, kernel_size, kernel_size))
13 |         nn.init.normal_(self.w_identity, 0, 0.01)
14 |         nn.init.normal_(self.w_all, 0, 0.01)
15 |
16 |     def forward(self, x):
17 |         C = x.size(1)
18 |
19 |         _w_identity = self.w_identity.expand(C, -1, -1, -1)
20 |         _x = x.mean(dim=1, keepdim=True)
21 |
22 |         alpha_identity = F.conv2d(x, _w_identity, padding=self.padding, groups=C)
23 |         alpha_all = F.conv2d(_x, self.w_all, padding=self.padding).expand(-1, C, -1, -1)
24 |         return alpha_identity + alpha_all
25 |
26 |
27 | class Perturb(nn.Module):
28 |     def __init__(self, channel_norm_factor, spatial_norm_factor, h_dim=4, kernel_size=3):
29 |         super().__init__()
30 |         # Noise generator
31 |         self.phi_1 = ChannelEquivarientOp(kernel_size)
32 |         self.phi_2 = ChannelEquivarientOp(kernel_size)
33 |
34 |         # Inference scale (& shift)
35 |         self.h_dim = h_dim
36 |         self.kernel = nn.Parameter(torch.ones(self.h_dim, 1, 3, 3))
37 |         self.fc = nn.Linear(2 * self.h_dim + 2, 1, bias=False)
38 |         self.channel_norm_factor = channel_norm_factor
39 |         self.spatial_norm_factor = spatial_norm_factor
40 |
41 |         # Parameter initialization
42 |         nn.init.normal_(self.kernel, 0, 0.01)
43 |         nn.init.normal_(self.fc.weight, 0, 0.01)
44 |
45 |         # Initializing running_mean & running_var for each perturb layer
46 |         self.momentum = 0.1
47 |         self.running_mean = []
48 |         self.running_var = []
49 |
50 |     def add_running_stats(self, channels):
51 |         self.running_mean.append(torch.zeros(channels, self.h_dim).cuda())  # running stats live on GPU
52 |         self.running_var.append(torch.ones(channels, self.h_dim).cuda())
53 |
54 |     def forward(self, x, clipval=None, noise_coeff=1.0, perturb_idx=None):
55 |         # Noise generation
56 |         alpha = self.phi_1(x)
57 |         alpha = F.relu(alpha, inplace=True)
58 |         alpha = self.phi_2(alpha)
59 |         dist = tdist.Normal(alpha, torch.ones_like(alpha))
60 |         noise = dist.rsample()
61 |         noise = F.softplus(noise)
62 |         metrics = {
63 |             "max_noise": torch.max(noise),
64 |             "avg_noise": torch.mean(noise),
65 |             "max_input": torch.max(x),
66 |             "avg_input": torch.mean(x),
67 |         }
68 |
69 |         # Scale inference
70 |         B, C, H, W = x.size()
71 |         if self.training:
72 |             kernel = self.kernel.repeat(C, 1, 1, 1)
73 |             _x = F.relu(F.conv2d(x, kernel, padding=1, groups=C)).mean(dim=[-1, -2]).view(B, C, self.h_dim)
74 |             _x_mean = _x.mean(dim=0)
75 |             _x_var = _x.var(dim=0)
76 |             with torch.no_grad():
77 |                 self.running_mean[perturb_idx] = (
78 |                     self.momentum * _x_mean + (1 - self.momentum) * self.running_mean[perturb_idx]
79 |                 )
80 |                 self.running_var[perturb_idx] = (
81 |                     self.momentum * _x_var * B / (B - 1) + (1 - self.momentum) * self.running_var[perturb_idx]
82 |                 )
83 |         else:
84 |             _x_mean = self.running_mean[perturb_idx]
85 |             _x_var = self.running_var[perturb_idx]
86 |
87 |         channel = 1.0 * C / self.channel_norm_factor * torch.ones(C, 1).cuda()
88 |         size = 1.0 * H / self.spatial_norm_factor * torch.ones(C, 1).cuda()
89 |         x_vec = torch.cat([_x_mean, _x_var, channel, size], dim=-1)
90 |         _x = self.fc(x_vec)
91 |         scale = _x.view(1, -1, 1, 1)
92 |         scale = torch.sigmoid(scale)
93 |
94 |         # Apply scale to noise & noise annealing
95 |         noise = ((scale * noise) - 1) * noise_coeff + 1
96 |         out = noise * x
97 |         metrics["max_output"] = torch.max(out)
98 |         metrics["avg_output"] = torch.mean(out)
99 |         metrics["norm_output"] = torch.norm(out, p=2, dim=[-1, -2]).mean()
100 |         if clipval is not None:
101 |             out = torch.clamp(out, max=clipval)
102 |         return out, metrics
103 |
--------------------------------------------------------------------------------
/data/get_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from torch.utils.data import Dataset
4 | from torchvision.datasets import CIFAR100, FashionMNIST, STL10, SVHN
5 | from meta_imagenet.lib.datasets.DownsampledImageNet import ImageNet32
6 |
7 |
8 | NUM_CLASSES = {
9 |     "aircraft": 100,
10 |     "cifar_100": 100,
11 |     "cub": 200,
12 |     "fashion_mnist": 10,
13 |     "stanford_cars": 196,
14 |     "stanford_dogs": 120,
15 |     "stl10": 10,
16 |     "svhn": 10,
17 |     "tiny_imagenet": 200,
18 |     "mini_imagenet": 200,
19 |     "dtd": 47,
20 | }
21 |
22 |
23 | def get_dataset(data: str, train: bool, transform=None, target_transform=None) -> Dataset:
24 |     if data == "aircraft":
25 |         return Aircraft(train, transform, target_transform)
26 |     elif data == "cifar_100":
27 |         return CIFAR100("data/cifar_100", train, transform, target_transform, download=True)
28 |     elif data == "cub":
29 |         return CUB(train, transform, target_transform)
30 |     elif data == "fashion_mnist":
31 |         return FashionMNIST("data/fashion_mnist", train, transform, target_transform, download=True)
32 |     elif data == "stanford_cars":
33 |         return StanfordCars(train, transform, target_transform)
34 |     elif data == "stanford_dogs":
35 |         return StanfordDogs(train, transform, target_transform)
36 |     elif data == "stl10":
37 |         return STL10("data/stl10", "train" if train else "test", None, transform, target_transform, download=True)
38 |     elif data == "svhn":
39 |         return SVHN("data/svhn", "train" if train else "test", transform, target_transform, download=True)
40 |     elif data == "tiny_imagenet":
41 |         return TinyImageNet(train, transform, target_transform)
42 |     elif data == "mini_imagenet": return MiniImageNet(train, transform, target_transform)
43 |     elif data == "dtd": return DTD(train, transform, target_transform)
44 |     elif data == "meta_imagenet":
45 |         return ImageNet32("/w14/dataset/MetaGen/batch32", train, transform, target_transform)
46 |     else: raise NotImplementedError()
47 |
48 | class NumpyDataset(Dataset):
49 |     def __init__(self, image_path, label_path, transform=None, target_transform=None):
50 |         super().__init__()
51 |         self.transform = transform
52 |         self.target_transform = 
target_transform 53 | self.images = np.load(image_path) 54 | self.labels = np.load(label_path) 55 | self.length = len(self.images) 56 | 57 | def __getitem__(self, index): 58 | img = Image.fromarray(self.images[index]) 59 | label = self.labels[index] 60 | if self.transform: 61 | img = self.transform(img) 62 | if self.target_transform: 63 | label = self.target_transform(label) 64 | return img, label 65 | 66 | def __len__(self): 67 | return self.length 68 | 69 | 70 | class TinyImageNet(NumpyDataset): 71 | def __init__(self, train=True, transform=None, target_transform=None): 72 | super().__init__( 73 | image_path="data/tiny_imagenet/{}_images.npy".format("train" if train else "valid"), 74 | label_path="data/tiny_imagenet/{}_labels.npy".format("train" if train else "valid"), 75 | transform=transform, 76 | target_transform=target_transform, 77 | ) 78 | 79 | 80 | class MiniImageNet(NumpyDataset): 81 | def __init__(self, train=True, transform=None, target_transform=None): 82 | super().__init__( 83 | image_path="data/mini_imagenet/{}_images.npy".format("train" if train else "valid"), 84 | label_path="data/mini_imagenet/{}_labels.npy".format("train" if train else "valid"), 85 | transform=transform, 86 | target_transform=target_transform, 87 | ) 88 | 89 | 90 | class CUB(NumpyDataset): 91 | def __init__(self, train=True, transform=None, target_transform=None): 92 | super().__init__( 93 | image_path="data/CUB_200_2011/84_npy/{}_images.npy".format("train" if train else "test"), 94 | label_path="data/CUB_200_2011/84_npy/{}_labels.npy".format("train" if train else "test"), 95 | transform=transform, 96 | target_transform=target_transform, 97 | ) 98 | 99 | 100 | class Aircraft(NumpyDataset): 101 | def __init__(self, train=True, transform=None, target_transform=None): 102 | super().__init__( 103 | image_path="data/aircraft/{}_images.npy".format("train" if train else "test"), 104 | label_path="data/aircraft/{}_labels.npy".format("train" if train else "test"), 105 | transform=transform, 106 | target_transform=target_transform, 107 | ) 108 | 109 | 110 | class StanfordCars(NumpyDataset): 111 | def __init__(self, train=True, transform=None, target_transform=None): 112 | super().__init__( 113 | image_path="data/stanford_cars/{}_images.npy".format("train" if train else "test"), 114 | label_path="data/stanford_cars/{}_labels.npy".format("train" if train else "test"), 115 | transform=transform, 116 | target_transform=target_transform, 117 | ) 118 | 119 | 120 | class StanfordDogs(NumpyDataset): 121 | def __init__(self, train=True, transform=None, target_transform=None): 122 | super().__init__( 123 | image_path="data/stanford_dogs/{}_images.npy".format("train" if train else "test"), 124 | label_path="data/stanford_dogs/{}_labels.npy".format("train" if train else "test"), 125 | transform=transform, 126 | target_transform=target_transform, 127 | ) 128 | 129 | 130 | class DTD(NumpyDataset): 131 | def __init__(self, train=True, transform=None, target_transform=None): 132 | super().__init__( 133 | image_path="data/dtd/{}_images.npy".format("train" if train else "test"), 134 | label_path="data/dtd/{}_labels.npy".format("train" if train else "test"), 135 | transform=transform, 136 | target_transform=target_transform, 137 | ) 138 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from .perturb import Perturb 5 | 6 | 7 | class 
ConvBlock(nn.Module):
8 |     def __init__(
9 |         self, in_channels, out_channels, perturb, kernel_size=3, stride=1, padding=1, do_maxpool=False, perturb_idx=None
10 |     ):
11 |         super().__init__()
12 |         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
13 |         self.bn = nn.BatchNorm2d(out_channels)
14 |         self.do_maxpool = do_maxpool
15 |
16 |         self.perturb = perturb
17 |         if self.perturb:
18 |             self.perturb.add_running_stats(out_channels)
19 |         self.perturb_idx = perturb_idx
20 |
21 |     def forward(self, x, clipval=None, noise_coeff=1.0):
22 |         out = self.bn(self.conv(x))
23 |         metrics = None
24 |         if self.perturb:
25 |             out, metrics = self.perturb(out, clipval, noise_coeff, self.perturb_idx)
26 |         out = F.relu(out, inplace=True)
27 |         if self.do_maxpool:
28 |             out = F.max_pool2d(out, 2)
29 |         return out, metrics
30 |
31 |
32 | class ConvNet(nn.Module):
33 |     def __init__(self):
34 |         super().__init__()
35 |         self.conv_layers = nn.ModuleList()
36 |
37 |     def forward(self, x, clipval=None, noise_coeff=1.0):
38 |         metrics_all = []
39 |         out = x
40 |         for layer in self.conv_layers:
41 |             out, metrics = layer(out, clipval, noise_coeff)
42 |             metrics_all.append(metrics)
43 |
44 |         out = out.view(out.size(0), -1)
45 |         out = self.fc(out)
46 |         return out, metrics_all
47 |
48 |
49 | class ConvNet4(ConvNet):
50 |     def __init__(self, num_classes, conv_channels=32, img_size=32, do_perturb=False):
51 |         super().__init__()
52 |         sz = (((img_size // 2) // 2) // 2) // 2
53 |         self.perturb = None
54 |         if do_perturb:
55 |             self.perturb = Perturb(channel_norm_factor=conv_channels, spatial_norm_factor=img_size)
56 |         self.conv_layers = nn.ModuleList(
57 |             [
58 |                 ConvBlock(3, conv_channels, self.perturb, do_maxpool=True, perturb_idx=0),
59 |                 ConvBlock(conv_channels, conv_channels, self.perturb, do_maxpool=True, perturb_idx=1),
60 |                 ConvBlock(conv_channels, conv_channels, self.perturb, do_maxpool=True, perturb_idx=2),
61 |                 ConvBlock(conv_channels, conv_channels, self.perturb, do_maxpool=True, perturb_idx=3),
62 |             ]
63 |         )
64 |         self.fc = nn.Linear(sz * sz * conv_channels, num_classes)
65 |
66 |         for m in self.modules():
67 |             if isinstance(m, nn.Conv2d):
68 |                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
69 |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
70 |                 nn.init.constant_(m.weight, 1)
71 |                 nn.init.constant_(m.bias, 0)
72 |
73 |
74 | class ConvNet6(ConvNet):  # inherits the shared forward() from ConvNet
75 |     def __init__(self, num_classes, conv_channels=32, img_size=32, do_perturb=False):
76 |         super().__init__()
77 |         sz = (((img_size // 2) // 2) // 2) // 2
78 |         self.perturb = None
79 |         if do_perturb:
80 |             self.perturb = Perturb(channel_norm_factor=conv_channels, spatial_norm_factor=img_size)
81 |         self.conv_layers = nn.ModuleList(
82 |             [
83 |                 ConvBlock(3, conv_channels, self.perturb, do_maxpool=True, perturb_idx=0),
84 |                 ConvBlock(conv_channels, conv_channels, self.perturb, do_maxpool=True, perturb_idx=1),
85 |                 ConvBlock(conv_channels, conv_channels, self.perturb, do_maxpool=False, perturb_idx=2),
86 |                 ConvBlock(conv_channels, conv_channels, self.perturb, do_maxpool=True, perturb_idx=3),
87 |                 ConvBlock(conv_channels, conv_channels, self.perturb, do_maxpool=False, perturb_idx=4),
88 |                 ConvBlock(conv_channels, conv_channels, self.perturb, do_maxpool=True, perturb_idx=5),
89 |             ]
90 |         )
91 |         self.fc = nn.Linear(sz * sz * conv_channels, num_classes)
92 |
93 |         for m in self.modules():
94 |             if isinstance(m, nn.Conv2d):
95 |                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
96 |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
97 |                 nn.init.constant_(m.weight, 1)
98 |                 nn.init.constant_(m.bias, 0)
99 |
100 |
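# NOTE (usage sketch): these ConvNets are constructed through models/get_model.py, where a
# trailing "_<channels>" suffix in the model name selects conv_channels, e.g.
#     get_model("conv6_64", num_classes=10, img_size=32, do_perturb=True)
# builds a 6-layer ConvNet whose blocks all have 64 channels and share a single Perturb module.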
101 | class VGG9(ConvNet):  # inherits the shared forward() from ConvNet
102 |     def __init__(self, num_classes, img_size=32, do_perturb=False):
103 |         super().__init__()
104 |         sz = ((((img_size // 2) // 2) // 2) // 2) // 2
105 |         self.perturb = None
106 |         if do_perturb:
107 |             self.perturb = Perturb(channel_norm_factor=512, spatial_norm_factor=img_size)
108 |         self.conv_layers = nn.ModuleList(
109 |             [
110 |                 ConvBlock(3, 64, self.perturb, do_maxpool=True, perturb_idx=0),
111 |                 ConvBlock(64, 128, self.perturb, do_maxpool=True, perturb_idx=1),
112 |                 ConvBlock(128, 256, self.perturb, do_maxpool=False, perturb_idx=2),
113 |                 ConvBlock(256, 256, self.perturb, do_maxpool=True, perturb_idx=3),
114 |                 ConvBlock(256, 512, self.perturb, do_maxpool=False, perturb_idx=4),
115 |                 ConvBlock(512, 512, self.perturb, do_maxpool=True, perturb_idx=5),
116 |                 ConvBlock(512, 512, self.perturb, do_maxpool=False, perturb_idx=6),
117 |                 ConvBlock(512, 512, self.perturb, do_maxpool=True, perturb_idx=7),
118 |             ]
119 |         )
120 |         self.fc = nn.Linear(sz * sz * 512, num_classes)
121 |
122 |         for m in self.modules():
123 |             if isinstance(m, nn.Conv2d):
124 |                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
125 |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
126 |                 nn.init.constant_(m.weight, 1)
127 |                 nn.init.constant_(m.bias, 0)
128 |
--------------------------------------------------------------------------------
/train_tgt.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.nn as nn
5 | from absl import app, flags, logging
6 |
7 | from dataloader import get_tgt_dataloader
8 | from models.get_model import get_model
9 | from utils import check_args, InfIterator, Logger, get_optimizier, get_scheduler
10 |
11 | FLAGS = flags.FLAGS
12 | # Training
13 | flags.DEFINE_integer("batch_size", 128, "Batch size")
14 | flags.DEFINE_integer("train_steps", 10000, "Total training steps")
15 | flags.DEFINE_enum("lr_schedule", "step_lr", ["step_lr", "cosine_lr"], "lr schedule")
16 | flags.DEFINE_enum("opt", "adam", ["adam", "sgd", "rmsprop"], "optimizer")
17 | flags.DEFINE_float("lr", 1e-3, "Learning rate")
18 | flags.DEFINE_float("noise_coeff", 1.0, "Noise coefficient for noise annealing")
19 |
20 | # Model
21 | flags.DEFINE_string("model", "resnet20", "Model")
22 | flags.DEFINE_integer("clipval", 100, "clip value for perturb module")
23 |
24 | # Data
25 | flags.DEFINE_integer("img_size", 32, "Image size")
26 | flags.DEFINE_string("data", "stl10", "Data")
27 |
28 | # Misc
29 | flags.DEFINE_string("tblog_dir", None, "Directory for tensorboard logs")
30 | flags.DEFINE_string("code_dir", "./codes", "Directory for backup code")
31 | flags.DEFINE_string("save_dir", "./checkpoints", "Directory for checkpoints")
32 | flags.DEFINE_bool("fine_tune", False, "Fine tune or not")
33 | flags.DEFINE_string("src_name", "", "Source name to use")
34 | flags.DEFINE_string("src_steps", "10000", "Source training steps")
35 | flags.DEFINE_string("exp_name", "", "Experiment name")
36 | flags.DEFINE_integer("print_every", 200, "Print period")
37 | flags.DEFINE_string("gpus", "", "GPUs to use")
38 | flags.DEFINE_integer("num_workers", 3, "The number of workers for dataloading")
39 |
40 |
41 | def accuracy(y, y_pred):
42 |     with torch.no_grad():
43 |         pred = torch.max(y_pred, dim=1)
44 |         return 1.0 * pred[1].eq(y).sum() / y.size(0)
45 |
46 |
47 | def train_step(model, noise_coeff,
train_iter, theta_optimizer, device, criterion, logger): 48 | model.train() 49 | 50 | x, y = next(train_iter) 51 | x, y = x.to(device), y.to(device) 52 | x = x.expand(-1, 3, -1, -1) 53 | 54 | y_pred, metrics = model(x, clipval=FLAGS.clipval, noise_coeff=noise_coeff) 55 | y_pred = nn.LogSoftmax(dim=1)(y_pred) 56 | loss = criterion(y_pred, y) 57 | 58 | theta_optimizer.zero_grad() 59 | loss.backward() 60 | theta_optimizer.step() 61 | 62 | # Meter logs 63 | logger.meter("train", "ce_loss", loss) 64 | logger.meter("train", "accuracy", accuracy(y, y_pred)) 65 | for i, metric in enumerate(metrics): 66 | if metric is None: 67 | continue 68 | for name, value in metric.items(): 69 | logger.meter(f"train_{name}", f"layer_{i}", value) 70 | 71 | 72 | def test(model, noise_coeff, test_loader, num_samples, device, criterion, logger): 73 | model.eval() 74 | correct, total = 0, 0 75 | with torch.no_grad(): 76 | for x, y in test_loader: 77 | x, y = x.to(device), y.to(device) 78 | x = x.expand(-1, 3, -1, -1) 79 | 80 | y_preds = torch.stack([nn.Softmax(dim=1)(model(x, noise_coeff=noise_coeff)[0]) for _ in range(num_samples)]) 81 | y_pred = torch.log(torch.mean(y_preds, dim=0)) 82 | 83 | loss = criterion(y_pred, y) 84 | logger.meter("test", "ce_loss", loss) 85 | 86 | pred = torch.max(y_pred, dim=1) 87 | correct += pred[1].eq(y).sum() 88 | total += y.size(0) 89 | 90 | logger.meter("test", "accuracy", 1.0 * correct / total) 91 | return (1.0 * correct / total).item() 92 | 93 | 94 | def main(argv): 95 | del argv 96 | check_args(FLAGS) 97 | 98 | os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpus 99 | os.environ["WANDB_SILENT"] = "true" 100 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 101 | 102 | # Dataloader 103 | train_loader, test_loader, num_classes = get_tgt_dataloader( 104 | name=FLAGS.data, img_size=FLAGS.img_size, batch_size=FLAGS.batch_size, num_workers=FLAGS.num_workers 105 | ) 106 | train_iter = InfIterator(train_loader) 107 | logging.info(f"Train dataset: {len(train_loader)} batches") 108 | logging.info(f"Total {FLAGS.train_steps//len(train_loader)} epochs") 109 | logging.info(f"Test dataset: {len(test_loader)} batches") 110 | 111 | # Model 112 | model = get_model(model_name=FLAGS.model, num_classes=num_classes, img_size=FLAGS.img_size, do_perturb=True) 113 | model = model.to(device) 114 | if FLAGS.fine_tune: 115 | src_model_state_dict = torch.load(f"{FLAGS.save_dir}/tgt/scratch_TIN_{FLAGS.model}/model_100000.pth") 116 | state_dict_wo_bn = { 117 | name: value for name, value in src_model_state_dict.items() if "bn" not in name and "fc" not in name 118 | } 119 | model.load_state_dict(state_dict_wo_bn, strict=False) 120 | logging.info(f"Model is loaded from {FLAGS.save_dir}/tgt/scratch_TIN_{FLAGS.model}/model_100000.pth") 121 | 122 | src_perturb_state_dict = torch.load(f"{FLAGS.save_dir}/{FLAGS.src_name}_src/split_1/perturb_{FLAGS.src_steps}.pth") 123 | model.perturb.load_state_dict(src_perturb_state_dict) 124 | logging.info( 125 | f"Perturb module is loaded from {FLAGS.save_dir}/{FLAGS.src_name}_src/split_1/perturb_{FLAGS.src_steps}.pth" 126 | ) 127 | 128 | theta = [p for name, p in model.named_parameters() if "perturb" not in name] 129 | 130 | # Optimizer 131 | theta_opt = get_optimizier(FLAGS.opt, FLAGS.lr, theta) 132 | 133 | # Scheduler 134 | scheduler = get_scheduler(FLAGS.lr_schedule, theta_opt, FLAGS.train_steps) 135 | 136 | # Criterion 137 | criterion = nn.NLLLoss().to(device) 138 | 139 | # Logger 140 | logger = Logger( 141 | exp_name=FLAGS.exp_name, 142 | 
log_dir=FLAGS.tblog_dir,
143 |         save_dir=FLAGS.save_dir,
144 |         exp_suffix="tgt",
145 |         print_every=FLAGS.print_every,
146 |         save_every=FLAGS.train_steps,
147 |         total_step=FLAGS.train_steps,
148 |         use_wandb=True,
149 |         wnadb_project_name="l2p",
150 |         wandb_config=FLAGS,
151 |     )
152 |     logger.register_model_to_save(model, "model")
153 |     logger.register_model_to_save(model.perturb, "perturb")
154 |
155 |     # Training Loop
156 |     logger.start()
157 |     for step in range(1, FLAGS.train_steps + 1):
158 |         # Noise annealing
159 |         noise_coeff = FLAGS.noise_coeff * min(1.0, step / int(FLAGS.train_steps * 0.4))
160 |         train_step(model, noise_coeff, train_iter, theta_opt, device, criterion, logger)
161 |         scheduler.step()
162 |         if step % FLAGS.print_every == 0:
163 |             test(model, noise_coeff, test_loader, 5, device, criterion, logger)
164 |
165 |         logger.step()
166 |
167 |     # Test final model
168 |     final_results = test(model, FLAGS.noise_coeff, test_loader, 100, device, criterion, logger)
169 |     logger.write_log_individually("final_accuracy", final_results, FLAGS.train_steps)
170 |     print(f"Final accuracy: {final_results}")
171 |     logger.finish()
172 |
173 |
174 | if __name__ == "__main__":
175 |     app.run(main)
176 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from datetime import datetime
4 |
5 | import torch
6 | import torch.optim as optim
7 | import wandb
8 | from absl import logging
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 |
12 | def get_optimizier(opt_name, lr, params):
13 |     if opt_name == "adam":
14 |         opt = optim.Adam(params, lr=lr, weight_decay=5e-4)
15 |     elif opt_name == "sgd":
16 |         opt = optim.SGD(params, lr=lr, weight_decay=5e-4, momentum=0.9)
17 |     elif opt_name == "rmsprop":
18 |         opt = optim.RMSprop(params, lr=lr, weight_decay=5e-4, momentum=0.9)
19 |     else:
20 |         raise NotImplementedError
21 |     return opt
22 |
23 |
24 | def get_scheduler(scheduler_name, opt, train_steps, milestones=[0.4, 0.7, 0.9], gamma=0.3):
25 |     if scheduler_name == "step_lr":
26 |         milestones = [int(train_steps * v) for v in milestones]
27 |         scheduler = optim.lr_scheduler.MultiStepLR(opt, milestones=milestones, gamma=gamma)
28 |     elif scheduler_name == "cosine_lr":
29 |         scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, train_steps)
30 |     else:
31 |         raise NotImplementedError
32 |     return scheduler
33 |
34 |
35 | class InfIterator:
36 |     def __init__(self, iterable):
37 |         self.iterable = iterable
38 |         self.iterator = iter(self.iterable)
39 |
40 |     def __next__(self):
41 |         try:
42 |             return next(self.iterator)
43 |         except StopIteration:
44 |             self.iterator = iter(self.iterable)
45 |             return next(self.iterator)
46 |
47 |
48 | def check_args(FLAGS):
49 |     ignore = [
50 |         "logtostderr",
51 |         "alsologtostderr",
52 |         "log_dir",
53 |         "v",
54 |         "verbosity",
55 |         "stderrthreshold",
56 |         "showprefixforinfo",
57 |         "run_with_pdb",
58 |         "pdb_post_mortem",
59 |         "run_with_profiling",
60 |         "profile_file",
61 |         "use_cprofile_for_profiling",
62 |         "only_check_args",
63 |         "?",
64 |         "help",
65 |         "helpshort",
66 |         "helpfull",
67 |         "helpxml",
68 |     ]
69 |     for name, value in FLAGS.flag_values_dict().items():
70 |         if name not in ignore:
71 |             print(f"{name:>15} : {value}")
72 |     print("Is this correct? 
(y/n)") 73 | ret = input() 74 | if ret.lower() != "y": 75 | exit(0) 76 | 77 | 78 | def backup_code( 79 | backup_dir, 80 | ignore_list={".gitignore", ".ipynb_checkpoints", ".vscode", "__pycache__", "checkpoint", "data", "runs", "wandb"}, 81 | ): 82 | shutil.copytree( 83 | os.path.abspath(os.path.curdir), backup_dir, ignore=lambda src, names: ignore_list, 84 | ) 85 | 86 | 87 | class Logger: 88 | def __init__( 89 | self, 90 | exp_name, 91 | exp_suffix="", 92 | log_dir=None, 93 | save_dir=None, 94 | print_every=100, 95 | save_every=100, 96 | initial_step=0, 97 | total_step=0, 98 | print_to_stdout=True, 99 | use_wandb=False, 100 | wnadb_project_name=None, 101 | wandb_tags=[], 102 | wandb_config=None, 103 | ): 104 | if log_dir is not None: 105 | self.log_dir = os.path.join(log_dir, exp_name, exp_suffix) 106 | os.makedirs(self.log_dir, exist_ok=True) 107 | else: 108 | self.log_dir = None 109 | assert use_wandb, "'log_dir' argument must be given or 'use_wandb' argument must be True." 110 | 111 | if save_dir is not None: 112 | self.save_dir = os.path.join(save_dir, exp_name, exp_suffix) 113 | os.makedirs(self.save_dir, exist_ok=True) 114 | else: 115 | self.save_dir = None 116 | 117 | self.print_every = print_every 118 | self.save_every = save_every 119 | self.step_count = initial_step 120 | self.total_step = total_step 121 | self.print_to_stdout = print_to_stdout 122 | self.use_wandb = use_wandb 123 | 124 | self.writer = None 125 | self.start_time = None 126 | self.groups = dict() 127 | self.models_to_save = dict() 128 | if self.use_wandb: 129 | exp_suffix = "_".join(exp_suffix.split("/")[:-1]) 130 | wandb.init(project=wnadb_project_name, name=exp_name + "_" + exp_suffix, tags=wandb_tags, reinit=True) 131 | wandb.config.update(wandb_config) 132 | 133 | def register_model_to_save(self, model, name): 134 | assert name not in self.models_to_save.keys(), "Name is already registered." 
135 |
136 |         self.models_to_save[name] = model
137 |
138 |     def step(self):
139 |         self.step_count += 1
140 |         if self.step_count % self.print_every == 0:
141 |             if self.print_to_stdout:
142 |                 self.print_log(self.step_count, self.total_step, elapsed_time=datetime.now() - self.start_time)
143 |             self.write_log(self.step_count)
144 |
145 |         if self.step_count % self.save_every == 0:
146 |             self.save_models()
147 |
148 |     def meter(self, group_name, log_name, value):
149 |         if group_name not in self.groups.keys():
150 |             self.groups[group_name] = dict()
151 |
152 |         if log_name not in self.groups[group_name].keys():
153 |             self.groups[group_name][log_name] = Accumulator()
154 |
155 |         self.groups[group_name][log_name].update_state(value)
156 |
157 |     def reset_state(self):
158 |         for _, group in self.groups.items():
159 |             for _, log in group.items():
160 |                 log.reset_state()
161 |
162 |     def print_log(self, step, total_step, elapsed_time=None):
163 |         print(f"[Step {step:5d}/{total_step}]", end=" ")
164 |
165 |         for name, group in self.groups.items():
166 |             print(f"({name})", end=" ")
167 |             for log_name, log in group.items():
168 |                 if "acc" in log_name.lower():
169 |                     print(f"{log_name} {log.result() * 100:.2f}", end=" | ")
170 |                 else:
171 |                     print(f"{log_name} {log.result():.4f}", end=" | ")
172 |
173 |         if elapsed_time is not None:
174 |             print(f"(Elapsed time) {elapsed_time}")
175 |         else:
176 |             print()
177 |
178 |     def write_log(self, step):
179 |         if self.use_wandb:
180 |             log_dict = {}
181 |             for group_name, group in self.groups.items():
182 |                 for log_name, log in group.items():
183 |                     log_dict["{}/{}".format(log_name, group_name)] = log.result()
184 |             wandb.log(log_dict, step=step)
185 |         else:
186 |             if self.writer is None:
187 |                 self.writer = SummaryWriter(self.log_dir)
188 |
189 |             for group_name, group in self.groups.items():
190 |                 for log_name, log in group.items():
191 |                     self.writer.add_scalar("{}/{}".format(log_name, group_name), log.result(), step)
192 |             self.writer.flush()
193 |
194 |         self.reset_state()
195 |
196 |     def write_log_individually(self, name, value, step):
197 |         if self.use_wandb:
198 |             wandb.log({name: value}, step=step)
199 |         else:
200 |             self.writer.add_scalar(name, value, step)
201 |
202 |     def save_models(self, suffix=None):
203 |         if self.save_dir is None:
204 |             return
205 |
206 |         for name, model in self.models_to_save.items():
207 |             if suffix:
208 |                 name += f"_{suffix}"
209 |             torch.save(model.state_dict(), os.path.join(self.save_dir, f"{name}.pth"))
210 |
211 |         if self.print_to_stdout:
212 |             logging.info(f"Model is saved to {self.save_dir}")
213 |
214 |     def start(self):
215 |         if self.print_to_stdout:
216 |             logging.info("Training starts!")
217 |         self.save_models("init")
218 |         self.start_time = datetime.now()
219 |
220 |     def finish(self):
221 |         if self.step_count % self.save_every != 0:
222 |             self.save_models(self.step_count)
223 |
224 |         if self.print_to_stdout:
225 |             logging.info("Training is finished!")
226 |
227 |         if self.use_wandb:
228 |             wandb.join()
229 |
230 |
231 | class Accumulator:
232 |     def __init__(self):
233 |         self.data = 0
234 |         self.num_data = 0
235 |
236 |     def reset_state(self):
237 |         self.data = 0
238 |         self.num_data = 0
239 |
240 |     def update_state(self, tensor):
241 |         with torch.no_grad():
242 |             self.data += tensor
243 |             self.num_data += 1
244 |
245 |     def result(self):
246 |         if self.num_data == 0:
247 |             return 0
248 |         return (1.0 * self.data / self.num_data).item()
249 |
--------------------------------------------------------------------------------
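A minimal sketch of how the Logger/Accumulator/InfIterator utilities above are driven (names are from utils.py; the loop and numbers are hypothetical, and the tensorboard dependency comes from the Dockerfile):

```python
import torch
from utils import InfIterator, Logger

logger = Logger(exp_name="demo", log_dir="runs", save_dir=None, print_every=10, total_step=100)
logger.start()
batches = InfIterator(range(5))  # endlessly restarts the underlying iterable
for _ in range(100):
    loss = torch.tensor(float(next(batches)))
    logger.meter("train", "ce_loss", loss)  # accumulated, then flushed every print_every steps
    logger.step()
logger.finish()
```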
/train_src.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datetime import datetime
3 |
4 | import numpy as np
5 | import torch
6 | import torch.distributed as dist
7 | import torch.nn as nn
8 | from absl import app, flags, logging
9 | from torch.multiprocessing import Process
10 |
11 | from dataloader import get_src_dataloader
12 | from models.get_model import get_model
13 | from utils import InfIterator, Logger, backup_code, check_args, get_optimizier
14 |
15 | FLAGS = flags.FLAGS
16 | # Training
17 | flags.DEFINE_integer("num_run", 1, "The number of meta-training runs")
18 | flags.DEFINE_integer("batch_size", 128, "Batch size")
19 | flags.DEFINE_integer("train_steps", 2000, "Total training steps for a single run")
20 | flags.DEFINE_enum("lr_schedule", "step_lr", ["step_lr", "cosine_lr"], "lr schedule")
21 | flags.DEFINE_enum("opt", "adam", ["adam", "sgd", "rmsprop"], "optimizer")
22 | flags.DEFINE_float("lr", 1e-3, "Learning rate")
23 |
24 | # Model
25 | flags.DEFINE_string("model", "resnet20", "Model")
26 | flags.DEFINE_integer("clipval", 100, "clip value for perturb module")
27 |
28 | # Data
29 | flags.DEFINE_integer("img_size", 32, "Image size")
30 | flags.DEFINE_string("data", "tiny_imagenet", "Data")
31 | flags.DEFINE_integer("num_split", 10, "The number of splits")
32 |
33 | # Misc
34 | flags.DEFINE_string("tblog_dir", None, "Directory for tensorboard logs")
35 | flags.DEFINE_string("code_dir", "./codes", "Directory for backup code")
36 | flags.DEFINE_string("save_dir", "./checkpoints", "Directory for checkpoints")
37 | flags.DEFINE_string("exp_name", "", "Experiment name")
38 | flags.DEFINE_integer("print_every", 200, "Print period")
39 | flags.DEFINE_integer("save_every", 1000, "Save period")
40 | flags.DEFINE_list("gpus", "", "GPUs to use")
41 | flags.DEFINE_string("port", "12355", "Port number for multiprocessing (must be a valid TCP port)")
42 | flags.DEFINE_integer("num_workers", 1, "The number of workers for dataloading")
43 |
44 |
45 | def share_grads(params):
46 |     tensors = torch.cat([p.grad.view(-1) for p in params])  # flatten all gradients into one vector
47 |     dist.all_reduce(tensors)  # sum across processes
48 |     tensors /= dist.get_world_size()  # ... and average
49 |
50 |     idx = 0
51 |     for p in params:
52 |         p.grad.data.copy_(tensors[idx : idx + np.prod(p.grad.shape)].view(p.grad.size()))
53 |         idx += np.prod(p.grad.shape)
54 |
55 |
56 | def accuracy(y, y_pred):
57 |     with torch.no_grad():
58 |         pred = torch.max(y_pred, dim=1)
59 |         return 1.0 * pred[1].eq(y).sum() / y.size(0)
60 |
61 |
62 | def train_step(model, phi, train_iter, valid_iter, theta_opt, phi_opt, clipval, device, criterion, logger):
63 |     model.train()
64 |
65 |     # Sample data from training set
66 |     x, y = next(train_iter)
67 |     x, y = x.to(device), y.to(device)
68 |     x = x.expand(-1, 3, -1, -1)
69 |
70 |     # Update theta
71 |     y_pred, metrics = model(x, clipval=clipval)
72 |     y_pred = nn.LogSoftmax(dim=1)(y_pred)
73 |     loss = criterion(y_pred, y)
74 |
75 |     theta_opt.zero_grad()
76 |     loss.backward()
77 |     gradient_norm = [torch.norm(param.grad) for name, param in model.named_parameters() if "perturb" not in name]
78 |     theta_opt.step()
79 |
80 |     # Meter logs
81 |     logger.meter("train", "ce_loss", loss)
82 |     logger.meter("train", "accuracy", accuracy(y, y_pred))
83 |     for i, metric in enumerate(metrics):
84 |         if metric is None:
85 |             continue
86 |         for name, value in metric.items():
87 |             logger.meter(f"train_{name}", f"layer_{i}", value)
88 |     logger.meter("theta", "gradient_norm", torch.sum(torch.stack(gradient_norm)))
89 |
90 |     # Stop accumulating running statistics
91 |     for m in model.modules():
92 |         if isinstance(m, nn.BatchNorm2d):
93 |             m.eval()
94 |
95 |     # Sample data from validation set
96 |     x, y = next(valid_iter)
97 |     x, y = x.to(device), y.to(device)
98 |     x = x.expand(-1, 3, -1, -1)
99 |
100 |     # Update phi
101 |     y_pred, metrics = model(x, clipval=clipval)
102 |     y_pred = nn.LogSoftmax(dim=1)(y_pred)
103 |     loss = criterion(y_pred, y)
104 |
105 |     phi_opt.zero_grad()
106 |     loss.backward()
107 |     gradient_norm = [torch.norm(param.grad) for param in phi]
108 |     share_grads(phi)
109 |     phi_opt.step()
110 |
111 |     # Meter logs
112 |     logger.meter("valid", "ce_loss", loss)
113 |     logger.meter("valid", "accuracy", accuracy(y, y_pred))
114 |     for i, metric in enumerate(metrics):
115 |         if metric is None:
116 |             continue
117 |         for name, value in metric.items():
118 |             logger.meter(f"valid_{name}", f"layer_{i}", value)
119 |
120 |     idx = 0
121 |     for name, _ in model.named_parameters():
122 |         if "perturb" in name:
123 |             logger.meter(name[8:], "gradient_norm", gradient_norm[idx])
124 |             idx += 1
125 |     logger.meter("all", "gradient_norm", torch.sum(torch.stack(gradient_norm)))
126 |
127 |
128 | def test(model, test_loader, device, criterion, logger):
129 |     model.eval()
130 |     correct, total = 0, 0
131 |     with torch.no_grad():
132 |         for x, y in test_loader:
133 |             x, y = x.to(device), y.to(device)
134 |             x = x.expand(-1, 3, -1, -1)
135 |
136 |             y_preds = torch.stack([nn.Softmax(dim=1)(model(x)[0]) for _ in range(5)])
137 |             y_pred = torch.log(torch.mean(y_preds, dim=0))
138 |
139 |             loss = criterion(y_pred, y)
140 |             logger.meter("test", "ce_loss", loss)
141 |
142 |             pred = torch.max(y_pred, dim=1)
143 |             correct += pred[1].eq(y).sum()
144 |             total += y.size(0)
145 |     logger.meter("test", "accuracy", 1.0 * correct / total)
146 |
147 |
148 | def run_single_process(rank, model_name, backend="nccl"):
149 |     dist.init_process_group(backend, rank=rank, world_size=FLAGS.num_split)
150 |     os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpus[rank]
151 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
152 |
153 |     # Dataloader
154 |     # Don't test in meta-training stage
155 |     train_loader, valid_loader, test_loader, num_classes = get_src_dataloader(
156 |         name=FLAGS.data,
157 |         split=rank,
158 |         total_split=FLAGS.num_split,
159 |         img_size=FLAGS.img_size,
160 |         batch_size=FLAGS.batch_size,
161 |         num_workers=FLAGS.num_workers,
162 |     )
163 |     train_iter = InfIterator(train_loader)
164 |     valid_iter = InfIterator(valid_loader)
165 |
166 |     if rank == 0:
167 |         logging.info(f"Train dataset: {len(train_loader)} batches")
168 |         logging.info(f"Total {FLAGS.train_steps//len(train_loader)} epochs")
169 |         logging.info(f"Valid dataset: {len(valid_loader)} batches")
170 |         logging.info(f"Total {FLAGS.train_steps//len(valid_loader)} epochs")
171 |         # logging.info(f"Test dataset: {len(test_loader)} batches")
172 |
173 |     for run in range(FLAGS.num_run):
174 |         if rank == 0:
175 |             logging.info(f"Run #{run+1}")
176 |         # Model
177 |         model = get_model(model_name=model_name, num_classes=num_classes, img_size=FLAGS.img_size, do_perturb=True)
178 |         model = model.to(device)
179 |
180 |         theta = [p for name, p in model.named_parameters() if "perturb" not in name]
181 |         phi = [p for name, p in model.named_parameters() if "perturb" in name]
182 |
183 |         # Synchronize phi at the beginning
184 |         for p in phi:
185 |             dist.all_reduce(p.data)
186 |             p.data /= FLAGS.num_split
187 |
188 |         # Criterion
189 |         criterion = nn.NLLLoss().to(device)
190 |
191 |         # Optimizers
192 |         theta_opt = get_optimizier(FLAGS.opt, FLAGS.lr, theta)
193 |         phi_opt = get_optimizier(FLAGS.opt, FLAGS.lr, phi)
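        # NOTE: theta (all non-perturb parameters) stays private to this process/split, while
        # phi (the Perturb parameters) is synchronized at initialization and has its gradients
        # averaged across splits through share_grads(), so one perturbation function is
        # meta-learned jointly over all tasks.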
194 |
195 |         # Logger
196 |         logger = Logger(
197 |             exp_name=FLAGS.exp_name,
198 |             log_dir=FLAGS.tblog_dir,
199 |             save_dir=FLAGS.save_dir,
200 |             exp_suffix=f"run_{run}_src/split_{rank+1}",
201 |             print_every=FLAGS.print_every,
202 |             save_every=FLAGS.save_every,
203 |             total_step=FLAGS.train_steps,
204 |             print_to_stdout=(rank == 0),
205 |             use_wandb=True,
206 |             wnadb_project_name="l2p",
207 |             wandb_tags=[f"split_{rank+1}"],
208 |             wandb_config=FLAGS,
209 |         )
210 |         logger.register_model_to_save(model, "model")
211 |         logger.register_model_to_save(model.perturb, "perturb")
212 |
213 |         # Training Loop
214 |         logger.start()
215 |         for i in range(FLAGS.train_steps):
216 |             train_step(model, phi, train_iter, valid_iter, theta_opt, phi_opt, FLAGS.clipval, device, criterion, logger)
217 |             # (periodic checkpointing is handled inside logger.step() every save_every steps)
218 |
219 |             logger.step()
220 |         logger.finish()
221 |
222 |
223 | def run_multi_process(argv):
224 |     del argv
225 |     check_args(FLAGS)
226 |     backup_code(os.path.join(FLAGS.code_dir, FLAGS.exp_name, datetime.now().strftime("%m-%d-%H-%M-%S")))
227 |
228 |     os.environ["MASTER_ADDR"] = "localhost"
229 |     os.environ["MASTER_PORT"] = FLAGS.port
230 |     os.environ["WANDB_SILENT"] = "true"
231 |     processes = []
232 |
233 |     for rank in range(FLAGS.num_split):
234 |         p = Process(target=run_single_process, args=(rank, FLAGS.model))
235 |         p.start()
236 |         processes.append(p)
237 |
238 |     for p in processes:
239 |         p.join()
240 |
241 |
242 | if __name__ == "__main__":
243 |     app.run(run_multi_process)
244 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .perturb import Perturb
6 |
7 |
8 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
9 |     """3x3 convolution with padding"""
10 |     return nn.Conv2d(
11 |         in_planes,
12 |         out_planes,
13 |         kernel_size=3,
14 |         stride=stride,
15 |         padding=dilation,
16 |         groups=groups,
17 |         bias=False,
18 |         dilation=dilation,
19 |     )
20 |
21 |
22 | def conv1x1(in_planes, out_planes, stride=1):
23 |     """1x1 convolution"""
24 |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
25 |
26 |
27 | class BasicBlock(nn.Module):
28 |     expansion = 1
29 |
30 |     def __init__(
31 |         self,
32 |         inplanes,
33 |         planes,
34 |         stride=1,
35 |         downsample=None,
36 |         groups=1,
37 |         base_width=64,
38 |         dilation=1,
39 |         norm_layer=None,
40 |         perturb=None,
41 |         perturb_idx=None,
42 |     ):
43 |         super(BasicBlock, self).__init__()
44 |         if norm_layer is None:
45 |             norm_layer = nn.BatchNorm2d
46 |         if groups != 1 or base_width != 64:
47 |             raise ValueError("BasicBlock only supports groups=1 and base_width=64")
48 |         if dilation > 1:
49 |             raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
50 |         # Both self.conv1 and self.downsample layers downsample the input when stride != 1
51 |         self.conv1 = conv3x3(inplanes, planes, stride)
52 |         self.bn1 = norm_layer(planes)
53 |         self.relu = nn.ReLU(inplace=True)
54 |         self.conv2 = conv3x3(planes, planes)
55 |         self.bn2 = norm_layer(planes)
56 |         self.downsample = downsample
57 |         self.stride = stride
58 |
59 |         # MetaPerturb
60 |         self.perturb = perturb
61 |         if perturb:
62 |             self.perturb.add_running_stats(planes)
63 |         self.perturb_idx = perturb_idx
64 |
65 |     def forward(self, x, clipval, noise_coeff):
66 |         identity = x
67 |
68 |         out = self.conv1(x)
69 |         out = 
self.bn1(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv2(out) 73 | out = self.bn2(out) 74 | 75 | if self.downsample is not None: 76 | identity = self.downsample(x) 77 | 78 | out += identity 79 | # MetaPerturb 80 | metrics = None 81 | if self.perturb: 82 | out, metrics = self.perturb(out, clipval, noise_coeff, self.perturb_idx) 83 | out = self.relu(out) 84 | return out, metrics 85 | 86 | 87 | class Bottleneck(nn.Module): 88 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 89 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 90 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 91 | # This variant is also known as ResNet V1.5 and improves accuracy according to 92 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 93 | 94 | expansion = 4 95 | 96 | def __init__( 97 | self, 98 | inplanes, 99 | planes, 100 | stride=1, 101 | downsample=None, 102 | groups=1, 103 | base_width=64, 104 | dilation=1, 105 | norm_layer=None, 106 | perturb=None, 107 | perturb_idx=None, 108 | ): 109 | super(Bottleneck, self).__init__() 110 | if norm_layer is None: 111 | norm_layer = nn.BatchNorm2d 112 | width = int(planes * (base_width / 64.0)) * groups 113 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 114 | self.conv1 = conv1x1(inplanes, width) 115 | self.bn1 = norm_layer(width) 116 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 117 | self.bn2 = norm_layer(width) 118 | self.conv3 = conv1x1(width, planes * self.expansion) 119 | self.bn3 = norm_layer(planes * self.expansion) 120 | self.relu = nn.ReLU(inplace=True) 121 | self.downsample = downsample 122 | self.stride = stride 123 | 124 | # MetaPerturb 125 | self.perturb = perturb 126 | if perturb: 127 | self.perturb.add_running_stats(planes) 128 | self.perturb_idx = perturb_idx 129 | 130 | def forward(self, x, clipval, noise_coeff): 131 | identity = x 132 | 133 | out = self.conv1(x) 134 | out = self.bn1(out) 135 | out = self.relu(out) 136 | 137 | out = self.conv2(out) 138 | out = self.bn2(out) 139 | out = self.relu(out) 140 | 141 | out = self.conv3(out) 142 | out = self.bn3(out) 143 | 144 | if self.downsample is not None: 145 | identity = self.downsample(x) 146 | 147 | out += identity 148 | # MetaPerturb 149 | metrics = None 150 | if self.perturb: 151 | out, metrics = self.perturb(out, clipval, noise_coeff, self.perturb_idx) 152 | out = self.relu(out) 153 | return out, metrics 154 | 155 | 156 | class IdentityShortCut(nn.Module): 157 | def __init__(self, pad): 158 | super().__init__() 159 | self.pad = pad 160 | 161 | def forward(self, x): 162 | return F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, self.pad, self.pad), "constant", 0) 163 | 164 | 165 | class ResNet_small(nn.Module): 166 | def __init__( 167 | self, 168 | block, 169 | layers, 170 | img_size, 171 | num_classes=1000, 172 | zero_init_residual=False, 173 | groups=1, 174 | width_per_group=64, 175 | replace_stride_with_dilation=None, 176 | norm_layer=None, 177 | do_perturb=False, 178 | ): 179 | super(ResNet_small, self).__init__() 180 | if norm_layer is None: 181 | norm_layer = nn.BatchNorm2d 182 | self._norm_layer = norm_layer 183 | 184 | self.inplanes = 16 185 | self.dilation = 1 186 | if replace_stride_with_dilation is None: 187 | # each element in the tuple indicates if we should replace 188 | # the 2x2 stride with a dilated convolution instead 189 | replace_stride_with_dilation = [False, 
False, False] 190 | if len(replace_stride_with_dilation) != 3: 191 | raise ValueError( 192 | "replace_stride_with_dilation should be None " 193 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation) 194 | ) 195 | self.groups = groups 196 | self.base_width = width_per_group 197 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 198 | self.bn1 = norm_layer(self.inplanes) 199 | self.relu = nn.ReLU(inplace=True) 200 | # MetaPerturb 201 | self.perturb = None 202 | if do_perturb: 203 | self.perturb = Perturb(channel_norm_factor=64, spatial_norm_factor=img_size) 204 | 205 | self.layer1 = self._make_layer(block, 16, layers[0], stride=1, perturb_idx=0) 206 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2, perturb_idx=layers[0]) 207 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2, perturb_idx=layers[0] + layers[1]) 208 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 209 | self.fc = nn.Linear(64 * block.expansion, num_classes) 210 | 211 | for m in self.modules(): 212 | if isinstance(m, nn.Conv2d): 213 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 214 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 215 | nn.init.constant_(m.weight, 1) 216 | nn.init.constant_(m.bias, 0) 217 | 218 | # Zero-initialize the last BN in each residual branch, 219 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 220 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 221 | if zero_init_residual: 222 | for m in self.modules(): 223 | if isinstance(m, Bottleneck): 224 | nn.init.constant_(m.bn3.weight, 0) 225 | elif isinstance(m, BasicBlock): 226 | nn.init.constant_(m.bn2.weight, 0) 227 | 228 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False, perturb_idx=0): 229 | norm_layer = self._norm_layer 230 | downsample = None 231 | previous_dilation = self.dilation 232 | if dilate: 233 | self.dilation *= stride 234 | stride = 1 235 | if stride != 1 or self.inplanes != planes * block.expansion: 236 | downsample = IdentityShortCut((planes * block.expansion - self.inplanes) // 2) 237 | 238 | layers = [] 239 | layers.append( 240 | block( 241 | self.inplanes, 242 | planes, 243 | stride, 244 | downsample, 245 | self.groups, 246 | self.base_width, 247 | previous_dilation, 248 | norm_layer, 249 | self.perturb, 250 | perturb_idx=perturb_idx, 251 | ) 252 | ) 253 | perturb_idx += 1 254 | 255 | self.inplanes = planes * block.expansion 256 | for i in range(1, blocks): 257 | layers.append( 258 | block( 259 | self.inplanes, 260 | planes, 261 | groups=self.groups, 262 | base_width=self.base_width, 263 | dilation=self.dilation, 264 | norm_layer=norm_layer, 265 | perturb=self.perturb, 266 | perturb_idx=perturb_idx, 267 | ) 268 | ) 269 | perturb_idx += 1 270 | return nn.Sequential(*layers) 271 | 272 | def _forward_impl(self, x, clipval, noise_coeff): 273 | # See note [TorchScript super()] 274 | x = self.conv1(x) 275 | x = self.bn1(x) 276 | x = self.relu(x) 277 | 278 | # MetaPerturb 279 | metrics_all = [] 280 | for layer in self.layer1: 281 | x, metrics = layer(x, clipval, noise_coeff) 282 | metrics_all.append(metrics) 283 | for layer in self.layer2: 284 | x, metrics = layer(x, clipval, noise_coeff) 285 | metrics_all.append(metrics) 286 | for layer in self.layer3: 287 | x, metrics = layer(x, clipval, noise_coeff) 288 | metrics_all.append(metrics) 289 | x = self.avgpool(x) 290 | x = torch.flatten(x, 1) 291 | x = self.fc(x) 292 | return x, 


def resnet20(num_classes, img_size, do_perturb):
    return ResNet_small(BasicBlock, [3, 3, 3], img_size=img_size, num_classes=num_classes, do_perturb=do_perturb)


def resnet32(num_classes, img_size, do_perturb):
    return ResNet_small(BasicBlock, [5, 5, 5], img_size=img_size, num_classes=num_classes, do_perturb=do_perturb)


def resnet44(num_classes, img_size, do_perturb):
    return ResNet_small(BasicBlock, [7, 7, 7], img_size=img_size, num_classes=num_classes, do_perturb=do_perturb)


def resnet56(num_classes, img_size, do_perturb):
    return ResNet_small(BasicBlock, [9, 9, 9], img_size=img_size, num_classes=num_classes, do_perturb=do_perturb)


class ResNet(nn.Module):
    def __init__(
        self,
        block,
        layers,
        img_size,
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm_layer=None,
        do_perturb=False,
    ):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.perturb = None
        if do_perturb:
            self.perturb = Perturb(channel_norm_factor=512, spatial_norm_factor=img_size)
        self.layer1 = self._make_layer(block, 64, layers[0], perturb_idx=0)
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0], perturb_idx=layers[0]
        )
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1], perturb_idx=layers[0] + layers[1]
        )
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2],
            perturb_idx=layers[0] + layers[1] + layers[2],
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, perturb_idx=0):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
                self.perturb,
                perturb_idx=perturb_idx,
            )
        )
        perturb_idx += 1

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                    perturb=self.perturb,
                    perturb_idx=perturb_idx,
                )
            )
            perturb_idx += 1
        return nn.Sequential(*layers)

    def _forward_impl(self, x, clipval, noise_coeff):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        metrics_all = []
        for layer in self.layer1:
            x, metrics = layer(x, clipval, noise_coeff)
            metrics_all.append(metrics)
        for layer in self.layer2:
            x, metrics = layer(x, clipval, noise_coeff)
            metrics_all.append(metrics)
        for layer in self.layer3:
            x, metrics = layer(x, clipval, noise_coeff)
            metrics_all.append(metrics)
        for layer in self.layer4:
            x, metrics = layer(x, clipval, noise_coeff)
            metrics_all.append(metrics)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x, metrics_all

    def forward(self, x, clipval=None, noise_coeff=1.0):
        return self._forward_impl(x, clipval, noise_coeff)
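
# --- Illustrative sketch (hypothetical helper, not from the repo): typical use of
# the perturbed forward pass. The model always returns (logits, metrics_all) with
# one metrics entry per residual block; with do_perturb=False every entry is None.
def _example_forward():
    model = resnet20(num_classes=10, img_size=32, do_perturb=False)
    x = torch.randn(4, 3, 32, 32)
    logits, metrics_all = model(x)  # defaults: clipval=None, noise_coeff=1.0
    assert logits.shape == (4, 10)
    assert len(metrics_all) == 9 and all(m is None for m in metrics_all)  # 3 stages x 3 blocks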


def resnet18(num_classes, img_size, do_perturb):
    return ResNet(BasicBlock, [2, 2, 2, 2], img_size=img_size, num_classes=num_classes, do_perturb=do_perturb)


def resnet34(num_classes, img_size, do_perturb):
    return ResNet(BasicBlock, [3, 4, 6, 3], img_size=img_size, num_classes=num_classes, do_perturb=do_perturb)


# def _resnet(arch, block, layers, pretrained, progress, **kwargs):
#     model = ResNet(block, layers, **kwargs)
#     if pretrained:
#         state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
#         model.load_state_dict(state_dict, strict=False)
#     return model


# def resnet18(pretrained=False, progress=True, **kwargs):
#     r"""ResNet-18 model from
#     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
#
#     Args:
#         pretrained (bool): If True, returns a model pre-trained on ImageNet
#         progress (bool): If True, displays a progress bar of the download to stderr
#     """
#     return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
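
# --- Illustrative sketch (hypothetical helper, not from the repo): the commented-out
# `_resnet` helper above loads ImageNet weights with strict=False. Assuming torchvision
# is installed (with torchvision >= 0.13, pass weights="IMAGENET1K_V1" instead of
# pretrained=True), a partial load that also tolerates the differing fc head is:
def _load_imagenet_backbone(num_classes, img_size):
    from torchvision.models import resnet18 as tv_resnet18

    model = resnet18(num_classes=num_classes, img_size=img_size, do_perturb=False)
    own = model.state_dict()
    pretrained = tv_resnet18(pretrained=True).state_dict()
    # Keep only keys that exist here with matching shapes (e.g. fc differs when
    # num_classes != 1000); strict=False then ignores the missing keys.
    compatible = {k: v for k, v in pretrained.items() if k in own and v.shape == own[k].shape}
    model.load_state_dict(compatible, strict=False)
    return model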


# def resnet34(pretrained=False, progress=True, **kwargs):
#     r"""ResNet-34 model from
#     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
#
#     Args:
#         pretrained (bool): If True, returns a model pre-trained on ImageNet
#         progress (bool): If True, displays a progress bar of the download to stderr
#     """
#     return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)


# def resnet50(pretrained=False, progress=True, **kwargs):
#     r"""ResNet-50 model from
#     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
#
#     Args:
#         pretrained (bool): If True, returns a model pre-trained on ImageNet
#         progress (bool): If True, displays a progress bar of the download to stderr
#     """
#     return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


# def resnet101(pretrained=False, progress=True, **kwargs):
#     r"""ResNet-101 model from
#     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
#
#     Args:
#         pretrained (bool): If True, returns a model pre-trained on ImageNet
#         progress (bool): If True, displays a progress bar of the download to stderr
#     """
#     return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)


# def resnet152(pretrained=False, progress=True, **kwargs):
#     r"""ResNet-152 model from
#     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
#
#     Args:
#         pretrained (bool): If True, returns a model pre-trained on ImageNet
#         progress (bool): If True, displays a progress bar of the download to stderr
#     """
#     return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)
--------------------------------------------------------------------------------