├── main.py
├── tester
│   ├── __init__.py
│   └── TestAcc.py
├── Normalizations
│   ├── __init__.py
│   ├── freeze_weight.py
│   ├── AdaBN.py
│   ├── BIN.py
│   ├── change.py
│   ├── ASRNormBN1d.py
│   ├── ASRNormBN2d.py
│   ├── ASRNormLN.py
│   └── ASRNormIN.py
├── evaluate.py
├── optimizer
│   ├── __init__.py
│   ├── FGSM.py
│   ├── CosineLRS.py
│   ├── default.py
│   └── ALRS.py
├── backbones
│   ├── __init__.py
│   ├── ConvNet.py
│   ├── vggv2.py
│   ├── utils.py
│   ├── wrnv2.py
│   ├── resnetv3.py
│   ├── ShuffleNetv1.py
│   ├── resnetv2.py
│   ├── ShuffleNetv2.py
│   ├── PyramidNet.py
│   ├── resnet_imagenet.py
│   ├── wrn.py
│   ├── mobilenetv2.py
│   ├── vgg.py
│   └── RSC.py
├── README.md
├── data
│   ├── __init__.py
│   ├── mnist.py
│   ├── svhn.py
│   ├── usps.py
│   ├── someset.py
│   ├── ImageNet.py
│   ├── PACS.py
│   ├── CIFAR10C.py
│   ├── mnistm.py
│   └── cifar.py
├── .gitignore
└── Solver.py

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/tester/__init__.py:
--------------------------------------------------------------------------------
from .TestAcc import test_acc

--------------------------------------------------------------------------------
/Normalizations/__init__.py:
--------------------------------------------------------------------------------
from .ASRNormBN2d import ASRNormBN2d, build_ASRNormBN2d
from .ASRNormBN1d import ASRNormBN1d, build_ASRNormBN1d
from .ASRNormIN import ASRNormIN, build_ASRNormIN
from .ASRNormLN import ASRNormLN, build_ASRNormLN

--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
from tester import test_acc
import torch

from data import get_CIFAR10_test

loader = get_CIFAR10_test()
from backbones import resnet32, ShuffleV2
from Normalizations import ASRNormBN2d, ASRNormIN

model = resnet32(num_classes=10)
model.load_state_dict(torch.load('./student.pth'))
test_acc(model, loader)

--------------------------------------------------------------------------------
/optimizer/__init__.py:
--------------------------------------------------------------------------------
from .FGSM import FGSM
from .ALRS import ALRS
from .CosineLRS import CosineLRS
from torch.optim import Adam, AdamW, SGD
from .default import default_optimizer, default_lr_scheduler, PACS_optimizer

__all__ = ['FGSM', 'AdamW', 'SGD', 'Adam', 'default_lr_scheduler', 'default_optimizer', 'CosineLRS', 'ALRS', 'PACS_optimizer']

--------------------------------------------------------------------------------
/backbones/__init__.py:
--------------------------------------------------------------------------------
from .mobilenetv2 import *
from .resnet import *
# from .resnet_imagenet import *
# from .resnetv2 import *
# from .resnetv3 import *
# from .RSC import *
from .ShuffleNetv1 import *
from .ShuffleNetv2 import *
# from .vgg import *
from .vggv2 import *
from .wrn import *
from .wrnv2 import *
from .PyramidNet import *
from .ConvNet import *

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ASRNorm
A PyTorch implementation of "Adversarially Adaptive Normalization for Single Domain Generalization".

AdaBN improves accuracy by about 0.1%, so changing the BN layers alone can make a difference!
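
A minimal usage sketch (hedged: the backbone choice is just an example; `build_ASRNormBN2d` is defined in `Normalizations/ASRNormBN2d.py` and swaps every `nn.BatchNorm2d` in place):

```python
import torch

from backbones import wideresnet          # CIFAR-style WRN from backbones/wrnv2.py
from Normalizations import build_ASRNormBN2d

model = wideresnet(depth=40, widen_factor=10)
build_ASRNormBN2d(model, init=False)      # replace nn.BatchNorm2d -> ASRNormBN2d
out = model(torch.randn(8, 3, 32, 32))    # ASR-Norm statistics are re-estimated per batch
```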

As for the gain being only about 0.1%: I think that is because there are 6 domains in both the training set and the test set. If the test set contained only 1 domain, the improvement would likely be larger.

So this suggests that ASRNorm improves accuracy by at least 0.1%, because its dynamic parameters let it discriminate between different domains.

--------------------------------------------------------------------------------
/optimizer/FGSM.py:
--------------------------------------------------------------------------------
import torch
from torch.optim import Optimizer


class FGSM(Optimizer):
    def __init__(self, params, lr):
        dampening = 0
        weight_decay = 0
        nesterov = False
        maximize = False
        momentum = 0
        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov, maximize=maximize)
        super(FGSM, self).__init__(params, defaults)
        self.lr = lr

    @torch.no_grad()
    def step(self, closure=None):
        # gradient-sign update: p <- p - lr * sign(grad), the FGSM-style step
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.add_(-self.lr * p.grad.sign())

--------------------------------------------------------------------------------
/tester/TestAcc.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import DataLoader
from torch import nn


@torch.no_grad()
def test_acc(model: nn.Module, loader: DataLoader,
             device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
    total_loss = 0
    total_acc = 0
    criterion = nn.CrossEntropyLoss().to(device)
    model.to(device).eval()
    denominator = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pre = model(x)
        total_loss += criterion(pre, y).item() * y.shape[0]
        _, pre = torch.max(pre, dim=1)
        total_acc += torch.sum((pre == y)).item()
        denominator += y.shape[0]

    test_loss = total_loss / denominator
    test_accuracy = total_acc / denominator
    print(f'loss = {test_loss}, acc = {test_accuracy}')
    return test_loss, test_accuracy

--------------------------------------------------------------------------------
/optimizer/CosineLRS.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from math import cos, pi


class CosineLRS():
    def __init__(self, optimizer, max_epoch=300, lr_min=0, lr_max=0.05, warmup_epoch=30):
        self.optimizer = optimizer
        self.max_epoch = max_epoch
        self.min_lr = lr_min
        self.max_lr = lr_max
        self.warmup_epoch = warmup_epoch

    def step(self, current_epoch):
        if current_epoch < self.warmup_epoch:
            # linear warmup (note: yields lr = 0 at epoch 0)
            lr = self.max_lr * current_epoch / self.warmup_epoch
        else:
            lr = self.min_lr + (self.max_lr - self.min_lr) * (1 + cos(pi * (current_epoch - self.warmup_epoch) /
                                                                      (self.max_epoch - self.warmup_epoch))) / 2

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        print(f'now lr = {lr}')
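

# A minimal usage sketch (hedged): `optimizer` is any torch.optim optimizer and
# step() takes the current epoch, so it is called once per epoch:
#   scheduler = CosineLRS(optimizer, max_epoch=300, lr_max=0.05, warmup_epoch=30)
#   for epoch in range(300):
#       train_one_epoch(...)  # hypothetical training routine
#       scheduler.step(epoch)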
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
from .cifar import get_CIFAR100_test, get_CIFAR100_train, get_CIFAR10_train, get_CIFAR10_test
from .someset import SomeDataSet, get_someset_loader
from .CIFAR10C import get_cifar_10_c_loader
from .PACS import get_PACS_train, get_PACS_test
from .ImageNet import get_imagenet_loader, get_imagenet10_loader
from .mnist import get_mnist_train, get_mnist_test
from .usps import get_usps_train, get_usps_test
from .svhn import get_svhn_train, get_svhn_test
from .mnistm import get_mnist_m_train, get_mnist_m_test

__all__ = ['get_CIFAR100_test', 'get_CIFAR100_train', 'get_CIFAR10_test', 'get_CIFAR10_train',
           'SomeDataSet', 'get_someset_loader', 'get_cifar_10_c_loader',
           'get_PACS_train', 'get_PACS_test', 'get_imagenet_loader', 'get_svhn_test', 'get_svhn_train',
           'get_mnist_test', 'get_mnist_train', 'get_usps_test', 'get_usps_train',
           'get_mnist_m_train', 'get_mnist_m_test'
           ]

--------------------------------------------------------------------------------
/optimizer/default.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from .ALRS import ALRS


# cifar10-c: WRN with 40 layers and widen factor 4
# SGD with Nesterov momentum 0.9 and batch_size 128
# learning rate 0.1 with cosine annealing, 200 epochs
# def default_optimizer(model: nn.Module, lr=0.1) -> torch.optim.Optimizer:
#     return torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=1e-5)

def default_optimizer(model: nn.Module, lr=1e-4) -> torch.optim.Optimizer:
    return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-6)


# ResNet-18 pretrained on ImageNet
# SGD with initial learning rate 0.004, decayed to 10% of its value after 24 epochs
# batch_size 128, 30 epochs
def PACS_optimizer(model: nn.Module, lr=0.1, decay=0) -> torch.optim.Optimizer:
    return torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=decay, nesterov=True)


def default_lr_scheduler(optimizer):
    return ALRS(optimizer)

--------------------------------------------------------------------------------
/Normalizations/freeze_weight.py:
--------------------------------------------------------------------------------
import torch.nn as nn


def get_norm_layers(model: nn.Module, norm_name):
    norm_layers = []
    for module in model.modules():
        if isinstance(module, norm_name):
            norm_layers.append(module)
        elif isinstance(module, nn.ModuleList):
            for sub_module in module:
                if isinstance(sub_module, norm_name):
                    norm_layers.append(sub_module)
        elif isinstance(module, nn.Sequential):
            for sub_module in module.children():
                for layer in list(sub_module.modules()):
                    if isinstance(layer, norm_name):
                        # collect the layer like the other branches do
                        # (the original called layer.requires_grad_(True) here
                        # instead of appending, which left norm_layers incomplete)
                        norm_layers.append(layer)

    return norm_layers


def freeze_weights(model: nn.Module, norm_name):
    for param in model.parameters():
        param.requires_grad = False
    for layer in get_norm_layers(model, norm_name):
        for param in layer.parameters():
            param.requires_grad = True
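

# A minimal usage sketch (hedged): fine-tune only the normalization layers,
# e.g. after swapping BatchNorm for ASRNormIN; `norm_name` is the layer class:
#   from Normalizations import ASRNormIN
#   freeze_weights(model, ASRNormIN)  # everything frozen except ASRNormIN params
#   optimizer = torch.optim.Adam(p for p in model.parameters() if p.requires_grad)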
--------------------------------------------------------------------------------
/data/mnist.py:
--------------------------------------------------------------------------------
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

__all__ = ['get_mnist_test', 'get_mnist_train']


def get_mnist_train(batch_size=256,
                    num_workers=40,
                    pin_memory=True,
                    ):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    set = MNIST('../resources/mnist/', train=True, download=True, transform=transform)
    loader = DataLoader(set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory,
                        shuffle=True)
    return loader


def get_mnist_test(batch_size=256,
                   num_workers=40,
                   pin_memory=True, ):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    set = MNIST('../resources/mnist/', train=False, download=True, transform=transform)
    loader = DataLoader(set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    return loader

--------------------------------------------------------------------------------
/data/svhn.py:
--------------------------------------------------------------------------------
from torchvision.datasets import SVHN
from torchvision import transforms
from torch.utils.data import DataLoader

__all__ = ['get_svhn_test', 'get_svhn_train']


def get_svhn_train(batch_size=256,
                   num_workers=40,
                   pin_memory=True,
                   ):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    set = SVHN('../resources/svhn/', split='train', download=True, transform=transform)
    loader = DataLoader(set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory,
                        shuffle=True)
    return loader


def get_svhn_test(batch_size=256,
                  num_workers=40,
                  pin_memory=True,
                  ):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    set = SVHN('../resources/svhn/', split='test', download=True, transform=transform)
    loader = DataLoader(set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    return loader

--------------------------------------------------------------------------------
/backbones/ConvNet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

__all__ = ['convnet']


class ConvNet(nn.Module):
    def __init__(self, dim=1, norm_layer=nn.BatchNorm2d, num_classes=10):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(dim, 16, kernel_size=3, stride=1, padding=1)
        self.norm1 = norm_layer(16)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm_layer(32)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(2048, 128)
        self.relu3 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.norm2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)

        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2(x)

        return x


def convnet(dim, norm_layer, num_classes):
    return ConvNet(dim, norm_layer, num_classes)
--------------------------------------------------------------------------------
/data/usps.py:
--------------------------------------------------------------------------------
from torchvision.datasets import USPS
from torchvision import transforms
from torch.utils.data import DataLoader

__all__ = ['get_usps_test', 'get_usps_train']


def get_usps_train(batch_size=256,
                   num_workers=40,
                   pin_memory=True,
                   ):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    set = USPS('../resources/usps/', train=True, download=True, transform=transform)
    loader = DataLoader(set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory,
                        shuffle=True)
    return loader


def get_usps_test(batch_size=256,
                  num_workers=40,
                  pin_memory=True,
                  ):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    set = USPS('../resources/usps/', train=False, download=True, transform=transform)
    loader = DataLoader(set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    return loader

--------------------------------------------------------------------------------
/optimizer/ALRS.py:
--------------------------------------------------------------------------------
import torch


class ALRS():
    '''
    proposer: Huanran Chen
    theory: landscape
    Bootstrap Generalization Ability from Loss Landscape Perspective
    '''

    def __init__(self, optimizer, loss_threshold=0.01, loss_ratio_threshold=0.01, decay_rate=0.97,
                 max_epoch=200, lr_min=0, lr_max=0.1, warmup_epoch=0):
        self.optimizer = optimizer
        self.loss_threshold = loss_threshold
        self.decay_rate = decay_rate
        self.loss_ratio_threshold = loss_ratio_threshold

        self.warmup_epoch = warmup_epoch
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.max_epoch = max_epoch

        self.last_loss = 999

    def step(self, loss, current_epoch):
        if current_epoch < self.warmup_epoch:
            now_lr = self.lr_max * current_epoch / self.warmup_epoch
            for group in self.optimizer.param_groups:
                group['lr'] = now_lr
            print(f'now lr = {now_lr}')

        else:
            delta = self.last_loss - loss
            if delta < self.loss_threshold and delta / self.last_loss < self.loss_ratio_threshold:
                for group in self.optimizer.param_groups:
                    group['lr'] *= self.decay_rate
                    now_lr = group['lr']
                print(f'now lr = {now_lr}')
            else:
                now_lr = self.optimizer.param_groups[0]['lr']
                print(f'now lr = {now_lr}')
        self.last_loss = loss
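

# A minimal usage sketch (hedged): unlike the torch schedulers, step() takes the
# epoch's training loss, so the decay reacts to a loss plateau:
#   scheduler = ALRS(optimizer)
#   for epoch in range(200):
#       epoch_loss = train_one_epoch(...)  # hypothetical training routine
#       scheduler.step(epoch_loss, epoch)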
--------------------------------------------------------------------------------
/data/someset.py:
--------------------------------------------------------------------------------
'''
this file aims to read any dataset satisfying:
1. all the images are in one folder
2. only a dict stores the ground truth: keys are image names, values are the
   ground truth labels (e.g. {'img_0001.jpg': 3, 'img_0002.jpg': 7})
'''

import os
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import numpy as np


class SomeDataSet(Dataset):
    def __init__(self, img_path, gt_path):
        self.transform = transforms.Compose([
            # transforms.RandomResizedCrop(size=(224, 224), scale=(0.7, 1)),
            # transforms.AutoAugment(),
            transforms.ToTensor(),
            transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]),
        ])
        self.images = [img for img in os.listdir(img_path) if img.endswith('.jpg')]
        self.gt = np.load(gt_path, allow_pickle=True).item()
        self.img_path = img_path

    def __len__(self):
        return len(self.images)

    def __getitem__(self, item):
        now = self.images[item]
        now_img = Image.open(os.path.join(self.img_path, now))  # PIL image
        return self.transform(now_img), self.gt[now]


def get_someset_loader(img_path,
                       gt_path,
                       batch_size=128,
                       num_workers=8,
                       pin_memory=False, ):
    set = SomeDataSet(img_path=img_path, gt_path=gt_path)
    loader = DataLoader(set, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory)
    return loader

--------------------------------------------------------------------------------
/Normalizations/AdaBN.py:
--------------------------------------------------------------------------------
from train import Model
import torch
from data.data import get_test_loader
from tqdm import tqdm
from data.dataUtils import write_result


def train(batch_size=64, total_epoch=3):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model().to(device)
    state = torch.load('adabn.pth', map_location=device)
    model.model.load_state_dict(state['model_state'])
    train_image_path = '../input/nico2022/track_1/public_dg_0416/train/'
    valid_image_path = '../input/nico2022/track_1/public_dg_0416/train/'
    label2id_path = '../input/nico2022/dg_label_id_mapping.json'
    test_image_path = '../input/nico2022/track_1/public_dg_0416/public_test_flat/'
    loader, _ = get_test_loader(batch_size=batch_size,
                                transforms=None,
                                label2id_path=label2id_path,
                                test_image_path=test_image_path)
    # AdaBN: run the model in train() mode over target-domain data with no
    # gradients, so that only the BN running statistics adapt to the new domain
    model.train()
    with torch.no_grad():
        for epoch in range(1, total_epoch + 1):
            pbar = tqdm(loader)
            for x, _ in pbar:
                x = x.to(device)
                x = model(x)

    model.eval()
    result = {}
    with torch.no_grad():
        for x, name in tqdm(loader):
            x = x.to(device)
            y = model(x)  # N, D
            _, y = torch.max(y, dim=1)  # (N,)

            for i, name in enumerate(list(name)):
                result[name] = y[i].item()

    write_result(result)
    torch.save(model.model.state_dict(), 'model.pth')
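
# Note (hedged): `train.Model`, `data.data.get_test_loader`, and
# `data.dataUtils.write_result` are not modules of this repository; judging by
# the '../input/nico2022/...' paths, this script was written against an
# external NICO-2022 challenge codebase and is kept here as a reference.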


if __name__ == '__main__':
    train()
--------------------------------------------------------------------------------
/backbones/vggv2.py:
--------------------------------------------------------------------------------
"""vgg in pytorch


[1] Karen Simonyan, Andrew Zisserman

    Very Deep Convolutional Networks for Large-Scale Image Recognition.
    https://arxiv.org/abs/1409.1556v6
"""
'''VGG11/13/16/19 in Pytorch.'''

import torch
import torch.nn as nn

cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
}


class VGG(nn.Module):

    def __init__(self, features, num_class=100):
        super().__init__()
        self.features = features

        self.classifier = nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_class)
        )

    def forward(self, x):
        output = self.features(x)
        output = output.view(output.size()[0], -1)
        output = self.classifier(output)

        return output


def make_layers(cfg, batch_norm=False):
    layers = []

    input_channel = 3
    for l in cfg:
        if l == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            continue

        layers += [nn.Conv2d(input_channel, l, kernel_size=3, padding=1)]

        if batch_norm:
            layers += [nn.BatchNorm2d(l)]

        layers += [nn.ReLU(inplace=True)]
        input_channel = l

    return nn.Sequential(*layers)


def vgg11_bn():
    return VGG(make_layers(cfg['A'], batch_norm=True))


def vgg13_bn():
    return VGG(make_layers(cfg['B'], batch_norm=True))


def vgg16_bn():
    return VGG(make_layers(cfg['D'], batch_norm=True))


def vgg19_bn():
    return VGG(make_layers(cfg['E'], batch_norm=True))

--------------------------------------------------------------------------------
/Normalizations/BIN.py:
--------------------------------------------------------------------------------
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parameter import Parameter
from torch.nn import functional as F
import torch


class _BatchInstanceNorm(_BatchNorm):  # renamed from the original misspelling `_BacthInstanceNorm`
    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True):
        super(_BatchInstanceNorm, self).__init__(num_features, eps, momentum, affine)
        self.gate = Parameter(torch.Tensor(num_features))
        self.gate.data.fill_(1)
        setattr(self.gate, 'bin_gate', True)

    def forward(self, input):
        self._check_input_dim(input)

        # Batch norm
        if self.affine:
            bn_w = self.weight * self.gate
        else:
            bn_w = self.gate

        out_bn = F.batch_norm(
            input, self.running_mean, self.running_var, bn_w, self.bias,
            self.training, self.momentum, self.eps
        )

        # Instance norm
        b, c = input.size(0), input.size(1)
        if self.affine:
            in_w = self.weight * (1 - self.gate)
        else:
            in_w = 1 - self.gate

        input = input.view(1, b * c, *input.size()[2:])
        out_in = F.batch_norm(
            input, None, None, None, None,
            True, self.momentum, self.eps
        )
        out_in = out_in.view(b, c, *input.size()[2:])
        out_in.mul_(in_w[None, :, None, None])

        return out_bn + out_in
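
# Per channel, the learned gate blends the two normalizations of
# Batch-Instance Normalization (Nam & Kim, 2018):
#     y = (gate * weight) * BN_hat(x) + bias + ((1 - gate) * weight) * IN_hat(x)
# where BN_hat/IN_hat are the normalizations before the affine transform;
# gate starts at 1, i.e. pure batch norm.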


class BatchInstanceNorm1d(_BatchInstanceNorm):
    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'.format(input.dim()))


class BatchInstanceNorm2d(_BatchInstanceNorm):
    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'.format(input.dim()))


class BatchInstanceNorm3d(_BatchInstanceNorm):
    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'.format(input.dim()))

--------------------------------------------------------------------------------
/Normalizations/change.py:
--------------------------------------------------------------------------------
from torchvision.models import resnet18, resnet50, resnet101
import torch.fx as fx
from torch.fx.node import Argument, Target
from torch.nn.utils.fusion import fuse_conv_bn_eval
from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast
# the original `from ASRBnrom import ASRNorm` referenced a module that does not
# exist in this repo; assuming the BN2d variant was the one intended
from .ASRNormBN2d import ASRNormBN2d as ASRNorm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import replace_pattern
from torch.fx.passes.shape_prop import ShapeProp
import copy
from collections import defaultdict
import torch.utils.mkldnn as th_mkldnn
import operator
import time
import logging
from enum import Enum
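
# A minimal usage sketch (hedged -- replace_bn_to_asrn() below symbolically
# traces the model, so it only works on fx-traceable architectures):
#   model = replace_bn_to_asrn(resnet18(pretrained=True))
#   # every nn.BatchNorm2d is now an ASRNorm layer with matching num_features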

def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]):
    if len(node.args) == 0:
        return False
    nodes: Tuple[Any, fx.Node] = (node.args[0], node)
    for expected_type, current_node in zip(pattern, nodes):
        if not isinstance(current_node, fx.Node):
            return False
        if current_node.op != 'call_module':
            return False
        if not isinstance(current_node.target, str):
            return False
        if current_node.target not in modules:
            return False
        if type(modules[current_node.target]) is not expected_type:
            return False
    return True


def _parent_name(target: str) -> Tuple[str, str]:
    """
    Splits a qualname into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name


def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    assert (isinstance(node.target, str))
    parent_name, name = _parent_name(node.target)
    modules[node.target] = new_module
    setattr(modules[parent_name], name, new_module)


def replace_bn_to_asrn(model):  # renamed from the original misspelling `relace_bn_to_asrn`
    from torch.fx import symbolic_trace
    traced: torch.fx.GraphModule = symbolic_trace(model)
    traced.graph.print_tabular()
    modules = dict(traced.named_modules())
    new_graph = copy.deepcopy(traced.graph)
    for n in traced.graph.nodes:
        if n.op == "call_module":
            if type(modules[n.target]) == nn.BatchNorm2d:
                replace_node_module(n, modules, ASRNorm(modules[n.target].num_features))
    return torch.fx.GraphModule(traced, new_graph)

--------------------------------------------------------------------------------
/backbones/utils.py:
--------------------------------------------------------------------------------
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.init as init


def cal_param_size(model):
    return sum([i.numel() for i in model.parameters()])


count_ops = 0


def measure_layer(layer, x, multi_add=1):
    delta_ops = 0
    type_name = str(layer)[: str(layer).find("(")].strip()

    if type_name in ["Conv2d"]:
        out_h = int(
            (x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) // layer.stride[0] + 1
        )
        out_w = int(
            (x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) // layer.stride[1] + 1
        )
        delta_ops = (
            layer.in_channels
            * layer.out_channels
            * layer.kernel_size[0]
            * layer.kernel_size[1]
            * out_h
            * out_w
            // layer.groups
            * multi_add
        )

    ### ops_linear
    elif type_name in ["Linear"]:
        weight_ops = layer.weight.numel() * multi_add
        bias_ops = 0
        delta_ops = weight_ops + bias_ops

    global count_ops
    count_ops += delta_ops
    return


def is_leaf(module):
    return sum(1 for x in module.children()) == 0


def should_measure(module):
    if str(module).startswith("Sequential"):
        return False
    if is_leaf(module):
        return True
    return False


def cal_multi_adds(model, shape=(2, 3, 32, 32)):
    global count_ops
    count_ops = 0
    data = torch.zeros(shape)

    def new_forward(m):
        def lambda_forward(x):
            measure_layer(m, x)
            return m.old_forward(x)

        return lambda_forward

    def modify_forward(model):
        for child in model.children():
            if should_measure(child):
                child.old_forward = child.forward
                child.forward = new_forward(child)
            else:
                modify_forward(child)

    def restore_forward(model):
        for child in model.children():
            if is_leaf(child) and hasattr(child, "old_forward"):
                child.forward = child.old_forward
                child.old_forward = None
            else:
                restore_forward(child)

    modify_forward(model)
    model.forward(data)
    restore_forward(model)

    return count_ops
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.idea
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/
--------------------------------------------------------------------------------
/data/ImageNet.py:
--------------------------------------------------------------------------------
from torchvision.datasets import ImageNet
from torch.utils.data import DataLoader
from torchvision import transforms
from typing import Optional, Tuple
import random

__all__ = ['get_imagenet10_loader', 'get_imagenet_loader']


class ImageNet10(ImageNet):
    def __init__(self,
                 *args,
                 target_class: Tuple[int, ...],
                 maximal_images: Optional[int] = None,
                 **kwargs):
        super(ImageNet10, self).__init__(*args, **kwargs)
        self.target_class = list(target_class)
        result = []
        for x, y in self.samples:
            if y in self.target_class:
                result.append((x, y))
        random.shuffle(result)
        self.maximal_images = maximal_images
        self.samples = result

    def __len__(self):
        if self.maximal_images is not None:
            return self.maximal_images
        return len(self.samples)


def get_transform(augment=False):
    if not augment:
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            # transforms.AutoAugment(transforms.AutoAugmentPolicy),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    return transform


def get_imagenet_loader(
        root='resources/ImageNet/',
        split='val',
        augment=False,
        batch_size=1,
        num_workers=8,
        pin_memory=False,
        shuffle=False,
):
    assert split in ['val', 'train']
    transform = get_transform(augment)
    set = ImageNet(root, split, transform=transform)
    loader = DataLoader(set, batch_size=batch_size,
                        num_workers=num_workers, pin_memory=pin_memory, shuffle=shuffle)
    return loader


def get_imagenet10_loader(
        target_class=(0, 100, 200, 300, 400, 500, 600, 700, 800, 900),
        maximum_images=None,
        root='resources/ImageNet/',
        split='val',
        augment=False,
        batch_size=1,
        num_workers=8,
        pin_memory=False,
        shuffle=False,
):
    assert split in ['val', 'train']
    transform = get_transform(augment)
    set = ImageNet10(root, split, target_class=target_class, transform=transform,
                     maximal_images=maximum_images)
    loader = DataLoader(set, batch_size=batch_size,
                        num_workers=num_workers, pin_memory=pin_memory, shuffle=shuffle)
    return loader
--------------------------------------------------------------------------------
/backbones/wrnv2.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class WideBasic(nn.Module):

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1
            )
        )

        self.shortcut = nn.Sequential()

        if in_channels != out_channels or stride != 1:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride)
            )

    def forward(self, x):

        residual = self.residual(x)
        shortcut = self.shortcut(x)

        return residual + shortcut


class WideResNet(nn.Module):
    def __init__(self, num_classes, block, depth=50, widen_factor=1):
        super().__init__()

        self.depth = depth
        k = widen_factor
        l = int((depth - 4) / 6)
        self.in_channels = 16
        self.init_conv = nn.Conv2d(3, self.in_channels, 3, 1, padding=1)
        self.conv2 = self._make_layer(block, 16 * k, l, 1)
        self.conv3 = self._make_layer(block, 32 * k, l, 2)
        self.conv4 = self._make_layer(block, 64 * k, l, 2)
        self.bn = nn.BatchNorm2d(64 * k)
        self.relu = nn.ReLU(inplace=True)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(64 * k, num_classes)

    def forward(self, x):
        x = self.init_conv(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)

        return x

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """make resnet layers (here 'layer' does not mean a single neural
        network layer such as a conv layer); one layer may contain more than
        one residual block

        Args:
            block: block type, basic block or bottle neck block
            out_channels: output depth channel number of this layer
            num_blocks: how many blocks per layer
            stride: the stride of the first block of this layer

        Return:
            return a resnet layer
        """

        # the first block of each layer may have stride 1 or 2,
        # all other blocks always have stride 1
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels

        return nn.Sequential(*layers)


# Table 9: Best WRN performance over various datasets, single run results.
def wideresnet(depth=40, widen_factor=10):
    # note: the number of classes is hard-coded to 100 (CIFAR-100)
    net = WideResNet(100, WideBasic, depth=depth, widen_factor=widen_factor)
    return net

--------------------------------------------------------------------------------
/data/PACS.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import torchvision
from torchvision import transforms
from tllib.vision.datasets import PACS
from tllib.vision.transforms import ResizeImage
from tllib.vision.datasets.imagelist import MultipleDomainsDataset

"""
install tllib:
git clone git@github.com:thuml/Transfer-Learning-Library.git
python setup.py install
pip install -r requirements.txt
"""


class NPACS(PACS):
    def __init__(self, root: str, task: str, split='all', download=True, **kwargs):
        super(NPACS, self).__init__(root, task, split, download, **kwargs)

    def __getitem__(self, index):
        img, target = super(NPACS, self).__getitem__(index)
        return img, target


def get_pacs_dataset(target_domain, root="./data/pacs", download=True, augment=True):
    assert target_domain in ["P", "A", "C", "S"]

    if augment:
        test_transform = transforms.Compose(
            [
                transforms.RandomResizedCrop((227, 227), scale=(0.7, 1.0)),
                transforms.RandomHorizontalFlip(),
                # transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
                # transforms.RandomGrayscale(),
                transforms.AutoAugment(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ]
        )
    else:
        test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ]
        )

    train_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop((227, 227), scale=(0.7, 1.0)),
            transforms.RandomHorizontalFlip(),
            # transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
            # transforms.RandomGrayscale(),
            transforms.AutoAugment(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]
    )
    test_dataset = NPACS(root=root,
                         task=target_domain,
                         transform=test_transform,
                         download=download)

    source_domain = [i for i in ["P", "A", "C", "S"] if target_domain != i]

    train_dataset = []
    for domain in source_domain:
        train_dataset.append(NPACS(root=root,
                                   task=domain,
                                   transform=train_transform if augment else test_transform,
                                   download=download))
    train_dataset = ConcatDataset(train_dataset)
    return train_dataset, test_dataset


def get_PACS_train(batch_size=128,
                   num_workers=0,
                   pin_memory=True,
                   augment=True,
                   target_domain="A"
                   ):
    set, _ = get_pacs_dataset(root='./resources/PACS', target_domain=target_domain, augment=augment)
    loader = DataLoader(set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory,
                        shuffle=True)
    return loader


def get_PACS_test(batch_size=128,
                  num_workers=0,
                  pin_memory=True,
                  augment=False,
                  target_domain="A"
                  ):
    _, set = get_pacs_dataset(root='./resources/PACS', target_domain=target_domain, augment=augment)
    loader = DataLoader(set, batch_size=batch_size,
                        num_workers=num_workers, pin_memory=pin_memory,
                        shuffle=True)
    return loader

--------------------------------------------------------------------------------
/data/CIFAR10C.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import random
from PIL import Image
from torch.utils.data import Subset, DataLoader
from torchvision import datasets, transforms

"""
Please download the data from
https://zenodo.org/record/2535967
"""

__all__ = [
    'get_cifar_10_c_loader',
]

"""
level 1: brightness, contrast, defocus blur, elastic transform, fog
level 2: JPEG Compression(?), Pixelate(?), Gaussian Noise, Snow(?)
level 3: Frost, Gaussian Blur, Motion Blur
level 4: Zoom Blur, Speckle Noise, Impulse Noise, Glass Blur
level 5: Shot Noise, Spatter, Saturate
"""

# level 1: 0~10000, level 2: 10000~20000, level 3: 20000~30000, level 4: 30000~40000, level 5: 40000~50000

corruptions = [
    'glass_blur',
    'gaussian_noise',
    'shot_noise',
    'speckle_noise',
    'impulse_noise',
    'defocus_blur',
    'gaussian_blur',
    'motion_blur',
    'zoom_blur',
    'snow',
    'fog',
    'brightness',
    'contrast',
    'elastic_transform',
    'pixelate',
    'jpeg_compression',
    'spatter',
    'saturate',
    'frost',
]


class CIFAR10C(datasets.VisionDataset):
    def __init__(self,
                 name: str,
                 root: str = './resources/CIFAR10-C/',
                 transform=None, target_transform=None):
        assert name in corruptions
        super(CIFAR10C, self).__init__(
            root, transform=transform,
            target_transform=target_transform
        )
        data_path = os.path.join(root, name + '.npy')
        target_path = os.path.join(root, 'labels.npy')

        self.data = np.load(data_path)
        self.targets = np.load(target_path)
        # if you want to test only a small amount of data, uncomment the following lines
        # self.data = np.concatenate([self.data[0:1000], self.data[10000:11000],
        #                             self.data[20000:21000], self.data[30000:31000],
        #                             self.data[40000:41000]])
        # self.targets = np.concatenate([self.targets[0:1000], self.targets[10000:11000],
        #                                self.targets[20000:21000], self.targets[30000:31000],
        #                                self.targets[40000:41000]])

    def __getitem__(self, index):
        img, targets = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            targets = self.target_transform(targets)

        return img, targets

    def __len__(self):
        return len(self.data)


def extract_subset(dataset, num_subset: int, random_subset: bool):
    if random_subset:
        random.seed(0)
        indices = random.sample(list(range(len(dataset))), num_subset)
    else:
        indices = [i for i in range(num_subset)]
    return Subset(dataset, indices)


def get_cifar_10_c_loader(name: str = 'gaussian_blur',
                          augment=False,
                          batch_size=128,
                          shuffle=False,
                          num_workers=8,
                          pin_memory=False):
    if not augment:
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
    else:
        transform = transforms.Compose([
            # transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    set = CIFAR10C(name, transform=transform)
    loader = DataLoader(set, batch_size=batch_size, shuffle=shuffle,
                        num_workers=num_workers, pin_memory=pin_memory)
    return loader

--------------------------------------------------------------------------------
/Normalizations/ASRNormBN1d.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


class ASRNormBN1d(nn.Module):
    def __init__(self, dim, eps=1e-6, init_beta=None, init_gamma=None):
        '''

        :param dim: C of N,C
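
        ASR-Norm, as implemented below (a summary of this code, not a quote
        from the paper): with batch statistics mu_B and sigma_B,
            mean  = sigmoid(lambda_mu)    * f_mu(mu_B)       + (1 - sigmoid(lambda_mu))    * mu_B
            std   = sigmoid(lambda_sigma) * f_sigma(sigma_B) + (1 - sigmoid(lambda_sigma)) * sigma_B
            x_hat = (x - mean) / std
            gamma = lambda_gamma * g_gamma(sigma_B) + bias_gamma
            beta  = lambda_beta  * g_beta(mu_B)     + bias_beta
            y     = gamma * x_hat + beta
        where f_* and g_* are small encoder-decoder MLPs over the channel
        dimension. The BN2d/IN/LN variants in this folder share this structure
        and differ only in the axes the statistics are taken over.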
        '''
        super(ASRNormBN1d, self).__init__()
        self.eps = eps
        self.num_channels = dim
        self.stan_mid_channel = self.num_channels // 2
        self.rsc_mid_channel = self.num_channels // 16

        self.relu = nn.ReLU(True)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

        self.standard_encoder = nn.Linear(dim, self.stan_mid_channel)  # 16
        self.rescale_encoder = nn.Linear(dim, self.rsc_mid_channel)

        # standardization
        self.standard_mean_decoder = nn.Sequential(
            self.relu,
            nn.Linear(self.stan_mid_channel, dim)
        )

        self.standard_std_decoder = nn.Sequential(
            self.relu,
            nn.Linear(self.stan_mid_channel, dim),
            self.relu
        )

        # Rescaling
        self.rescale_beta_decoder = nn.Sequential(
            self.relu,
            nn.Linear(self.rsc_mid_channel, dim),
            self.tanh
        )

        self.rescale_gamma_decoder = nn.Sequential(
            self.relu,
            nn.Linear(self.rsc_mid_channel, dim),
            self.sigmoid
        )

        self.lambda_mu = nn.Parameter(torch.empty(1))
        self.lambda_sigma = nn.Parameter(torch.empty(1))

        self.lambda_beta = nn.Parameter(torch.empty(1))
        self.lambda_gamma = nn.Parameter(torch.empty(1))

        self.bias_beta = nn.Parameter(torch.empty(dim))
        self.bias_gamma = nn.Parameter(torch.empty(dim))

        self.drop_out = nn.Dropout(p=0.3)

        # init lambda and bias (float literals: sigmoid is not defined for integer tensors)
        # note: forward() applies sigmoid to lambda_mu/lambda_sigma again, so initialising
        # them to sigmoid(-3) yields sigmoid(sigmoid(-3)) ~ 0.51 at the start; initialising
        # the raw parameter to -3 may have been the intent. The other ASRNorm variants
        # share this pattern.
        with torch.no_grad():
            init.constant_(self.lambda_mu, self.sigmoid(torch.tensor(-3.)))
            init.constant_(self.lambda_sigma, self.sigmoid(torch.tensor(-3.)))
            init.constant_(self.lambda_beta, self.sigmoid(torch.tensor(-5.)))
            init.constant_(self.lambda_gamma, self.sigmoid(torch.tensor(-5.)))

        if init_beta is None:
            init.constant_(self.bias_beta, 0.)
        else:
            self.bias_beta.copy_(init_beta)
        if init_gamma is None:
            init.constant_(self.bias_gamma, 1.)
        else:
            self.bias_gamma.copy_(init_gamma)

    def forward(self, x):
        '''

        :param x: N,C
        :return:
        '''
        N, C = x.size()
        x_mean = torch.mean(x, dim=0)
        x_std = torch.sqrt(torch.var(x, dim=0)) + self.eps

        # standardization
        x_standard_mean = self.standard_mean_decoder(self.standard_encoder(self.drop_out(x_mean.view(1, -1)))).squeeze()
        x_standard_std = self.standard_std_decoder(self.standard_encoder(self.drop_out(x_std.view(1, -1)))).squeeze()

        lambda_sigma = self.sigmoid(self.lambda_sigma)
        lambda_mu = self.sigmoid(self.lambda_mu)

        mean = lambda_mu * x_standard_mean + (1 - lambda_mu) * x_mean
        std = lambda_sigma * x_standard_std + (1 - lambda_sigma) * x_std

        mean = mean.reshape((1, C))
        std = std.reshape((1, C))

        x = (x - mean) / std

        # rescaling
        x_rescaling_beta = self.rescale_beta_decoder(self.rescale_encoder(x_mean.view(1, -1))).squeeze()
        x_rescaling_gamma = self.rescale_gamma_decoder(self.rescale_encoder(x_std.view(1, -1))).squeeze()

        beta = self.lambda_beta * x_rescaling_beta + self.bias_beta
        gamma = self.lambda_gamma * x_rescaling_gamma + self.bias_gamma

        beta = beta.reshape((1, C))
        gamma = gamma.reshape((1, C))

        x = x * gamma + beta

        return x


def set_module(model, name, norm_layer):
    """Replace module of model by name with multi-level path"""
    path = name.split('.')
    cur = model
    for p in path[:-1]:
        cur = getattr(cur, p)
    setattr(cur, path[-1], norm_layer)


# try this first; if it works, refine it later
def build_ASRNormBN1d(model: nn.Module, init=False):
    for name, module in model.named_modules():
        if isinstance(module, nn.BatchNorm2d):
            set_module(model, name, ASRNormBN1d(module.num_features, init_beta=(module.bias if init else None), init_gamma=(module.weight if init else None)))
--------------------------------------------------------------------------------
/Normalizations/ASRNormBN2d.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


class ASRNormBN2d(nn.Module):
    def __init__(self, dim, eps=1e-6, init_beta=None, init_gamma=None):
        '''

        :param dim: C of N,C,H,W
        '''
        super(ASRNormBN2d, self).__init__()
        self.eps = eps
        self.num_channels = dim
        self.stan_mid_channel = self.num_channels // 2
        self.rsc_mid_channel = self.num_channels // 16

        self.relu = nn.ReLU(True)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

        self.standard_encoder = nn.Linear(dim, self.stan_mid_channel)  # 16
        self.rescale_encoder = nn.Linear(dim, self.rsc_mid_channel)

        # standardization
        self.standard_mean_decoder = nn.Sequential(
            self.relu,
            nn.Linear(self.stan_mid_channel, dim)
        )

        self.standard_std_decoder = nn.Sequential(
            self.relu,
            nn.Linear(self.stan_mid_channel, dim),
            self.relu
        )

        # Rescaling
        self.rescale_beta_decoder = nn.Sequential(
            self.relu,
            nn.Linear(self.rsc_mid_channel, dim),
            self.tanh
        )

        self.rescale_gamma_decoder = nn.Sequential(
            self.relu,
            nn.Linear(self.rsc_mid_channel, dim),
            self.sigmoid
        )

        self.lambda_mu = nn.Parameter(torch.empty(1))
        self.lambda_sigma = nn.Parameter(torch.empty(1))

        self.lambda_beta = nn.Parameter(torch.empty(1))
        self.lambda_gamma = nn.Parameter(torch.empty(1))

        self.bias_beta = nn.Parameter(torch.empty(dim))
        self.bias_gamma = nn.Parameter(torch.empty(dim))

        self.drop_out = nn.Dropout(p=0.3)

        # init lambda and bias (see the note in ASRNormBN1d about the double sigmoid)
        with torch.no_grad():
            init.constant_(self.lambda_mu, self.sigmoid(torch.tensor(-3.)))
            init.constant_(self.lambda_sigma, self.sigmoid(torch.tensor(-3.)))
            init.constant_(self.lambda_beta, self.sigmoid(torch.tensor(-5.)))
            init.constant_(self.lambda_gamma, self.sigmoid(torch.tensor(-5.)))

        if init_beta is None:
            init.constant_(self.bias_beta, 0.)
        else:
            self.bias_beta.copy_(init_beta)
        if init_gamma is None:
            init.constant_(self.bias_gamma, 1.)
        else:
            self.bias_gamma.copy_(init_gamma)

    def forward(self, x):
        '''

        :param x: N,C,H,W
        :return:
        '''
        N, C, H, W = x.size()
        x_mean = torch.mean(x, dim=(0, 2, 3))
        x_std = torch.sqrt(torch.var(x, dim=(0, 2, 3))) + self.eps

        # standardization
        x_standard_mean = self.standard_mean_decoder(self.standard_encoder(self.drop_out(x_mean.view(1, -1)))).squeeze()
        x_standard_std = self.standard_std_decoder(self.standard_encoder(self.drop_out(x_std.view(1, -1)))).squeeze()

        lambda_sigma = self.sigmoid(self.lambda_sigma)
        lambda_mu = self.sigmoid(self.lambda_mu)

        mean = lambda_mu * x_standard_mean + (1 - lambda_mu) * x_mean
        std = lambda_sigma * x_standard_std + (1 - lambda_sigma) * x_std

        mean = mean.reshape((1, C, 1, 1))
        std = std.reshape((1, C, 1, 1))

        x = (x - mean) / std

        # rescaling
        x_rescaling_beta = self.rescale_beta_decoder(self.rescale_encoder(x_mean.view(1, -1))).squeeze()
        x_rescaling_gamma = self.rescale_gamma_decoder(self.rescale_encoder(x_std.view(1, -1))).squeeze()

        beta = self.lambda_beta * x_rescaling_beta + self.bias_beta
        gamma = self.lambda_gamma * x_rescaling_gamma + self.bias_gamma

        beta = beta.reshape((1, C, 1, 1))
        gamma = gamma.reshape((1, C, 1, 1))

        x = x * gamma + beta

        return x


def set_module(model, name, norm_layer):
    """Replace module of model by name with multi-level path"""
    path = name.split('.')
    cur = model
    for p in path[:-1]:
        cur = getattr(cur, p)
    setattr(cur, path[-1], norm_layer)


# try this first; if it works, refine it later
def build_ASRNormBN2d(model: nn.Module, init=False):
    for name, module in model.named_modules():
        if isinstance(module, nn.BatchNorm2d):
            set_module(model, name, ASRNormBN2d(module.num_features, init_beta=(module.bias if init else None), init_gamma=(module.weight if init else None)))
--------------------------------------------------------------------------------
/Normalizations/ASRNormLN.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

"""
ASRNormLN has a known bug: the layer-norm statistics are per sample (length
batch_size), but the encoders expect `dim` input features, so this only runs
when the batch size matches the input channel width of every block.
"""


class ASRNormLN(nn.Module):
    def __init__(self, dim, eps=1e-6, init_beta=None, init_gamma=None):
        '''

        :param dim: C of N,C,H,W
        '''
        super(ASRNormLN, self).__init__()
        self.eps = eps
        self.num_channels = dim
        self.stan_mid_channel = self.num_channels // 2
        self.rsc_mid_channel = self.num_channels // 16

        self.relu = nn.ReLU(True)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

        self.standard_encoder = nn.Linear(dim, self.stan_mid_channel)  # 16
        self.rescale_encoder = nn.Linear(dim, self.rsc_mid_channel)

        # standardization
        self.standard_mean_decoder = nn.Sequential(
            self.relu,
            nn.Linear(self.stan_mid_channel, dim)
        )

        self.standard_std_decoder = nn.Sequential(
            self.relu,
            nn.Linear(self.stan_mid_channel, dim),
            self.relu
        )

        # Rescaling
        self.rescale_beta_decoder = nn.Sequential(
            self.relu,
            nn.Linear(self.rsc_mid_channel, dim),
            self.tanh
        )

        self.rescale_gamma_decoder = nn.Sequential(
            self.relu,
            nn.Linear(self.rsc_mid_channel, dim),
            self.sigmoid
        )

        self.lambda_mu = nn.Parameter(torch.empty(1))
        self.lambda_sigma = nn.Parameter(torch.empty(1))

        self.lambda_beta = nn.Parameter(torch.empty(1))
        self.lambda_gamma = nn.Parameter(torch.empty(1))

        self.bias_beta = nn.Parameter(torch.empty(dim))
        self.bias_gamma = nn.Parameter(torch.empty(dim))

        self.drop_out = nn.Dropout(p=0.3)

        # init lambda and bias (see the note in ASRNormBN1d about the double sigmoid)
        with torch.no_grad():
            init.constant_(self.lambda_mu, self.sigmoid(torch.tensor(-3.)))
            init.constant_(self.lambda_sigma, self.sigmoid(torch.tensor(-3.)))
            init.constant_(self.lambda_beta, self.sigmoid(torch.tensor(-5.)))
            init.constant_(self.lambda_gamma, self.sigmoid(torch.tensor(-5.)))

        if init_beta is None:
            init.constant_(self.bias_beta, 0.)
        else:
            self.bias_beta.copy_(init_beta)
        if init_gamma is None:
            init.constant_(self.bias_gamma, 1.)
        else:
            self.bias_gamma.copy_(init_gamma)

    def forward(self, x):
        '''

        :param x: N,C,H,W
        :return:
        '''
        N, C, H, W = x.size()
        x_mean = torch.mean(x, dim=(1, 2, 3))
        x_std = torch.sqrt(torch.var(x, dim=(1, 2, 3))) + self.eps

        # standardization
        x_standard_mean = self.standard_mean_decoder(self.standard_encoder(self.drop_out(x_mean.view(1, -1)))).squeeze()
        x_standard_std = self.standard_std_decoder(self.standard_encoder(self.drop_out(x_std.view(1, -1)))).squeeze()

        lambda_sigma = self.sigmoid(self.lambda_sigma)
        lambda_mu = self.sigmoid(self.lambda_mu)

        mean = lambda_mu * x_standard_mean + (1 - lambda_mu) * x_mean
        std = lambda_sigma * x_standard_std + (1 - lambda_sigma) * x_std

        mean = mean.reshape((N, 1, 1, 1))
        std = std.reshape((N, 1, 1, 1))

        x = (x - mean) / std

        # rescaling
        x_rescaling_beta = self.rescale_beta_decoder(self.rescale_encoder(x_mean.view(1, -1))).squeeze()
        x_rescaling_gamma = self.rescale_gamma_decoder(self.rescale_encoder(x_std.view(1, -1))).squeeze()

        beta = self.lambda_beta * x_rescaling_beta + self.bias_beta
        gamma = self.lambda_gamma * x_rescaling_gamma + self.bias_gamma

        beta = beta.reshape((N, 1, 1, 1))
        gamma = gamma.reshape((N, 1, 1, 1))

        x = x * gamma + beta

        return x


def set_module(model, name, norm_layer):
    """Replace module of model by name with multi-level path"""
    path = name.split('.')
    cur = model
    for p in path[:-1]:
        cur = getattr(cur, p)
    setattr(cur, path[-1], norm_layer)


# try this first; if it works, refine it later
def build_ASRNormLN(model: nn.Module, init=False):
    for name, module in model.named_modules():
        if isinstance(module, nn.BatchNorm2d):
            set_module(model, name, ASRNormLN(module.num_features, init_beta=(module.bias if init else None), init_gamma=(module.weight if init else None)))
79 | else: 80 | self.bias_gamma.copy_(init_gamma) 81 | 82 | 83 | 84 | 85 | def forward(self, x): 86 | ''' 87 | 88 | :param x: N, C, H, W 89 | :return: normalized tensor of the same shape 90 | ''' 91 | N, C, H, W = x.size() 92 | x_mean = torch.mean(x, dim=(1, 2, 3)) 93 | x_std = torch.sqrt(torch.var(x, dim=(1, 2, 3))) + self.eps 94 | 95 | # standardization 96 | x_standard_mean = self.standard_mean_decoder(self.standard_encoder(self.drop_out(x_mean.view(1, -1)))).squeeze() 97 | x_standard_std = self.standard_std_decoder(self.standard_encoder(self.drop_out(x_std.view(1, -1)))).squeeze() 98 | 99 | lambda_sigma = self.sigmoid(self.lambda_sigma) 100 | lambda_mu = self.sigmoid(self.lambda_mu) 101 | 102 | mean = lambda_mu * x_standard_mean + (1 - lambda_mu) * x_mean 103 | std = lambda_sigma * x_standard_std + (1 - lambda_sigma) * x_std 104 | 105 | mean = mean.reshape((N, 1, 1, 1)) 106 | std = std.reshape((N, 1, 1, 1)) 107 | 108 | x = (x - mean) / std 109 | 110 | # rescaling 111 | x_rescaling_beta = self.rescale_beta_decoder(self.rescale_encoder(x_mean.view(1, -1))).squeeze() 112 | x_rescaling_gamma = self.rescale_gamma_decoder(self.rescale_encoder(x_std.view(1, -1))).squeeze() 113 | 114 | beta = self.lambda_beta * x_rescaling_beta + self.bias_beta 115 | gamma = self.lambda_gamma * x_rescaling_gamma + self.bias_gamma 116 | 117 | beta = beta.reshape((N, 1, 1, 1)) 118 | gamma = gamma.reshape((N, 1, 1, 1)) 119 | 120 | x = x * gamma + beta 121 | 122 | return x 123 | 124 | def set_module(model, name, norm_layer): 125 | """Replace a submodule of model, addressed by its dotted multi-level path name""" 126 | path = name.split('.') 127 | cur = model 128 | for p in path[:-1]: 129 | cur = getattr(cur, p) 130 | setattr(cur, path[-1], norm_layer) 131 | 132 | 133 | # try this first; if it works, refine it later 134 | def build_ASRNormLN(model: nn.Module, init=False): 135 | for name, module in model.named_modules(): 136 | if isinstance(module, nn.BatchNorm2d): 137 | set_module(model, name, ASRNormLN(module.num_features, init_beta=(module.bias if init else None), init_gamma=(module.weight if init else None))) 138 | -------------------------------------------------------------------------------- /Normalizations/ASRNormIN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | 6 | 7 | class ASRNormIN(nn.Module): 8 | def __init__(self, dim, eps=1e-6, init_beta=None, init_gamma=None): 9 | ''' 10 | 11 | :param dim: C of N, C, H, W 12 | ''' 13 | super(ASRNormIN, self).__init__() 14 | 15 | self.eps = eps 16 | self.num_channels = dim 17 | self.stan_mid_channel = self.num_channels // 2 18 | self.rsc_mid_channel = self.num_channels // 16 19 | 20 | self.relu = nn.ReLU(True) 21 | self.tanh = nn.Tanh() 22 | self.sigmoid = nn.Sigmoid() 23 | 24 | self.standard_encoder = nn.Linear(dim, self.stan_mid_channel) # 16 25 | self.rescale_encoder = nn.Linear(dim, self.rsc_mid_channel) 26 | 27 | # standardization 28 | self.standard_mean_decoder = nn.Sequential( 29 | self.relu, 30 | nn.Linear(self.stan_mid_channel, dim) 31 | ) 32 | 33 | self.standard_std_decoder = nn.Sequential( 34 | self.relu, 35 | nn.Linear(self.stan_mid_channel, dim), 36 | self.relu 37 | ) 38 | 39 | # Rescaling 40 | self.rescale_beta_decoder = nn.Sequential( 41 | self.relu, 42 | nn.Linear(self.rsc_mid_channel, dim), 43 | self.tanh 44 | ) 45 | 46 | self.rescale_gamma_decoder = nn.Sequential( 47 | self.relu, 48 | nn.Linear(self.rsc_mid_channel, dim), 49 | self.sigmoid 50 | ) 51 | 52 | self.lambda_mu = nn.Parameter(torch.empty(1)) 53 | self.lambda_sigma = nn.Parameter(torch.empty(1)) 54 | 55 | # for PACS multi-source generalization these two parameters are kept; otherwise they are not 56 | # meanwhile, in that case the bias terms above are fixed values from the pretrained model 57 | # if pacs: 58 | self.lambda_beta = nn.Parameter(torch.empty(1)) 59 | self.lambda_gamma = nn.Parameter(torch.empty(1)) 60 | 61 | # else: 62 | self.bias_beta = nn.Parameter(torch.empty(dim)) 63 | self.bias_gamma = nn.Parameter(torch.empty(dim)) 64 | 65 | 66 | # init lambda and bias 67 | with torch.no_grad(): 68 | init.constant_(self.lambda_mu, self.sigmoid(torch.tensor(-3))) 69 | init.constant_(self.lambda_sigma, self.sigmoid(torch.tensor(-3))) 70 | init.constant_(self.lambda_beta, self.sigmoid(torch.tensor(-5))) 71 | init.constant_(self.lambda_gamma, self.sigmoid(torch.tensor(-5))) 72 | 73 | if init_beta is None: 74 | init.constant_(self.bias_beta, 0.) 75 | else: 76 | self.bias_beta.copy_(init_beta) 77 | if init_gamma is None: 78 | init.constant_(self.bias_gamma, 1.) 79 | else: 80 | self.bias_gamma.copy_(init_gamma) 81 | 82 | def forward(self, x): 83 | ''' 84 | 85 | :param x: N, C, H, W 86 | :return: normalized tensor of the same shape 87 | ''' 88 | 89 | N, C, H, W = x.size() 90 | x_mean = torch.mean(x, dim=(2, 3)) 91 | x_std = torch.sqrt(torch.var(x, dim=(2, 3)) + self.eps) 92 | 93 | # standardization 94 | x_standard_mean = self.standard_mean_decoder(self.standard_encoder(x_mean)) 95 | x_standard_std = self.standard_std_decoder(self.standard_encoder(x_std)) 96 | 97 | lambda_sigma = self.sigmoid(self.lambda_sigma) 98 | lambda_mu = self.sigmoid(self.lambda_mu) 99 | 100 | mean = lambda_mu * x_standard_mean + (1 - lambda_mu) * x_mean 101 | std = lambda_sigma * x_standard_std + (1 - lambda_sigma) * x_std 102 | 103 | mean = mean.reshape((N, C, 1, 1)) 104 | std = std.reshape((N, C, 1, 1)) 105 | 106 | x = (x - mean) / std 107 | 108 | # rescaling 109 | x_rescaling_beta = self.rescale_beta_decoder(self.rescale_encoder(x_mean)) 110 | x_rescaling_gamma = self.rescale_gamma_decoder(self.rescale_encoder(x_std)) 111 | 112 | # if self.pacs: 113 | beta = self.lambda_beta * x_rescaling_beta + self.bias_beta 114 | gamma = self.lambda_gamma * x_rescaling_gamma + self.bias_gamma 115 | # else: 116 | # beta = x_rescaling_beta + self.bias_beta 117 | # gamma = x_rescaling_gamma + self.bias_gamma 118 | 119 | beta = beta.reshape((N, C, 1, 1)) 120 | gamma = gamma.reshape((N, C, 1, 1)) 121 | 122 | x = x * gamma + beta 123 | 124 | return x 125 | 126 | 127 | def set_module(model, name, norm_layer): 128 | """Replace a submodule of model, addressed by its dotted multi-level path name""" 129 | path = name.split('.') 130 | cur = model 131 | for p in path[:-1]: 132 | cur = getattr(cur, p) 133 | setattr(cur, path[-1], norm_layer) 134 | 135 | 136 | # try this first; if it works, refine it later 137 | def build_ASRNormIN(model: nn.Module, init=False): 138 | for name, module in model.named_modules(): 139 | if isinstance(module, nn.BatchNorm2d): 140 | set_module(model, name, ASRNormIN(module.num_features, init_beta=(module.bias if init else None), init_gamma=(module.weight if init else None))) 141 | 
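142 | 143 | # A minimal usage sketch, mirroring the commented-out call in Solver.py's __main__ 144 | # (assumes the resnet32 backbone exported by this repo's backbones package): 145 | if __name__ == '__main__': 146 | from backbones import resnet32 147 | model = resnet32(num_classes=10) 148 | build_ASRNormIN(model, init=True)  # every nn.BatchNorm2d is now an ASRNormIN 149 | 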
-------------------------------------------------------------------------------- /data/mnistm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | 4 | import numpy as np 5 | import torch 6 | import cv2 7 | from PIL import Image 8 | from torchvision import transforms 9 | from torch.utils.data import DataLoader 10 | from torchvision.datasets import VisionDataset 11 | from torchvision.datasets.utils import download_and_extract_archive 12 | 13 | 14 | class MNISTM(VisionDataset): 
15 | """MNIST-M Dataset. 16 | """ 17 | 18 | resources = [ 19 | ('https://github.com/liyxi/mnist-m/releases/download/data/mnist_m_train.pt.tar.gz', 20 | '191ed53db9933bd85cc9700558847391'), 21 | ('https://github.com/liyxi/mnist-m/releases/download/data/mnist_m_test.pt.tar.gz', 22 | 'e11cb4d7fff76d7ec588b1134907db59') 23 | ] 24 | 25 | training_file = "mnist_m_train.pt" 26 | test_file = "mnist_m_test.pt" 27 | classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', 28 | '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] 29 | 30 | @property 31 | def train_labels(self): 32 | warnings.warn("train_labels has been renamed targets") 33 | return self.targets 34 | 35 | @property 36 | def test_labels(self): 37 | warnings.warn("test_labels has been renamed targets") 38 | return self.targets 39 | 40 | @property 41 | def train_data(self): 42 | warnings.warn("train_data has been renamed data") 43 | return self.data 44 | 45 | @property 46 | def test_data(self): 47 | warnings.warn("test_data has been renamed data") 48 | return self.data 49 | 50 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False): 51 | """Init MNIST-M dataset.""" 52 | super(MNISTM, self).__init__(root, transform=transform, target_transform=target_transform) 53 | 54 | self.train = train 55 | 56 | if download: 57 | self.download() 58 | 59 | if not self._check_exists(): 60 | raise RuntimeError("Dataset not found." + 61 | " You can use download=True to download it") 62 | 63 | if self.train: 64 | data_file = self.training_file 65 | else: 66 | data_file = self.test_file 67 | 68 | print(os.path.join(self.processed_folder, data_file)) 69 | 70 | self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file)) 71 | 72 | def __getitem__(self, index): 73 | """Get images and target for data loader. 74 | Args: 75 | index (int): Index 76 | Returns: 77 | tuple: (image, target) where target is index of the target class. 
78 | """ 79 | img, target = self.data[index], int(self.targets[index]) 80 | 81 | # doing this so that it is consistent with all other datasets 82 | # to return a PIL Image 83 | img = Image.fromarray(img.squeeze().numpy(), mode="RGB") 84 | 85 | if self.transform is not None: 86 | img = self.transform(img) 87 | 88 | if self.target_transform is not None: 89 | target = self.target_transform(target) 90 | 91 | return img, target 92 | 93 | def __len__(self): 94 | """Return size of dataset.""" 95 | return len(self.data) 96 | 97 | @property 98 | def raw_folder(self): 99 | return os.path.join(self.root, self.__class__.__name__, 'raw') 100 | 101 | @property 102 | def processed_folder(self): 103 | return os.path.join(self.root, self.__class__.__name__, 'processed') 104 | 105 | @property 106 | def class_to_idx(self): 107 | return {_class: i for i, _class in enumerate(self.classes)} 108 | 109 | def _check_exists(self): 110 | return (os.path.exists(os.path.join(self.processed_folder, self.training_file)) and 111 | os.path.exists(os.path.join(self.processed_folder, self.test_file))) 112 | 113 | def download(self): 114 | """Download the MNIST-M data.""" 115 | 116 | if self._check_exists(): 117 | return 118 | 119 | os.makedirs(self.raw_folder, exist_ok=True) 120 | os.makedirs(self.processed_folder, exist_ok=True) 121 | 122 | # download files 123 | for url, md5 in self.resources: 124 | filename = url.rpartition('/')[2] 125 | download_and_extract_archive(url, download_root=self.raw_folder, 126 | extract_root=self.processed_folder, 127 | filename=filename, md5=md5) 128 | 129 | print('Done!') 130 | 131 | def extra_repr(self): 132 | return "Split: {}".format("Train" if self.train is True else "Test") 133 | 134 | 135 | def get_mnist_m_train(batch_size=256, 136 | num_workers=40, 137 | pin_memory=True, 138 | ): 139 | transform = transforms.Compose([ 140 | transforms.Resize((32, 32)), 141 | transforms.Grayscale(num_output_channels=1), 142 | transforms.ToTensor(), 143 | transforms.Normalize((0.5,), (0.5,)) 144 | ]) 145 | set = MNISTM('../resources/mnistm', train=True, download=False, transform=transform) 146 | 147 | loader = DataLoader(set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, 148 | shuffle=True) 149 | 150 | return loader 151 | 152 | 153 | def get_mnist_m_test(batch_size=256, 154 | num_workers=40, 155 | pin_memory=True, 156 | ): 157 | transform = transforms.Compose([ 158 | transforms.Resize((32, 32)), 159 | transforms.Grayscale(num_output_channels=1), 160 | transforms.ToTensor(), 161 | transforms.Normalize((0.5,), (0.5,)) 162 | ]) 163 | set = MNISTM('../resources/mnistm/', train=False, download=False, transform=transform) 164 | 165 | loader = DataLoader(set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, 166 | shuffle=True) 167 | 168 | return loader 169 | 170 | -------------------------------------------------------------------------------- /backbones/resnetv3.py: -------------------------------------------------------------------------------- 1 | """resnet in pytorch 2 | 3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 
4 | 5 | Deep Residual Learning for Image Recognition 6 | https://arxiv.org/abs/1512.03385v1 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | class BasicBlock(nn.Module): 13 | """Basic Block for resnet 18 and resnet 34 14 | 15 | """ 16 | 17 | #BasicBlock and BottleNeck block 18 | #have different output size 19 | #we use class attribute expansion 20 | #to distinct 21 | expansion = 1 22 | 23 | def __init__(self, in_channels, out_channels, stride=1): 24 | super().__init__() 25 | 26 | #residual function 27 | self.residual_function = nn.Sequential( 28 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), 29 | nn.BatchNorm2d(out_channels), 30 | nn.ReLU(inplace=True), 31 | nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False), 32 | nn.BatchNorm2d(out_channels * BasicBlock.expansion) 33 | ) 34 | 35 | #shortcut 36 | self.shortcut = nn.Sequential() 37 | 38 | #the shortcut output dimension is not the same with residual function 39 | #use 1*1 convolution to match the dimension 40 | if stride != 1 or in_channels != BasicBlock.expansion * out_channels: 41 | self.shortcut = nn.Sequential( 42 | nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False), 43 | nn.BatchNorm2d(out_channels * BasicBlock.expansion) 44 | ) 45 | 46 | def forward(self, x): 47 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) 48 | 49 | class BottleNeck(nn.Module): 50 | """Residual block for resnet over 50 layers 51 | 52 | """ 53 | expansion = 4 54 | def __init__(self, in_channels, out_channels, stride=1): 55 | super().__init__() 56 | self.residual_function = nn.Sequential( 57 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 58 | nn.BatchNorm2d(out_channels), 59 | nn.ReLU(inplace=True), 60 | nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False), 61 | nn.BatchNorm2d(out_channels), 62 | nn.ReLU(inplace=True), 63 | nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False), 64 | nn.BatchNorm2d(out_channels * BottleNeck.expansion), 65 | ) 66 | 67 | self.shortcut = nn.Sequential() 68 | 69 | if stride != 1 or in_channels != out_channels * BottleNeck.expansion: 70 | self.shortcut = nn.Sequential( 71 | nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False), 72 | nn.BatchNorm2d(out_channels * BottleNeck.expansion) 73 | ) 74 | 75 | def forward(self, x): 76 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) 77 | 78 | class ResNet(nn.Module): 79 | 80 | def __init__(self, block, num_block, num_classes=100): 81 | super().__init__() 82 | 83 | self.in_channels = 64 84 | 85 | self.conv1 = nn.Sequential( 86 | nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), 87 | nn.BatchNorm2d(64), 88 | nn.ReLU(inplace=True)) 89 | #we use a different inputsize than the original paper 90 | #so conv2_x's stride is 1 91 | self.conv2_x = self._make_layer(block, 64, num_block[0], 1) 92 | self.conv3_x = self._make_layer(block, 128, num_block[1], 2) 93 | self.conv4_x = self._make_layer(block, 256, num_block[2], 2) 94 | self.conv5_x = self._make_layer(block, 512, num_block[3], 2) 95 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 96 | self.fc = nn.Linear(512 * block.expansion, num_classes) 97 | 98 | def _make_layer(self, block, out_channels, num_blocks, stride): 99 | """make resnet layers(by layer i didnt mean this 'layer' was the 100 | same as a 
neuron netowork layer, ex. conv layer), one layer may 101 | contain more than one residual block 102 | 103 | Args: 104 | block: block type, basic block or bottle neck block 105 | out_channels: output depth channel number of this layer 106 | num_blocks: how many blocks per layer 107 | stride: the stride of the first block of this layer 108 | 109 | Return: 110 | return a resnet layer 111 | """ 112 | 113 | # we have num_block blocks per layer, the first block 114 | # could be 1 or 2, other blocks would always be 1 115 | strides = [stride] + [1] * (num_blocks - 1) 116 | layers = [] 117 | for stride in strides: 118 | layers.append(block(self.in_channels, out_channels, stride)) 119 | self.in_channels = out_channels * block.expansion 120 | 121 | return nn.Sequential(*layers) 122 | 123 | def forward(self, x): 124 | output = self.conv1(x) 125 | output = self.conv2_x(output) 126 | output = self.conv3_x(output) 127 | output = self.conv4_x(output) 128 | output = self.conv5_x(output) 129 | output = self.avg_pool(output) 130 | output = output.view(output.size(0), -1) 131 | output = self.fc(output) 132 | 133 | return output 134 | 135 | def resnet18(): 136 | """ return a ResNet 18 object 137 | """ 138 | return ResNet(BasicBlock, [2, 2, 2, 2]) 139 | 140 | def resnet34(): 141 | """ return a ResNet 34 object 142 | """ 143 | return ResNet(BasicBlock, [3, 4, 6, 3]) 144 | 145 | def resnet50(): 146 | """ return a ResNet 50 object 147 | """ 148 | return ResNet(BottleNeck, [3, 4, 6, 3]) 149 | 150 | def resnet101(): 151 | """ return a ResNet 101 object 152 | """ 153 | return ResNet(BottleNeck, [3, 4, 23, 3]) 154 | 155 | def resnet152(): 156 | """ return a ResNet 152 object 157 | """ 158 | return ResNet(BottleNeck, [3, 8, 36, 3]) 159 | -------------------------------------------------------------------------------- /Solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch import optim 4 | from typing import Callable 5 | from torch.nn import functional as F 6 | from tqdm import tqdm 7 | from torch.utils.data import DataLoader 8 | from optimizer import default_optimizer, default_lr_scheduler, CosineLRS, ALRS 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | 12 | def default_loss(x, y): 13 | cross_entropy = F.cross_entropy(x, y) 14 | return cross_entropy 15 | 16 | 17 | class Solver(): 18 | def __init__(self, student: nn.Module, 19 | loss_function: Callable or None = None, 20 | optimizer: torch.optim.Optimizer or None = None, 21 | scheduler=None, 22 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), 23 | ): 24 | self.student = student 25 | self.criterion = loss_function if loss_function is not None else default_loss 26 | self.optimizer = optimizer if optimizer is not None else default_optimizer(self.student) 27 | self.scheduler = scheduler if scheduler is not None else ALRS(self.optimizer) 28 | # self.optimizer = optim.SGD(self.student.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001) 29 | # self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[150, 250], gamma=0.1) 30 | self.device = device 31 | 32 | # initialization 33 | self.init() 34 | 35 | def init(self): 36 | # change device 37 | self.student.to(self.device) 38 | 39 | # # tensorboard 40 | # self.writer = SummaryWriter(log_dir="runs/Solver", flush_secs=120) 41 | 42 | def train(self, 43 | train_loader: DataLoader, 44 | validation_loader: DataLoader, 45 | total_epoch=500, 46 | fp16=False, 47 | ): 48 | from 
torch.cuda.amp import autocast, GradScaler 49 | scaler = GradScaler() 50 | for epoch in range(1, total_epoch + 1): 51 | train_loss, train_acc, validation_loss, validation_acc = 0, 0, 0, 0 52 | self.student.train() 53 | # train 54 | pbar = tqdm(train_loader) 55 | for step, (x, y) in enumerate(pbar, 1): 56 | x, y = x.to(self.device), y.to(self.device) 57 | if fp16: 58 | with autocast(): 59 | student_out = self.student(x) # N, 60 60 | # student_out = self.student(x, y, epoch) 61 | _, pre = torch.max(student_out, dim=1) 62 | loss = self.criterion(student_out, y) 63 | else: 64 | student_out = self.student(x) # N, 60 65 | # student_out = self.student(x, y, epoch) 66 | _, pre = torch.max(student_out, dim=1) 67 | loss = self.criterion(student_out, y) 68 | if pre.shape != y.shape: 69 | _, y = torch.max(y, dim=1) 70 | train_acc += (torch.sum(pre == y).item()) / y.shape[0] 71 | train_loss += loss.item() 72 | self.optimizer.zero_grad() 73 | 74 | if fp16: 75 | scaler.scale(loss).backward() 76 | scaler.unscale_(self.optimizer) 77 | # nn.utils.clip_grad_value_(self.student.parameters(), 0.1) 78 | # nn.utils.clip_grad_norm(self.student.parameters(), max_norm=10) 79 | scaler.step(self.optimizer) 80 | scaler.update() 81 | else: 82 | loss.backward() 83 | # nn.utils.clip_grad_value_(self.student.parameters(), 0.1) 84 | # nn.utils.clip_grad_norm(self.student.parameters(), max_norm=10) 85 | self.optimizer.step() 86 | 87 | if step % 10 == 0: 88 | pbar.set_postfix_str(f'loss={train_loss / step}, acc={train_acc / step}') 89 | 90 | train_loss /= len(train_loader) 91 | train_acc /= len(train_loader) 92 | 93 | # validation 94 | vbar = tqdm(validation_loader, colour='yellow') 95 | self.student.eval() 96 | with torch.no_grad(): 97 | for step, (x, y) in enumerate(vbar, 1): 98 | x, y = x.to(self.device), y.to(self.device) 99 | student_out = self.student(x) # N, 60 100 | # student_out = self.student(x, y, epoch) 101 | _, pre = torch.max(student_out, dim=1) 102 | loss = self.criterion(student_out, y) 103 | if pre.shape != y.shape: 104 | _, y = torch.max(y, dim=1) 105 | validation_acc += (torch.sum(pre == y).item()) / y.shape[0] 106 | validation_loss += loss.item() 107 | 108 | if step % 10 == 0: 109 | vbar.set_postfix_str(f'loss={validation_loss / step}, acc={validation_acc / step}') 110 | 111 | validation_loss /= len(validation_loader) 112 | validation_acc /= len(validation_loader) 113 | 114 | self.scheduler.step(train_loss, epoch) 115 | # self.optimizer.step() 116 | 117 | print(f'epoch {epoch}, train_loss = {train_loss}, train_acc = {train_acc}') 118 | print(f'epoch {epoch}, validation_loss = {validation_loss}, validation_acc = {validation_acc}') 119 | print('-' * 100) 120 | 121 | torch.save(self.student.state_dict(), 'student.pth') 122 | 123 | 124 | if __name__ == '__main__': 125 | import torchvision 126 | from Normalizations import ASRNormBN2d, ASRNormIN, build_ASRNormIN, build_ASRNormBN2d 127 | from torchvision.models import resnet18 128 | from data import get_PACS_train, get_PACS_test, get_CIFAR100_train, get_cifar_10_c_loader, get_CIFAR100_test 129 | from backbones import resnet56, resnet110, wrn_40_2, wrn_16_2, vgg16_bn, vgg19_bn, pyramidnet164, pyramidnet272, resnet32 130 | 131 | a = resnet32(num_classes=100) 132 | 133 | 134 | # freeze_weights(a, nn.BatchNorm2d) 135 | # build_ASRNormIN(a, True) 136 | # build_ASRNormBN2d(a, True) 137 | 138 | train_loader = get_CIFAR100_train(batch_size=128, augment=True) 139 | test_loader = get_CIFAR100_test(batch_size=256) 140 | 141 | 142 | w = Solver(a) 143 | w.train(train_loader, 
test_loader) 144 | -------------------------------------------------------------------------------- /backbones/ShuffleNetv1.py: -------------------------------------------------------------------------------- 1 | """ShuffleNet in PyTorch. 2 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details. 3 | """ 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | __all__ = ["ShuffleV1_aux", "ShuffleV1"] 11 | 12 | 13 | class ShuffleBlock(nn.Module): 14 | def __init__(self, groups): 15 | super(ShuffleBlock, self).__init__() 16 | self.groups = groups 17 | 18 | def forward(self, x): 19 | """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]""" 20 | N, C, H, W = x.size() 21 | g = self.groups 22 | return x.reshape(N, g, C // g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W) 23 | 24 | 25 | class Bottleneck(nn.Module): 26 | def __init__(self, in_planes, out_planes, stride, groups, is_last=False): 27 | super(Bottleneck, self).__init__() 28 | self.is_last = is_last 29 | self.stride = stride 30 | 31 | mid_planes = int(out_planes / 4) 32 | g = 1 if in_planes == 24 else groups 33 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False) 34 | self.bn1 = nn.BatchNorm2d(mid_planes) 35 | self.shuffle1 = ShuffleBlock(groups=g) 36 | self.conv2 = nn.Conv2d( 37 | mid_planes, 38 | mid_planes, 39 | kernel_size=3, 40 | stride=stride, 41 | padding=1, 42 | groups=mid_planes, 43 | bias=False, 44 | ) 45 | self.bn2 = nn.BatchNorm2d(mid_planes) 46 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False) 47 | self.bn3 = nn.BatchNorm2d(out_planes) 48 | 49 | self.shortcut = nn.Sequential() 50 | if stride == 2: 51 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1)) 52 | 53 | def forward(self, x): 54 | out = F.relu(self.bn1(self.conv1(x))) 55 | out = self.shuffle1(out) 56 | out = F.relu(self.bn2(self.conv2(out))) 57 | out = self.bn3(self.conv3(out)) 58 | res = self.shortcut(x) 59 | preact = torch.cat([out, res], 1) if self.stride == 2 else out + res 60 | out = F.relu(preact) 61 | # out = F.relu(torch.cat([out, res], 1)) if self.stride == 2 else F.relu(out+res) 62 | if self.is_last: 63 | return out, preact 64 | else: 65 | return out 66 | 67 | 68 | class ShuffleNet(nn.Module): 69 | def __init__(self, cfg, num_classes=10): 70 | super(ShuffleNet, self).__init__() 71 | out_planes = cfg["out_planes"] 72 | num_blocks = cfg["num_blocks"] 73 | groups = cfg["groups"] 74 | 75 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 76 | self.bn1 = nn.BatchNorm2d(24) 77 | self.in_planes = 24 78 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups) 79 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) 80 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups) 81 | self.linear = nn.Linear(out_planes[2], num_classes) 82 | self.last_channel = out_planes[2] 83 | 84 | def _make_layer(self, out_planes, num_blocks, groups): 85 | layers = [] 86 | for i in range(num_blocks): 87 | stride = 2 if i == 0 else 1 88 | cat_planes = self.in_planes if i == 0 else 0 89 | layers.append( 90 | Bottleneck( 91 | self.in_planes, 92 | out_planes - cat_planes, 93 | stride=stride, 94 | groups=groups, 95 | is_last=(i == num_blocks - 1), 96 | ) 97 | ) 98 | self.in_planes = out_planes 99 | return nn.Sequential(*layers) 100 | 101 | def get_feat_modules(self): 102 | feat_m = nn.ModuleList([]) 103 | 
feat_m.append(self.conv1) 104 | feat_m.append(self.bn1) 105 | feat_m.append(self.layer1) 106 | feat_m.append(self.layer2) 107 | feat_m.append(self.layer3) 108 | return feat_m 109 | 110 | def get_bn_before_relu(self): 111 | raise NotImplementedError('ShuffleNet currently is not supported for "Overhaul" teacher') 112 | 113 | def forward(self, x, is_feat=False): 114 | out = F.relu(self.bn1(self.conv1(x))) 115 | f0 = out 116 | out, f1_pre = self.layer1(out) 117 | f1 = out 118 | out, f2_pre = self.layer2(out) 119 | f2 = out 120 | out, f3_pre = self.layer3(out) 121 | f3 = out 122 | out = F.avg_pool2d(out, 4) 123 | out = out.reshape(out.size(0), -1) 124 | out = self.linear(out) 125 | if is_feat: 126 | return [f0, f1, f2, f3], out 127 | else: 128 | return out 129 | 130 | 131 | class Auxiliary_Classifier(nn.Module): 132 | def __init__(self, cfg, num_classes=10): 133 | super(Auxiliary_Classifier, self).__init__() 134 | out_planes = cfg["out_planes"] 135 | num_blocks = cfg["num_blocks"] 136 | groups = cfg["groups"] 137 | 138 | self.in_planes = out_planes[0] 139 | self.block_extractor1 = nn.Sequential( 140 | *[ 141 | self._make_layer(out_planes[1], num_blocks[1], groups), 142 | self._make_layer(out_planes[2], num_blocks[2], groups), 143 | ] 144 | ) 145 | self.in_planes = out_planes[1] 146 | self.block_extractor2 = nn.Sequential( 147 | *[self._make_layer(out_planes[2], num_blocks[2], groups)] 148 | ) 149 | 150 | self.inplanes = out_planes[2] 151 | self.block_extractor3 = nn.Sequential( 152 | *[self._make_layer(out_planes[2], num_blocks[2], groups, downsample=False)] 153 | ) 154 | 155 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 156 | self.fc1 = nn.Linear(out_planes[2], num_classes) 157 | self.fc2 = nn.Linear(out_planes[2], num_classes) 158 | self.fc3 = nn.Linear(out_planes[2], num_classes) 159 | 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 163 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 164 | elif isinstance(m, nn.BatchNorm2d): 165 | m.weight.data.fill_(1) 166 | m.bias.data.zero_() 167 | elif isinstance(m, nn.Linear): 168 | m.bias.data.zero_() 169 | 170 | def _make_layer(self, out_planes, num_blocks, groups, downsample=True): 171 | layers = [] 172 | for i in range(num_blocks): 173 | stride = 2 if i == 0 and downsample is True else 1 174 | cat_planes = self.in_planes if i == 0 and downsample is True else 0 175 | layers.append( 176 | Bottleneck( 177 | self.in_planes, 178 | out_planes - cat_planes, 179 | stride=stride, 180 | groups=groups, 181 | is_last=(i == num_blocks - 1), 182 | ) 183 | ) 184 | self.in_planes = out_planes 185 | return nn.Sequential(*layers) 186 | 187 | def forward(self, x): 188 | ss_logits = [] 189 | for i in range(len(x)): 190 | idx = i + 1 191 | out = getattr(self, "block_extractor" + str(idx))(x[i]) 192 | out = self.avg_pool(out) 193 | out = out.view(out.size(0), -1) 194 | out = getattr(self, "fc" + str(idx))(out) 195 | ss_logits.append(out) 196 | return ss_logits 197 | 198 | 199 | class ShuffleNet_Auxiliary(nn.Module): 200 | def __init__(self, cfg, num_classes=100): 201 | super(ShuffleNet_Auxiliary, self).__init__() 202 | self.backbone = ShuffleNet(cfg, num_classes=num_classes) 203 | self.auxiliary_classifier = Auxiliary_Classifier(cfg, num_classes=num_classes * 4) 204 | 205 | def forward(self, x, grad=False): 206 | feats, logit = self.backbone(x, is_feat=True) 207 | if grad is False: 208 | for i in range(len(feats)): 209 | feats[i] = feats[i].detach() 210 | ss_logits = 
self.auxiliary_classifier(feats) 211 | return logit, ss_logits 212 | 213 | 214 | def ShuffleV1(**kwargs): 215 | cfg = {"out_planes": [240, 480, 960], "num_blocks": [4, 8, 4], "groups": 3} 216 | return ShuffleNet(cfg, **kwargs) 217 | 218 | 219 | def ShuffleV1_aux(**kwargs): 220 | cfg = {"out_planes": [240, 480, 960], "num_blocks": [4, 8, 4], "groups": 3} 221 | return ShuffleNet_Auxiliary(cfg, **kwargs) 222 | -------------------------------------------------------------------------------- /data/cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, random_split 3 | from torchvision import transforms 4 | import os.path 5 | import pickle 6 | from typing import Any, Callable, Optional, Tuple 7 | 8 | import numpy as np 9 | from PIL import Image 10 | 11 | from torchvision.datasets.utils import check_integrity, download_and_extract_archive 12 | from torchvision.datasets.vision import VisionDataset 13 | 14 | 15 | class CIFAR10(VisionDataset): 16 | """`CIFAR10 `_ Dataset. 17 | 18 | Args: 19 | root (string): Root directory of dataset where directory 20 | ``cifar-10-batches-py`` exists or will be saved to if download is set to True. 21 | train (bool, optional): If True, creates dataset from training set, otherwise 22 | creates from test set. 23 | transform (callable, optional): A function/transform that takes in an PIL image 24 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 25 | target_transform (callable, optional): A function/transform that takes in the 26 | target and transforms it. 27 | download (bool, optional): If true, downloads the dataset from the internet and 28 | puts it in root directory. If dataset is already downloaded, it is not 29 | downloaded again. 30 | 31 | """ 32 | 33 | base_folder = "cifar-10-batches-py" 34 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 35 | filename = "cifar-10-python.tar.gz" 36 | tgz_md5 = "c58f30108f718f92721af3b95e74349a" 37 | train_list = [ 38 | ["data_batch_1", "c99cafc152244af753f735de768cd75f"], 39 | ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"], 40 | ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"], 41 | ["data_batch_4", "634d18415352ddfa80567beed471001a"], 42 | ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"], 43 | ] 44 | 45 | test_list = [ 46 | ["test_batch", "40351d587109b95175f43aff81a1287e"], 47 | ] 48 | meta = { 49 | "filename": "batches.meta", 50 | "key": "label_names", 51 | "md5": "5ff9c542aee3614f3951f8cda6e48888", 52 | } 53 | 54 | def __init__( 55 | self, 56 | root: str, 57 | train: bool = True, 58 | transform: Optional[Callable] = None, 59 | target_transform: Optional[Callable] = None, 60 | download: bool = False, 61 | ) -> None: 62 | 63 | super().__init__(root, transform=transform, target_transform=target_transform) 64 | 65 | self.train = train # training set or test set 66 | 67 | if download: 68 | self.download() 69 | 70 | if not self._check_integrity(): 71 | raise RuntimeError("Dataset not found or corrupted. 
You can use download=True to download it") 72 | 73 | if self.train: 74 | downloaded_list = self.train_list 75 | else: 76 | downloaded_list = self.test_list 77 | 78 | self.data: Any = [] 79 | self.targets = [] 80 | 81 | # now load the picked numpy arrays 82 | for file_name, checksum in downloaded_list: 83 | file_path = os.path.join(self.root, self.base_folder, file_name) 84 | with open(file_path, "rb") as f: 85 | entry = pickle.load(f, encoding="latin1") 86 | self.data.append(entry["data"]) 87 | if "labels" in entry: 88 | self.targets.extend(entry["labels"]) 89 | else: 90 | self.targets.extend(entry["fine_labels"]) 91 | 92 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) 93 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 94 | 95 | self._load_meta() 96 | 97 | def _load_meta(self) -> None: 98 | path = os.path.join(self.root, self.base_folder, self.meta["filename"]) 99 | if not check_integrity(path, self.meta["md5"]): 100 | raise RuntimeError("Dataset metadata file not found or corrupted. You can use download=True to download it") 101 | with open(path, "rb") as infile: 102 | data = pickle.load(infile, encoding="latin1") 103 | self.classes = data[self.meta["key"]] 104 | self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} 105 | 106 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 107 | """ 108 | Args: 109 | index (int): Index 110 | 111 | Returns: 112 | tuple: (image, target) where target is index of the target class. 113 | """ 114 | img, target = self.data[index], self.targets[index] 115 | 116 | # doing this so that it is consistent with all other datasets 117 | # to return a PIL Image 118 | img = Image.fromarray(img) 119 | 120 | if self.transform is not None: 121 | img = self.transform(img) 122 | 123 | if self.target_transform is not None: 124 | target = self.target_transform(target) 125 | 126 | return img, target 127 | 128 | def __len__(self) -> int: 129 | return len(self.data) 130 | 131 | def _check_integrity(self) -> bool: 132 | root = self.root 133 | for fentry in self.train_list + self.test_list: 134 | filename, md5 = fentry[0], fentry[1] 135 | fpath = os.path.join(root, self.base_folder, filename) 136 | if not check_integrity(fpath, md5): 137 | return False 138 | return True 139 | 140 | def download(self) -> None: 141 | if self._check_integrity(): 142 | print("Files already downloaded and verified") 143 | return 144 | download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) 145 | 146 | def extra_repr(self) -> str: 147 | split = "Train" if self.train is True else "Test" 148 | return f"Split: {split}" 149 | 150 | 151 | class CIFAR100(CIFAR10): 152 | """`CIFAR100 `_ Dataset. 153 | 154 | This is a subclass of the `CIFAR10` Dataset. 
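The loaders below construct it as, for example, ``CIFAR100('./resources/CIFAR100', train=True, download=True, transform=transform)``.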
155 | """ 156 | 157 | base_folder = "cifar-100-python" 158 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 159 | filename = "cifar-100-python.tar.gz" 160 | tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85" 161 | train_list = [ 162 | ["train", "16019d7e3df5f24257cddd939b257f8d"], 163 | ] 164 | 165 | test_list = [ 166 | ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"], 167 | ] 168 | meta = { 169 | "filename": "meta", 170 | "key": "fine_label_names", 171 | "md5": "7973b15100ade9c7d40fb424638fde48", 172 | } 173 | 174 | 175 | def get_CIFAR100_train(batch_size=256, 176 | num_workers=8, 177 | pin_memory=True, 178 | augment=False, 179 | ): 180 | if not augment: 181 | transform = transforms.Compose([ 182 | transforms.ToTensor(), 183 | transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]), 184 | ]) 185 | else: 186 | transform = transforms.Compose([ 187 | transforms.RandomHorizontalFlip(), 188 | transforms.ColorJitter(0.1, 0.1, 0.1, 0.1), 189 | transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10), 190 | transforms.RandomRotation(5), 191 | transforms.ToTensor(), 192 | transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]), 193 | ]) 194 | 195 | set = CIFAR100('./resources/CIFAR100', train=True, download=True, transform=transform) 196 | loader = DataLoader(set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, 197 | shuffle=True) 198 | return loader 199 | 200 | 201 | def get_CIFAR100_test(batch_size=256, 202 | num_workers=8, 203 | pin_memory=False, ): 204 | transform = transforms.Compose([ 205 | transforms.ToTensor(), 206 | transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]), 207 | ]) 208 | set = CIFAR100('./resources/CIFAR100', train=False, download=True, transform=transform) 209 | loader = DataLoader(set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory) 210 | return loader 211 | 212 | 213 | def get_CIFAR10_train(batch_size=256, 214 | num_workers=8, 215 | pin_memory=True, 216 | augment=False, 217 | validate=False 218 | ): 219 | 220 | if not augment: 221 | transform = transforms.Compose([ 222 | transforms.ToTensor(), 223 | transforms.Normalize(((0.4914, 0.4822, 0.4465)), (0.2470, 0.2435, 0.2616)) 224 | ]) 225 | else: 226 | transform = transforms.Compose([ 227 | transforms.RandomHorizontalFlip(), 228 | transforms.ColorJitter(0.1, 0.1, 0.1, 0.1), 229 | transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10), 230 | transforms.RandomRotation(5), 231 | transforms.ToTensor(), 232 | transforms.Normalize(((0.4914, 0.4822, 0.4465)), (0.2470, 0.2435, 0.2616)), 233 | 234 | ]) 235 | set = CIFAR10('./resources/CIFAR10', train=True, download=True, transform=transform) 236 | 237 | if validate: 238 | train_size = int(0.8 * len(set)) 239 | validation_size = int(0.2 * len(set)) 240 | 241 | train_datasets = random_split(set, [train_size, validation_size]) 242 | train_set, valiation_set = train_datasets 243 | 244 | train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, 245 | shuffle=True) 246 | validation_loader = DataLoader(valiation_set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, 247 | shuffle=True) 248 | return train_loader, validation_loader 249 | 250 | else: 251 | loader = DataLoader(set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, 252 | shuffle=True) 253 | return loader 254 | 255 | 256 | def get_CIFAR10_test(batch_size=256, 257 | num_workers=8, 258 | pin_memory=True, ): 259 | transform = 
transforms.Compose([ 260 | transforms.ToTensor(), 261 | transforms.Normalize(((0.4914, 0.4822, 0.4465)), (0.2470, 0.2435, 0.2616)) 262 | ]) 263 | set = CIFAR10('./resources/CIFAR10', train=False, download=True, transform=transform) 264 | loader = DataLoader(set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory) 265 | return loader 266 | -------------------------------------------------------------------------------- /backbones/resnetv2.py: -------------------------------------------------------------------------------- 1 | """ResNet in PyTorch. 2 | For Pre-activation ResNet, see 'preact_resnet.py'. 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | __all__ = ["ResNet50_aux"] 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1, is_last=False): 18 | super(BasicBlock, self).__init__() 19 | self.is_last = is_last 20 | self.conv1 = nn.Conv2d( 21 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 22 | ) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | 27 | self.shortcut = nn.Sequential() 28 | if stride != 1 or in_planes != self.expansion * planes: 29 | self.shortcut = nn.Sequential( 30 | nn.Conv2d( 31 | in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False 32 | ), 33 | nn.BatchNorm2d(self.expansion * planes), 34 | ) 35 | 36 | def forward(self, x): 37 | out = F.relu(self.bn1(self.conv1(x))) 38 | out = self.bn2(self.conv2(out)) 39 | out += self.shortcut(x) 40 | preact = out 41 | out = F.relu(out) 42 | if self.is_last: 43 | return out, preact 44 | else: 45 | return out 46 | 47 | 48 | class Bottleneck(nn.Module): 49 | expansion = 4 50 | 51 | def __init__(self, in_planes, planes, stride=1, is_last=False): 52 | super(Bottleneck, self).__init__() 53 | self.is_last = is_last 54 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 55 | self.bn1 = nn.BatchNorm2d(planes) 56 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 57 | self.bn2 = nn.BatchNorm2d(planes) 58 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 59 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 60 | 61 | self.shortcut = nn.Sequential() 62 | if stride != 1 or in_planes != self.expansion * planes: 63 | self.shortcut = nn.Sequential( 64 | nn.Conv2d( 65 | in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False 66 | ), 67 | nn.BatchNorm2d(self.expansion * planes), 68 | ) 69 | 70 | def forward(self, x): 71 | out = F.relu(self.bn1(self.conv1(x))) 72 | out = F.relu(self.bn2(self.conv2(out))) 73 | out = self.bn3(self.conv3(out)) 74 | out += self.shortcut(x) 75 | out = F.relu(out) 76 | return out 77 | 78 | 79 | class ResNet(nn.Module): 80 | def __init__(self, block, num_blocks, num_classes=10, zero_init_residual=False): 81 | super(ResNet, self).__init__() 82 | self.in_planes = 64 83 | 84 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 85 | self.bn1 = nn.BatchNorm2d(64) 86 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 87 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 88 | self.layer3 = self._make_layer(block, 256, 
num_blocks[2], stride=2) 89 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 90 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 91 | self.linear = nn.Linear(512 * block.expansion, num_classes) 92 | self.last_channel = 512 * block.expansion 93 | 94 | for m in self.modules(): 95 | if isinstance(m, nn.Conv2d): 96 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 97 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 98 | nn.init.constant_(m.weight, 1) 99 | nn.init.constant_(m.bias, 0) 100 | 101 | # Zero-initialize the last BN in each residual branch, 102 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 103 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 104 | if zero_init_residual: 105 | for m in self.modules(): 106 | if isinstance(m, Bottleneck): 107 | nn.init.constant_(m.bn3.weight, 0) 108 | elif isinstance(m, BasicBlock): 109 | nn.init.constant_(m.bn2.weight, 0) 110 | 111 | def get_feat_modules(self): 112 | feat_m = nn.ModuleList([]) 113 | feat_m.append(self.conv1) 114 | feat_m.append(self.bn1) 115 | feat_m.append(self.layer1) 116 | feat_m.append(self.layer2) 117 | feat_m.append(self.layer3) 118 | feat_m.append(self.layer4) 119 | return feat_m 120 | 121 | def get_bn_before_relu(self): 122 | if isinstance(self.layer1[0], Bottleneck): 123 | bn1 = self.layer1[-1].bn3 124 | bn2 = self.layer2[-1].bn3 125 | bn3 = self.layer3[-1].bn3 126 | bn4 = self.layer4[-1].bn3 127 | elif isinstance(self.layer1[0], BasicBlock): 128 | bn1 = self.layer1[-1].bn2 129 | bn2 = self.layer2[-1].bn2 130 | bn3 = self.layer3[-1].bn2 131 | bn4 = self.layer4[-1].bn2 132 | else: 133 | raise NotImplementedError("ResNet unknown block error !!!") 134 | 135 | return [bn1, bn2, bn3, bn4] 136 | 137 | def _make_layer(self, block, planes, num_blocks, stride): 138 | strides = [stride] + [1] * (num_blocks - 1) 139 | layers = [] 140 | for i in range(num_blocks): 141 | stride = strides[i] 142 | layers.append(block(self.in_planes, planes, stride, i == num_blocks - 1)) 143 | self.in_planes = planes * block.expansion 144 | return nn.Sequential(*layers) 145 | 146 | def forward(self, x, is_feat=False, preact=False): 147 | out = F.relu(self.bn1(self.conv1(x))) 148 | out = self.layer1(out) 149 | f1 = out 150 | out = self.layer2(out) 151 | f2 = out 152 | out = self.layer3(out) 153 | f3 = out 154 | out = self.layer4(out) 155 | f4 = out 156 | out = self.avgpool(out) 157 | out = out.view(out.size(0), -1) 158 | out = self.linear(out) 159 | if is_feat: 160 | return [f1, f2, f3, f4], out 161 | else: 162 | return out 163 | 164 | 165 | class Auxiliary_Classifier(nn.Module): 166 | def __init__(self, block, num_blocks, num_classes=10, zero_init_residual=False): 167 | super(Auxiliary_Classifier, self).__init__() 168 | 169 | self.in_planes = 64 * block.expansion 170 | self.block_extractor1 = nn.Sequential( 171 | *[ 172 | self._make_layer(block, 128, num_blocks[1], stride=2), 173 | self._make_layer(block, 256, num_blocks[2], stride=2), 174 | self._make_layer(block, 512, num_blocks[3], stride=2), 175 | ] 176 | ) 177 | 178 | self.in_planes = 128 * block.expansion 179 | self.block_extractor2 = nn.Sequential( 180 | *[ 181 | self._make_layer(block, 256, num_blocks[2], stride=2), 182 | self._make_layer(block, 512, num_blocks[3], stride=2), 183 | ] 184 | ) 185 | 186 | self.in_planes = 256 * block.expansion 187 | self.block_extractor3 = nn.Sequential( 188 | *[self._make_layer(block, 512, num_blocks[3], stride=2)] 189 | ) 190 | 191 
| self.in_planes = 512 * block.expansion 192 | self.block_extractor4 = nn.Sequential( 193 | *[self._make_layer(block, 512, num_blocks[3], stride=1)] 194 | ) 195 | 196 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 197 | self.fc1 = nn.Linear(512 * block.expansion, num_classes) 198 | self.fc2 = nn.Linear(512 * block.expansion, num_classes) 199 | self.fc3 = nn.Linear(512 * block.expansion, num_classes) 200 | self.fc4 = nn.Linear(512 * block.expansion, num_classes) 201 | 202 | for m in self.modules(): 203 | if isinstance(m, nn.Conv2d): 204 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 205 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 206 | nn.init.constant_(m.weight, 1) 207 | nn.init.constant_(m.bias, 0) 208 | 209 | def _make_layer(self, block, planes, num_blocks, stride): 210 | strides = [stride] + [1] * (num_blocks - 1) 211 | layers = [] 212 | for i in range(num_blocks): 213 | stride = strides[i] 214 | layers.append(block(self.in_planes, planes, stride, i == num_blocks - 1)) 215 | self.in_planes = planes * block.expansion 216 | return nn.Sequential(*layers) 217 | 218 | def forward(self, x): 219 | ss_logits = [] 220 | for i in range(len(x)): 221 | idx = i + 1 222 | out = getattr(self, "block_extractor" + str(idx))(x[i]) 223 | out = self.avg_pool(out) 224 | out = out.view(out.size(0), -1) 225 | out = getattr(self, "fc" + str(idx))(out) 226 | ss_logits.append(out) 227 | return ss_logits 228 | 229 | 230 | class ResNet_Auxiliary(nn.Module): 231 | def __init__(self, block, num_blocks, num_classes=10, zero_init_residual=False): 232 | super(ResNet_Auxiliary, self).__init__() 233 | self.backbone = ResNet( 234 | block, num_blocks, num_classes=num_classes, zero_init_residual=zero_init_residual 235 | ) 236 | self.auxiliary_classifier = Auxiliary_Classifier( 237 | block, num_blocks, num_classes=num_classes * 4, zero_init_residual=zero_init_residual 238 | ) 239 | 240 | def forward(self, x, grad=False): 241 | if grad is False: 242 | feats, logit = self.backbone(x, is_feat=True) 243 | for i in range(len(feats)): 244 | feats[i] = feats[i].detach() 245 | else: 246 | feats, logit = self.backbone(x, is_feat=True) 247 | 248 | ss_logits = self.auxiliary_classifier(feats) 249 | return logit, ss_logits 250 | 251 | 252 | def ResNet18(**kwargs): 253 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 254 | 255 | 256 | def ResNet34(**kwargs): 257 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 258 | 259 | 260 | def ResNet50(**kwargs): 261 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 262 | 263 | 264 | def ResNet50_aux(**kwargs): 265 | return ResNet_Auxiliary(Bottleneck, [3, 4, 6, 3], **kwargs) 266 | 267 | 268 | def ResNet101(**kwargs): 269 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 270 | 271 | 272 | def ResNet152(**kwargs): 273 | return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 274 | -------------------------------------------------------------------------------- /backbones/ShuffleNetv2.py: -------------------------------------------------------------------------------- 1 | """ShuffleNetV2 in PyTorch. 2 | See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details. 
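The normalization layer used inside each BasicBlock can be swapped via the norm_layer argument of ShuffleNetV2 (default nn.BatchNorm2d); the stem, DownBlocks, and final conv keep plain BatchNorm.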
3 | 4 | 5 | adding hyperparameter norm_layer: Huanran Chen 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | __all__ = ["ShuffleV2_aux", "ShuffleV2"] 12 | 13 | 14 | class ShuffleBlock(nn.Module): 15 | def __init__(self, groups=2): 16 | super(ShuffleBlock, self).__init__() 17 | self.groups = groups 18 | 19 | def forward(self, x): 20 | """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]""" 21 | N, C, H, W = x.size() 22 | g = self.groups 23 | return x.view(N, g, C // g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W) 24 | 25 | 26 | class SplitBlock(nn.Module): 27 | def __init__(self, ratio): 28 | super(SplitBlock, self).__init__() 29 | self.ratio = ratio 30 | 31 | def forward(self, x): 32 | c = int(x.size(1) * self.ratio) 33 | return x[:, :c, :, :], x[:, c:, :, :] 34 | 35 | 36 | class BasicBlock(nn.Module): 37 | def __init__(self, in_channels, split_ratio=0.5, is_last=False, norm_layer=nn.BatchNorm2d): 38 | super(BasicBlock, self).__init__() 39 | self.is_last = is_last 40 | self.split = SplitBlock(split_ratio) 41 | in_channels = int(in_channels * split_ratio) 42 | self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False) 43 | self.bn1 = norm_layer(in_channels) 44 | self.conv2 = nn.Conv2d( 45 | in_channels, 46 | in_channels, 47 | kernel_size=3, 48 | stride=1, 49 | padding=1, 50 | groups=in_channels, 51 | bias=False, 52 | ) 53 | self.bn2 = norm_layer(in_channels) 54 | self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False) 55 | self.bn3 = norm_layer(in_channels) 56 | self.shuffle = ShuffleBlock() 57 | 58 | def forward(self, x): 59 | x1, x2 = self.split(x) 60 | out = F.relu(self.bn1(self.conv1(x2))) 61 | out = self.bn2(self.conv2(out)) 62 | preact = self.bn3(self.conv3(out)) 63 | out = F.relu(preact) 64 | # out = F.relu(self.bn3(self.conv3(out))) 65 | preact = torch.cat([x1, preact], 1) 66 | out = torch.cat([x1, out], 1) 67 | out = self.shuffle(out) 68 | return out 69 | 70 | 71 | class DownBlock(nn.Module): 72 | def __init__(self, in_channels, out_channels, stride=2): 73 | super(DownBlock, self).__init__() 74 | mid_channels = out_channels // 2 75 | # left 76 | self.conv1 = nn.Conv2d( 77 | in_channels, 78 | in_channels, 79 | kernel_size=3, 80 | stride=stride, 81 | padding=1, 82 | groups=in_channels, 83 | bias=False, 84 | ) 85 | self.bn1 = nn.BatchNorm2d(in_channels) 86 | self.conv2 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False) 87 | self.bn2 = nn.BatchNorm2d(mid_channels) 88 | # right 89 | self.conv3 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False) 90 | self.bn3 = nn.BatchNorm2d(mid_channels) 91 | self.conv4 = nn.Conv2d( 92 | mid_channels, 93 | mid_channels, 94 | kernel_size=3, 95 | stride=stride, 96 | padding=1, 97 | groups=mid_channels, 98 | bias=False, 99 | ) 100 | self.bn4 = nn.BatchNorm2d(mid_channels) 101 | self.conv5 = nn.Conv2d(mid_channels, mid_channels, kernel_size=1, bias=False) 102 | self.bn5 = nn.BatchNorm2d(mid_channels) 103 | 104 | self.shuffle = ShuffleBlock() 105 | 106 | def forward(self, x): 107 | # left 108 | out1 = self.bn1(self.conv1(x)) 109 | out1 = F.relu(self.bn2(self.conv2(out1))) 110 | # right 111 | out2 = F.relu(self.bn3(self.conv3(x))) 112 | out2 = self.bn4(self.conv4(out2)) 113 | out2 = F.relu(self.bn5(self.conv5(out2))) 114 | # concat 115 | out = torch.cat([out1, out2], 1) 116 | out = self.shuffle(out) 117 | return out 118 | 119 | 120 | class ShuffleNetV2(nn.Module): 121 | def __init__(self, net_size, num_classes=100, 
norm_layer=nn.BatchNorm2d): 122 | super(ShuffleNetV2, self).__init__() 123 | out_channels = configs[net_size]["out_channels"] 124 | num_blocks = configs[net_size]["num_blocks"] 125 | 126 | # self.conv1 = nn.Conv2d(3, 24, kernel_size=3, 127 | # stride=1, padding=1, bias=False) 128 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 129 | self.bn1 = nn.BatchNorm2d(24) 130 | self.in_channels = 24 131 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0], norm_layer) 132 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1], norm_layer) 133 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2], norm_layer) 134 | self.conv2 = nn.Conv2d( 135 | out_channels[2], out_channels[3], kernel_size=1, stride=1, padding=0, bias=False 136 | ) 137 | self.bn2 = nn.BatchNorm2d(out_channels[3]) 138 | self.linear = nn.Linear(out_channels[3], num_classes) 139 | self.last_channel = out_channels[3] 140 | 141 | def _make_layer(self, out_channels, num_blocks, norm_layer=nn.BatchNorm2d): 142 | layers = [DownBlock(self.in_channels, out_channels)] 143 | for i in range(num_blocks): 144 | layers.append(BasicBlock(out_channels, is_last=(i == num_blocks - 1), norm_layer=norm_layer)) 145 | self.in_channels = out_channels 146 | return nn.Sequential(*layers) 147 | 148 | def get_feat_modules(self): 149 | feat_m = nn.ModuleList([]) 150 | feat_m.append(self.conv1) 151 | feat_m.append(self.bn1) 152 | feat_m.append(self.layer1) 153 | feat_m.append(self.layer2) 154 | feat_m.append(self.layer3) 155 | return feat_m 156 | 157 | def get_bn_before_relu(self): 158 | raise NotImplementedError('ShuffleNetV2 currently is not supported for "Overhaul" teacher2') 159 | 160 | def forward(self, x, is_feat=False, preact=False): 161 | out = F.relu(self.bn1(self.conv1(x))) 162 | # out = F.max_pool2d(out, 3, stride=2, padding=1) 163 | out = self.layer1(out) 164 | f1 = out 165 | out = self.layer2(out) 166 | f2 = out 167 | out = self.layer3(out) 168 | f3 = out 169 | out = F.relu(self.bn2(self.conv2(out))) 170 | f4 = out 171 | out = F.avg_pool2d(out, 4) 172 | out = out.view(out.size(0), -1) 173 | out = self.linear(out) 174 | if is_feat: 175 | return [f1, f2, f3, f4], out 176 | else: 177 | return out 178 | 179 | 180 | class Auxiliary_Classifier(nn.Module): 181 | def __init__(self, net_size, num_classes=100, norm_layer=nn.BatchNorm2d): 182 | super(Auxiliary_Classifier, self).__init__() 183 | out_channels = configs[net_size]["out_channels"] 184 | num_blocks = configs[net_size]["num_blocks"] 185 | 186 | self.in_channels = out_channels[0] 187 | self.block_extractor1 = nn.Sequential( 188 | *[ 189 | self._make_layer(out_channels[1], num_blocks[1]), 190 | self._make_layer(out_channels[2], num_blocks[2]), 191 | nn.Conv2d( 192 | out_channels[2], out_channels[3], kernel_size=1, stride=1, padding=0, bias=False 193 | ), 194 | nn.BatchNorm2d(out_channels[3]), 195 | nn.ReLU(inplace=True), 196 | ] 197 | ) 198 | 199 | self.in_channels = out_channels[1] 200 | self.block_extractor2 = nn.Sequential( 201 | *[ 202 | self._make_layer(out_channels[2], num_blocks[2]), 203 | nn.Conv2d( 204 | out_channels[2], out_channels[3], kernel_size=1, stride=1, padding=0, bias=False 205 | ), 206 | nn.BatchNorm2d(out_channels[3]), 207 | nn.ReLU(inplace=True), 208 | ] 209 | ) 210 | 211 | self.in_channels = out_channels[2] 212 | self.block_extractor3 = nn.Sequential( 213 | *[ 214 | self._make_layer(out_channels[2], num_blocks[2], stride=1), 215 | nn.Conv2d( 216 | out_channels[2], out_channels[3], kernel_size=1, stride=1, padding=0, bias=False 217 | ), 218 | 
nn.BatchNorm2d(out_channels[3]), 219 | nn.ReLU(inplace=True), 220 | ] 221 | ) 222 | 223 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 224 | self.fc1 = nn.Linear(out_channels[3], num_classes) 225 | self.fc2 = nn.Linear(out_channels[3], num_classes) 226 | self.fc3 = nn.Linear(out_channels[3], num_classes) 227 | 228 | def _make_layer(self, out_channels, num_blocks, stride=2): 229 | layers = [DownBlock(self.in_channels, out_channels, stride=stride)] 230 | for i in range(num_blocks): 231 | layers.append(BasicBlock(out_channels, is_last=(i == num_blocks - 1))) 232 | self.in_channels = out_channels 233 | return nn.Sequential(*layers) 234 | 235 | def forward(self, x): 236 | ss_logits = [] 237 | for i in range(len(x)): 238 | idx = i + 1 239 | out = getattr(self, "block_extractor" + str(idx))(x[i]) 240 | out = self.avg_pool(out) 241 | out = out.view(out.size(0), -1) 242 | out = getattr(self, "fc" + str(idx))(out) 243 | ss_logits.append(out) 244 | return ss_logits 245 | 246 | 247 | class ShuffleNetV2_Auxiliary(nn.Module): 248 | def __init__(self, net_size, num_classes=100): 249 | super(ShuffleNetV2_Auxiliary, self).__init__() 250 | self.backbone = ShuffleNetV2(net_size, num_classes=num_classes) 251 | self.auxiliary_classifier = Auxiliary_Classifier(net_size, num_classes=num_classes * 4) 252 | 253 | def forward(self, x, grad=False): 254 | feats, logit = self.backbone(x, is_feat=True) 255 | if grad is False: 256 | for i in range(len(feats)): 257 | feats[i] = feats[i].detach() 258 | ss_logits = self.auxiliary_classifier(feats) 259 | return logit, ss_logits 260 | 261 | 262 | configs = { 263 | 0.2: {"out_channels": (40, 80, 160, 512), "num_blocks": (3, 3, 3)}, 264 | 0.3: {"out_channels": (40, 80, 160, 512), "num_blocks": (3, 7, 3)}, 265 | 0.5: {"out_channels": (48, 96, 192, 1024), "num_blocks": (3, 7, 3)}, 266 | 1: {"out_channels": (116, 232, 464, 1024), "num_blocks": (3, 7, 3)}, 267 | 1.5: {"out_channels": (176, 352, 704, 1024), "num_blocks": (3, 7, 3)}, 268 | 2: {"out_channels": (224, 488, 976, 2048), "num_blocks": (3, 7, 3)}, 269 | } 270 | 271 | 272 | def ShuffleV2(**kwargs): 273 | model = ShuffleNetV2(net_size=1, **kwargs) 274 | return model 275 | 276 | 277 | def ShuffleV2_aux(**kwargs): 278 | model = ShuffleNetV2_Auxiliary(net_size=1, **kwargs) 279 | return model 280 | -------------------------------------------------------------------------------- /backbones/PyramidNet.py: -------------------------------------------------------------------------------- 1 | """ 2 | PyramidNet with shakedrop with high resolution 3 | Huanran Chen 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | 10 | from .batchensemble import Ensemble_Conv2d, Ensemble_FC, Ensemble_orderFC 11 | 12 | __all__ = ["pyramidnet272", "pyramidnet164"] 13 | 14 | _inplace_flag = False 15 | 16 | """ 17 | splitnet/model/pyramidnet.py:175: 18 | UserWarning: Output 0 of ShakeDropFunctionBackward is a view and is being modified inplace. 19 | This view was created inside a custom Function (or because an input was returned as-is) 20 | and the autograd logic to handle view+inplace would override 21 | the custom backward associated with the custom Function, 22 | leading to incorrect gradients. 23 | This behavior is deprecated and will be forbidden starting version 1.6. 24 | You can remove this warning by cloning the output of the custom Function. 25 | (Triggered internally at /opt/conda/conda-bld/pytorch_1595629403081/work/torch/csrc/autograd/variable.cpp:464.) 
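Hence _inplace_flag = False above keeps the ReLUs out-of-place; the in-place write the warning points at appears to be the out += accumulation after shake_drop in BasicBlock.forward.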
26 | """ 27 | 28 | 29 | class ShakeDropFunction(torch.autograd.Function): 30 | @staticmethod 31 | def forward(ctx, x, training=True, p_drop=0.5, alpha_range=[-1, 1]): 32 | if training: 33 | gate = torch.FloatTensor([0]).bernoulli_(1 - p_drop).to(x.device) 34 | ctx.save_for_backward(gate) 35 | if gate.item() == 0: 36 | alpha = torch.FloatTensor(x.size(0)).uniform_(*alpha_range).to(x.device) 37 | alpha = alpha.view(alpha.size(0), 1, 1, 1).expand_as(x) 38 | return alpha * x 39 | else: 40 | return x 41 | else: 42 | return (1 - p_drop) * x 43 | 44 | @staticmethod 45 | def backward(ctx, grad_output): 46 | gate = ctx.saved_tensors[0] 47 | if gate.item() == 0: 48 | beta = torch.FloatTensor(grad_output.size(0)).uniform_(0, 1).to(grad_output.device) 49 | beta = beta.view(beta.size(0), 1, 1, 1).expand_as(grad_output) 50 | beta = Variable(beta) 51 | return beta * grad_output, None, None, None 52 | else: 53 | return grad_output, None, None, None 54 | 55 | 56 | class ShakeDrop(nn.Module): 57 | def __init__(self, p_drop=0.5, alpha_range=[-1, 1]): 58 | super(ShakeDrop, self).__init__() 59 | self.p_drop = p_drop 60 | self.alpha_range = alpha_range 61 | 62 | def forward(self, x): 63 | return ShakeDropFunction.apply(x, self.training, self.p_drop, self.alpha_range) 64 | 65 | 66 | def conv3x3(in_planes, out_planes, stride=1, num_model=-1): 67 | """ 68 | 3x3 convolution with padding 69 | """ 70 | return ( 71 | nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 72 | if num_model <= 0 73 | else Ensemble_Conv2d( 74 | in_planes, 75 | out_planes, 76 | kernel_size=3, 77 | stride=stride, 78 | padding=1.0, 79 | bias=False, 80 | num_models=num_model, 81 | ) 82 | ) 83 | 84 | 85 | class BasicBlock(nn.Module): 86 | outchannel_ratio = 4 87 | 88 | def __init__(self, inplanes, planes, stride=1, downsample=None, p_shakedrop=1.0, num_model=-1): 89 | super(BasicBlock, self).__init__() 90 | self.num_model = num_model 91 | self.bn1 = nn.BatchNorm2d(inplanes) 92 | self.conv1 = ( 93 | conv3x3(inplanes, planes, stride) 94 | if self.num_model <= 0 95 | else conv3x3(inplanes, planes, stride, num_model) 96 | ) 97 | self.bn2 = nn.BatchNorm2d(planes) 98 | self.conv2 = ( 99 | conv3x3(planes, planes) 100 | if self.num_model <= 0 101 | else conv3x3(inplanes, planes, 1, num_model) 102 | ) 103 | self.bn3 = nn.BatchNorm2d(planes) 104 | self.relu = nn.ReLU(inplace=_inplace_flag) 105 | self.downsample = downsample 106 | self.stride = stride 107 | self.shake_drop = ShakeDrop(p_shakedrop) 108 | 109 | def forward(self, x): 110 | 111 | out = self.bn1(x) 112 | out = self.conv1(out) 113 | out = self.bn2(out) 114 | out = self.relu(out) 115 | out = self.conv2(out) 116 | out = self.bn3(out) 117 | 118 | out = self.shake_drop(out) 119 | 120 | if self.downsample is not None: 121 | shortcut = self.downsample(x) 122 | featuremap_size = shortcut.size()[2:4] 123 | else: 124 | shortcut = x 125 | featuremap_size = out.size()[2:4] 126 | 127 | batch_size = out.size()[0] 128 | residual_channel = out.size()[1] 129 | shortcut_channel = shortcut.size()[1] 130 | 131 | if residual_channel != shortcut_channel: 132 | padding = torch.autograd.Variable( 133 | torch.FloatTensor( 134 | batch_size, 135 | residual_channel - shortcut_channel, 136 | featuremap_size[0], 137 | featuremap_size[1], 138 | ).fill_(0) 139 | ).to(x.device) 140 | out += torch.cat((shortcut, padding), 1) 141 | else: 142 | out += shortcut 143 | 144 | return out 145 | 146 | 147 | class Bottleneck(nn.Module): 148 | outchannel_ratio = 4 149 | 150 | def __init__(self, inplanes, 
planes, stride=1, downsample=None, p_shakedrop=1.0, num_model=-1): 151 | super(Bottleneck, self).__init__() 152 | self.bn1 = nn.BatchNorm2d(inplanes) 153 | self.num_model = num_model 154 | self.conv1 = ( 155 | nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 156 | if self.num_model <= 0 157 | else Ensemble_Conv2d(inplanes, planes, kernel_size=1, bias=False, num_models=num_model) 158 | ) 159 | self.bn2 = nn.BatchNorm2d(planes) 160 | self.conv2 = ( 161 | nn.Conv2d(planes, (planes * 1), kernel_size=3, stride=stride, padding=1, bias=False) 162 | if self.num_model <= 0 163 | else Ensemble_Conv2d( 164 | planes, 165 | (planes * 1), 166 | kernel_size=3, 167 | stride=stride, 168 | padding=1, 169 | bias=False, 170 | num_models=num_model, 171 | ) 172 | ) 173 | self.bn3 = nn.BatchNorm2d((planes * 1)) 174 | self.conv3 = ( 175 | nn.Conv2d((planes * 1), planes * Bottleneck.outchannel_ratio, kernel_size=1, bias=False) 176 | if self.num_model <= 0 177 | else Ensemble_Conv2d( 178 | (planes * 1), 179 | planes * Bottleneck.outchannel_ratio, 180 | kernel_size=1, 181 | bias=False, 182 | num_models=num_model, 183 | ) 184 | ) 185 | self.bn4 = nn.BatchNorm2d(planes * Bottleneck.outchannel_ratio) 186 | self.relu = nn.ReLU(inplace=_inplace_flag) 187 | self.downsample = downsample 188 | self.stride = stride 189 | self.shake_drop = ShakeDrop(p_shakedrop) 190 | 191 | def forward(self, x): 192 | 193 | out = self.bn1(x) 194 | out = self.conv1(out) 195 | 196 | out = self.bn2(out) 197 | out = self.relu(out) 198 | out = self.conv2(out) 199 | 200 | out = self.bn3(out) 201 | out = self.relu(out) 202 | out = self.conv3(out) 203 | 204 | out = self.bn4(out) 205 | 206 | out = self.shake_drop(out) 207 | 208 | if self.downsample is not None: 209 | shortcut = self.downsample(x) 210 | featuremap_size = shortcut.size()[2:4] 211 | else: 212 | shortcut = x 213 | featuremap_size = out.size()[2:4] 214 | 215 | batch_size = out.size()[0] 216 | residual_channel = out.size()[1] 217 | shortcut_channel = shortcut.size()[1] 218 | 219 | if residual_channel != shortcut_channel: 220 | padding = torch.autograd.Variable( 221 | torch.FloatTensor( 222 | batch_size, 223 | residual_channel - shortcut_channel, 224 | featuremap_size[0], 225 | featuremap_size[1], 226 | ).fill_(0) 227 | ).to(x.device) 228 | out = out + torch.cat((shortcut, padding), 1) 229 | else: 230 | out = out + shortcut 231 | 232 | return out 233 | 234 | 235 | class PyramidNet(nn.Module): 236 | def __init__( 237 | self, 238 | bottleneck=True, 239 | depth=272, 240 | alpha=200, 241 | num_classes=60, 242 | split_factor=1, 243 | num_models=-1, 244 | ): 245 | super(PyramidNet, self).__init__() 246 | """ 247 | inplanes_dict = {'imagenet': {1: 64, 2: 44, 4: 32, 8: 24}, 248 | 'cifar10': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4}, 249 | 'cifar100': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4}, 250 | 'svhn': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4}, 251 | } 252 | """ 253 | self.num_models = num_models 254 | self.inplanes = 16 255 | 256 | if bottleneck: 257 | n = int((depth - 2) / 9) 258 | block = Bottleneck 259 | else: 260 | n = int((depth - 2) / 6) 261 | block = BasicBlock 262 | 263 | # self.addrate = alpha / (3 * n * 1.0) 264 | self.addrate = alpha / (3 * n * (split_factor ** 0.5)) 265 | self.final_shake_p = 0.5 / (split_factor ** 0.5) 266 | print( 267 | "INFO:PyTorch: PyramidNet: The add rate is {}, " 268 | "the final shake p is {}".format(self.addrate, self.final_shake_p) 269 | ) 270 | 271 | self.ps_shakedrop = [ 272 | 1.0 - (1.0 - (self.final_shake_p / (3 * n)) * (i + 1)) for i in range(3 * n) 273 | ] 274 | 275 | 
self.input_featuremap_dim = self.inplanes 276 | self.conv1 = ( 277 | nn.Conv2d(3, self.input_featuremap_dim, kernel_size=7, stride=2, padding=3, bias=False) 278 | if self.num_models <= 0 279 | else Ensemble_Conv2d( 280 | 3, 281 | self.input_featuremap_dim, 282 | kernel_size=7, 283 | stride=2, 284 | padding=3, 285 | bias=False, 286 | num_models=num_models, 287 | ) 288 | ) 289 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) 290 | self.relu = nn.ReLU(inplace=_inplace_flag) 291 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 292 | 293 | self.featuremap_dim = self.input_featuremap_dim 294 | self.layer1 = self.pyramidal_make_layer(block, n, stride=1, ensemble=False) 295 | self.layer2 = self.pyramidal_make_layer(block, n, stride=2, ensemble=False) 296 | self.layer3 = self.pyramidal_make_layer(block, n, stride=2, ensemble=True) 297 | 298 | self.final_featuremap_dim = self.input_featuremap_dim 299 | self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim) 300 | self.relu_final = nn.ReLU(inplace=_inplace_flag) 301 | # self.avgpool = nn.AvgPool2d(8) 302 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 303 | self.fc = ( 304 | nn.Linear(self.final_featuremap_dim, num_classes) 305 | if self.num_models <= 0 306 | else Ensemble_orderFC( 307 | self.final_featuremap_dim, num_classes, num_models=self.num_models 308 | ) 309 | ) 310 | 311 | def pyramidal_make_layer(self, block, block_depth, stride=1, ensemble=False): 312 | downsample = None 313 | if ( 314 | stride != 1 315 | ): # or self.inplanes != int(round(featuremap_dim_1st)) * block.outchannel_ratio: 316 | downsample = nn.AvgPool2d((2, 2), stride=(2, 2), ceil_mode=True) 317 | 318 | layers = [] 319 | self.featuremap_dim = self.featuremap_dim + self.addrate 320 | layers.append( 321 | block( 322 | self.input_featuremap_dim, 323 | int(round(self.featuremap_dim)), 324 | stride, 325 | downsample, 326 | p_shakedrop=self.ps_shakedrop.pop(0), 327 | num_model=self.num_models if ensemble else -1, 328 | ) 329 | ) 330 | for i in range(1, block_depth): 331 | temp_featuremap_dim = self.featuremap_dim + self.addrate 332 | layers.append( 333 | block( 334 | int(round(self.featuremap_dim)) * block.outchannel_ratio, 335 | int(round(temp_featuremap_dim)), 336 | 1, 337 | p_shakedrop=self.ps_shakedrop.pop(0), 338 | num_model=-1 if stride == 1 else self.num_models, 339 | ) 340 | ) 341 | self.featuremap_dim = temp_featuremap_dim 342 | self.input_featuremap_dim = int(round(self.featuremap_dim)) * block.outchannel_ratio 343 | 344 | return nn.Sequential(*layers) 345 | 346 | def forward(self, x): 347 | x = self.conv1(x) 348 | x = self.bn1(x) 349 | x = self.relu(x) 350 | x = self.maxpool(x) 351 | x = self.layer1(x) 352 | x = self.layer2(x) 353 | x = self.layer3(x) 354 | x = self.bn_final(x) 355 | x = self.relu_final(x) 356 | x = self.avgpool(x) 357 | x = x.view(x.size(0), -1) 358 | x = self.fc(x) 359 | return x 360 | 361 | 362 | def pyramidnet164(bottleneck=True, num_models=-1, **kwargs): 363 | """PyramidNet164 for CIFAR and SVHN""" 364 | return PyramidNet(bottleneck=bottleneck, depth=164, alpha=270, num_models=num_models, **kwargs) 365 | 366 | 367 | def pyramidnet272(bottleneck=True, num_models=-1, **kwargs): 368 | """PyramidNet272 for CIFAR and SVHN""" 369 | return PyramidNet(bottleneck=bottleneck, depth=272, alpha=200, num_models=num_models, **kwargs) 370 | -------------------------------------------------------------------------------- /backbones/resnet_imagenet.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | import torch.nn as nn 3 | 4 | __all__ = [ 5 | "resnet18_imagenet", 6 | "resnet18_imagenet_aux", 7 | "resnet34_imagenet", 8 | "resnet34_imagenet_aux", 9 | "resnet50_imagenet", 10 | "resnet50_imagenet_aux", 11 | ] 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d( 17 | in_planes, 18 | out_planes, 19 | kernel_size=3, 20 | stride=stride, 21 | padding=dilation, 22 | groups=groups, 23 | bias=False, 24 | dilation=dilation, 25 | ) 26 | 27 | 28 | def conv1x1(in_planes, out_planes, stride=1): 29 | """1x1 convolution""" 30 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | expansion = 1 35 | 36 | def __init__( 37 | self, 38 | inplanes, 39 | planes, 40 | stride=1, 41 | downsample=None, 42 | groups=1, 43 | base_width=64, 44 | dilation=1, 45 | norm_layer=None, 46 | ): 47 | super(BasicBlock, self).__init__() 48 | if norm_layer is None: 49 | norm_layer = nn.BatchNorm2d 50 | if groups != 1 or base_width != 64: 51 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 52 | if dilation > 1: 53 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 54 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 55 | self.conv1 = conv3x3(inplanes, planes, stride) 56 | self.bn1 = norm_layer(planes) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.conv2 = conv3x3(planes, planes) 59 | self.bn2 = norm_layer(planes) 60 | self.downsample = downsample 61 | self.stride = stride 62 | 63 | def forward(self, x): 64 | identity = x 65 | 66 | out = self.conv1(x) 67 | out = self.bn1(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv2(out) 71 | out = self.bn2(out) 72 | 73 | if self.downsample is not None: 74 | identity = self.downsample(x) 75 | 76 | out += identity 77 | out = self.relu(out) 78 | 79 | return out 80 | 81 | 82 | class Bottleneck(nn.Module): 83 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 84 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 85 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 86 | # This variant is also known as ResNet V1.5 and improves accuracy according to 87 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
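# A compact sketch of the two layouts summarized above (illustrative only):
#   V1:   conv1x1(stride=s) -> conv3x3(stride=1) -> conv1x1(stride=1)
#   V1.5: conv1x1(stride=1) -> conv3x3(stride=s) -> conv1x1(stride=1)   <- this class (stride on self.conv2)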
88 | 89 | expansion = 4 90 | 91 | def __init__( 92 | self, 93 | inplanes, 94 | planes, 95 | stride=1, 96 | downsample=None, 97 | groups=1, 98 | base_width=64, 99 | dilation=1, 100 | norm_layer=None, 101 | ): 102 | super(Bottleneck, self).__init__() 103 | if norm_layer is None: 104 | norm_layer = nn.BatchNorm2d 105 | width = int(planes * (base_width / 64.0)) * groups 106 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 107 | self.conv1 = conv1x1(inplanes, width) 108 | self.bn1 = norm_layer(width) 109 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 110 | self.bn2 = norm_layer(width) 111 | self.conv3 = conv1x1(width, planes * self.expansion) 112 | self.bn3 = norm_layer(planes * self.expansion) 113 | self.relu = nn.ReLU(inplace=True) 114 | self.downsample = downsample 115 | self.stride = stride 116 | 117 | def forward(self, x): 118 | identity = x 119 | 120 | out = self.conv1(x) 121 | out = self.bn1(out) 122 | out = self.relu(out) 123 | 124 | out = self.conv2(out) 125 | out = self.bn2(out) 126 | out = self.relu(out) 127 | 128 | out = self.conv3(out) 129 | out = self.bn3(out) 130 | 131 | if self.downsample is not None: 132 | identity = self.downsample(x) 133 | 134 | out += identity 135 | out = self.relu(out) 136 | 137 | return out 138 | 139 | 140 | class ResNet(nn.Module): 141 | def __init__( 142 | self, 143 | block, 144 | layers, 145 | num_classes=1000, 146 | zero_init_residual=False, 147 | groups=1, 148 | width_per_group=64, 149 | replace_stride_with_dilation=None, 150 | norm_layer=None, 151 | ): 152 | super(ResNet, self).__init__() 153 | if norm_layer is None: 154 | norm_layer = nn.BatchNorm2d 155 | self._norm_layer = norm_layer 156 | 157 | self.inplanes = 64 158 | self.dilation = 1 159 | if replace_stride_with_dilation is None: 160 | # each element in the tuple indicates if we should replace 161 | # the 2x2 stride with a dilated convolution instead 162 | replace_stride_with_dilation = [False, False, False] 163 | if len(replace_stride_with_dilation) != 3: 164 | raise ValueError( 165 | "replace_stride_with_dilation should be None " 166 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation) 167 | ) 168 | self.groups = groups 169 | self.base_width = width_per_group 170 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 171 | self.bn1 = norm_layer(self.inplanes) 172 | self.relu = nn.ReLU(inplace=True) 173 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 174 | self.layer1 = self._make_layer(block, 64, layers[0]) 175 | self.layer2 = self._make_layer( 176 | block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] 177 | ) 178 | self.layer3 = self._make_layer( 179 | block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] 180 | ) 181 | self.layer4 = self._make_layer( 182 | block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] 183 | ) 184 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 185 | self.fc = nn.Linear(512 * block.expansion, num_classes) 186 | self.last_channel = 512 * block.expansion 187 | 188 | for m in self.modules(): 189 | if isinstance(m, nn.Conv2d): 190 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 191 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 192 | nn.init.constant_(m.weight, 1) 193 | nn.init.constant_(m.bias, 0) 194 | 195 | # Zero-initialize the last BN in each residual branch, 196 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
197 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 198 | if zero_init_residual: 199 | for m in self.modules(): 200 | if isinstance(m, Bottleneck): 201 | nn.init.constant_(m.bn3.weight, 0) 202 | elif isinstance(m, BasicBlock): 203 | nn.init.constant_(m.bn2.weight, 0) 204 | 205 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 206 | norm_layer = self._norm_layer 207 | downsample = None 208 | previous_dilation = self.dilation 209 | if dilate: 210 | self.dilation *= stride 211 | stride = 1 212 | if stride != 1 or self.inplanes != planes * block.expansion: 213 | downsample = nn.Sequential( 214 | conv1x1(self.inplanes, planes * block.expansion, stride), 215 | norm_layer(planes * block.expansion), 216 | ) 217 | 218 | layers = [] 219 | layers.append( 220 | block( 221 | self.inplanes, 222 | planes, 223 | stride, 224 | downsample, 225 | self.groups, 226 | self.base_width, 227 | previous_dilation, 228 | norm_layer, 229 | ) 230 | ) 231 | self.inplanes = planes * block.expansion 232 | for _ in range(1, blocks): 233 | layers.append( 234 | block( 235 | self.inplanes, 236 | planes, 237 | groups=self.groups, 238 | base_width=self.base_width, 239 | dilation=self.dilation, 240 | norm_layer=norm_layer, 241 | ) 242 | ) 243 | 244 | return nn.Sequential(*layers) 245 | 246 | def forward(self, x, is_feat=False): 247 | # See note [TorchScript super()] 248 | x = self.conv1(x) 249 | x = self.bn1(x) 250 | x = self.relu(x) 251 | x = self.maxpool(x) 252 | 253 | x = self.layer1(x) 254 | f1 = x 255 | x = self.layer2(x) 256 | f2 = x 257 | x = self.layer3(x) 258 | f3 = x 259 | x = self.layer4(x) 260 | f4 = x 261 | 262 | x = self.avgpool(x) 263 | x = torch.flatten(x, 1) 264 | x = self.fc(x) 265 | 266 | if is_feat: 267 | return [f1, f2, f3, f4], x 268 | else: 269 | return x 270 | 271 | 272 | class Auxiliary_Classifier(nn.Module): 273 | def __init__( 274 | self, 275 | block, 276 | layers, 277 | num_classes=1000, 278 | zero_init_residual=False, 279 | groups=1, 280 | width_per_group=64, 281 | replace_stride_with_dilation=None, 282 | norm_layer=None, 283 | ): 284 | super(Auxiliary_Classifier, self).__init__() 285 | 286 | self.dilation = 1 287 | self.groups = groups 288 | self.base_width = width_per_group 289 | self.inplanes = 64 * block.expansion 290 | self.block_extractor1 = nn.Sequential( 291 | *[ 292 | self._make_layer(block, 128, layers[1], stride=2), 293 | self._make_layer(block, 256, layers[2], stride=2), 294 | self._make_layer(block, 512, layers[3], stride=2), 295 | ] 296 | ) 297 | 298 | self.inplanes = 128 * block.expansion 299 | self.block_extractor2 = nn.Sequential( 300 | *[ 301 | self._make_layer(block, 256, layers[2], stride=2), 302 | self._make_layer(block, 512, layers[3], stride=2), 303 | ] 304 | ) 305 | 306 | self.inplanes = 256 * block.expansion 307 | self.block_extractor3 = nn.Sequential(*[self._make_layer(block, 512, layers[3], stride=2)]) 308 | 309 | self.inplanes = 512 * block.expansion 310 | self.block_extractor4 = nn.Sequential(*[self._make_layer(block, 512, layers[3], stride=1)]) 311 | 312 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 313 | self.fc1 = nn.Linear(512 * block.expansion, num_classes) 314 | self.fc2 = nn.Linear(512 * block.expansion, num_classes) 315 | self.fc3 = nn.Linear(512 * block.expansion, num_classes) 316 | self.fc4 = nn.Linear(512 * block.expansion, num_classes) 317 | 318 | for m in self.modules(): 319 | if isinstance(m, nn.Conv2d): 320 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 321 | elif 
isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 322 | nn.init.constant_(m.weight, 1) 323 | nn.init.constant_(m.bias, 0) 324 | 325 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 326 | norm_layer = nn.BatchNorm2d 327 | downsample = None 328 | previous_dilation = self.dilation 329 | if dilate: 330 | self.dilation *= stride 331 | stride = 1 332 | if stride != 1 or self.inplanes != planes * block.expansion: 333 | downsample = nn.Sequential( 334 | conv1x1(self.inplanes, planes * block.expansion, stride), 335 | norm_layer(planes * block.expansion), 336 | ) 337 | 338 | layers = [] 339 | layers.append( 340 | block( 341 | self.inplanes, 342 | planes, 343 | stride, 344 | downsample, 345 | self.groups, 346 | self.base_width, 347 | previous_dilation, 348 | norm_layer, 349 | ) 350 | ) 351 | self.inplanes = planes * block.expansion 352 | for _ in range(1, blocks): 353 | layers.append( 354 | block( 355 | self.inplanes, 356 | planes, 357 | groups=self.groups, 358 | base_width=self.base_width, 359 | dilation=self.dilation, 360 | norm_layer=norm_layer, 361 | ) 362 | ) 363 | 364 | return nn.Sequential(*layers) 365 | 366 | def forward(self, x): 367 | ss_logits = [] 368 | for i in range(len(x)): 369 | idx = i + 1 370 | 371 | out = getattr(self, "block_extractor" + str(idx))(x[i]) 372 | out = self.avg_pool(out) 373 | out = out.view(out.size(0), -1) 374 | out = getattr(self, "fc" + str(idx))(out) 375 | ss_logits.append(out) 376 | return ss_logits 377 | 378 | 379 | class ResNet_Auxiliary(nn.Module): 380 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 381 | super(ResNet_Auxiliary, self).__init__() 382 | self.backbone = ResNet( 383 | block, layers, num_classes=num_classes, zero_init_residual=zero_init_residual 384 | ) 385 | self.auxiliary_classifier = Auxiliary_Classifier( 386 | block, layers, num_classes=num_classes * 4, zero_init_residual=zero_init_residual 387 | ) 388 | 389 | def forward(self, x, grad=False): 390 | if grad is False: 391 | feats, logit = self.backbone(x, is_feat=True) 392 | for i in range(len(feats)): 393 | feats[i] = feats[i].detach() 394 | else: 395 | feats, logit = self.backbone(x, is_feat=True) 396 | 397 | ss_logits = self.auxiliary_classifier(feats) 398 | return logit, ss_logits 399 | 400 | 401 | def resnet18_imagenet(**kwargs): 402 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 403 | 404 | 405 | def resnet18_imagenet_aux(**kwargs): 406 | return ResNet_Auxiliary(BasicBlock, [2, 2, 2, 2], **kwargs) 407 | 408 | 409 | def resnet34_imagenet(**kwargs): 410 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 411 | 412 | 413 | def resnet34_imagenet_aux(**kwargs): 414 | return ResNet_Auxiliary(BasicBlock, [3, 4, 6, 3], **kwargs) 415 | 416 | 417 | def resnet50_imagenet(**kwargs): 418 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 419 | 420 | 421 | def resnet50_imagenet_aux(**kwargs): 422 | return ResNet_Auxiliary(Bottleneck, [3, 4, 6, 3], **kwargs) 423 | -------------------------------------------------------------------------------- /backbones/wrn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | """ 8 | Original Author: Wei Yang 9 | 10 | 11 | adding hyperparameter norm_layer: Huanran Chen 12 | """ 13 | 14 | __all__ = [ 15 | "wrn", 16 | "wrn_40_2_aux", 17 | "wrn_16_2_aux", 18 | "wrn_16_1", 19 | "wrn_16_2", 20 | "wrn_40_1", 21 | "wrn_40_2", 22 | "wrn_40_1_aux", 23 | "wrn_16_2_spkd", 24 | 
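# Suffix conventions (inferred from the class names defined below): *_aux -> WideResNet_Auxiliary,
# *_spkd -> WideResNet_SPKD, *_crd -> WideResNet_CRD, *_sskd -> WideResNet_SSKD.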
"wrn_40_1_spkd", 25 | "wrn_40_2_spkd", 26 | "wrn_40_1_crd", 27 | "wrn_16_2_crd", 28 | "wrn_40_2_crd", 29 | "wrn_16_2_sskd", 30 | "wrn_40_1_sskd", 31 | "wrn_40_2_sskd", 32 | ] 33 | 34 | 35 | class Normalizer4CRD(nn.Module): 36 | def __init__(self, linear, power=2): 37 | super().__init__() 38 | self.linear = linear 39 | self.power = power 40 | 41 | def forward(self, x): 42 | x = x.flatten(1) 43 | z = self.linear(x) 44 | norm = z.pow(self.power).sum(1, keepdim=True).pow(1.0 / self.power) 45 | out = z.div(norm) 46 | return out 47 | 48 | 49 | class BasicBlock(nn.Module): 50 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0, norm_layer=nn.BatchNorm2d): 51 | super(BasicBlock, self).__init__() 52 | self.bn1 = norm_layer(in_planes) 53 | self.relu1 = nn.ReLU(inplace=True) 54 | self.conv1 = nn.Conv2d( 55 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 56 | ) 57 | self.bn2 = norm_layer(out_planes) 58 | self.relu2 = nn.ReLU(inplace=True) 59 | self.conv2 = nn.Conv2d( 60 | out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False 61 | ) 62 | self.droprate = dropRate 63 | self.equalInOut = in_planes == out_planes 64 | self.convShortcut = ( 65 | (not self.equalInOut) 66 | and nn.Conv2d( 67 | in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False 68 | ) 69 | or None 70 | ) 71 | 72 | def forward(self, x): 73 | if not self.equalInOut: 74 | x = self.relu1(self.bn1(x)) 75 | else: 76 | out = self.relu1(self.bn1(x)) 77 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 78 | if self.droprate > 0: 79 | out = F.dropout(out, p=self.droprate, training=self.training) 80 | out = self.conv2(out) 81 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 82 | 83 | 84 | class NetworkBlock(nn.Module): 85 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0, 86 | norm_layer=nn.BatchNorm2d): 87 | super(NetworkBlock, self).__init__() 88 | self.layer = self._make_layer(block, in_planes, out_planes, 89 | nb_layers, stride, dropRate, norm_layer) 90 | 91 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate, norm_layer): 92 | layers = [] 93 | for i in range(nb_layers): 94 | layers.append( 95 | block( 96 | i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate, 97 | norm_layer=norm_layer 98 | ) 99 | ) 100 | return nn.Sequential(*layers) 101 | 102 | def forward(self, x): 103 | return self.layer(x) 104 | 105 | 106 | class WideResNet(nn.Module): 107 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0, 108 | norm_layer=nn.BatchNorm2d): 109 | super(WideResNet, self).__init__() 110 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 111 | assert (depth - 4) % 6 == 0, "depth should be 6n+4" 112 | n = (depth - 4) // 6 113 | block = BasicBlock 114 | # 1st conv before any network block 115 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False) 116 | # 1st block 117 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, norm_layer) 118 | # 2nd block 119 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate, norm_layer) 120 | # 3rd block 121 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate, norm_layer) 122 | # global average pooling and classifier 123 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 124 | self.relu = nn.ReLU(inplace=True) 125 | self.last_channel = nChannels[3] 126 | self.fc = 
nn.Linear(nChannels[3], num_classes) 127 | self.nChannels = nChannels[3] 128 | 129 | for m in self.modules(): 130 | if isinstance(m, nn.Conv2d): 131 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 132 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 133 | elif isinstance(m, nn.BatchNorm2d): 134 | m.weight.data.fill_(1) 135 | m.bias.data.zero_() 136 | elif isinstance(m, nn.Linear): 137 | m.bias.data.zero_() 138 | 139 | def get_feat_modules(self): 140 | feat_m = nn.ModuleList([]) 141 | feat_m.append(self.conv1) 142 | feat_m.append(self.block1) 143 | feat_m.append(self.block2) 144 | feat_m.append(self.block3) 145 | return feat_m 146 | 147 | def get_bn_before_relu(self): 148 | bn1 = self.block2.layer[0].bn1 149 | bn2 = self.block3.layer[0].bn1 150 | bn3 = self.bn1 151 | 152 | return [bn1, bn2, bn3] 153 | 154 | def forward(self, x, is_feat=False, preact=False): 155 | out = self.conv1(x) 156 | out = self.block1(out) 157 | f1 = out 158 | out = self.block2(out) 159 | f2 = out 160 | out = self.block3(out) 161 | f3 = out 162 | out = self.relu(self.bn1(out)) 163 | f4 = out 164 | out = F.avg_pool2d(out, 8) 165 | out = out.view(-1, self.nChannels) 166 | out = self.fc(out) 167 | if is_feat: 168 | return [f1, f2, f3, f4], out 169 | else: 170 | return out 171 | 172 | 173 | class Auxiliary_Classifier(nn.Module): 174 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 175 | super(Auxiliary_Classifier, self).__init__() 176 | self.nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 177 | block = BasicBlock 178 | n = (depth - 4) // 6 179 | self.block_extractor1 = nn.Sequential( 180 | *[ 181 | NetworkBlock(n, self.nChannels[1], self.nChannels[2], block, 2), 182 | NetworkBlock(n, self.nChannels[2], self.nChannels[3], block, 2), 183 | ] 184 | ) 185 | self.block_extractor2 = nn.Sequential( 186 | *[NetworkBlock(n, self.nChannels[2], self.nChannels[3], block, 2)] 187 | ) 188 | self.block_extractor3 = nn.Sequential( 189 | *[NetworkBlock(n, self.nChannels[3], self.nChannels[3], block, 1)] 190 | ) 191 | 192 | self.bn1 = nn.BatchNorm2d(self.nChannels[3]) 193 | self.bn2 = nn.BatchNorm2d(self.nChannels[3]) 194 | self.bn3 = nn.BatchNorm2d(self.nChannels[3]) 195 | 196 | self.relu = nn.ReLU(inplace=True) 197 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 198 | self.fc1 = nn.Linear(self.nChannels[3], num_classes) 199 | self.fc2 = nn.Linear(self.nChannels[3], num_classes) 200 | self.fc3 = nn.Linear(self.nChannels[3], num_classes) 201 | 202 | for m in self.modules(): 203 | if isinstance(m, nn.Conv2d): 204 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 205 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 206 | elif isinstance(m, nn.BatchNorm2d): 207 | m.weight.data.fill_(1) 208 | m.bias.data.zero_() 209 | elif isinstance(m, nn.Linear): 210 | m.bias.data.zero_() 211 | 212 | def forward(self, x): 213 | ss_logits = [] 214 | ss_feats = [] 215 | for i in range(len(x)): 216 | idx = i + 1 217 | out = getattr(self, "block_extractor" + str(idx))(x[i]) 218 | out = self.relu(getattr(self, "bn" + str(idx))(out)) 219 | out = self.avg_pool(out) 220 | out = out.view(-1, self.nChannels[3]) 221 | ss_feats.append(out) 222 | out = getattr(self, "fc" + str(idx))(out) 223 | ss_logits.append(out) 224 | return ss_logits 225 | 226 | 227 | class WideResNet_Auxiliary(nn.Module): 228 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 229 | super(WideResNet_Auxiliary, self).__init__() 230 | self.backbone = WideResNet(depth, num_classes, widen_factor=widen_factor) 231 | 
self.auxiliary_classifier = Auxiliary_Classifier( 232 | depth=depth, num_classes=num_classes * 4, widen_factor=widen_factor 233 | ) 234 | 235 | def forward(self, x, grad=False): 236 | feats, logit = self.backbone(x, is_feat=True) 237 | if grad is False: 238 | for i in range(len(feats)): 239 | feats[i] = feats[i].detach() 240 | ss_logits = self.auxiliary_classifier(feats) 241 | 242 | return logit, ss_logits 243 | 244 | 245 | class WideResNet_SPKD(WideResNet): 246 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 247 | super(WideResNet_SPKD, self).__init__(depth, num_classes, widen_factor, dropRate) 248 | 249 | def forward(self, x, is_feat=False, preact=False): 250 | out = self.conv1(x) 251 | out = self.block1(out) 252 | out = self.block2(out) 253 | out = self.block3(out) 254 | out = self.relu(self.bn1(out)) 255 | out = F.avg_pool2d(out, 8) 256 | out = out.view(-1, self.nChannels) 257 | f4 = out 258 | out = self.fc(out) 259 | return f4, out 260 | 261 | 262 | class WideResNet_SSKD(WideResNet): 263 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 264 | super(WideResNet_SSKD, self).__init__(depth, num_classes, widen_factor, dropRate) 265 | self.ss_module = nn.Sequential( 266 | nn.Linear(self.nChannels, self.nChannels), 267 | nn.ReLU(inplace=True), 268 | nn.Linear(self.nChannels, self.nChannels), 269 | ) 270 | 271 | def forward(self, x, is_feat=False, preact=False): 272 | out = self.conv1(x) 273 | out = self.block1(out) 274 | out = self.block2(out) 275 | out = self.block3(out) 276 | out = self.relu(self.bn1(out)) 277 | out = F.avg_pool2d(out, 8) 278 | out = out.view(-1, self.nChannels) 279 | f4 = self.ss_module(out) 280 | out = self.fc(out) 281 | return f4, out 282 | 283 | 284 | class WideResNet_CRD(nn.Module): 285 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0, 286 | norm_layer=nn.BatchNorm2d): 287 | super(WideResNet_CRD, self).__init__() 288 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 289 | assert (depth - 4) % 6 == 0, "depth should be 6n+4" 290 | n = (depth - 4) // 6 291 | block = BasicBlock 292 | # 1st conv before any network block 293 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False) 294 | # 1st block 295 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 296 | # 2nd block 297 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 298 | # 3rd block 299 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 300 | # global average pooling and classifier 301 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 302 | self.relu = nn.ReLU(inplace=True) 303 | self.fc = nn.Linear(nChannels[3], num_classes) 304 | linear = nn.Linear(nChannels[3], 128, bias=True) 305 | self.normalizer = Normalizer4CRD(linear, power=2) 306 | self.nChannels = nChannels[3] 307 | 308 | for m in self.modules(): 309 | if isinstance(m, nn.Conv2d): 310 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 311 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 312 | elif isinstance(m, nn.BatchNorm2d): 313 | m.weight.data.fill_(1) 314 | m.bias.data.zero_() 315 | elif isinstance(m, nn.Linear): 316 | m.bias.data.zero_() 317 | 318 | def get_feat_modules(self): 319 | feat_m = nn.ModuleList([]) 320 | feat_m.append(self.conv1) 321 | feat_m.append(self.block1) 322 | feat_m.append(self.block2) 323 | feat_m.append(self.block3) 324 | return feat_m 325 | 326 | def get_bn_before_relu(self): 327 | bn1 = self.block2.layer[0].bn1 328 | 
bn2 = self.block3.layer[0].bn1 329 | bn3 = self.bn1 330 | 331 | return [bn1, bn2, bn3] 332 | 333 | def forward(self, x, is_feat=False, preact=False): 334 | out = self.conv1(x) 335 | out = self.block1(out) 336 | out = self.block2(out) 337 | out = self.block3(out) 338 | out = self.relu(self.bn1(out)) 339 | out = F.avg_pool2d(out, 8) 340 | crdout = out 341 | out = out.view(-1, self.nChannels) 342 | out = self.fc(out) 343 | crdout = self.normalizer(crdout) 344 | return crdout, out 345 | 346 | 347 | def wrn(**kwargs): 348 | """ 349 | Constructs a Wide Residual Networks. 350 | """ 351 | model = WideResNet(**kwargs) 352 | return model 353 | 354 | 355 | def wrn_40_2(**kwargs): 356 | model = WideResNet(depth=40, widen_factor=2, **kwargs) 357 | return model 358 | 359 | 360 | def wrn_40_2_aux(**kwargs): 361 | model = WideResNet_Auxiliary(depth=40, widen_factor=2, **kwargs) 362 | return model 363 | 364 | 365 | def wrn_40_2_spkd(**kwargs): 366 | model = WideResNet_SPKD(depth=40, widen_factor=2, **kwargs) 367 | return model 368 | 369 | 370 | def wrn_40_2_sskd(**kwargs): 371 | model = WideResNet_SSKD(depth=40, widen_factor=2, **kwargs) 372 | return model 373 | 374 | 375 | def wrn_40_2_crd(**kwargs): 376 | model = WideResNet_CRD(depth=40, widen_factor=2, **kwargs) 377 | return model 378 | 379 | 380 | def wrn_40_1(**kwargs): 381 | model = WideResNet(depth=40, widen_factor=1, **kwargs) 382 | return model 383 | 384 | 385 | def wrn_40_1_aux(**kwargs): 386 | model = WideResNet_Auxiliary(depth=40, widen_factor=1, **kwargs) 387 | return model 388 | 389 | 390 | def wrn_40_1_spkd(**kwargs): 391 | model = WideResNet_SPKD(depth=40, widen_factor=1, **kwargs) 392 | return model 393 | 394 | 395 | def wrn_40_1_crd(**kwargs): 396 | model = WideResNet_CRD(depth=40, widen_factor=1, **kwargs) 397 | return model 398 | 399 | 400 | def wrn_40_1_sskd(**kwargs): 401 | model = WideResNet_SSKD(depth=40, widen_factor=1, **kwargs) 402 | return model 403 | 404 | 405 | def wrn_16_2(**kwargs): 406 | model = WideResNet(depth=16, widen_factor=2, **kwargs) 407 | return model 408 | 409 | 410 | def wrn_16_2_aux(**kwargs): 411 | model = WideResNet_Auxiliary(depth=16, widen_factor=2, **kwargs) 412 | return model 413 | 414 | 415 | def wrn_16_2_spkd(**kwargs): 416 | model = WideResNet_SPKD(depth=16, widen_factor=2, **kwargs) 417 | return model 418 | 419 | 420 | def wrn_16_2_crd(**kwargs): 421 | model = WideResNet_CRD(depth=16, widen_factor=2, **kwargs) 422 | return model 423 | 424 | 425 | def wrn_16_2_sskd(**kwargs): 426 | model = WideResNet_SSKD(depth=16, widen_factor=2, **kwargs) 427 | return model 428 | 429 | 430 | def wrn_16_1(**kwargs): 431 | model = WideResNet(depth=16, widen_factor=1, **kwargs) 432 | return model 433 | -------------------------------------------------------------------------------- /backbones/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | MobileNetV2 implementation used in 3 | 4 | 5 | adding hyperparameter norm_layer: Huanran Chen 6 | """ 7 | 8 | import math 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | __all__ = [ 14 | "mobilenetv2_T_w", 15 | "mobilenetV2", 16 | "mobilenetV2_aux", 17 | "mobilenetV2_spkd", 18 | "mobilenetV2_crd", 19 | ] 20 | 21 | BN = None 22 | 23 | 24 | class Normalizer4CRD(nn.Module): 25 | def __init__(self, linear, power=2): 26 | super().__init__() 27 | self.linear = linear 28 | self.power = power 29 | 30 | def forward(self, x): 31 | x = x.flatten(1) 32 | z = self.linear(x) 33 | norm = z.pow(self.power).sum(1, 
keepdim=True).pow(1.0 / self.power) 34 | out = z.div(norm) 35 | return out 36 | 37 | 38 | def conv_bn(inp, oup, stride): 39 | return nn.Sequential( 40 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), nn.ReLU(inplace=True) 41 | ) 42 | 43 | 44 | def conv_1x1_bn(inp, oup): 45 | return nn.Sequential( 46 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), nn.ReLU(inplace=True) 47 | ) 48 | 49 | 50 | class InvertedResidual(nn.Module): 51 | def __init__(self, inp, oup, stride, expand_ratio, norm_layer=nn.BatchNorm2d): 52 | super(InvertedResidual, self).__init__() 53 | self.blockname = None 54 | 55 | self.stride = stride 56 | assert stride in [1, 2] 57 | 58 | self.use_res_connect = self.stride == 1 and inp == oup 59 | 60 | self.conv = nn.Sequential( 61 | # pw 62 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 63 | norm_layer(inp * expand_ratio), 64 | nn.ReLU(), 65 | # dw 66 | nn.Conv2d( 67 | inp * expand_ratio, 68 | inp * expand_ratio, 69 | 3, 70 | stride, 71 | 1, 72 | groups=inp * expand_ratio, 73 | bias=False, 74 | ), 75 | norm_layer(inp * expand_ratio), 76 | nn.ReLU(), 77 | # pw-linear 78 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 79 | nn.BatchNorm2d(oup), 80 | ) 81 | self.names = ["0", "1", "2", "3", "4", "5", "6", "7"] 82 | 83 | def forward(self, x): 84 | t = x 85 | 86 | if self.use_res_connect: 87 | return t + self.conv(x) 88 | else: 89 | return self.conv(x) 90 | 91 | 92 | class MobileNetV2(nn.Module): 93 | """mobilenetV2""" 94 | 95 | def __init__(self, T, feature_dim, input_size=32, width_mult=1.0, remove_avg=False, 96 | norm_layer=nn.BatchNorm2d): 97 | super(MobileNetV2, self).__init__() 98 | self.remove_avg = remove_avg 99 | 100 | # setting of inverted residual blocks 101 | self.interverted_residual_setting = [ 102 | # t, c, n, s 103 | [1, 16, 1, 1], 104 | [T, 24, 2, 1], 105 | [T, 32, 3, 2], 106 | [T, 64, 4, 2], 107 | [T, 96, 3, 1], 108 | [T, 160, 3, 2], 109 | [T, 320, 1, 1], 110 | ] 111 | 112 | # building first layer 113 | assert input_size % 32 == 0 114 | input_channel = int(32 * width_mult) 115 | self.conv1 = conv_bn(3, input_channel, 2) 116 | 117 | # building inverted residual blocks 118 | self.blocks = nn.ModuleList([]) 119 | for t, c, n, s in self.interverted_residual_setting: 120 | output_channel = int(c * width_mult) 121 | layers = [] 122 | strides = [s] + [1] * (n - 1) 123 | for stride in strides: 124 | layers.append(InvertedResidual(input_channel, output_channel, stride, t, 125 | norm_layer=norm_layer)) 126 | input_channel = output_channel 127 | self.blocks.append(nn.Sequential(*layers)) 128 | 129 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 130 | self.conv2 = conv_1x1_bn(input_channel, self.last_channel) 131 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 132 | 133 | # building classifier 134 | # self.classifier = nn.Sequential( 135 | # # nn.Dropout(0.5), 136 | # nn.Linear(self.last_channel, feature_dim), 137 | # ) 138 | self.classifier = nn.Linear(self.last_channel, feature_dim) 139 | 140 | self._initialize_weights() 141 | 142 | def get_bn_before_relu(self): 143 | bn1 = self.blocks[1][-1].conv[-1] 144 | bn2 = self.blocks[2][-1].conv[-1] 145 | bn3 = self.blocks[4][-1].conv[-1] 146 | bn4 = self.blocks[6][-1].conv[-1] 147 | return [bn1, bn2, bn3, bn4] 148 | 149 | def get_feat_modules(self): 150 | feat_m = nn.ModuleList([]) 151 | feat_m.append(self.conv1) 152 | feat_m.append(self.blocks) 153 | return feat_m 154 | 155 | def forward(self, x, is_feat=False, preact=False): 156 | 157 | out = 
self.conv1(x) 158 | out = self.blocks[0](out) 159 | out = self.blocks[1](out) 160 | f1 = out 161 | out = self.blocks[2](out) 162 | f2 = out 163 | out = self.blocks[3](out) 164 | out = self.blocks[4](out) 165 | f3 = out 166 | out = self.blocks[5](out) 167 | out = self.blocks[6](out) 168 | out = self.conv2(out) 169 | f4 = out 170 | 171 | if not self.remove_avg: 172 | out = self.avgpool(out) 173 | out = out.view(out.size(0), -1) 174 | out = self.classifier(out) 175 | 176 | if is_feat: 177 | return [f1, f2, f3, f4], out 178 | else: 179 | return out 180 | 181 | def _initialize_weights(self): 182 | for m in self.modules(): 183 | if isinstance(m, nn.Conv2d): 184 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 185 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 186 | if m.bias is not None: 187 | m.bias.data.zero_() 188 | elif isinstance(m, nn.BatchNorm2d): 189 | m.weight.data.fill_(1) 190 | m.bias.data.zero_() 191 | elif isinstance(m, nn.Linear): 192 | n = m.weight.size(1) 193 | m.weight.data.normal_(0, 0.01) 194 | m.bias.data.zero_() 195 | 196 | 197 | class Auxiliary_Classifier(nn.Module): 198 | def __init__(self, T, feature_dim, input_size=32, width_mult=1.0, remove_avg=False): 199 | super(Auxiliary_Classifier, self).__init__() 200 | 201 | self.remove_avg = remove_avg 202 | self.width_mult = width_mult 203 | # setting of inverted residual blocks 204 | interverted_residual_setting1 = [ 205 | [T, 32, 3, 2], 206 | [T, 64, 4, 2], 207 | [T, 96, 3, 1], 208 | [T, 160, 3, 2], 209 | [T, 320, 1, 1], 210 | ] 211 | self.block_extractor1 = self._make_layer( 212 | input_channel=12, interverted_residual_setting=interverted_residual_setting1 213 | ) 214 | 215 | interverted_residual_setting2 = [ 216 | [T, 64, 4, 2], 217 | [T, 96, 3, 1], 218 | [T, 160, 3, 2], 219 | [T, 320, 1, 1], 220 | ] 221 | self.block_extractor2 = self._make_layer( 222 | input_channel=16, interverted_residual_setting=interverted_residual_setting2 223 | ) 224 | 225 | interverted_residual_setting3 = [ 226 | [T, 160, 3, 2], 227 | [T, 320, 1, 1], 228 | ] 229 | self.block_extractor3 = self._make_layer( 230 | input_channel=48, interverted_residual_setting=interverted_residual_setting3 231 | ) 232 | 233 | interverted_residual_setting4 = [ 234 | [T, 160, 3, 1], 235 | [T, 320, 1, 1], 236 | ] 237 | self.block_extractor4 = self._make_layer( 238 | input_channel=160, interverted_residual_setting=interverted_residual_setting4 239 | ) 240 | 241 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 242 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 243 | 244 | self.conv2_1 = conv_1x1_bn(160, self.last_channel) 245 | self.conv2_2 = conv_1x1_bn(160, self.last_channel) 246 | self.conv2_3 = conv_1x1_bn(160, self.last_channel) 247 | self.conv2_4 = conv_1x1_bn(160, self.last_channel) 248 | 249 | self.fc1 = nn.Linear(self.last_channel, feature_dim) 250 | self.fc2 = nn.Linear(self.last_channel, feature_dim) 251 | self.fc3 = nn.Linear(self.last_channel, feature_dim) 252 | self.fc4 = nn.Linear(self.last_channel, feature_dim) 253 | 254 | self._initialize_weights() 255 | 256 | def _make_layer(self, input_channel, interverted_residual_setting): 257 | # building inverted residual blocks 258 | blocks = [] 259 | for t, c, n, s in interverted_residual_setting: 260 | output_channel = int(c * self.width_mult) 261 | layers = [] 262 | strides = [s] + [1] * (n - 1) 263 | for stride in strides: 264 | layers.append(InvertedResidual(input_channel, output_channel, stride, t)) 265 | input_channel = output_channel 266 | 
blocks.append(nn.Sequential(*layers)) 267 | 268 | return nn.Sequential(*blocks) 269 | 270 | def _initialize_weights(self): 271 | for m in self.modules(): 272 | if isinstance(m, nn.Conv2d): 273 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 274 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 275 | if m.bias is not None: 276 | m.bias.data.zero_() 277 | elif isinstance(m, nn.BatchNorm2d): 278 | m.weight.data.fill_(1) 279 | m.bias.data.zero_() 280 | elif isinstance(m, nn.Linear): 281 | n = m.weight.size(1) 282 | m.weight.data.normal_(0, 0.01) 283 | m.bias.data.zero_() 284 | 285 | def forward(self, x): 286 | ss_logits = [] 287 | ss_feats = [] 288 | 289 | for i in range(len(x)): 290 | idx = i + 1 291 | out = getattr(self, "block_extractor" + str(idx))(x[i]) 292 | out = getattr(self, "conv2_" + str(idx))(out) 293 | out = self.avg_pool(out) 294 | out = out.view(out.size(0), -1) 295 | ss_feats.append(out) 296 | out = getattr(self, "fc" + str(idx))(out) 297 | ss_logits.append(out) 298 | 299 | return ss_feats, ss_logits 300 | 301 | 302 | class MobileNetv2_Auxiliary(nn.Module): 303 | def __init__(self, T, W, feature_dim=100): 304 | super(MobileNetv2_Auxiliary, self).__init__() 305 | self.backbone = MobileNetV2(T=T, feature_dim=feature_dim, width_mult=W) 306 | self.auxiliary_classifier = Auxiliary_Classifier( 307 | T=T, feature_dim=4 * feature_dim, width_mult=W 308 | ) 309 | 310 | def forward(self, x, grad=False, att=False): 311 | feats, logit = self.backbone(x, is_feat=True) 312 | if grad is False: 313 | for i in range(len(feats)): 314 | feats[i] = feats[i].detach() 315 | ss_feats, ss_logits = self.auxiliary_classifier(feats) 316 | if att is False: 317 | return logit, ss_logits 318 | else: 319 | return logit, ss_logits, feats 320 | 321 | 322 | class MobileNetV2_SPKD(MobileNetV2): 323 | def __init__(self, T, feature_dim, input_size=32, width_mult=1.0, remove_avg=False): 324 | super(MobileNetV2_SPKD, self).__init__(T, feature_dim, input_size, width_mult, remove_avg) 325 | 326 | def forward(self, x): 327 | out = self.conv1(x) 328 | out = self.blocks[0](out) 329 | out = self.blocks[1](out) 330 | out = self.blocks[2](out) 331 | out = self.blocks[3](out) 332 | out = self.blocks[4](out) 333 | out = self.blocks[5](out) 334 | out = self.blocks[6](out) 335 | f4 = out 336 | 337 | out = self.conv2(out) 338 | 339 | if not self.remove_avg: 340 | out = self.avgpool(out) 341 | out = out.view(out.size(0), -1) 342 | out = self.classifier(out) 343 | return f4, out 344 | 345 | 346 | class MobileNetV2_CRD(nn.Module): 347 | def __init__(self, T, feature_dim, input_size=32, width_mult=1.0, remove_avg=False): 348 | super(MobileNetV2_CRD, self).__init__() 349 | self.remove_avg = remove_avg 350 | 351 | # setting of inverted residual blocks 352 | self.interverted_residual_setting = [ 353 | # t, c, n, s 354 | [1, 16, 1, 1], 355 | [T, 24, 2, 1], 356 | [T, 32, 3, 2], 357 | [T, 64, 4, 2], 358 | [T, 96, 3, 1], 359 | [T, 160, 3, 2], 360 | [T, 320, 1, 1], 361 | ] 362 | 363 | # building first layer 364 | assert input_size % 32 == 0 365 | input_channel = int(32 * width_mult) 366 | self.conv1 = conv_bn(3, input_channel, 1) 367 | 368 | # building inverted residual blocks 369 | self.blocks = nn.ModuleList([]) 370 | for t, c, n, s in self.interverted_residual_setting: 371 | output_channel = int(c * width_mult) 372 | layers = [] 373 | strides = [s] + [1] * (n - 1) 374 | for stride in strides: 375 | layers.append(InvertedResidual(input_channel, output_channel, stride, t)) 376 | input_channel = output_channel 377 | 
self.blocks.append(nn.Sequential(*layers)) 378 | 379 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 380 | self.conv2 = conv_1x1_bn(input_channel, self.last_channel) 381 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 382 | self.classifier = nn.Linear(self.last_channel, feature_dim) 383 | linear = nn.Linear(self.last_channel, 128, bias=True) 384 | self.normalizer = Normalizer4CRD(linear, power=2) 385 | self._initialize_weights() 386 | 387 | def get_bn_before_relu(self): 388 | bn1 = self.blocks[1][-1].conv[-1] 389 | bn2 = self.blocks[2][-1].conv[-1] 390 | bn3 = self.blocks[4][-1].conv[-1] 391 | bn4 = self.blocks[6][-1].conv[-1] 392 | return [bn1, bn2, bn3, bn4] 393 | 394 | def get_feat_modules(self): 395 | feat_m = nn.ModuleList([]) 396 | feat_m.append(self.conv1) 397 | feat_m.append(self.blocks) 398 | return feat_m 399 | 400 | def forward(self, x): 401 | 402 | out = self.conv1(x) 403 | out = self.blocks[0](out) 404 | out = self.blocks[1](out) 405 | out = self.blocks[2](out) 406 | out = self.blocks[3](out) 407 | out = self.blocks[4](out) 408 | out = self.blocks[5](out) 409 | out = self.blocks[6](out) 410 | out = self.conv2(out) 411 | out = self.avgpool(out) 412 | f = out 413 | out = out.view(out.size(0), -1) 414 | out = self.classifier(out) 415 | crdout = self.normalizer(f) 416 | return crdout, out 417 | 418 | def _initialize_weights(self): 419 | for m in self.modules(): 420 | if isinstance(m, nn.Conv2d): 421 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 422 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 423 | if m.bias is not None: 424 | m.bias.data.zero_() 425 | elif isinstance(m, nn.BatchNorm2d): 426 | m.weight.data.fill_(1) 427 | m.bias.data.zero_() 428 | elif isinstance(m, nn.Linear): 429 | n = m.weight.size(1) 430 | m.weight.data.normal_(0, 0.01) 431 | m.bias.data.zero_() 432 | 433 | 434 | def mobilenetv2_T_w(T, W, feature_dim=100, **kwargs): 435 | model = MobileNetV2(T=T, feature_dim=feature_dim, width_mult=W, **kwargs) 436 | return model 437 | 438 | 439 | def mobilenetV2(num_classes, **kwargs): 440 | return mobilenetv2_T_w(6, 0.5, num_classes, **kwargs) 441 | 442 | 443 | def mobilenetV2_aux(num_classes): 444 | return MobileNetv2_Auxiliary(6, 0.5, num_classes) 445 | 446 | 447 | def mobilenetV2_spkd(num_classes): 448 | return MobileNetV2_SPKD(T=6, width_mult=0.5, feature_dim=num_classes) 449 | 450 | 451 | def mobilenetV2_crd(num_classes): 452 | return MobileNetV2_CRD(T=6, width_mult=0.5, feature_dim=num_classes) 453 | -------------------------------------------------------------------------------- /backbones/vgg.py: -------------------------------------------------------------------------------- 1 | """VGG for CIFAR10. FC layers are removed. 
2 | (c) YANG, Wei 3 | 4 | 5 | adding hyperparameter norm_layer: Huanran Chen 6 | """ 7 | 8 | import math 9 | 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | __all__ = ["vgg13_bn_aux", "vgg13_bn", "vgg13_bn_spkd", "vgg13_bn_crd", "vgg8_bn"] 14 | 15 | model_urls = { 16 | "vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth", 17 | "vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth", 18 | "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth", 19 | "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", 20 | } 21 | 22 | 23 | class Normalizer4CRD(nn.Module): 24 | def __init__(self, linear, power=2): 25 | super().__init__() 26 | self.linear = linear 27 | self.power = power 28 | 29 | def forward(self, x): 30 | x = x.flatten(1) 31 | z = self.linear(x) 32 | norm = z.pow(self.power).sum(1, keepdim=True).pow(1.0 / self.power) 33 | out = z.div(norm) 34 | return out 35 | 36 | 37 | class VGG(nn.Module): 38 | def __init__(self, cfg, batch_norm=False, num_classes=1000, 39 | norm_layer=nn.BatchNorm2d): 40 | super(VGG, self).__init__() 41 | self.block0 = self._make_layers(cfg[0], batch_norm, 3, norm_layer) 42 | self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1], norm_layer) 43 | self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1], norm_layer) 44 | self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1], norm_layer) 45 | self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1], norm_layer) 46 | 47 | self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2) 48 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 49 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 50 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 51 | self.pool4 = nn.AdaptiveAvgPool2d((1, 1)) 52 | # self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 53 | self.last_channel = 512 54 | self.classifier = nn.Linear(512, num_classes) 55 | self._initialize_weights() 56 | 57 | def get_feat_modules(self): 58 | feat_m = nn.ModuleList([]) 59 | feat_m.append(self.block0) 60 | feat_m.append(self.pool0) 61 | feat_m.append(self.block1) 62 | feat_m.append(self.pool1) 63 | feat_m.append(self.block2) 64 | feat_m.append(self.pool2) 65 | feat_m.append(self.block3) 66 | feat_m.append(self.pool3) 67 | feat_m.append(self.block4) 68 | feat_m.append(self.pool4) 69 | return feat_m 70 | 71 | def get_bn_before_relu(self): 72 | bn1 = self.block1[-1] 73 | bn2 = self.block2[-1] 74 | bn3 = self.block3[-1] 75 | bn4 = self.block4[-1] 76 | return [bn1, bn2, bn3, bn4] 77 | 78 | def forward(self, x, is_feat=False, preact=False): 79 | h = x.shape[2] 80 | x = F.relu(self.block0(x)) 81 | f0 = x 82 | 83 | x = self.pool0(x) 84 | x = self.block1(x) 85 | x = F.relu(x) 86 | f1 = x 87 | 88 | x = self.pool1(x) 89 | x = self.block2(x) 90 | x = F.relu(x) 91 | f2 = x 92 | 93 | x = self.pool2(x) 94 | x = self.block3(x) 95 | x = F.relu(x) 96 | if h == 64: 97 | x = self.pool3(x) 98 | x = self.block4(x) 99 | x = F.relu(x) 100 | f3 = x 101 | 102 | x = self.pool4(x) 103 | x = x.view(x.size(0), -1) 104 | x = self.classifier(x) 105 | 106 | if is_feat: 107 | return [f0, f1, f2, f3], x 108 | else: 109 | return x 110 | 111 | @staticmethod 112 | def _make_layers(cfg, batch_norm=False, in_channels=3, norm_layer=nn.BatchNorm2d): 113 | layers = [] 114 | for v in cfg: 115 | if v == "M": 116 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 117 | else: 118 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 119 | if batch_norm: 120 | layers += [conv2d, norm_layer(v), 
nn.ReLU(inplace=True)]
121 |             else:
122 |                 layers += [conv2d, nn.ReLU(inplace=True)]
123 |             in_channels = v
124 |         layers = layers[:-1]
125 |         return nn.Sequential(*layers)
126 | 
127 |     def _initialize_weights(self):
128 |         for m in self.modules():
129 |             if isinstance(m, nn.Conv2d):
130 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
131 |                 m.weight.data.normal_(0, math.sqrt(2.0 / n))
132 |                 if m.bias is not None:
133 |                     m.bias.data.zero_()
134 |             elif isinstance(m, nn.BatchNorm2d):
135 |                 m.weight.data.fill_(1)
136 |                 m.bias.data.zero_()
137 |             elif isinstance(m, nn.Linear):
138 |                 n = m.weight.size(1)
139 |                 m.weight.data.normal_(0, 0.01)
140 |                 m.bias.data.zero_()
141 | 
142 | 
143 | class Auxiliary_Classifier(nn.Module):
144 |     def __init__(self, cfg, batch_norm=False, num_classes=100):
145 |         super(Auxiliary_Classifier, self).__init__()
146 | 
147 |         self.block_extractor1 = nn.Sequential(
148 |             *[
149 |                 nn.MaxPool2d(kernel_size=2, stride=2),
150 |                 self._make_layers(cfg[1], batch_norm, cfg[0][-1]),
151 |                 nn.ReLU(inplace=True),
152 |                 nn.MaxPool2d(kernel_size=2, stride=2),
153 |                 self._make_layers(cfg[2], batch_norm, cfg[1][-1]),
154 |                 nn.ReLU(inplace=True),
155 |                 nn.MaxPool2d(kernel_size=2, stride=2),
156 |                 self._make_layers(cfg[3], batch_norm, cfg[2][-1]),
157 |                 nn.ReLU(inplace=True),
158 |                 self._make_layers(cfg[4], batch_norm, cfg[3][-1]),
159 |                 nn.ReLU(inplace=True),
160 |                 nn.AdaptiveAvgPool2d((1, 1)),
161 |             ]
162 |         )
163 | 
164 |         self.block_extractor2 = nn.Sequential(
165 |             *[
166 |                 nn.MaxPool2d(kernel_size=2, stride=2),
167 |                 self._make_layers(cfg[2], batch_norm, cfg[1][-1]),
168 |                 nn.ReLU(inplace=True),
169 |                 nn.MaxPool2d(kernel_size=2, stride=2),
170 |                 self._make_layers(cfg[3], batch_norm, cfg[2][-1]),
171 |                 nn.ReLU(inplace=True),
172 |                 self._make_layers(cfg[4], batch_norm, cfg[3][-1]),
173 |                 nn.ReLU(inplace=True),
174 |                 nn.AdaptiveAvgPool2d((1, 1)),
175 |             ]
176 |         )
177 | 
178 |         self.block_extractor3 = nn.Sequential(
179 |             *[
180 |                 nn.MaxPool2d(kernel_size=2, stride=2),
181 |                 self._make_layers(cfg[3], batch_norm, cfg[2][-1]),
182 |                 nn.ReLU(inplace=True),
183 |                 self._make_layers(cfg[4], batch_norm, cfg[3][-1]),
184 |                 nn.ReLU(inplace=True),
185 |                 nn.AdaptiveAvgPool2d((1, 1)),
186 |             ]
187 |         )
188 | 
189 |         self.block_extractor4 = nn.Sequential(
190 |             *[
191 |                 self._make_layers(cfg[3], batch_norm, cfg[4][-1]),
192 |                 nn.ReLU(inplace=True),
193 |                 self._make_layers(cfg[4], batch_norm, cfg[3][-1]),
194 |                 nn.ReLU(inplace=True),
195 |                 nn.AdaptiveAvgPool2d((1, 1)),
196 |             ]
197 |         )
198 | 
199 |         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
200 |         self.fc1 = nn.Linear(512, num_classes)
201 |         self.fc2 = nn.Linear(512, num_classes)
202 |         self.fc3 = nn.Linear(512, num_classes)
203 |         self.fc4 = nn.Linear(512, num_classes)
204 | 
205 |     def _initialize_weights(self):
206 |         for m in self.modules():
207 |             if isinstance(m, nn.Conv2d):
208 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
209 |                 m.weight.data.normal_(0, math.sqrt(2.0 / n))
210 |                 if m.bias is not None:
211 |                     m.bias.data.zero_()
212 |             elif isinstance(m, nn.BatchNorm2d):
213 |                 m.weight.data.fill_(1)
214 |                 m.bias.data.zero_()
215 |             elif isinstance(m, nn.Linear):
216 |                 n = m.weight.size(1)
217 |                 m.weight.data.normal_(0, 0.01)
218 |                 m.bias.data.zero_()
219 | 
220 |     @staticmethod
221 |     def _make_layers(cfg, batch_norm=False, in_channels=3):
222 |         layers = []
223 |         for v in cfg:
224 |             if v == "M":
225 |                 layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
226 |             else:
227 |                 conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
228 |                 if batch_norm:
229 |                     layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
230 |                 else:
231 |                     layers += [conv2d, nn.ReLU(inplace=True)]
232 |             in_channels = v
233 |         layers = layers[:-1]
234 |         return nn.Sequential(*layers)
235 | 
236 |     def forward(self, x):
237 |         ss_logits = []
238 |         for i in range(len(x)):
239 |             idx = i + 1
240 |             out = getattr(self, "block_extractor" + str(idx))(x[i])
241 |             out = out.view(-1, 512)
242 |             out = getattr(self, "fc" + str(idx))(out)
243 |             ss_logits.append(out)
244 |         return ss_logits
245 | 
246 | 
247 | class VGG_Auxiliary(nn.Module):
248 |     def __init__(self, cfg, batch_norm=False, num_classes=100):
249 |         super(VGG_Auxiliary, self).__init__()
250 |         self.backbone = VGG(cfg, batch_norm=batch_norm, num_classes=num_classes)
251 |         self.auxiliary_classifier = Auxiliary_Classifier(
252 |             cfg, batch_norm=batch_norm, num_classes=num_classes * 4
253 |         )
254 | 
255 |     def forward(self, x, grad=False):
256 |         feats, logit = self.backbone(x, is_feat=True)
257 |         if grad is False:
258 |             for i in range(len(feats)):
259 |                 feats[i] = feats[i].detach()
260 |         ss_logits = self.auxiliary_classifier(feats)
261 |         return logit, ss_logits
262 | 
263 | 
264 | class VGG_SPKD(VGG):
265 |     def __init__(self, cfg, batch_norm=False, num_classes=1000):
266 |         super(VGG_SPKD, self).__init__(cfg, batch_norm, num_classes)
267 | 
268 |     def forward(self, x):
269 |         h = x.shape[2]
270 |         x = F.relu(self.block0(x))
271 | 
272 |         x = self.pool0(x)
273 |         x = self.block1(x)
274 |         x = F.relu(x)
275 | 
276 |         x = self.pool1(x)
277 |         x = self.block2(x)
278 |         x = F.relu(x)
279 | 
280 |         x = self.pool2(x)
281 |         x = self.block3(x)
282 |         x = F.relu(x)
283 |         if h == 64:
284 |             x = self.pool3(x)
285 |         x = self.block4(x)
286 |         x = F.relu(x)
287 |         f3 = x
288 | 
289 |         x = self.pool4(x)
290 |         x = x.view(x.size(0), -1)
291 |         x = self.classifier(x)
292 |         return f3, x
293 | 
294 | 
295 | cfg = {
296 |     "A": [[64], [128], [256, 256], [512, 512], [512, 512]],
297 |     "B": [[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]],
298 |     "D": [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]],
299 |     "E": [[64, 64], [128, 128], [256, 256, 256, 256], [512, 512, 512, 512], [512, 512, 512, 512]],
300 |     "S": [[64], [128], [256], [512], [512]],
301 | }
302 | 
303 | 
304 | class VGG_CRD(nn.Module):
305 |     def __init__(self, cfg, batch_norm=False, num_classes=1000):
306 |         super(VGG_CRD, self).__init__()
307 |         self.block0 = self._make_layers(cfg[0], batch_norm, 3)
308 |         self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1])
309 |         self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1])
310 |         self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1])
311 |         self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1])
312 | 
313 |         self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)
314 |         self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
315 |         self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
316 |         self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
317 |         self.pool4 = nn.AdaptiveAvgPool2d((1, 1))
318 |         # self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
319 |         self.classifier = nn.Linear(512, num_classes)
320 |         linear = nn.Linear(512, 128, bias=True)
321 |         self.normalizer = Normalizer4CRD(linear, power=2)
322 |         self._initialize_weights()
323 | 
324 |     def get_feat_modules(self):
325 |         feat_m = nn.ModuleList([])
326 |         feat_m.append(self.block0)
327 |         feat_m.append(self.pool0)
328 |         feat_m.append(self.block1)
329 |         feat_m.append(self.pool1)
330 |         feat_m.append(self.block2)
331 |         feat_m.append(self.pool2)
332 |         feat_m.append(self.block3)
333 |         feat_m.append(self.pool3)
334 |         feat_m.append(self.block4)
335 |         feat_m.append(self.pool4)
336 |         return feat_m
337 | 
338 |     def get_bn_before_relu(self):
339 |         bn1 = self.block1[-1]
340 |         bn2 = self.block2[-1]
341 |         bn3 = self.block3[-1]
342 |         bn4 = self.block4[-1]
343 |         return [bn1, bn2, bn3, bn4]
344 | 
345 |     def forward(self, x):
346 |         h = x.shape[2]
347 |         x = F.relu(self.block0(x))
348 | 
349 |         x = self.pool0(x)
350 |         x = self.block1(x)
351 |         x = F.relu(x)
352 | 
353 |         x = self.pool1(x)
354 |         x = self.block2(x)
355 |         x = F.relu(x)
356 | 
357 |         x = self.pool2(x)
358 |         x = self.block3(x)
359 |         x = F.relu(x)
360 |         if h == 64:
361 |             x = self.pool3(x)
362 |         x = self.block4(x)
363 |         x = F.relu(x)
364 | 
365 |         x = self.pool4(x)
366 |         crdout = x
367 |         x = x.view(x.size(0), -1)
368 |         x = self.classifier(x)
369 |         crdout = self.normalizer(crdout)
370 |         return crdout, x
371 | 
372 |     @staticmethod
373 |     def _make_layers(cfg, batch_norm=False, in_channels=3):
374 |         layers = []
375 |         for v in cfg:
376 |             if v == "M":
377 |                 layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
378 |             else:
379 |                 conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
380 |                 if batch_norm:
381 |                     layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
382 |                 else:
383 |                     layers += [conv2d, nn.ReLU(inplace=True)]
384 |             in_channels = v
385 |         layers = layers[:-1]
386 |         return nn.Sequential(*layers)
387 | 
388 |     def _initialize_weights(self):
389 |         for m in self.modules():
390 |             if isinstance(m, nn.Conv2d):
391 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
392 |                 m.weight.data.normal_(0, math.sqrt(2.0 / n))
393 |                 if m.bias is not None:
394 |                     m.bias.data.zero_()
395 |             elif isinstance(m, nn.BatchNorm2d):
396 |                 m.weight.data.fill_(1)
397 |                 m.bias.data.zero_()
398 |             elif isinstance(m, nn.Linear):
399 |                 n = m.weight.size(1)
400 |                 m.weight.data.normal_(0, 0.01)
401 |                 m.bias.data.zero_()
402 | 
403 | 
404 | def vgg8(**kwargs):
405 |     """VGG 8-layer model (configuration "S").
406 | 
407 |     Note: this factory takes no ``pretrained`` argument; weights are randomly initialized.
408 |     """
409 |     model = VGG(cfg["S"], **kwargs)
410 |     return model
411 | 
412 | 
413 | def vgg8_bn(**kwargs):
414 |     """VGG 8-layer model (configuration "S") with batch normalization.
415 | 
416 |     Note: this factory takes no ``pretrained`` argument; weights are randomly initialized.
417 |     """
418 |     model = VGG(cfg["S"], batch_norm=True, **kwargs)
419 |     return model
420 | 
421 | 
422 | def vgg8_bn_aux(**kwargs):
423 |     """VGG 8-layer model (configuration "S") with batch normalization and auxiliary classifiers.
424 | 
425 |     Note: this factory takes no ``pretrained`` argument; weights are randomly initialized.
426 |     """
427 |     model = VGG_Auxiliary(cfg["S"], batch_norm=True, **kwargs)
428 |     return model
429 | 
430 | 
431 | def vgg8_bn_spkd(**kwargs):
432 |     """VGG 8-layer model (configuration "S") with batch normalization (SPKD variant).
433 | 
434 |     Note: this factory takes no ``pretrained`` argument; weights are randomly initialized.
435 |     """
436 |     model = VGG_SPKD(cfg["S"], batch_norm=True, **kwargs)
437 |     return model
438 | 
439 | 
440 | def vgg8_bn_crd(**kwargs):
441 |     """VGG 8-layer model (configuration "S") with batch normalization (CRD variant).
442 | 
443 |     Note: this factory takes no ``pretrained`` argument; weights are randomly initialized.
444 |     """
445 |     model = VGG_CRD(cfg["S"], batch_norm=True, **kwargs)
446 |     return model
447 | 
448 | 
449 | def vgg11(**kwargs):
450 |     """VGG 11-layer model (configuration "A").
451 | 
452 |     Note: this factory takes no ``pretrained`` argument; weights are randomly initialized.
453 |     """
454 |     model = VGG(cfg["A"], **kwargs)
455 |     return model
456 | 
457 | 
458 | def vgg11_bn(**kwargs):
459 |     """VGG 11-layer model (configuration "A") with batch normalization"""
460 |     model = VGG(cfg["A"], batch_norm=True, **kwargs)
461 |     return model
462 | 
463 | 
464 | def vgg13(**kwargs):
465 |     """VGG 13-layer model (configuration "B").
466 | 
467 |     Note: this factory takes no ``pretrained`` argument; weights are randomly initialized.
468 |     """
469 |     model = VGG(cfg["B"], **kwargs)
470 |     return model
471 | 
472 | 
473 | def vgg13_bn(**kwargs):
474 |     """VGG 13-layer model (configuration "B") with batch normalization"""
475 |     model = VGG(cfg["B"], batch_norm=True, **kwargs)
476 |     return model
477 | 
478 | 
479 | def vgg13_bn_aux(**kwargs):
480 |     """VGG 13-layer model (configuration "B") with batch normalization and auxiliary classifiers"""
481 |     model = VGG_Auxiliary(cfg["B"], batch_norm=True, **kwargs)
482 |     return model
483 | 
484 | 
485 | def vgg13_bn_spkd(**kwargs):
486 |     """VGG 13-layer model (configuration "B") with batch normalization (SPKD variant)"""
487 |     model = VGG_SPKD(cfg["B"], batch_norm=True, **kwargs)
488 |     return model
489 | 
490 | 
491 | def vgg13_bn_crd(**kwargs):
492 |     """VGG 13-layer model (configuration "B") with batch normalization (CRD variant)"""
493 |     model = VGG_CRD(cfg["B"], batch_norm=True, **kwargs)
494 |     return model
495 | 
496 | 
497 | def vgg16(**kwargs):
498 |     """VGG 16-layer model (configuration "D").
499 | 
500 |     Note: this factory takes no ``pretrained`` argument; weights are randomly initialized.
501 |     """
502 |     model = VGG(cfg["D"], **kwargs)
503 |     return model
504 | 
505 | 
506 | def vgg16_bn(**kwargs):
507 |     """VGG 16-layer model (configuration "D") with batch normalization"""
508 |     model = VGG(cfg["D"], batch_norm=True, **kwargs)
509 |     return model
510 | 
511 | 
512 | def vgg19(**kwargs):
513 |     """VGG 19-layer model (configuration "E").
514 | 
515 |     Note: this factory takes no ``pretrained`` argument; weights are randomly initialized.
516 |     """
517 |     model = VGG(cfg["E"], **kwargs)
518 |     return model
519 | 
520 | 
521 | def vgg19_bn(**kwargs):
522 |     """VGG 19-layer model (configuration "E") with batch normalization"""
523 |     model = VGG(cfg["E"], batch_norm=True, **kwargs)
524 |     return model
525 | 
--------------------------------------------------------------------------------
/backbones/RSC.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import numpy.random as npr
5 | import numpy as np
6 | import torch.nn.functional as F
7 | import random
8 | # from .utils import load_state_dict_from_url
9 | try:
10 |     from torch.hub import load_state_dict_from_url
11 | except ImportError:
12 |     from torch.utils.model_zoo import load_url as load_state_dict_from_url
13 | 
14 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
15 |            'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
16 |            'wide_resnet50_2', 'wide_resnet101_2']
17 | 
18 | 
19 | model_urls = {
20 |     'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
21 |     # 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
22 |     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
23 |     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
24 |     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
25 |     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
26 |     'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
27 |     'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
28 |     'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
29 |     'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
30 | }
31 | 
32 | 
33 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
34 |     """3x3 convolution with padding"""
35 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
36 |                      padding=dilation, groups=groups, bias=False, dilation=dilation)
37 | 
38 | 
39 | def conv1x1(in_planes, out_planes, stride=1):
40 |     """1x1 convolution"""
41 |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
42 | 
43 | 
44 | class BasicBlock(nn.Module):
45 |     expansion = 1
46 |     __constants__ = ['downsample']
47 | 
48 |     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
49 |                  base_width=64, dilation=1, norm_layer=None):
50 |         super(BasicBlock, self).__init__()
51 |         if norm_layer is None:
52 |             norm_layer = nn.BatchNorm2d
53 |         if groups != 1 or base_width != 64:
54 |             raise ValueError('BasicBlock only supports groups=1 and base_width=64')
55 |         if dilation > 1:
56 |             raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
57 |         # Both self.conv1 and self.downsample layers downsample the input when stride != 1
58 |         self.conv1 = conv3x3(inplanes, planes, stride)
59 |         self.bn1 = norm_layer(planes)
60 |         self.relu = nn.ReLU(inplace=True)
61 |         self.conv2 = conv3x3(planes, planes)
62 |         self.bn2 = norm_layer(planes)
63 |         self.downsample = downsample
64 |         self.stride = stride
65 | 
66 |     def forward(self, x):
67 |         identity = x
68 | 
69 |         out = self.conv1(x)
70 |         out = self.bn1(out)
71 |         out = self.relu(out)
72 | 
73 |         out = self.conv2(out)
74 |         out = self.bn2(out)
75 | 
76 |         if self.downsample is not None:
77 |             identity = self.downsample(x)
78 | 
79 |         out += identity
80 |         out = self.relu(out)
81 | 
82 |         return out
83 | 
84 | 
85 | class Bottleneck(nn.Module):
86 |     expansion = 4
87 |     __constants__ = ['downsample']
88 | 
89 |     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
90 |                  base_width=64, dilation=1, norm_layer=None):
91 |         super(Bottleneck, self).__init__()
92 |         if norm_layer is None:
93 |             norm_layer = nn.BatchNorm2d
94 |         width = int(planes * (base_width / 64.)) * groups
95 |         # Both self.conv2 and self.downsample layers downsample the input when stride != 1
96 |         self.conv1 = conv1x1(inplanes, width)
97 |         self.bn1 = norm_layer(width)
98 |         self.conv2 = conv3x3(width, width, stride, groups, dilation)
99 |         self.bn2 = norm_layer(width)
100 |         self.conv3 = conv1x1(width, planes * self.expansion)
101 |         self.bn3 = norm_layer(planes * self.expansion)
102 |         self.relu = nn.ReLU(inplace=True)
103 |         self.downsample = downsample
104 |         self.stride = stride
105 | 
106 |     def forward(self, x):
107 |         identity = x
108 | 
109 |         out = self.conv1(x)
110 |         out = self.bn1(out)
111 |         out = self.relu(out)
112 | 
113 |         out = self.conv2(out)
114 |         out = self.bn2(out)
115 |         out = self.relu(out)
116 | 
117 |         out = self.conv3(out)
118 |         out = self.bn3(out)
119 | 
120 |         if self.downsample is not None:
121 |             identity = self.downsample(x)
122 | 
123 |         out += identity
124 |         out = self.relu(out)
125 | 
126 |         return out
127 | 
128 | 
129 | class ResNet(nn.Module):
130 | 
131 |     def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
132 |                  groups=1, width_per_group=64, replace_stride_with_dilation=None,
133 |                  norm_layer=None):
134 |         super(ResNet, self).__init__()
135 |         if norm_layer is None:
136 |             norm_layer = nn.BatchNorm2d
137 |         self._norm_layer = norm_layer
138 | 
139 |         self.inplanes = 64
140 |         self.dilation = 1
141 |         if replace_stride_with_dilation is None:
142 |             # each element in the tuple indicates if we should replace
143 |             # the 2x2 stride with a dilated convolution instead
144 |             replace_stride_with_dilation = [False, False, False]
145 |         if len(replace_stride_with_dilation) != 3:
146 |             raise ValueError("replace_stride_with_dilation should be None "
147 |                              "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
148 |         self.groups = groups
149 |         self.base_width = width_per_group
150 |         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
151 |                                bias=False)
152 |         self.bn1 = norm_layer(self.inplanes)
153 |         self.relu = nn.ReLU(inplace=True)
154 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
155 |         self.layer1 = self._make_layer(block, 64, layers[0])
156 |         self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
157 |                                        dilate=replace_stride_with_dilation[0])
158 |         self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
159 |                                        dilate=replace_stride_with_dilation[1])
160 |         self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
161 |                                        dilate=replace_stride_with_dilation[2])
162 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
163 |         self.fc = nn.Linear(512 * block.expansion, num_classes)
164 | 
165 |         for m in self.modules():
166 |             if isinstance(m, nn.Conv2d):
167 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
168 |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
169 |                 nn.init.constant_(m.weight, 1)
170 |                 nn.init.constant_(m.bias, 0)
171 | 
172 |         # Zero-initialize the last BN in each residual branch,
173 |         # so that the residual branch starts with zeros, and each residual block behaves like an identity.
174 |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
175 |         if zero_init_residual:
176 |             for m in self.modules():
177 |                 if isinstance(m, Bottleneck):
178 |                     nn.init.constant_(m.bn3.weight, 0)
179 |                 elif isinstance(m, BasicBlock):
180 |                     nn.init.constant_(m.bn2.weight, 0)
181 | 
182 |     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
183 |         norm_layer = self._norm_layer
184 |         downsample = None
185 |         previous_dilation = self.dilation
186 |         if dilate:
187 |             self.dilation *= stride
188 |             stride = 1
189 |         if stride != 1 or self.inplanes != planes * block.expansion:
190 |             downsample = nn.Sequential(
191 |                 conv1x1(self.inplanes, planes * block.expansion, stride),
192 |                 norm_layer(planes * block.expansion),
193 |             )
194 | 
195 |         layers = []
196 |         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
197 |                             self.base_width, previous_dilation, norm_layer))
198 |         self.inplanes = planes * block.expansion
199 |         for _ in range(1, blocks):
200 |             layers.append(block(self.inplanes, planes, groups=self.groups,
201 |                                 base_width=self.base_width, dilation=self.dilation,
202 |                                 norm_layer=norm_layer))
203 | 
204 |         return nn.Sequential(*layers)
205 | 
206 |     def forward(self, x, gt=None, epoch=None):
207 |         x = self.conv1(x)
208 |         x = self.bn1(x)
209 |         x = self.relu(x)
210 |         x = self.maxpool(x)
211 | 
212 |         x = self.layer1(x)
213 |         x = self.layer2(x)
214 |         x = self.layer3(x)
215 |         x = self.layer4(x)
216 | 
217 |         if self.training:
218 |             if epoch <= 18:
219 |                 percent = 1/6.0
220 |             elif epoch <= 38:
221 |                 percent = 1/5.5
222 |             elif epoch <= 58:
223 |                 percent = 1/5.0
224 |             elif epoch <= 78:
225 |                 percent = 1/4.5
226 |             else:
227 |                 percent = 1/4.0
228 | 
229 |             self.eval()
230 |             x_new = x.clone().detach()
231 |             x_new = Variable(x_new.data, requires_grad=True)
232 |             x_new_view = self.avgpool(x_new)
233 |             x_new_view = x_new_view.view(x_new_view.size(0), -1)
234 |             output = self.fc(x_new_view)
235 |             class_num = output.shape[1]
236 |             index = gt
237 |             num_rois = x_new.shape[0]
238 |             num_channel = x_new.shape[1]
239 |             H = x_new.shape[2]
240 |             HW = H * H
241 |             one_hot = torch.zeros((1), dtype=torch.float32).cuda()
242 |             one_hot = Variable(one_hot, requires_grad=False)
243 |             sp_i = torch.ones([2, num_rois]).long()
244 |             sp_i[0, :] = torch.arange(num_rois)
245 |             sp_i[1, :] = index
246 |             sp_v = torch.ones([num_rois])
247 |             one_hot_sparse = torch.sparse.FloatTensor(sp_i, sp_v, torch.Size([num_rois, class_num])).to_dense().cuda()
248 |             one_hot_sparse = Variable(one_hot_sparse, requires_grad=False)
249 |             one_hot = torch.sum(output * one_hot_sparse)
250 |             self.zero_grad()
251 |             one_hot.backward()
252 |             grads_val = x_new.grad.clone().detach()
253 |             grad_channel_mean = torch.mean(grads_val.view(num_rois, num_channel, -1), dim=2)
254 |             grad_channel_mean = grad_channel_mean.view(num_rois, num_channel, 1, 1)
255 |             spatial_mean = torch.sum(x_new * grad_channel_mean, 1)
256 |             spatial_mean = spatial_mean.view(num_rois, HW)
257 |             self.zero_grad()
258 | 
259 |             th_mask_value = torch.sort(spatial_mean, dim=1, descending=True)[0][:, int(HW/2.0)]
260 |             th_mask_value = th_mask_value.view(num_rois, 1).expand(num_rois, HW)
261 |             mask_all_cuda = torch.where(spatial_mean > th_mask_value, torch.zeros(spatial_mean.shape).cuda(),
262 |                                         torch.ones(spatial_mean.shape).cuda())
263 |             mask_all = mask_all_cuda.detach().cpu().numpy()
264 |             spatial_drop_num = int(HW/3.0)
265 |             for q in range(num_rois):
266 |                 mask_all_temp = np.ones((HW), dtype=np.float32)
267 |                 zero_index = np.where(mask_all[q, :] == 0)[0]
268 |                 num_zero_index = zero_index.size
269 |                 if num_zero_index >= spatial_drop_num:
270 |                     dumy_index = npr.choice(zero_index, size=spatial_drop_num, replace=False)
271 |                 else:
272 |                     zero_index = np.arange(HW)  # fixed: was np.arange(49), which hard-coded a 7x7 feature map
273 |                     dumy_index = npr.choice(zero_index, size=spatial_drop_num, replace=False)
274 |                 mask_all_temp[dumy_index] = 0
275 |                 mask_all[q, :] = mask_all_temp
276 |             mask_all = torch.from_numpy(mask_all.reshape(num_rois, H, H)).cuda()
277 |             mask_all = mask_all.view(num_rois, 1, H, H)
278 | 
279 |             cls_prob_before = F.softmax(output, dim=1)
280 |             x_new_view_after = x_new * mask_all
281 |             x_new_view_after = self.avgpool(x_new_view_after)
282 |             x_new_view_after = x_new_view_after.view(x_new_view_after.size(0), -1)
283 |             x_new_view_after = self.fc(x_new_view_after)
284 |             cls_prob_after = F.softmax(x_new_view_after, dim=1)
285 | 
286 |             sp_i = torch.ones([2, num_rois]).long()
287 |             sp_i[0, :] = torch.arange(num_rois)
288 |             sp_i[1, :] = index
289 |             sp_v = torch.ones([num_rois])
290 |             one_hot_sparse = torch.sparse.FloatTensor(sp_i, sp_v, torch.Size([num_rois, class_num])).to_dense().cuda()
291 |             before_vector = torch.sum(one_hot_sparse * cls_prob_before, dim=1)
292 |             after_vector = torch.sum(one_hot_sparse * cls_prob_after, dim=1)
293 |             change_vector = before_vector - after_vector - 0.0001
294 |             change_vector = torch.where(change_vector > 0, change_vector, torch.zeros(change_vector.shape).cuda())
295 |             th_fg_value = torch.sort(change_vector, dim=0, descending=True)[0][int(round(float(num_rois) * percent))]
296 |             drop_index_fg = change_vector.gt(th_fg_value)
297 |             ignore_index_fg = 1 - drop_index_fg.float()
298 |             not_01_ignore_index_fg = ignore_index_fg.nonzero()[:, 0]
299 |             mask_all[not_01_ignore_index_fg.long(), :] = 1
300 | 
301 |             self.train()
302 |             mask_all = Variable(mask_all, requires_grad=True)
303 |             x = x * mask_all
304 | 
305 |         x = self.avgpool(x)
306 |         x = torch.flatten(x, 1)
307 |         x = self.fc(x)
308 | 
309 |         return x
310 | 
311 | 
312 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
313 |     model = ResNet(block, layers, **kwargs)
314 |     if pretrained:
315 |         state_dict = load_state_dict_from_url(model_urls[arch],
316 |                                               progress=progress)
317 |         # model.load_state_dict(state_dict, strict=False)
318 |         model.load_state_dict(state_dict)
319 |     return model
320 | 
321 | 
322 | def resnet18(pretrained=False, progress=True, **kwargs):
323 |     r"""ResNet-18 model from
324 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
325 |     Args:
326 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
327 |         progress (bool): If True, displays a progress bar of the download to stderr
328 |     """
329 |     return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
330 |                    **kwargs)
331 | 
332 | 
333 | def resnet34(pretrained=False, progress=True, **kwargs):
334 |     r"""ResNet-34 model from
335 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
336 |     Args:
337 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
338 |         progress (bool): If True, displays a progress bar of the download to stderr
339 |     """
340 |     return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
341 |                    **kwargs)
342 | 
343 | 
344 | def resnet50(pretrained=False, progress=True, **kwargs):
345 |     r"""ResNet-50 model from
346 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
347 |     Args:
348 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
349 |         progress (bool): If True, displays a progress bar of the download to stderr
350 |     """
351 |     return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
352 |                    **kwargs)
353 | 
354 | 
355 | def resnet101(pretrained=False, progress=True, **kwargs):
356 |     r"""ResNet-101 model from
357 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
358 |     Args:
359 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
360 |         progress (bool): If True, displays a progress bar of the download to stderr
361 |     """
362 |     return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
363 |                    **kwargs)
364 | 
365 | 
366 | def resnet152(pretrained=False, progress=True, **kwargs):
367 |     r"""ResNet-152 model from
368 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
369 |     Args:
370 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
371 |         progress (bool): If True, displays a progress bar of the download to stderr
372 |     """
373 |     return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
374 |                    **kwargs)
375 | 
376 | 
377 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
378 |     r"""ResNeXt-50 32x4d model from
379 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
380 |     Args:
381 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
382 |         progress (bool): If True, displays a progress bar of the download to stderr
383 |     """
384 |     kwargs['groups'] = 32
385 |     kwargs['width_per_group'] = 4
386 |     return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
387 |                    pretrained, progress, **kwargs)
388 | 
389 | 
390 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
391 |     r"""ResNeXt-101 32x8d model from
392 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
393 |     Args:
394 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
395 |         progress (bool): If True, displays a progress bar of the download to stderr
396 |     """
397 |     kwargs['groups'] = 32
398 |     kwargs['width_per_group'] = 8
399 |     return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
400 |                    pretrained, progress, **kwargs)
401 | 
402 | 
403 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
404 |     r"""Wide ResNet-50-2 model from
405 |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
406 |     The model is the same as ResNet except for the bottleneck number of channels
407 |     which is twice larger in every block. The number of channels in outer 1x1
408 |     convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
409 |     channels, and in Wide ResNet-50-2 has 2048-1024-2048.
410 |     Args:
411 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
412 |         progress (bool): If True, displays a progress bar of the download to stderr
413 |     """
414 |     kwargs['width_per_group'] = 64 * 2
415 |     return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
416 |                    pretrained, progress, **kwargs)
417 | 
418 | 
419 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
420 |     r"""Wide ResNet-101-2 model from
421 |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
422 |     The model is the same as ResNet except for the bottleneck number of channels
423 |     which is twice larger in every block. The number of channels in outer 1x1
424 |     convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
425 |     channels, and in Wide ResNet-50-2 has 2048-1024-2048.
426 |     Args:
427 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
428 |         progress (bool): If True, displays a progress bar of the download to stderr
429 |     """
430 |     kwargs['width_per_group'] = 64 * 2
431 |     return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
432 |                    pretrained, progress, **kwargs)
--------------------------------------------------------------------------------
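The *_aux factories above wrap a backbone VGG in VGG_Auxiliary, whose four auxiliary heads each emit num_classes * 4 logits; the factor of 4 suggests a four-way self-supervised task (e.g. rotation prediction), though the training side of that is not shown in these files. A minimal smoke-test sketch, assuming CIFAR-sized 32x32 inputs and that the backbone's is_feat=True path returns exactly the four stage features the block extractors index:

    import torch

    model = vgg8_bn_aux(num_classes=100)
    model.eval()
    x = torch.randn(2, 3, 32, 32)          # two CIFAR-sized RGB images
    logit, ss_logits = model(x)            # backbone logits + per-stage auxiliary logits
    print(logit.shape)                     # expected: torch.Size([2, 100])
    print([s.shape for s in ss_logits])    # expected: four of torch.Size([2, 400]), i.e. 100 * 4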
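VGG_CRD returns a pair (embedding, logits): the 512-channel pooled feature passes through Normalizer4CRD (defined earlier in vgg.py), which here wraps a Linear(512, 128) projection. A sketch under the assumption that Normalizer4CRD flattens its input before projecting, so the embedding comes out 128-dimensional:

    import torch

    model = vgg8_bn_crd(num_classes=100)
    model.eval()
    emb, logits = model(torch.randn(2, 3, 32, 32))
    print(emb.shape, logits.shape)         # expected: torch.Size([2, 128]) torch.Size([2, 100])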
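The RSC ResNet's forward takes two extra arguments that only matter in train mode: gt (ground-truth labels, used to pick the logits whose gradients drive the spatial mask) and epoch (which selects the drop fraction: 1/6 through epoch 18, then 1/5.5, 1/5, 1/4.5, and 1/4 after epoch 78). The masking branch calls .cuda() unconditionally, so it requires a GPU. A one-step training sketch; the batch size, class count (7, as in PACS), and optimizer settings are illustrative only:

    import torch
    from torch import nn

    model = resnet18(num_classes=7).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

    x = torch.randn(8, 3, 224, 224).cuda()  # 224x224 input gives a 7x7 layer4 map
    y = torch.randint(0, 7, (8,)).cuda()

    model.train()
    logits = model(x, gt=y, epoch=0)        # gt/epoch activate the RSC masking branch
    loss = criterion(logits, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        logits = model(x)                   # eval path skips masking; gt/epoch may be omitted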