├── LICENSE.md
├── README.md
├── config.py
├── dataloader.py
├── detection.py
├── image
│   ├── overview.pdf
│   └── overview.png
├── mitigation.py
├── models.py
├── models
│   ├── ULP_model.py
│   ├── __init__.py
│   ├── lenet.py
│   ├── meta_classifier_cifar10_model.py
│   └── preact_resnet.py
├── requirements.txt
├── resnet_nole.py
├── reverse_engineering.py
├── train_models
│   ├── config.py
│   ├── dataloader.py
│   ├── resnet_nole.py
│   └── train_model.py
├── unet_blocks.py
└── unet_model.py


/LICENSE.md:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 RUSSS

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FeatureRE
This repository contains the source code for ["Rethinking the Reverse-engineering of Trojan Triggers"](https://arxiv.org/abs/2210.15127) (NeurIPS 2022).

Existing reverse-engineering methods consider only input-space constraints: they search for a static trigger pattern in the input space. As a result, they fail to reverse-engineer feature-space Trojans, whose triggers are dynamic in the input space. Our idea is instead to exploit a feature-space constraint and search for a feature-space trigger, using the constraint that Trojan features form a hyperplane. At the same time, we also reverse-engineer the input-space Trojan transformation based on this feature-space constraint.

## Environment
See `requirements.txt`.

## Generating models
Trojaned models can be generated using the existing code of the following attacks:

- [BadNets] https://github.com/verazuo/badnets-pytorch
- [WaNet] https://github.com/VinAIResearch/Warping-based_Backdoor_Attack-release
- [IA] https://github.com/VinAIResearch/input-aware-backdoor-attack-release
- [CL] https://github.com/MadryLab/label-consistent-backdoor-code
- [Filter] https://github.com/trojai
- [SIG] https://github.com/bboylyg/NAD
- [ISSBA] https://github.com/yuezunli/ISSBA

For example, to generate Trojaned models with WaNet:
```bash
cd train_models
CUDA_VISIBLE_DEVICES=0 python train_model.py --dataset cifar10 --set_arch resnet18 --pc 0.1
```
To generate benign models:
```bash
cd train_models
CUDA_VISIBLE_DEVICES=0 python train_model.py --dataset cifar10 --set_arch resnet18 --pc 0
```

## Detection

For example, to run FeatureRE detection on CIFAR10 with a ResNet18 network:

```bash
CUDA_VISIBLE_DEVICES=0 python detection.py \
--dataset cifar10 --set_arch resnet18 \
--hand_set_model_path \
--data_fraction 0.01 \
--lr 1e-3 --bs 256 \
--set_all2one_target all
```

## Mitigation

For example, to run FeatureRE mitigation on CIFAR10 with a ResNet18 network produced by the filter attack:

```bash
CUDA_VISIBLE_DEVICES=0 python mitigation.py \
--dataset cifar10 --set_arch resnet18 \
--hand_set_model_path \
--data_fraction 0.01 \
--lr 1e-3 --bs 256 \
--set_all2one_target \
--mask_size 0.05 --override_epoch 400 --asr_test_type wanet
```

## Cite this work
If you use this repository for academic research, please cite the following paper:

```
@inproceedings{wang2022rethinking,
  title={Rethinking the Reverse-engineering of Trojan Triggers},
  author={Wang, Zhenting and Mei, Kai and Ding, Hailun and Zhai, Juan and Ma, Shiqing},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022}
}
```

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import argparse


def get_argument():
    parser = argparse.ArgumentParser()

    # Directory option
    parser.add_argument("--checkpoints", type=str, default="../../checkpoints/")
    parser.add_argument("--data_root", type=str, default="../../data/")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--dataset", type=str, default="mnist")
    parser.add_argument("--attack_mode", type=str, default="all2one")

    parser.add_argument("--data_fraction", type=float, default=1.0)

    parser.add_argument("--hand_set_model_path", type=str, default=None)
    parser.add_argument("--set_arch", type=str, default=None)
    parser.add_argument("--internal_index", type=int, default=None)

    parser.add_argument("--set_all2one_target", type=str, default=None)

    parser.add_argument("--ae_atk_succ_t", type=float, default=0.9)

    parser.add_argument("--ae_filter_num", type=int, default=32)
    parser.add_argument("--ae_num_blocks", type=int, default=4)

    parser.add_argument("--mask_size", type=float, default=0.03)
    parser.add_argument("--override_epoch", type=int, default=None)
    parser.add_argument("--ignore_dist", action='store_true')
    parser.add_argument("--p_loss_bound", type=float, default=0.15)
    parser.add_argument("--loss_std_bound", type=float, default=1)
    parser.add_argument("--asr_test_type", type=str, default="filter")


    parser.add_argument("--bs", type=int, default=256)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--num_workers", type=int, default=8)

    parser.add_argument("--EPSILON", type=float, default=1e-7)
    parser.add_argument("--use_norm", type=int, default=1)

    parser.add_argument("--mixed_value_threshold", type=float, default=-0.75)

    return parser

--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
import torch.utils.data as data
import torch
import torchvision
import torchvision.transforms as transforms
import os
import csv
import random
import numpy as np

from PIL import Image
from torch.utils.tensorboard import SummaryWriter

from torch.utils.data import Dataset

from io import BytesIO


def get_transform(opt, train=True, pretensor_transform=False):
    add_nad_transform = False

    transforms_list = []
    transforms_list.append(transforms.Resize((opt.input_height, opt.input_width)))
    if pretensor_transform:
        if train:
            transforms_list.append(transforms.RandomCrop((opt.input_height, opt.input_width), padding=opt.random_crop))
            transforms_list.append(transforms.RandomRotation(opt.random_rotation))
            if opt.dataset == "cifar10":
                transforms_list.append(transforms.RandomHorizontalFlip(p=0.5))

            if add_nad_transform:
                transforms_list.append(transforms.RandomCrop(opt.input_height, padding=4))
                transforms_list.append(transforms.RandomHorizontalFlip())


    transforms_list.append(transforms.ToTensor())
    if opt.dataset == "cifar10":
        transforms_list.append(transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]))
        if add_nad_transform:
            transforms_list.append(Cutout(1,9))

    elif opt.dataset == "mnist":
        transforms_list.append(transforms.Normalize([0.1307], [0.3081]))
        if add_nad_transform:
            transforms_list.append(Cutout(1,9))
    elif opt.dataset == "gtsrb" or opt.dataset == "celeba":
        transforms_list.append(transforms.Normalize((0.3403, 0.3121, 0.3214),(0.2724, 0.2608, 0.2669)))
        if add_nad_transform:
            transforms_list.append(Cutout(1,9))
    elif opt.dataset == "imagenet":
        transforms_list.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
        if add_nad_transform:
            transforms_list.append(Cutout(1,9))
    else:
        raise Exception("Invalid Dataset")

    return transforms.Compose(transforms_list)



class GTSRB(data.Dataset):
    def __init__(self, opt, train, transforms):
        super(GTSRB, self).__init__()
        if train:
            self.data_folder = os.path.join(opt.data_root, "GTSRB/Train")
            self.images, self.labels = self._get_data_train_list()
        else:
            self.data_folder = os.path.join(opt.data_root, "GTSRB/Test")
            self.images, self.labels = self._get_data_test_list()

        self.transforms = transforms

    def _get_data_train_list(self):
        images = []
        labels = []
        for c in range(0, 43):
            prefix = self.data_folder + "/" + format(c, "05d") + "/"
            gtFile = open(prefix + "GT-" + format(c, "05d") + ".csv")
            gtReader = csv.reader(gtFile, delimiter=";")
            next(gtReader)
            for row in gtReader:
                images.append(prefix + row[0])
                labels.append(int(row[7]))
            gtFile.close()
        return images, labels

    def _get_data_test_list(self):
        images = []
        labels = []
        prefix = os.path.join(self.data_folder, "GT-final_test.csv")
        gtFile = open(prefix)
        gtReader = csv.reader(gtFile, delimiter=";")
        next(gtReader)
        for row in gtReader:
            images.append(self.data_folder + "/" + row[0])
            labels.append(int(row[7]))
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = Image.open(self.images[index])
        image = self.transforms(image)
        label = self.labels[index]
        return image, label

def get_dataloader_partial_split(opt, train_fraction=0.1, train=True, pretensor_transform=False,shuffle=True,return_index = False):
    data_fraction = train_fraction

    transform_train = get_transform(opt, True, pretensor_transform)
    transform_test = get_transform(opt, False, pretensor_transform)

    transform = transform_train

    if opt.dataset == "gtsrb":
        dataset = GTSRB(opt, train, transform_train)
        dataset_test = GTSRB(opt, train, transform_test)
        class_num=43
    elif opt.dataset == "mnist":
        dataset = torchvision.datasets.MNIST(opt.data_root, train, transform=transform_train, download=True)
        dataset_test = torchvision.datasets.MNIST(opt.data_root, train, transform=transform_test, download=True)

        class_num=10
    elif opt.dataset == "cifar10":
        dataset = torchvision.datasets.CIFAR10(opt.data_root, train, transform=transform_train, download=True)
        dataset_test = torchvision.datasets.CIFAR10(opt.data_root, train, transform=transform_test, download=True)
        class_num=10
    elif opt.dataset == "celeba":
        if train:
            split = "train"
        else:
            split = "test"
        dataset = CelebA_attr(opt, split, transform)
        class_num=8
    elif opt.dataset == "imagenet":
        if train==True:
            file_dir = "/workspace/data/imagenet/train"
        elif train==False:
            file_dir = "/workspace/data/imagenet/val"
        dataset = torchvision.datasets.ImageFolder(
            file_dir,
            transform
        )
        dataset_test = torchvision.datasets.ImageFolder(
            file_dir,
            transform
        )
        class_num=1000
    else:
        raise Exception("Invalid dataset")
    #dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.bs, num_workers=opt.num_workers, shuffle=True)
    #finetuneset = torch.utils.data.Subset(dataset, range(0,dataset.__len__(),int(1/data_fraction)))
    dataloader_total = torch.utils.data.DataLoader(dataset, batch_size=1, pin_memory=True,num_workers=opt.num_workers, shuffle=False)

    idx = []
    counter = [0]*class_num
    for batch_idx, (inputs, targets) in enumerate(dataloader_total):

        if counter[targets.item()] 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

        # Added another relu here
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        # Modified to use relu2
        out = self.relu2(out)

        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.conv3(x)
        x = self.bn3(x)

        if self.downsample is not None:
            residual = self.downsample(residual)

        x += residual
        x = self.relu(x)

        return x


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=10,in_channels=3):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(kernel_size=4)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

    def from_input_to_features(self, x, index):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def from_features_to_output(self, x, index):
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

def resnet18(**kwargs):
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)


def resnet34(**kwargs):
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


def resnet50(**kwargs):
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnet101(**kwargs):
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)


def resnet152(**kwargs):
    return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
--------------------------------------------------------------------------------
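For readers following the feature-space search in `reverse_engineering.py`, the NumPy sketch below restates three of its computations in isolation, as we read them: the tanh-bounded mask of `RegressionModel.get_raw_mask`, the blending of masked features with randomly drawn reference features in `forward_ae_mask_p`, and the `loss_std` term of `train_step`. It is an illustration, not the repository's code: the helper names are hypothetical, NumPy stands in for PyTorch, and `ddof=1` is chosen to match the unbiased default of PyTorch's `Tensor.std`.

```python
import numpy as np

EPSILON = 1e-7  # same value as the --EPSILON default in config.py


def bounded_mask(mask_tanh, eps=EPSILON):
    # Squash an unconstrained parameter into the open interval (0, 1);
    # eps keeps the result strictly below 1 even when tanh saturates to 1.0.
    return np.tanh(mask_tanh) / (2 + eps) + 0.5


def sample_reference(all_features, batch_size, rng):
    # Hypothetical helper: draw reference feature rows uniformly with
    # replacement, like the np.random.choice call in forward_ae_mask_p.
    idx = rng.choice(all_features.shape[0], size=batch_size, replace=True)
    return all_features[idx]


def blend_features(features, reference, mask):
    # Keep the masked feature units and replace the rest with reference features.
    return mask * features + (1 - mask) * reference


def masked_std_loss(features, mask):
    # Per-unit standard deviation of the masked features across the batch,
    # summed and normalized by the L1 norm of the mask.
    return (features * mask).std(axis=0, ddof=1).sum() / np.abs(mask).sum()


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    all_features = rng.normal(size=(100, 4))  # hypothetical pool of 4-unit features
    features = rng.normal(size=(8, 4))        # hypothetical batch of 8 samples
    mask = bounded_mask(np.array([10.0, 10.0, -10.0, -10.0]))  # close to [1, 1, 0, 0]
    reference = sample_reference(all_features, features.shape[0], rng)
    blended = blend_features(features, reference, mask)
    print(mask)
    print(masked_std_loss(features, mask))
```

Assuming the mask has the shape of a single sample's feature map, the same formulas carry over to the 4-D feature tensors used in the repository, because both NumPy and PyTorch broadcast the mask elementwise across the batch dimension.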
/reverse_engineering.py:
--------------------------------------------------------------------------------
import torch
from torch import Tensor, nn
import torchvision
import os
import numpy as np
from resnet_nole import *
from models import meta_classifier_cifar10_model,lenet,ULP_model,preact_resnet
import torch.nn.functional as F

import unet_model
import random
import pilgram
from PIL import Image
from functools import reduce

class RegressionModel(nn.Module):
    def __init__(self, opt, init_mask):
        self._EPSILON = opt.EPSILON
        super(RegressionModel, self).__init__()

        if init_mask is not None:
            self.mask_tanh = nn.Parameter(torch.tensor(init_mask))

        self.classifier = self._get_classifier(opt)
        self.example_features = None

        if opt.dataset == "mnist":
            self.AE = unet_model.UNet(n_channels=1,num_classes=1,base_filter_num=opt.ae_filter_num, num_blocks=opt.ae_num_blocks)
        else:
            self.AE = unet_model.UNet(n_channels=3,num_classes=3,base_filter_num=opt.ae_filter_num, num_blocks=opt.ae_num_blocks)

        self.AE.train()
        self.example_ori_img = None
        self.example_ae_img = None
        self.opt = opt

    def forward_ori(self, x,opt):

        features = self.classifier.from_input_to_features(x, opt.internal_index)
        out = self.classifier.from_features_to_output(features, opt.internal_index)

        return out, features

    def forward_flip_mask(self, x,opt):

        strategy = "flip"
        features = self.classifier.from_input_to_features(x, opt.internal_index)
        if strategy == "flip":
            features = (1 - opt.flip_mask) * features - opt.flip_mask * features
        elif strategy == "zero":
            features = (1 - opt.flip_mask) * features

        out = self.classifier.from_features_to_output(features, opt.internal_index)

        return out, features

    def forward_ae(self, x,opt):

        self.example_ori_img = x
        x_before_ae = x
        x = self.AE(x)
        x_after_ae = x
        self.example_ae_img = x

        features = self.classifier.from_input_to_features(x, opt.internal_index)
        out = self.classifier.from_features_to_output(features, opt.internal_index)

        self.example_features = features

        return out, features, x_before_ae, x_after_ae


    def forward_ae_mask_p(self, x,opt):
        mask = self.get_raw_mask(opt)
        self.example_ori_img = x
        x_before_ae = x
        x = self.AE(x)
        x_after_ae = x
        self.example_ae_img = x

        features = self.classifier.from_input_to_features(x, opt.internal_index)
        reference_features_index_list = np.random.choice(range(opt.all_features.shape[0]), features.shape[0], replace=True)
        reference_features = opt.all_features[reference_features_index_list]
        features_ori = features
        features = mask * features + (1-mask) * reference_features.reshape(features.shape)

        out = self.classifier.from_features_to_output(features, opt.internal_index)

        self.example_features = features_ori

        return out, features, x_before_ae, x_after_ae, features_ori

    def forward_ae_mask_p_test(self, x,opt):
        mask = self.get_raw_mask(opt)
        self.example_ori_img = x
        x_before_ae = x
        x = self.AE(x)
        x_after_ae = x
        self.example_ae_img = x

        features = self.classifier.from_input_to_features(x, opt.internal_index)
        bs = features.shape[0]
        index_1 = list(range(bs))
        random.shuffle(index_1)
        reference_features = features[index_1]
        features_ori = features
        features = mask * features + (1-mask) * reference_features.reshape(features.shape)
        out = self.classifier.from_features_to_output(features, opt.internal_index)
        self.example_features = features_ori

        return out, features, x_before_ae, x_after_ae, features_ori

    def get_raw_mask(self,opt):
        mask = nn.Tanh()(self.mask_tanh)
        bounded = mask / (2 + self._EPSILON) + 0.5
        return bounded

    def _get_classifier(self, opt):

        if opt.set_arch:
            if opt.set_arch == "resnet18":
                classifier = resnet18(num_classes = opt.num_classes, in_channels = opt.input_channel)
            elif opt.set_arch=="preact_resnet18":
                classifier = preact_resnet.PreActResNet18(num_classes=opt.num_classes)
            elif opt.set_arch=="meta_classifier_cifar10_model":
                classifier = meta_classifier_cifar10_model.MetaClassifierCifar10Model()
            elif opt.set_arch=="mnist_lenet":
                classifier = lenet.LeNet5()
            elif opt.set_arch=="ulp_vgg":
                classifier = ULP_model.CNN_classifier()
            else:
                print("invalid arch")

        if opt.hand_set_model_path:
            ckpt_path = opt.hand_set_model_path

        state_dict = torch.load(ckpt_path)
        try:
            classifier.load_state_dict(state_dict["net_state_dict"])
        except:
            try:
                classifier.load_state_dict(state_dict["netC"])
            except:
                try:
                    from collections import OrderedDict
                    new_state_dict = OrderedDict()
                    for k, v in state_dict["model"].items():
                        name = k[7:] # remove `module.`
                        new_state_dict[name] = v
                    classifier.load_state_dict(new_state_dict)

                except:
                    classifier.load_state_dict(state_dict)

        for param in classifier.parameters():
            param.requires_grad = False
        classifier.eval()
        return classifier.to(opt.device)

class Recorder:
    def __init__(self, opt):
        super().__init__()
        self.mixed_value_best = float("inf")

def test_ori(opt, regression_model, testloader, flip=False):
    regression_model.eval()
    regression_model.AE.eval()
    regression_model.classifier.eval()
    total_pred = 0
    true_pred = 0
    cross_entropy = nn.CrossEntropyLoss()
    for inputs,labels in testloader:
        inputs = inputs.to(opt.device)
        labels = labels.to(opt.device)
        sample_num = inputs.shape[0]
        total_pred += sample_num
        target_labels = torch.ones((sample_num), dtype=torch.int64).to(opt.device) * opt.target_label

        if flip:
            out, features = regression_model.forward_flip_mask(inputs,opt)
        else:
            out, features = regression_model.forward_ori(inputs,opt)
        predictions = out

        true_pred += torch.sum(torch.argmax(predictions, dim=1) == labels).detach()
        loss_ce = cross_entropy(predictions, target_labels)

    print("BA true_pred:",true_pred)
    print("BA total_pred:",total_pred)
    print(
        "BA test acc:",true_pred * 100.0 / total_pred
    )

def test_ori_attack(opt, regression_model, testloader, flip=False):
    regression_model.eval()
    regression_model.AE.eval()
    regression_model.classifier.eval()
    total_pred = 0
    true_pred = 0
    cross_entropy = nn.CrossEntropyLoss()
    for inputs,labels in testloader:

        inputs = inputs.to(opt.device)

        if opt.asr_test_type == "filter":

            t_mean = opt.t_mean.cuda()
            t_std = opt.t_std.cuda()
            GT_img = inputs
            GT_img = (torch.clamp(GT_img*t_std+t_mean, min=0, max=1).detach().cpu().numpy()*255).astype(np.uint8)
            for j in range(GT_img.shape[0]):
                ori_pil_img = Image.fromarray(GT_img[j].transpose((1,2,0)))
                convered_pil_img = pilgram._1977(ori_pil_img)
                GT_img[j] = np.asarray(convered_pil_img).transpose((2,0,1))
            GT_img = GT_img.astype(np.float32)
            GT_img = GT_img/255
            GT_img = torch.from_numpy(GT_img).cuda()
            GT_img = (GT_img - t_mean)/t_std
            inputs = GT_img
        elif opt.asr_test_type == "wanet":
            inputs = F.grid_sample(inputs, opt.grid_temps.repeat(inputs.shape[0], 1, 1, 1), align_corners=True)


        inputs = inputs.to(opt.device)
        labels = labels.to(opt.device)
        sample_num = inputs.shape[0]
        total_pred += sample_num
        target_labels = torch.ones((sample_num),
            dtype=torch.int64).to(opt.device) * opt.target_label

        if flip:
            out, features = regression_model.forward_flip_mask(inputs,opt)
        else:
            out, features = regression_model.forward_ori(inputs,opt)
        predictions = out

        true_pred += torch.sum(torch.argmax(predictions, dim=1) == target_labels).detach()
        loss_ce = cross_entropy(predictions, target_labels)

    print("ASR true_pred:",true_pred)
    print("ASR total_pred:",total_pred)
    print(
        "ASR test acc:",true_pred * 100.0 / total_pred
    )

def fix_neuron_flip(opt,trainloader,testloader,testloader_asr):

    trained_regression_model = opt.trained_regression_model
    trained_regression_model.eval()
    trained_regression_model.AE.eval()
    trained_regression_model.classifier.eval()

    if opt.asr_test_type == "wanet":
        ckpt_path = opt.hand_set_model_path
        state_dict = torch.load(ckpt_path)
        identity_grid = state_dict["identity_grid"]
        noise_grid = state_dict["noise_grid"]
        grid_temps = (identity_grid + 0.5 * noise_grid / opt.input_height) * 1
        grid_temps = torch.clamp(grid_temps, -1, 1)

        opt.grid_temps = grid_temps

    test_ori(opt, trained_regression_model,testloader,flip=False)
    test_ori_attack(opt, trained_regression_model,testloader_asr,flip=False)

    neuron_finding_strategy = "hyperplane"

    cross_entropy = nn.CrossEntropyLoss()
    for batch_idx, (inputs, labels) in enumerate(trainloader):
        inputs = inputs.to(opt.device)
        labels = labels.to(opt.device)
        out, features_reversed, x_before_ae, x_after_ae = trained_regression_model.forward_ae(inputs,opt)
        loss_ce_transformed = cross_entropy(out, labels)

        out, features_ori = trained_regression_model.forward_ori(inputs,opt)
        loss_ce_ori = cross_entropy(out, labels)

        feature_dist = torch.nn.MSELoss(reduction='none').cuda()(features_ori,features_reversed).mean(0)
        print(feature_dist)

        if neuron_finding_strategy == "diff":
            values, indices = feature_dist.reshape(-1).topk(int(0.03*torch.numel(feature_dist)), largest=True, sorted=True)
            flip_mask = torch.zeros(feature_dist.reshape(-1).shape).to(opt.device)
            for index in indices:
                flip_mask[index] = 1
            flip_mask = flip_mask.reshape(feature_dist.shape)

        elif neuron_finding_strategy == "hyperplane":
            flip_mask = trained_regression_model.get_raw_mask(opt)

        opt.flip_mask = flip_mask

        print("loss_ce_transformed:",loss_ce_transformed)
        print("loss_ce_ori:",loss_ce_ori)


    test_ori(opt, trained_regression_model,testloader,flip=True)
    test_ori_attack(opt, trained_regression_model,testloader_asr,flip=True)

def train(opt, init_mask):

    data_now = opt.data_now
    opt.weight_p = 1
    opt.weight_acc = 1
    opt.weight_std = 1
    opt.init_mask = init_mask

    recorder = Recorder(opt)
    regression_model = RegressionModel(opt, init_mask).to(opt.device)

    opt.epoch = 400
    if opt.override_epoch:
        opt.epoch = opt.override_epoch

    optimizerR = torch.optim.Adam(regression_model.AE.parameters(),lr=opt.lr,betas=(0.5,0.9))
    optimizerR_mask = torch.optim.Adam([regression_model.mask_tanh],lr=1e-1,betas=(0.5,0.9))

    regression_model.AE.train()
    recorder = Recorder(opt)
    process = train_step

    warm_up_epoch = 100
    for epoch in range(warm_up_epoch):
        process(regression_model, optimizerR, optimizerR_mask, data_now, recorder, epoch, opt, warm_up=True)

    for epoch in range(opt.epoch):
        process(regression_model, optimizerR, optimizerR_mask, data_now, recorder, epoch, opt)

    opt.trained_regression_model = regression_model

    return recorder, opt

def get_range(opt, init_mask):

    test_dataloader = opt.re_dataloader_total_fixed
    inversion_engine = RegressionModel(opt, init_mask).to(opt.device)

    features_list = []
    features_list_class = [[] for i in range(opt.num_classes)]
    for batch_idx, (inputs, labels) in enumerate(test_dataloader):
        inputs = inputs.to(opt.device)
        out, features = inversion_engine.forward_ori(inputs,opt)
        print(torch.argmax(out,dim=1))

        features_list.append(features)
        for i in range(inputs.shape[0]):
            features_list_class[labels[i].item()].append(features[i].unsqueeze(0))
    all_features = torch.cat(features_list,dim=0)
    opt.all_features = all_features
    print(all_features.shape)

    del features_list
    del test_dataloader

    weight_map_class = []
    for i in range(opt.num_classes):
        feature_mean_class = torch.cat(features_list_class[i],dim=0).mean(0)
        weight_map_class.append(feature_mean_class)

    opt.weight_map_class = weight_map_class
    del all_features
    del features_list_class

def train_step(regression_model, optimizerR, optimizerR_mask, data_now, recorder, epoch, opt, warm_up=False):
    print("Epoch {} - Label: {} | {} - {}:".format(epoch, opt.target_label, opt.dataset, opt.attack_mode))
    cross_entropy = nn.CrossEntropyLoss()
    total_pred = 0
    true_pred = 0

    loss_ce_list = []
    loss_dist_list = []
    loss_list = []
    acc_list = []

    p_loss_list = []
    loss_mask_norm_list = []
    loss_std_list = []

    for inputs in data_now:
        regression_model.AE.train()
        regression_model.mask_tanh.requires_grad = False

        optimizerR.zero_grad()

        inputs = inputs.to(opt.device)
        sample_num = inputs.shape[0]
        total_pred += sample_num
        target_labels = torch.ones((sample_num), dtype=torch.int64).to(opt.device) * opt.target_label
        if warm_up:
            predictions, features, x_before_ae, x_after_ae = regression_model.forward_ae(inputs,opt)
        else:
            predictions, features, x_before_ae, x_after_ae, features_ori = regression_model.forward_ae_mask_p(inputs,opt)

        loss_ce = cross_entropy(predictions, target_labels)

        mse_loss = torch.nn.MSELoss(size_average = True).cuda()(x_after_ae,x_before_ae)

        if warm_up:
            dist_loss = torch.cosine_similarity(opt.weight_map_class[opt.target_label].reshape(-1),features.mean(0).reshape(-1),dim=0)
        else:
            dist_loss = torch.cosine_similarity(opt.weight_map_class[opt.target_label].reshape(-1),features_ori.mean(0).reshape(-1),dim=0)

        acc_list_ = []
        minibatch_accuracy_ = torch.sum(torch.argmax(predictions, dim=1) == target_labels).detach() / sample_num
        acc_list_.append(minibatch_accuracy_)
        acc_list_ = torch.stack(acc_list_)
        avg_acc_G = torch.mean(acc_list_)

        acc_list.append(minibatch_accuracy_)

        p_loss = mse_loss
        p_loss_bound = opt.p_loss_bound
        loss_std_bound = opt.loss_std_bound

        atk_succ_threshold = opt.ae_atk_succ_t

        if opt.ignore_dist:
            dist_loss = dist_loss*0

        if warm_up:
            if (p_loss>p_loss_bound):
                total_loss = loss_ce + p_loss*100
            else:
                total_loss = loss_ce
        else:
            loss_std = (features_ori*regression_model.get_raw_mask(opt)).std(0).sum()
            loss_std = loss_std/(torch.norm(regression_model.get_raw_mask(opt), 1))

            total_loss = dist_loss*5
            if dist_loss<0:
                total_loss = total_loss - dist_loss*5
            if loss_std>loss_std_bound:
                total_loss = total_loss + loss_std*10*(1+opt.weight_std)
            if (p_loss>p_loss_bound):
                total_loss = total_loss + p_loss*10*(1+opt.weight_p)

            if avg_acc_G.item()mask_norm_bound:
                loss_mask_total = loss_mask_total + loss_mask_norm

            loss_mask_total.backward()
            optimizerR_mask.step()
        loss_ce_list.append(loss_ce.detach())
        loss_dist_list.append(dist_loss.detach())
        loss_list.append(total_loss.detach())

        true_pred += torch.sum(torch.argmax(predictions, dim=1) == target_labels).detach()

        if not warm_up:
            p_loss_list.append(p_loss)
            loss_mask_norm_list.append(loss_mask_norm)
            loss_std_list.append(loss_std)

    loss_ce_list = torch.stack(loss_ce_list)
    loss_dist_list = torch.stack(loss_dist_list)
    loss_list = torch.stack(loss_list)
    acc_list = torch.stack(acc_list)

    avg_loss_ce = torch.mean(loss_ce_list)
    avg_loss_dist = torch.mean(loss_dist_list)
    avg_loss = torch.mean(loss_list)
    avg_acc = torch.mean(acc_list)

    if not warm_up:
        p_loss_list = torch.stack(p_loss_list)
        loss_mask_norm_list = torch.stack(loss_mask_norm_list)
        loss_std_list = torch.stack(loss_std_list)

        avg_p_loss = torch.mean(p_loss_list)
        avg_loss_mask_norm = torch.mean(loss_mask_norm_list)
        avg_loss_std = torch.mean(loss_std_list)
        print("avg_ce_loss:",avg_loss_ce)
        print("avg_asr:",avg_acc)
        print("avg_p_loss:",avg_p_loss)
        print("avg_loss_mask_norm:",avg_loss_mask_norm)
        print("avg_loss_std:",avg_loss_std)


        if avg_acc.item() … 1.0*p_loss_bound:
            print("@avg_p_loss larger than bound")
        if avg_loss_mask_norm>1.0*mask_norm_bound:
            print("@avg_loss_mask_norm larger than bound")
        if avg_loss_std>1.0*loss_std_bound:
            print("@avg_loss_std larger than bound")


        mixed_value = avg_loss_dist.detach() - avg_acc + max(avg_p_loss.detach()-p_loss_bound,0)/p_loss_bound + max(avg_loss_mask_norm.detach()-mask_norm_bound,0)/mask_norm_bound + max(avg_loss_std.detach()-loss_std_bound,0)/loss_std_bound
        print("mixed_value:",mixed_value)
        if mixed_value < recorder.mixed_value_best:
            recorder.mixed_value_best = mixed_value
        opt.weight_p = max(avg_p_loss.detach()-p_loss_bound,0)/p_loss_bound
        opt.weight_acc = max(atk_succ_threshold-avg_acc,0)/atk_succ_threshold
        opt.weight_std = max(avg_loss_std.detach()-loss_std_bound,0)/loss_std_bound


    print(
        " Result: ASR: {:.3f} | Cross Entropy Loss: {:.6f} | Dist Loss: {:.6f} | Mixed_value best: {:.6f}".format(
            true_pred * 100.0 / total_pred, avg_loss_ce, avg_loss_dist, recorder.mixed_value_best
        )
    )

    recorder.final_asr = avg_acc

    return avg_acc

if __name__ == "__main__":
    pass

--------------------------------------------------------------------------------
/train_models/config.py:
--------------------------------------------------------------------------------
import argparse


def get_arguments():
    parser = argparse.ArgumentParser()

    parser.add_argument("--data_root", type=str, default="./data/")
    parser.add_argument("--checkpoints", type=str, default="./checkpoints")
    parser.add_argument("--temps", type=str, default="./temps")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--continue_training", action="store_true")

    parser.add_argument("--model_filepath", type=str, default="./checkpoints")

    parser.add_argument("--dataset", type=str, default="cifar10")
    parser.add_argument("--set_arch", type=str, default=None)
    parser.add_argument("--attack_mode", type=str, default="all2one")

    parser.add_argument("--save_all", type=bool, default=False)
    parser.add_argument("--save_freq", type=int, default=50)

    parser.add_argument("--bs", type=int, default=128)
    parser.add_argument("--lr_C", type=float, default=1e-2)
    parser.add_argument("--schedulerC_milestones", type=list, default=[100, 200, 300, 400])
    parser.add_argument("--schedulerC_lambda", type=float, default=0.1)
    parser.add_argument("--n_iters", type=int, default=1000)
    parser.add_argument("--num_workers", type=int, default=6)  # DataLoader requires an integer worker count

    parser.add_argument("--target_label", type=int, default=0)
    parser.add_argument("--pc", type=float, default=0.1)
    parser.add_argument("--cross_ratio", type=float, default=2)  # rho_a = pc, rho_n = pc * cross_ratio

    parser.add_argument("--random_rotation", type=int, default=10)
    parser.add_argument("--random_crop", type=int, default=5)

    parser.add_argument("--extra_flag", type=str, default="")

    parser.add_argument("--s", type=float, default=0.5)
    parser.add_argument("--k", type=int, default=4)
    parser.add_argument(
        "--grid-rescale", type=float, default=1
    )  # scale grid values to avoid pixel values going out of [-1, 1]. For example, grid-rescale = 0.98

    return parser

--------------------------------------------------------------------------------
/train_models/dataloader.py:
--------------------------------------------------------------------------------
import torch.utils.data as data
import torch
import torchvision
import torchvision.transforms as transforms
import os
import csv
import kornia.augmentation as A
import random
import numpy as np

from PIL import Image
from torch.utils.tensorboard import SummaryWriter

from torch.utils.data import Dataset
from natsort import natsorted

from io import BytesIO

class ToNumpy:
    def __call__(self, x):
        x = np.array(x)
        if len(x.shape) == 2:
            x = np.expand_dims(x, axis=2)
        return x


class ProbTransform(torch.nn.Module):
    def __init__(self, f, p=1):
        super(ProbTransform, self).__init__()
        self.f = f
        self.p = p

    def forward(self, x):  # , **kwargs):
        if random.random() < self.p:
            return self.f(x)
        else:
            return x

def get_transform(opt, train=True, pretensor_transform=False):
    add_nad_transform = False

    if opt.dataset == "trojai":
        return transforms.Compose([transforms.CenterCrop(opt.input_height),transforms.ToTensor()])

    transforms_list = []
    transforms_list.append(transforms.Resize((opt.input_height, opt.input_width)))
    if pretensor_transform:
        if train:
            transforms_list.append(transforms.RandomCrop((opt.input_height, opt.input_width), padding=opt.random_crop))
            transforms_list.append(transforms.RandomRotation(opt.random_rotation))
            if opt.dataset == "cifar10":
                transforms_list.append(transforms.RandomHorizontalFlip(p=0.5))

    if add_nad_transform:
        transforms_list.append(transforms.RandomCrop(opt.input_height, padding=4))
        transforms_list.append(transforms.RandomHorizontalFlip())


    transforms_list.append(transforms.ToTensor())
    if (opt.set_arch is not None) and (("nole" in opt.set_arch) or ("mnist_lenet" in opt.set_arch)):
        if opt.dataset == "cifar10":
            transforms_list.append(transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]))
            if add_nad_transform:
                transforms_list.append(Cutout(1,9))

        elif opt.dataset == "mnist":
            transforms_list.append(transforms.Normalize([0.1307], [0.3081]))
            if add_nad_transform:
                transforms_list.append(Cutout(1,9))
        elif opt.dataset == "gtsrb" or opt.dataset == "celeba":
            transforms_list.append(transforms.Normalize((0.3403, 0.3121, 0.3214),(0.2724, 0.2608, 0.2669)))
            if add_nad_transform:
                transforms_list.append(Cutout(1,9))
        elif opt.dataset == "imagenet":
            transforms_list.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
            if add_nad_transform:
                transforms_list.append(Cutout(1,9))
        else:
            raise Exception("Invalid Dataset")
    else:
        if opt.dataset == "cifar10":
            transforms_list.append(transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]))
            if add_nad_transform:
                transforms_list.append(Cutout(1,9))
        elif opt.dataset == "mnist":
            transforms_list.append(transforms.Normalize([0.5], [0.5]))
            if add_nad_transform:
                transforms_list.append(Cutout(1,9))
        elif opt.dataset == "gtsrb" or opt.dataset == "celeba":
            pass
        elif opt.dataset == "imagenet":
            transforms_list.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
            if add_nad_transform:
                transforms_list.append(Cutout(1,9))
        else:
            raise Exception("Invalid Dataset")
    return transforms.Compose(transforms_list)
class Cutout(object):
    """Randomly mask out one or more patches from an image.
    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The length (in pixels) of each square patch.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: Image with n_holes of dimension length x length cut out of it.
        """
        h = img.size(1)
        w = img.size(2)

        mask = np.ones((h, w), np.float32)

        for n in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1: y2, x1: x2] = 0.

        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img = img * mask
        #print(img)

        return img

class PostTensorTransform(torch.nn.Module):
    def __init__(self, opt):
        super(PostTensorTransform, self).__init__()
        self.random_crop = ProbTransform(
            A.RandomCrop((opt.input_height, opt.input_width), padding=opt.random_crop), p=0.8
        )
        self.random_rotation = ProbTransform(A.RandomRotation(opt.random_rotation), p=0.5)
        if opt.dataset == "cifar10":
            self.random_horizontal_flip = A.RandomHorizontalFlip(p=0.5)

    def forward(self, x):
        for module in self.children():
            x = module(x)
        return x

def get_dataloader(opt, train=True, pretensor_transform=False, shuffle=True, return_dataset = False):
    transform = get_transform(opt, train, pretensor_transform)
    if opt.dataset == "cifar10":
        dataset = torchvision.datasets.CIFAR10(opt.data_root, train, transform=transform, download=True)
    else:
        raise Exception("Invalid dataset")
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.bs, num_workers=opt.num_workers, shuffle=shuffle)
    if return_dataset:
        return dataset, dataloader, transform
    else:
        return dataloader, transform

def get_dataloader_random_ratio(opt, train=True, pretensor_transform=False, shuffle=True):
    transform = get_transform(opt, train, pretensor_transform)
    if opt.dataset == "cifar10":
        dataset = torchvision.datasets.CIFAR10(opt.data_root, train, transform=transform, download=True)
    else:
        raise Exception("Invalid dataset")

    idx = random.sample(range(dataset.__len__()),int(dataset.__len__()*opt.random_ratio))
    dataset = torch.utils.data.Subset(dataset,idx)
    #trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.bs, num_workers=opt.num_workers, shuffle=shuffle)
    return dataloader, transform

def main():
    pass


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/train_models/resnet_nole.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import math


def conv3x3(in_planes, out_planes, stride=1):
    # 3x3 convolution with padding
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


'''class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        #print(downsample)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)

        if self.downsample is not None:
            #print(x.shape)
            residual = self.downsample(residual)

        x += residual
        x = self.relu(x)

        return x

    def input_to_residual(self, x):
        residual = x
        if self.downsample is not None:
            residual = self.downsample(residual)
        return residual

    def residual_to_output(self, residual,conv2):
        x = residual + conv2
        x = self.relu(x)

        return x


    def input_to_conv2(self, x):
        residual = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        return x

    def conv2_to_output(self, x, residual):
        x = self.bn2(x)
        x = residual + x
        x = self.relu(x)
        return x

    def conv2_to_output_mask(self, x, residual,mask,pattern):
        x = self.bn2(x)
        x = residual + x
        x = (1 - mask) * x + mask * pattern
        x = self.relu(x)
        return x

    def input_to_conv1(self, x):
        x = self.conv1(x)
        return x

    def conv1_to_output(self, x, residual):
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)

        x += residual
        x = self.relu(x)

        return x'''

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

        # Added another relu here
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        # Modified to use relu2
        out = self.relu2(out)

        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.conv3(x)
        x = self.bn3(x)

        if self.downsample is not None:
            residual = self.downsample(residual)

        x += residual
        x = self.relu(x)

        return x


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=10,in_channels=3):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(kernel_size=4)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        self.inter_feature = {}
        self.inter_gradient = {}

        self.register_all_hooks()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

    def get_fm(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        #x = self.avgpool(x)

        return x

    def input_to_conv1(self, x):
        x = self.conv1(x)

        return x

    def conv1_to_output(self, x):
        #x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def input_to_layer1(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)

        return x

    def layer1_to_output(self, x):
        #x = self.conv1(x)
        #x = self.bn1(x)
        #x = self.relu(x)

        #x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def input_to_layer2(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)

        return x

    def layer2_to_output(self, x):
        #x = self.conv1(x)
        #x = self.bn1(x)
        #x = self.relu(x)

        #x = self.layer1(x)
        #x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def input_to_layer3(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        return x

    def layer3_to_output(self, x):

        #x = self.conv1(x)
        #x = self.bn1(x)
        #x = self.relu(x)

        #x = self.layer1(x)
        #x = self.layer2(x)
        #x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def input_to_layer4(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

    def layer4_to_output(self, x):

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def make_hook(self, name, flag):
        if flag == 'forward':
            def hook(m, input, output):
                self.inter_feature[name] = output
            return hook
        elif flag == 'backward':
            def hook(m, input, output):
                self.inter_gradient[name] = output
            return hook
        else:
            assert False

    def register_all_hooks(self):
        self.conv1.register_forward_hook(self.make_hook("Conv1_Conv1_Conv1_", 'forward'))
        self.layer1[0].conv1.register_forward_hook(self.make_hook("Layer1_0_Conv1_", 'forward'))
        self.layer1[0].conv2.register_forward_hook(self.make_hook("Layer1_0_Conv2_", 'forward'))
        self.layer1[1].conv1.register_forward_hook(self.make_hook("Layer1_1_Conv1_", 'forward'))
        self.layer1[1].conv2.register_forward_hook(self.make_hook("Layer1_1_Conv2_", 'forward'))

        self.layer2[0].conv1.register_forward_hook(self.make_hook("Layer2_0_Conv1_", 'forward'))
        self.layer2[0].downsample.register_forward_hook(self.make_hook("Layer2_0_Downsample_", 'forward'))
        self.layer2[0].conv2.register_forward_hook(self.make_hook("Layer2_0_Conv2_", 'forward'))
        self.layer2[1].conv1.register_forward_hook(self.make_hook("Layer2_1_Conv1_", 'forward'))
        self.layer2[1].conv2.register_forward_hook(self.make_hook("Layer2_1_Conv2_", 'forward'))

        self.layer3[0].conv1.register_forward_hook(self.make_hook("Layer3_0_Conv1_", 'forward'))
        self.layer3[0].downsample.register_forward_hook(self.make_hook("Layer3_0_Downsample_", 'forward'))
        self.layer3[0].conv2.register_forward_hook(self.make_hook("Layer3_0_Conv2_", 'forward'))
        self.layer3[1].conv1.register_forward_hook(self.make_hook("Layer3_1_Conv1_", 'forward'))
        self.layer3[1].conv2.register_forward_hook(self.make_hook("Layer3_1_Conv2_", 'forward'))

        self.layer4[0].conv1.register_forward_hook(self.make_hook("Layer4_0_Conv1_", 'forward'))
        self.layer4[0].downsample.register_forward_hook(self.make_hook("Layer4_0_Downsample_", 'forward'))
        self.layer4[0].conv2.register_forward_hook(self.make_hook("Layer4_0_Conv2_", 'forward'))
        self.layer4[1].conv1.register_forward_hook(self.make_hook("Layer4_1_Conv1_", 'forward'))
        self.layer4[1].conv2.register_forward_hook(self.make_hook("Layer4_1_Conv2_", 'forward'))



    '''def get_all_inner_activation(self, x):
        inner_output_index = [0,2,4,8,10,12,16,18]
        inner_output_list = []
        for i in range(23):
            x = self.classifier[i](x)
            if i in inner_output_index:
                inner_output_list.append(x)
        x = x.view(x.size(0), self.num_classes)
        return x,inner_output_list'''

    #############################################################################
    def input_to_conv1(self, x):
        x = self.conv1(x)
        return x

    def conv1_to_output(self, x):
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

    #############################################################################
    def input_to_layer1_0_residual(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1[0].input_to_residual(x)

        return x

    def layer1_0_residual_to_output(self, residual, conv2):

        x = self.layer1[0].residual_to_output(residual,conv2)
        x = self.layer1[1](x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def input_to_layer1_0_conv2(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1[0].input_to_conv2(x)
        return x

    def layer1_0_conv2_to_output(self, x, residual):
        x = self.layer1[0].conv2_to_output(x, residual)
        x = self.layer1[1](x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def input_to_layer1_0_conv1(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1[0].input_to_conv1(x)
        return x

    def layer1_0_conv1_to_output(self, x, residual):
        x = self.layer1[0].conv1_to_output(x, residual)
        x = self.layer1[1](x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
    #############################################################################

    def input_to_layer1_1_residual(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1[0](x)
        x = self.layer1[1].input_to_residual(x)

        return x

    def input_to_layer1_1_conv2(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1[0](x)
        x = self.layer1[1].input_to_conv2(x)
        return x

    def layer1_1_conv2_to_output(self, x, residual):
        x = self.layer1[1].conv2_to_output(x, residual)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def layer1_1_conv2_to_output_mask(self, x, residual,mask,pattern):
        x = self.layer1[1].conv2_to_output_mask(x, residual,mask,pattern)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def input_to_layer1_1_conv1(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1[0](x)
        x = self.layer1[1].input_to_conv1(x)
        return x

    def layer1_1_conv1_to_output(self, x, residual):
        x = self.layer1[1].conv1_to_output(x, residual)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    #############################################################################

    #############################################################################
    def input_to_layer2_0_residual(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2[0].input_to_residual(x)

        return x

    def layer2_0_residual_to_output(self, residual, conv2):

        x = self.layer2[0].residual_to_output(residual,conv2)
        x = self.layer2[1](x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def input_to_layer2_0_conv2(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2[0].input_to_conv2(x)
        return x

    def layer2_0_conv2_to_output(self, x, residual):
        x = self.layer2[0].conv2_to_output(x, residual)
        x = self.layer2[1](x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def input_to_layer2_0_conv1(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2[0].input_to_conv1(x)
        return x

    def layer2_0_conv1_to_output(self, x, residual):
        x = self.layer2[0].conv1_to_output(x, residual)
        x = self.layer2[1](x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
    #############################################################################

    def input_to_layer2_1_residual(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2[0](x)
        x = self.layer2[1].input_to_residual(x)

        return x

    def input_to_layer2_1_conv2(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2[0](x)
        x = self.layer2[1].input_to_conv2(x)
        return x

    def layer2_1_conv2_to_output(self, x, residual):
        x = self.layer2[1].conv2_to_output(x, residual)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


    def layer2_1_conv2_to_output_mask(self, x, residual,mask,pattern):
        x = self.layer2[1].conv2_to_output_mask(x, residual,mask,pattern)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def input_to_layer2_1_conv1(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2[0](x)
        x = self.layer2[1].input_to_conv1(x)
        return x

    def layer2_1_conv1_to_output(self, x, residual):
        x = self.layer2[1].conv1_to_output(x, residual)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    #############################################################################

    #############################################################################
    def input_to_layer3_0_residual(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3[0].input_to_residual(x)

        return x

    def layer3_0_residual_to_output(self, residual, conv2):

        x = self.layer3[0].residual_to_output(residual,conv2)
        x = self.layer3[1](x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def input_to_layer3_0_conv2(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3[0].input_to_conv2(x)
        return x

    def layer3_0_conv2_to_output(self, x, residual):
        x = self.layer3[0].conv2_to_output(x, residual)
        x = self.layer3[1](x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def input_to_layer3_0_conv1(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3[0].input_to_conv1(x)
        return x

    def layer3_0_conv1_to_output(self, x, residual):
        x = self.layer3[0].conv1_to_output(x, residual)
        x = self.layer3[1](x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
    #############################################################################

    def input_to_layer3_1_residual(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3[0](x)
        x = self.layer3[1].input_to_residual(x)

        return x

    def input_to_layer3_1_conv2(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3[0](x)
        x = self.layer3[1].input_to_conv2(x)
        return x

    def layer3_1_conv2_to_output(self, x, residual):
        x = self.layer3[1].conv2_to_output(x, residual)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def layer3_1_conv2_to_output_mask(self, x, residual,mask,pattern):
        x = self.layer3[1].conv2_to_output_mask(x, residual,mask,pattern)
        x = self.layer4(x)
x = self.avgpool(x) 762 | x = x.view(x.size(0), -1) 763 | x = self.fc(x) 764 | return x 765 | 766 | def input_to_layer3_1_conv1(self, x): 767 | x = self.conv1(x) 768 | x = self.bn1(x) 769 | x = self.relu(x) 770 | x = self.layer1(x) 771 | x = self.layer2(x) 772 | x = self.layer3[0](x) 773 | x = self.layer3[1].input_to_conv1(x) 774 | return x 775 | 776 | def layer3_1_conv1_to_output(self, x, residual): 777 | x = self.layer3[1].conv1_to_output(x, residual) 778 | x = self.layer4(x) 779 | x = self.avgpool(x) 780 | x = x.view(x.size(0), -1) 781 | x = self.fc(x) 782 | return x 783 | 784 | ############################################################################# 785 | def input_to_layer4_0_residual(self, x): 786 | x = self.conv1(x) 787 | x = self.bn1(x) 788 | x = self.relu(x) 789 | 790 | x = self.layer1(x) 791 | x = self.layer2(x) 792 | x = self.layer3(x) 793 | x = self.layer4[0].input_to_residual(x) 794 | 795 | return x 796 | 797 | def layer4_0_residual_to_output(self, residual, conv2): 798 | 799 | x = self.layer4[0].residual_to_output(residual,conv2) 800 | x = self.layer4[1](x) 801 | 802 | x = self.avgpool(x) 803 | x = x.view(x.size(0), -1) 804 | x = self.fc(x) 805 | return x 806 | 807 | def input_to_layer4_0_conv2(self, x): 808 | x = self.conv1(x) 809 | x = self.bn1(x) 810 | x = self.relu(x) 811 | x = self.layer1(x) 812 | x = self.layer2(x) 813 | x = self.layer3(x) 814 | x = self.layer4[0].input_to_conv2(x) 815 | return x 816 | 817 | def layer4_0_conv2_to_output(self, x, residual): 818 | x = self.layer4[0].conv2_to_output(x, residual) 819 | x = self.layer4[1](x) 820 | x = self.avgpool(x) 821 | x = x.view(x.size(0), -1) 822 | x = self.fc(x) 823 | return x 824 | 825 | def input_to_layer4_0_conv1(self, x): 826 | x = self.conv1(x) 827 | x = self.bn1(x) 828 | x = self.relu(x) 829 | x = self.layer1(x) 830 | x = self.layer2(x) 831 | x = self.layer3(x) 832 | x = self.layer4[0].input_to_conv1(x) 833 | return x 834 | 835 | def layer4_0_conv1_to_output(self, x, residual): 836 
| x = self.layer4[0].conv1_to_output(x, residual) 837 | x = self.layer4[1](x) 838 | x = self.avgpool(x) 839 | x = x.view(x.size(0), -1) 840 | x = self.fc(x) 841 | return x 842 | ############################################################################# 843 | def input_to_layer4_1_residual(self, x): 844 | x = self.conv1(x) 845 | x = self.bn1(x) 846 | x = self.relu(x) 847 | 848 | x = self.layer1(x) 849 | x = self.layer2(x) 850 | x = self.layer3(x) 851 | x = self.layer4[0](x) 852 | x = self.layer4[1].input_to_residual(x) 853 | 854 | return x 855 | 856 | def input_to_layer4_1_conv2(self, x): 857 | x = self.conv1(x) 858 | x = self.bn1(x) 859 | x = self.relu(x) 860 | x = self.layer1(x) 861 | x = self.layer2(x) 862 | x = self.layer3(x) 863 | x = self.layer4[0](x) 864 | x = self.layer4[1].input_to_conv2(x) 865 | return x 866 | 867 | def layer4_1_conv2_to_output(self, x, residual): 868 | x = self.layer4[1].conv2_to_output(x, residual) 869 | x = self.avgpool(x) 870 | x = x.view(x.size(0), -1) 871 | x = self.fc(x) 872 | return x 873 | 874 | def layer4_1_conv2_to_output_mask(self, x, residual,mask,pattern): 875 | x = self.layer4[1].conv2_to_output_mask(x, residual,mask,pattern) 876 | x = self.avgpool(x) 877 | x = x.view(x.size(0), -1) 878 | x = self.fc(x) 879 | return x 880 | 881 | def input_to_layer4_1_conv1(self, x): 882 | x = self.conv1(x) 883 | x = self.bn1(x) 884 | x = self.relu(x) 885 | x = self.layer1(x) 886 | x = self.layer2(x) 887 | x = self.layer3(x) 888 | x = self.layer4[0](x) 889 | x = self.layer4[1].input_to_conv1(x) 890 | return x 891 | 892 | def layer4_1_conv1_to_output(self, x, residual): 893 | x = self.layer4[1].conv1_to_output(x, residual) 894 | x = self.avgpool(x) 895 | x = x.view(x.size(0), -1) 896 | x = self.fc(x) 897 | return x 898 | ############################################################################# 899 | 900 | def resnet18(**kwargs): 901 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 902 | 903 | 904 | def resnet34(**kwargs): 905 | 
return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 906 | 907 | 908 | def resnet50(**kwargs): 909 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 910 | 911 | 912 | def resnet101(**kwargs): 913 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 914 | 915 | 916 | def resnet152(**kwargs): 917 | return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) -------------------------------------------------------------------------------- /train_models/train_model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import shutil 4 | from time import time 5 | 6 | import config 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | import torchvision 11 | from torch import nn 12 | from torch.utils.tensorboard import SummaryWriter 13 | from torchvision.transforms import RandomErasing 14 | from dataloader import PostTensorTransform, get_dataloader,get_dataloader_random_ratio 15 | from resnet_nole import * 16 | 17 | import random 18 | 19 | 20 | class Normalize: 21 | def __init__(self, opt, expected_values, variance): 22 | self.n_channels = opt.input_channel 23 | self.expected_values = expected_values 24 | self.variance = variance 25 | assert self.n_channels == len(self.expected_values) 26 | 27 | def __call__(self, x): 28 | x_clone = x.clone() 29 | for channel in range(self.n_channels): 30 | x_clone[:, channel] = (x[:, channel] - self.expected_values[channel]) / self.variance[channel] 31 | return x_clone 32 | 33 | 34 | class Denormalize: 35 | def __init__(self, opt, expected_values, variance): 36 | self.n_channels = opt.input_channel 37 | self.expected_values = expected_values 38 | self.variance = variance 39 | assert self.n_channels == len(self.expected_values) 40 | 41 | def __call__(self, x): 42 | x_clone = x.clone() 43 | for channel in range(self.n_channels): 44 | x_clone[:, channel] = x[:, channel] * self.variance[channel] + self.expected_values[channel] 45 | return x_clone 46 | 47 | 48 | 
class Normalizer: 49 | def __init__(self, opt): 50 | self.normalizer = self._get_normalizer(opt) 51 | 52 | def _get_normalizer(self, opt): 53 | if opt.dataset == "cifar10": 54 | normalizer = Normalize(opt, [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]) 55 | else: 56 | raise Exception("Invalid dataset") 57 | return normalizer 58 | 59 | def __call__(self, x): 60 | if self.normalizer: 61 | x = self.normalizer(x) 62 | return x 63 | 64 | 65 | class Denormalizer: 66 | def __init__(self, opt): 67 | self.denormalizer = self._get_denormalizer(opt) 68 | 69 | def _get_denormalizer(self, opt): 70 | print(opt.dataset) 71 | if opt.dataset == "cifar10": 72 | denormalizer = Denormalize(opt, [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]) 73 | else: 74 | raise Exception("Invalid dataset") 75 | return denormalizer 76 | 77 | def __call__(self, x): 78 | if self.denormalizer: 79 | x = self.denormalizer(x) 80 | return x 81 | 82 | 83 | def get_model(opt): 84 | netC = None 85 | optimizerC = None 86 | schedulerC = None 87 | 88 | if opt.set_arch: 89 | 90 | if opt.set_arch=="resnet18": 91 | netC = resnet18(num_classes = opt.num_classes, in_channels = opt.input_channel) 92 | netC = netC.to(opt.device) 93 | 94 | optimizerC = torch.optim.SGD(netC.parameters(), opt.lr_C, momentum=0.9, weight_decay=5e-4) 95 | 96 | schedulerC = torch.optim.lr_scheduler.MultiStepLR(optimizerC, opt.schedulerC_milestones, opt.schedulerC_lambda) 97 | 98 | return netC, optimizerC, schedulerC 99 | 100 | 101 | def train(train_transform, netC, optimizerC, schedulerC, train_dl, noise_grid, identity_grid, tf_writer, epoch, opt): 102 | print(" Train:") 103 | 104 | netC.train() 105 | rate_bd = opt.pc 106 | total_loss_ce = 0 107 | total_sample = 0 108 | 109 | total_clean = 0 110 | total_bd = 0 111 | total_cross = 0 112 | total_clean_correct = 0 113 | total_bd_correct = 0 114 | total_cross_correct = 0 115 | criterion_CE = torch.nn.CrossEntropyLoss() 116 | criterion_BCE = torch.nn.BCELoss() 117 | 118 | denormalizer = 
Denormalizer(opt) 119 | transforms = PostTensorTransform(opt).to(opt.device) 120 | total_time = 0 121 | 122 | avg_acc_cross = 0 123 | 124 | for batch_idx, (inputs, targets) in enumerate(train_dl): 125 | optimizerC.zero_grad() 126 | 127 | inputs, targets = inputs.to(opt.device), targets.to(opt.device) 128 | bs = inputs.shape[0] 129 | 130 | num_bd = int(bs * rate_bd) 131 | num_cross = int(num_bd * opt.cross_ratio) 132 | grid_temps = (identity_grid + opt.s * noise_grid / opt.input_height) * opt.grid_rescale 133 | grid_temps = torch.clamp(grid_temps, -1, 1) 134 | 135 | ins = torch.rand(num_cross, opt.input_height, opt.input_height, 2).to(opt.device) * 2 - 1 136 | grid_temps2 = grid_temps.repeat(num_cross, 1, 1, 1) + ins / opt.input_height 137 | grid_temps2 = torch.clamp(grid_temps2, -1, 1) 138 | 139 | if num_bd!=0: 140 | 141 | inputs_bd = F.grid_sample(inputs[:num_bd], grid_temps.repeat(num_bd, 1, 1, 1), align_corners=True) 142 | 143 | if opt.attack_mode == "all2one": 144 | targets_bd = torch.ones_like(targets[:num_bd]) * opt.target_label 145 | if opt.attack_mode == "all2all": 146 | targets_bd = torch.remainder(targets[:num_bd] + 1, opt.num_classes) 147 | 148 | inputs_cross = F.grid_sample(inputs[num_bd : (num_bd + num_cross)], grid_temps2, align_corners=True) 149 | 150 | if (num_bd==0 and num_cross==0): 151 | total_inputs = inputs 152 | total_targets = targets 153 | else: 154 | total_inputs = torch.cat([inputs_bd, inputs_cross, inputs[(num_bd + num_cross) :]], dim=0) 155 | total_targets = torch.cat([targets_bd, targets[num_bd:]], dim=0) 156 | 157 | total_inputs = transforms(total_inputs) 158 | start = time() 159 | total_preds = netC(total_inputs) 160 | total_time += time() - start 161 | 162 | loss_ce = criterion_CE(total_preds, total_targets) 163 | 164 | loss = loss_ce 165 | loss.backward() 166 | 167 | optimizerC.step() 168 | 169 | total_sample += bs 170 | total_loss_ce += loss_ce.detach() 171 | 172 | total_clean += bs - num_bd - num_cross 173 | total_bd += num_bd 174 
| total_cross += num_cross 175 | total_clean_correct += torch.sum( 176 | torch.argmax(total_preds[(num_bd + num_cross) :], dim=1) == total_targets[(num_bd + num_cross) :] 177 | ) 178 | if num_bd: 179 | total_bd_correct += torch.sum(torch.argmax(total_preds[:num_bd], dim=1) == targets_bd) 180 | avg_acc_bd = total_bd_correct * 100.0 / total_bd 181 | else: 182 | avg_acc_bd = 0 183 | 184 | if num_cross: 185 | total_cross_correct += torch.sum( 186 | torch.argmax(total_preds[num_bd : (num_bd + num_cross)], dim=1) 187 | == total_targets[num_bd : (num_bd + num_cross)] 188 | ) 189 | avg_acc_cross = total_cross_correct * 100.0 / total_cross 190 | else: 191 | avg_acc_cross = 0 192 | 193 | avg_acc_clean = total_clean_correct * 100.0 / total_clean 194 | avg_loss_ce = total_loss_ce / total_sample 195 | 196 | # Save image for debugging 197 | if not batch_idx % 50: 198 | if not os.path.exists(opt.temps): 199 | os.makedirs(opt.temps) 200 | 201 | path = os.path.join(opt.temps, "backdoor_image.png") 202 | path_cross = os.path.join(opt.temps, "cross_image.png") 203 | if num_bd>0: 204 | torchvision.utils.save_image(inputs_bd, path, normalize=True) 205 | if num_cross>0: 206 | torchvision.utils.save_image(inputs_cross, path_cross, normalize=True) 207 | 208 | if (num_bd>0 and num_cross==0): 209 | print( 210 | batch_idx, 211 | len(train_dl), 212 | "CE Loss: {:.4f} | Clean Acc: {:.4f} | Bd Acc: {:.4f}".format( 213 | avg_loss_ce, avg_acc_clean, avg_acc_bd, 214 | )) 215 | elif (num_bd>0 and num_cross>0): 216 | print( 217 | batch_idx, 218 | len(train_dl), 219 | "CE Loss: {:.4f} | Clean Acc: {:.4f} | Bd Acc: {:.4f} | Cross Acc: {:.4f}".format( 220 | avg_loss_ce, avg_acc_clean, avg_acc_bd, avg_acc_cross 221 | )) 222 | else: 223 | print( 224 | batch_idx, 225 | len(train_dl), 226 | "CE Loss: {:.4f} | Clean Acc: {:.4f}".format(avg_loss_ce, avg_acc_clean)) 227 | # Image for tensorboard 228 | if batch_idx == len(train_dl) - 2: 229 | if num_bd>0: 230 | residual = inputs_bd - inputs[:num_bd] 231 |
batch_img = torch.cat([inputs[:num_bd], inputs_bd, total_inputs[:num_bd], residual], dim=2) 232 | batch_img = denormalizer(batch_img) 233 | batch_img = F.interpolate(batch_img, scale_factor=(4, 4)) 234 | grid = torchvision.utils.make_grid(batch_img, normalize=True) 235 | path = os.path.join(opt.temps, "batch_img.png") 236 | torchvision.utils.save_image(batch_img, path, normalize=True) 237 | 238 | # for tensorboard 239 | if not epoch % 1: 240 | tf_writer.add_scalars( 241 | "Clean Accuracy", {"Clean": avg_acc_clean, "Bd": avg_acc_bd, "Cross": avg_acc_cross}, epoch 242 | ) 243 | if num_bd>0: 244 | tf_writer.add_image("Images", grid, global_step=epoch) 245 | 246 | schedulerC.step() 247 | 248 | 249 | def eval( 250 | test_transform, 251 | netC, 252 | optimizerC, 253 | schedulerC, 254 | test_dl, 255 | noise_grid, 256 | identity_grid, 257 | best_clean_acc, 258 | best_bd_acc, 259 | best_cross_acc, 260 | tf_writer, 261 | epoch, 262 | opt, 263 | ): 264 | print(" Eval:") 265 | 266 | netC.eval() 267 | 268 | total_sample = 0 269 | total_clean_correct = 0 270 | total_bd_correct = 0 271 | total_cross_correct = 0 272 | total_ae_loss = 0 273 | 274 | criterion_BCE = torch.nn.BCELoss() 275 | 276 | for batch_idx, (inputs, targets) in enumerate(test_dl): 277 | with torch.no_grad(): 278 | inputs, targets = inputs.to(opt.device), targets.to(opt.device) 279 | #inputs = test_transform(inputs) 280 | bs = inputs.shape[0] 281 | total_sample += bs 282 | 283 | # Evaluate Clean 284 | preds_clean = netC(inputs) 285 | total_clean_correct += torch.sum(torch.argmax(preds_clean, 1) == targets) 286 | 287 | # Evaluate Backdoor 288 | grid_temps = (identity_grid + opt.s * noise_grid / opt.input_height) * opt.grid_rescale 289 | grid_temps = torch.clamp(grid_temps, -1, 1) 290 | 291 | ins = torch.rand(bs, opt.input_height, opt.input_height, 2).to(opt.device) * 2 - 1 292 | grid_temps2 = grid_temps.repeat(bs, 1, 1, 1) + ins / opt.input_height 293 | grid_temps2 = torch.clamp(grid_temps2, -1, 1) 294 | 295 |
inputs_bd = F.grid_sample(inputs, grid_temps.repeat(bs, 1, 1, 1), align_corners=True) 296 | 297 | if opt.attack_mode == "all2one": 298 | targets_bd = torch.ones_like(targets) * opt.target_label 299 | if opt.attack_mode == "all2all": 300 | targets_bd = torch.remainder(targets + 1, opt.num_classes) 301 | 302 | preds_bd = netC(inputs_bd) 303 | total_bd_correct += torch.sum(torch.argmax(preds_bd, 1) == targets_bd) 304 | 305 | acc_clean = total_clean_correct * 100.0 / total_sample 306 | acc_bd = total_bd_correct * 100.0 / total_sample 307 | 308 | # Evaluate cross 309 | if opt.cross_ratio: 310 | inputs_cross = F.grid_sample(inputs, grid_temps2, align_corners=True) 311 | preds_cross = netC(inputs_cross) 312 | total_cross_correct += torch.sum(torch.argmax(preds_cross, 1) == targets) 313 | 314 | acc_cross = total_cross_correct * 100.0 / total_sample 315 | 316 | info_string = ( 317 | "Clean Acc: {:.4f} - Best: {:.4f} | Bd Acc: {:.4f} - Best: {:.4f} | Cross: {:.4f}".format( 318 | acc_clean, best_clean_acc, acc_bd, best_bd_acc, acc_cross, best_cross_acc 319 | ) 320 | ) 321 | else: 322 | info_string = "Clean Acc: {:.4f} - Best: {:.4f} | Bd Acc: {:.4f} - Best: {:.4f}".format( 323 | acc_clean, best_clean_acc, acc_bd, best_bd_acc 324 | ) 325 | print(batch_idx, len(test_dl), info_string) 326 | 327 | 328 | # tensorboard 329 | if not epoch % 1: 330 | tf_writer.add_scalars("Test Accuracy", {"Clean": acc_clean, "Bd": acc_bd}, epoch) 331 | 332 | # Save checkpoint 333 | if acc_clean > best_clean_acc or (acc_clean > best_clean_acc - 0.1 and acc_bd > best_bd_acc): 334 | print(" Saving...") 335 | best_clean_acc = acc_clean 336 | best_bd_acc = acc_bd 337 | if opt.cross_ratio: 338 | best_cross_acc = acc_cross 339 | else: 340 | best_cross_acc = torch.tensor([0]) 341 | state_dict = { 342 | "netC": netC.state_dict(), 343 | "schedulerC": schedulerC.state_dict(), 344 | "optimizerC": optimizerC.state_dict(), 345 | "best_clean_acc": best_clean_acc, 346 | "best_bd_acc": best_bd_acc, 347 | 
"best_cross_acc": best_cross_acc, 348 | "epoch_current": epoch, 349 | "identity_grid": identity_grid, 350 | "noise_grid": noise_grid, 351 | } 352 | torch.save(state_dict, opt.ckpt_path) 353 | with open(os.path.join(opt.ckpt_folder, "results.txt"), "w+") as f: 354 | results_dict = { 355 | "clean_acc": best_clean_acc.item(), 356 | "bd_acc": best_bd_acc.item(), 357 | "cross_acc": best_cross_acc.item(), 358 | } 359 | json.dump(results_dict, f, indent=2) 360 | 361 | return best_clean_acc, best_bd_acc, best_cross_acc 362 | 363 | 364 | def main(): 365 | opt = config.get_arguments().parse_args() 366 | 367 | if opt.dataset in ["cifar10"]: 368 | opt.num_classes = 10 369 | 370 | 371 | if opt.dataset == "cifar10": 372 | opt.input_height = 32 373 | opt.input_width = 32 374 | opt.input_channel = 3 375 | 376 | 377 | # Dataset 378 | 379 | opt.random_ratio = 0.95 380 | train_dl, train_transform = get_dataloader_random_ratio(opt, True) 381 | test_dl, test_transform = get_dataloader(opt, False) 382 | 383 | # prepare model 384 | netC, optimizerC, schedulerC = get_model(opt) 385 | 386 | # Load pretrained model 387 | mode = opt.attack_mode 388 | opt.ckpt_folder = os.path.join(opt.checkpoints, opt.dataset) 389 | if opt.set_arch: 390 | opt.ckpt_folder = opt.ckpt_folder + "/neurips_wanet/" + opt.set_arch + "_" + opt.extra_flag+"_"+str(opt.target_label) 391 | else: 392 | opt.ckpt_folder = opt.ckpt_folder + "/neurips_wanet/" + opt.extra_flag+"_"+str(opt.target_label) 393 | opt.ckpt_path = os.path.join(opt.ckpt_folder, "{}_{}_morph_wanet.pth.tar".format(opt.dataset, mode)) 394 | opt.log_dir = os.path.join(opt.ckpt_folder, "log_dir") 395 | if not os.path.exists(opt.log_dir): 396 | os.makedirs(opt.log_dir) 397 | 398 | if opt.continue_training: 399 | if os.path.exists(opt.ckpt_path): 400 | print("Continue training!!") 401 | state_dict = torch.load(opt.ckpt_path) 402 | netC.load_state_dict(state_dict["netC"]) 403 | optimizerC.load_state_dict(state_dict["optimizerC"]) 404 | 
schedulerC.load_state_dict(state_dict["schedulerC"]) 405 | best_clean_acc = state_dict["best_clean_acc"] 406 | best_bd_acc = state_dict["best_bd_acc"] 407 | best_cross_acc = state_dict["best_cross_acc"] 408 | epoch_current = state_dict["epoch_current"] 409 | identity_grid = state_dict["identity_grid"] 410 | noise_grid = state_dict["noise_grid"] 411 | tf_writer = SummaryWriter(log_dir=opt.log_dir) 412 | else: 413 | print("Pretrained model doesn't exist") 414 | exit() 415 | else: 416 | print("Train from scratch!!!") 417 | best_clean_acc = 0.0 418 | best_bd_acc = 0.0 419 | best_cross_acc = 0.0 420 | epoch_current = 0 421 | 422 | # Prepare grid 423 | ins = torch.rand(1, 2, opt.k, opt.k) * 2 - 1 424 | ins = ins / torch.mean(torch.abs(ins)) 425 | noise_grid = ( 426 | F.interpolate(ins, size=opt.input_height, mode="bicubic", align_corners=True) 427 | .permute(0, 2, 3, 1) 428 | .to(opt.device) 429 | ) 430 | array1d = torch.linspace(-1, 1, steps=opt.input_height) 431 | x, y = torch.meshgrid(array1d, array1d) 432 | identity_grid = torch.stack((y, x), 2)[None, ...].to(opt.device) 433 | 434 | shutil.rmtree(opt.ckpt_folder, ignore_errors=True) 435 | os.makedirs(opt.log_dir) 436 | with open(os.path.join(opt.ckpt_folder, "opt.json"), "w+") as f: 437 | json.dump(opt.__dict__, f, indent=2) 438 | tf_writer = SummaryWriter(log_dir=opt.log_dir) 439 | 440 | 441 | for epoch in range(epoch_current, opt.n_iters): 442 | print("Epoch {}:".format(epoch + 1)) 443 | train(train_transform,netC, optimizerC, schedulerC, train_dl, noise_grid, identity_grid, tf_writer, epoch, opt) 444 | best_clean_acc, best_bd_acc, best_cross_acc = eval( 445 | test_transform, 446 | netC, 447 | optimizerC, 448 | schedulerC, 449 | test_dl, 450 | noise_grid, 451 | identity_grid, 452 | best_clean_acc, 453 | best_bd_acc, 454 | best_cross_acc, 455 | tf_writer, 456 | epoch, 457 | opt, 458 | ) 459 | 460 | if opt.save_all: 461 | if (epoch)%opt.save_freq == 0: 462 | state_dict = { 463 | "netC": netC.state_dict(), 464 |
"schedulerC": schedulerC.state_dict(), 465 | "optimizerC": optimizerC.state_dict(), 466 | "epoch_current": epoch, 467 | } 468 | epoch_path = os.path.join(opt.ckpt_folder, "{}_{}_epoch{}.pth.tar".format(opt.dataset, mode,epoch)) 469 | torch.save(state_dict, epoch_path) 470 | 471 | 472 | if __name__ == "__main__": 473 | main() 474 | -------------------------------------------------------------------------------- /unet_blocks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Class definitions for standard U-Net up- and down-sampling blocks 3 | https://arxiv.org/abs/1505.04597 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class EncoderBlock(nn.Module): 11 | """ 12 | Instantiates the Encoder block that forms a part of a U-Net 13 | Parameters: 14 | in_channels (int): Depth (or number of channels) of the tensor that the block acts on 15 | filter_num (int) : Number of filters used in the convolution ops inside the block, 16 | depth of the output of the enc block 17 | dropout (bool) : Flag to decide whether a dropout layer should be applied 18 | dropout_rate (float) : Probability of zeroing each element of a convolution output 19 | """ 20 | def __init__(self, filter_num=64, in_channels=1, dropout=False, dropout_rate=0.3): 21 | 22 | super(EncoderBlock,self).__init__() 23 | self.filter_num = int(filter_num) 24 | self.in_channels = int(in_channels) 25 | self.dropout = dropout 26 | self.dropout_rate = dropout_rate 27 | 28 | self.conv1 = nn.Conv2d(in_channels=self.in_channels, 29 | out_channels=self.filter_num, 30 | kernel_size=3, 31 | padding=1) 32 | 33 | self.conv2 = nn.Conv2d(in_channels=self.filter_num, 34 | out_channels=self.filter_num, 35 | kernel_size=3, 36 | padding=1) 37 | 38 | self.bn_op_1 = nn.InstanceNorm2d(num_features=self.filter_num, affine=True) 39 | self.bn_op_2 = nn.InstanceNorm2d(num_features=self.filter_num, affine=True) 40 | 41 | # Use Dropout ops as nn.Module
instead of nn.functional definitions, 42 | # so that the .train() and .eval() flags can modify their behavior (e.g. for MC-Dropout) 43 | if dropout is True: 44 | self.dropout_1 = nn.Dropout(p=dropout_rate) 45 | self.dropout_2 = nn.Dropout(p=dropout_rate) 46 | 47 | def apply_manual_dropout_mask(self, x, seed): 48 | # Mask size : [Batch_size, Channels, Height, Width]; each entry is kept with probability 1 - dropout_rate 49 | dropout_mask = torch.bernoulli(input=torch.empty(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).fill_(1 - self.dropout_rate), 50 | generator=torch.Generator().manual_seed(seed)) 51 | 52 | x = x*dropout_mask.to(x.device) / (1 - self.dropout_rate) # rescale kept entries, as nn.Dropout does 53 | 54 | return x 55 | 56 | def forward(self, x, seeds=None): 57 | 58 | if seeds is not None: 59 | assert(seeds.shape[0] == 2) 60 | 61 | x = self.conv1(x) 62 | x = self.bn_op_1(x) 63 | x = F.leaky_relu(x) 64 | if self.dropout is True: 65 | if seeds is None: 66 | x = self.dropout_1(x) 67 | else: 68 | x = self.apply_manual_dropout_mask(x, seeds[0].item()) 69 | 70 | x = self.conv2(x) 71 | x = self.bn_op_2(x) 72 | x = F.leaky_relu(x) 73 | if self.dropout is True: 74 | if seeds is None: 75 | x = self.dropout_2(x) 76 | else: 77 | x = self.apply_manual_dropout_mask(x, seeds[1].item()) 78 | 79 | return x 80 | 81 | 82 | class DecoderBlock(nn.Module): 83 | """ 84 | Decoder block used in the U-Net 85 | Parameters: 86 | in_channels (int) : Number of channels of the incoming tensor for the upsampling op 87 | concat_layer_depth (int) : Number of channels to be concatenated via skip connections 88 | filter_num (int) : Number of filters used in convolution, the depth of the output of the dec block 89 | interpolate (bool) : Decides if upsampling needs to be performed via interpolation or transposed convolution 90 | dropout (bool) : Flag to decide whether a dropout layer should be applied 91 | dropout_rate (float) : Probability of zeroing each element of a convolution output 92 | """ 93 | def 
__init__(self, in_channels, concat_layer_depth, filter_num, interpolate=False, dropout=False, dropout_rate=0.3): 94 | 95 | #
Up-sampling (interpolation or transposed conv) --> EncoderBlock 96 | super(DecoderBlock, self).__init__() 97 | self.filter_num = int(filter_num) 98 | self.in_channels = int(in_channels) 99 | self.concat_layer_depth = int(concat_layer_depth) 100 | self.interpolate = interpolate 101 | self.dropout = dropout 102 | self.dropout_rate = dropout_rate 103 | 104 | # Upsample by interpolation followed by a 3x3 convolution to obtain desired depth 105 | self.up_sample_interpolate = nn.Sequential(nn.Upsample(scale_factor=2, 106 | mode='bilinear', 107 | align_corners=True), 108 | 109 | nn.Conv2d(in_channels=self.in_channels, 110 | out_channels=self.in_channels, 111 | kernel_size=3, 112 | padding=1) 113 | ) 114 | 115 | # Upsample via transposed convolution (known to produce artifacts) 116 | self.up_sample_tranposed = nn.ConvTranspose2d(in_channels=self.in_channels, 117 | out_channels=self.in_channels, 118 | kernel_size=3, 119 | stride=2, 120 | padding=1, 121 | output_padding=1) 122 | 123 | self.down_sample = EncoderBlock(in_channels=self.in_channels+self.concat_layer_depth, 124 | filter_num=self.filter_num, 125 | dropout=self.dropout, 126 | dropout_rate=self.dropout_rate) 127 | 128 | def forward(self, x, skip_layer, seeds=None): 129 | if self.interpolate is True: 130 | up_sample_out = F.leaky_relu(self.up_sample_interpolate(x)) 131 | else: 132 | up_sample_out = F.leaky_relu(self.up_sample_tranposed(x)) 133 | 134 | merged_out = torch.cat([up_sample_out, skip_layer], dim=1) 135 | out = self.down_sample(merged_out, seeds=seeds) 136 | return out 137 | 138 | 139 | class EncoderBlock3D(nn.Module): 140 | 141 | """ 142 | Instantiates the 3D Encoder block that forms a part of a 3D U-Net 143 | Parameters: 144 | in_channels (int): Depth (or number of channels) of the tensor that the block acts on 145 | filter_num (int) : Number of filters used in the first convolution op inside the block; 146 | the second convolution doubles it, so the block outputs 2*filter_num channels 147 | """ 148 | def __init__(self, filter_num=64, in_channels=1, 
dropout=False, dropout_rate=0.3): 149 | 150 | super(EncoderBlock3D, self).__init__() 151 | self.filter_num = int(filter_num) 152 | self.in_channels = int(in_channels) 153 | self.dropout = dropout 154 | self.dropout_rate = dropout_rate 155 | self.conv1 = nn.Conv3d(in_channels=self.in_channels, 156 | out_channels=self.filter_num, 157 | kernel_size=3, 158 | padding=1) 159 | 160 | self.conv2 = nn.Conv3d(in_channels=self.filter_num, 161 | out_channels=self.filter_num*2, 162 | kernel_size=3, 163 | padding=1) 164 | 165 | self.bn_op_1 = nn.InstanceNorm3d(num_features=self.filter_num) 166 | self.bn_op_2 = nn.InstanceNorm3d(num_features=self.filter_num*2) 167 | 168 | def forward(self, x): 169 | 170 | x = self.conv1(x) 171 | x = self.bn_op_1(x) 172 | x = F.leaky_relu(x) 173 | if self.dropout is True: 174 | x = F.dropout3d(x, p=self.dropout_rate, training=self.training) 175 | 176 | x = self.conv2(x) 177 | x = self.bn_op_2(x) 178 | x = F.leaky_relu(x) 179 | 180 | if self.dropout is True: 181 | x = F.dropout3d(x, p=self.dropout_rate, training=self.training) 182 | 183 | return x 184 | 185 | 186 | class DecoderBlock3D(nn.Module): 187 | """ 188 | Decoder block used in the 3D U-Net 189 | Parameters: 190 | in_channels (int) : Number of channels of the incoming tensor for the upsampling op 191 | concat_layer_depth (int) : Number of channels to be concatenated via skip connections 192 | filter_num (int) : Number of filters used in convolution, the depth of the output of the dec block 193 | interpolate (bool) : Decides if upsampling needs to be performed via interpolation or transposed convolution 194 | """ 195 | def __init__(self, in_channels, concat_layer_depth, filter_num, interpolate=False, dropout=False, dropout_rate=0.3): 196 | 197 | super(DecoderBlock3D, self).__init__() 198 | self.filter_num = int(filter_num) 199 | self.in_channels = int(in_channels) 200 | self.concat_layer_depth = int(concat_layer_depth) 201 | self.interpolate = interpolate 202 | self.dropout = dropout 203 | self.dropout_rate = dropout_rate 204 | # Upsample by 
interpolation followed by a 3x3x3 convolution to obtain desired depth 205 | self.up_sample_interpolate = 
nn.Sequential(nn.Upsample(scale_factor=2, 206 | mode='nearest'), 207 | 208 | nn.Conv3d(in_channels=self.in_channels, 209 | out_channels=self.in_channels, 210 | kernel_size=3, 211 | padding=1) 212 | ) 213 | 214 | # Upsample via transposed convolution (known to produce artifacts) 215 | self.up_sample_transposed = nn.ConvTranspose3d(in_channels=self.in_channels, 216 | out_channels=self.in_channels, 217 | kernel_size=3, 218 | stride=2, 219 | padding=1, 220 | output_padding=1) 221 | 222 | if self.dropout is True: 223 | self.down_sample = nn.Sequential(nn.Conv3d(in_channels=self.in_channels+self.concat_layer_depth, 224 | out_channels=self.filter_num, 225 | kernel_size=3, 226 | padding=1), 227 | 228 | nn.InstanceNorm3d(num_features=self.filter_num), 229 | 230 | nn.LeakyReLU(), 231 | 232 | nn.Dropout3d(p=self.dropout_rate), 233 | 234 | nn.Conv3d(in_channels=self.filter_num, 235 | out_channels=self.filter_num, 236 | kernel_size=3, 237 | padding=1), 238 | 239 | nn.InstanceNorm3d(num_features=self.filter_num), 240 | 241 | nn.LeakyReLU(), 242 | 243 | nn.Dropout3d(p=self.dropout_rate)) 244 | else: 245 | self.down_sample = nn.Sequential(nn.Conv3d(in_channels=self.in_channels+self.concat_layer_depth, 246 | out_channels=self.filter_num, 247 | kernel_size=3, 248 | padding=1), 249 | 250 | nn.InstanceNorm3d(num_features=self.filter_num), 251 | 252 | nn.LeakyReLU(), 253 | 254 | nn.Conv3d(in_channels=self.filter_num, 255 | out_channels=self.filter_num, 256 | kernel_size=3, 257 | padding=1), 258 | 259 | nn.InstanceNorm3d(num_features=self.filter_num), 260 | 261 | nn.LeakyReLU()) 262 | 263 | def forward(self, x, skip_layer): 264 | 265 | if self.interpolate is True: 266 | up_sample_out = F.leaky_relu(self.up_sample_interpolate(x)) 267 | else: 268 | up_sample_out = F.leaky_relu(self.up_sample_transposed(x)) 269 | 270 | merged_out = torch.cat([up_sample_out, skip_layer], dim=1) 271 | out = self.down_sample(merged_out) 272 | return out --------------------------------------------------------------------------------
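The encoder and decoder blocks above only fit together if every upsampling step exactly doubles the spatial size, so that `torch.cat([up_sample_out, skip_layer], dim=1)` sees matching shapes. The following minimal pure-Python sketch applies the output-size formulas that PyTorch documents for `Conv2d`/`Conv3d` and `ConvTranspose2d`/`ConvTranspose3d` along a single spatial axis. The helper names `conv_out`, `conv_transpose_out` and `skips_align` are hypothetical and are not part of this repository. `skips_align` also assumes max-pool downsampling, as used by `UNet(use_pooling=True)`.

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length of a ConvNd along one spatial axis (PyTorch's documented formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0, dilation=1):
    """Output length of a ConvTransposeNd along one spatial axis (PyTorch's documented formula)."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


def skips_align(size, num_blocks):
    """Hypothetical check: with MaxPool(kernel_size=2) downsampling and the
    DecoderBlock's transposed conv, does every upsampled map match its skip?"""
    skips = []
    for _ in range(num_blocks):
        skips.append(size)   # encoder output kept for the skip connection
        size = size // 2     # MaxPool with kernel_size=2 floors odd sizes
    # The bottleneck block uses 3x3 convs with padding=1, so it keeps `size`.
    for skip in reversed(skips):
        size = conv_transpose_out(size, kernel=3, stride=2, padding=1, output_padding=1)
        if size != skip:
            return False
    return True


# EncoderBlock: a 3x3 conv with padding=1 keeps the spatial size.
assert conv_out(32, kernel=3, padding=1) == 32
# Strided downsampling conv used when use_pooling=False: halves an even size.
assert conv_out(32, kernel=3, stride=2, padding=1) == 16
# DecoderBlock transposed conv (k=3, s=2, p=1, output_padding=1): doubles the size.
for n in (4, 8, 16, 31):
    assert conv_transpose_out(n, kernel=3, stride=2, padding=1, output_padding=1) == 2 * n
```

Under these formulas, a 32x32 input works with `num_blocks=4`. A size that is not divisible by `2**num_blocks`, such as 30, leaves a skip connection off by one pixel, so the concatenation in the decoder would fail.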
/unet_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | A PyTorch Implementation of a U-Net. 3 | Supports 2D (https://arxiv.org/abs/1505.04597) and 3D (https://arxiv.org/abs/1606.06650) variants 4 | Author: Ishaan Bhat 5 | Email: ishaan@isi.uu.nl 6 | """ 7 | from unet_blocks import * 8 | from math import pow 9 | 10 | 11 | class UNet(nn.Module): 12 | """ 13 | PyTorch class definition for the U-Net architecture for image segmentation 14 | Parameters: 15 | n_channels (int) : Number of image channels 16 | base_filter_num (int) : Number of filters for the first convolution (doubled for every subsequent block, capped at 512) 17 | num_blocks (int) : Number of encoder/decoder blocks 18 | num_classes (int) : Number of classes that need to be segmented 19 | mode (str): '2D' or '3D' 20 | use_pooling (bool): Set to 'True' to use MaxPool as the downsampling op. 21 | If 'False', a stride-2 convolution is used to downsample feature maps (http://arxiv.org/abs/1908.02182) 22 | dropout (bool) : Whether dropout should be added to the central encoder and decoder blocks (e.g. Bayesian SegNet) 23 | dropout_rate (float) : Dropout probability 24 | Returns: 25 | out (torch.Tensor) : Prediction of the segmentation map 26 | """ 27 | def __init__(self, n_channels=1, base_filter_num=64, num_blocks=4, num_classes=5, mode='2D', dropout=False, dropout_rate=0.3, use_pooling=True): 28 | 29 | super(UNet, self).__init__() 30 | self.contracting_path = nn.ModuleList() 31 | self.expanding_path = nn.ModuleList() 32 | self.downsampling_ops = nn.ModuleList() 33 | 34 | self.num_blocks = num_blocks 35 | self.n_channels = int(n_channels) 36 | self.n_classes = int(num_classes) 37 | self.base_filter_num = int(base_filter_num) 38 | self.enc_layer_depths = [] # Keep track of the output depths of each encoder block 39 | self.mode = mode 40 | self.pooling = use_pooling 41 | self.dropout = dropout 42 | self.dropout_rate = dropout_rate 43 | 44 | if mode == '2D': 45 | self.encoder = 
EncoderBlock 46 | self.decoder = DecoderBlock 47 | self.pool = nn.MaxPool2d 48 | 49 | elif mode == '3D': 50 | self.encoder = EncoderBlock3D 51 | self.decoder = DecoderBlock3D 52 | self.pool = nn.MaxPool3d 53 | else: 54 | raise ValueError('{} mode is invalid'.format(mode)) 55 | 56 | for block_id in range(num_blocks): 57 | # Due to GPU mem constraints, we cap the filter depth at 512 58 | enc_block_filter_num = min(int(pow(2, block_id)*self.base_filter_num), 512) # Output depth of current encoder stage of the 2-D variant 59 | if block_id == 0: 60 | enc_in_channels = self.n_channels 61 | else: 62 | if self.mode == '2D': 63 | if int(pow(2, block_id)*self.base_filter_num) <= 512: 64 | enc_in_channels = enc_block_filter_num//2 65 | else: 66 | enc_in_channels = 512 67 | else: 68 | enc_in_channels = enc_block_filter_num # In the 3D UNet arch, the encoder features double in the 2nd convolution op 69 | 70 | 71 | # Dropout only applied to central encoder blocks -- See BayesianSegNet by Kendall et al. 72 | if self.dropout is True and block_id >= num_blocks-2: 73 | self.contracting_path.append(self.encoder(in_channels=enc_in_channels, 74 | filter_num=enc_block_filter_num, 75 | dropout=True, 76 | dropout_rate=self.dropout_rate)) 77 | else: 78 | self.contracting_path.append(self.encoder(in_channels=enc_in_channels, 79 | filter_num=enc_block_filter_num, 80 | dropout=False)) 81 | if self.mode == '2D': 82 | self.enc_layer_depths.append(enc_block_filter_num) 83 | if self.pooling is False: 84 | self.downsampling_ops.append(nn.Sequential(nn.Conv2d(in_channels=self.enc_layer_depths[-1], 85 | out_channels=self.enc_layer_depths[-1], 86 | kernel_size=3, 87 | stride=2, 88 | padding=1), 89 | nn.InstanceNorm2d(num_features=self.enc_layer_depths[-1]), 90 | nn.LeakyReLU())) 91 | else: 92 | self.enc_layer_depths.append(enc_block_filter_num*2) # Specific to 3D U-Net architecture (due to doubling of #feature_maps inside the 3-D Encoder) 93 | if self.pooling is False: 94 |
self.downsampling_ops.append(nn.Sequential(nn.Conv3d(in_channels=self.enc_layer_depths[-1], 95 | out_channels=self.enc_layer_depths[-1], 96 | kernel_size=3, 97 | stride=2, 98 | padding=1), 99 | nn.InstanceNorm3d(num_features=self.enc_layer_depths[-1]), 100 | nn.LeakyReLU())) 101 | 102 | # Bottleneck layer 103 | if self.mode == '2D': 104 | bottle_neck_filter_num = self.enc_layer_depths[-1]*2 105 | bottle_neck_in_channels = self.enc_layer_depths[-1] 106 | self.bottle_neck_layer = self.encoder(filter_num=bottle_neck_filter_num, 107 | in_channels=bottle_neck_in_channels) 108 | 109 | else: # Modified for the 3D UNet architecture 110 | bottle_neck_in_channels = self.enc_layer_depths[-1] 111 | bottle_neck_filter_num = self.enc_layer_depths[-1]*2 112 | self.bottle_neck_layer = nn.Sequential(nn.Conv3d(in_channels=bottle_neck_in_channels, 113 | out_channels=bottle_neck_in_channels, 114 | kernel_size=3, 115 | padding=1), 116 | 117 | nn.InstanceNorm3d(num_features=bottle_neck_in_channels), 118 | 119 | nn.LeakyReLU(), 120 | 121 | nn.Conv3d(in_channels=bottle_neck_in_channels, 122 | out_channels=bottle_neck_filter_num, 123 | kernel_size=3, 124 | padding=1), 125 | 126 | nn.InstanceNorm3d(num_features=bottle_neck_filter_num), 127 | 128 | nn.LeakyReLU()) 129 | 130 | # Decoder Path 131 | dec_in_channels = int(bottle_neck_filter_num) 132 | for block_id in range(num_blocks): 133 | if self.dropout is True and block_id < 2: 134 | self.expanding_path.append(self.decoder(in_channels=dec_in_channels, 135 | filter_num=self.enc_layer_depths[-1-block_id], 136 | concat_layer_depth=self.enc_layer_depths[-1-block_id], 137 | interpolate=False, 138 | dropout=True, 139 | dropout_rate=self.dropout_rate)) 140 | else: 141 | self.expanding_path.append(self.decoder(in_channels=dec_in_channels, 142 | filter_num=self.enc_layer_depths[-1-block_id], 143 | concat_layer_depth=self.enc_layer_depths[-1-block_id], 144 | interpolate=False, 145 | dropout=False)) 146 | 147 | dec_in_channels = 
self.enc_layer_depths[-1-block_id] 148 | 149 | # Output Layer 150 | if mode == '2D': 151 | self.output = nn.Conv2d(in_channels=int(self.enc_layer_depths[0]), 152 | out_channels=self.n_classes, 153 | kernel_size=1) 154 | else: 155 | self.output = nn.Conv3d(in_channels=int(self.enc_layer_depths[0]), 156 | out_channels=self.n_classes, 157 | kernel_size=1) 158 | 159 | def forward(self, x, seeds=None): 160 | 161 | if self.mode == '2D': 162 | h, w = x.shape[-2:] 163 | else: 164 | d, h, w = x.shape[-3:] 165 | 166 | # Encoder 167 | enc_outputs = [] 168 | seed_index = 0 169 | for stage, enc_op in enumerate(self.contracting_path): 170 | if stage >= len(self.contracting_path) - 2: 171 | if seeds is not None: 172 | x = enc_op(x, seeds[seed_index:seed_index+2]) 173 | else: 174 | x = enc_op(x) 175 | seed_index += 2 # 2 seeds required per block 176 | else: 177 | x = enc_op(x) 178 | enc_outputs.append(x) 179 | 180 | if self.pooling is True: 181 | x = self.pool(kernel_size=2)(x) 182 | else: 183 | x = self.downsampling_ops[stage](x) 184 | 185 | # Bottle-neck layer 186 | x = self.bottle_neck_layer(x) 187 | # Decoder 188 | for block_id, dec_op in enumerate(self.expanding_path): 189 | if block_id < 2: 190 | if seeds is not None: 191 | x = dec_op(x, enc_outputs[-1-block_id], seeds[seed_index:seed_index+2]) 192 | else: 193 | x = dec_op(x, enc_outputs[-1-block_id]) 194 | seed_index += 2 195 | else: 196 | x = dec_op(x, enc_outputs[-1-block_id]) 197 | 198 | 199 | # Output 200 | x = self.output(x) 201 | 202 | return x --------------------------------------------------------------------------------
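The channel bookkeeping in `UNet.__init__` (2D mode) and the seed slicing in `forward` are easy to misread. The following is a minimal sketch, not part of the repository: the two helpers `unet_2d_channel_plan` and `seed_slices` are hypothetical names that replay the same arithmetic in plain Python without building any layers. They therefore say nothing about what `EncoderBlock`/`DecoderBlock` in `unet_blocks` do internally, and they assume the 2D path only.

```python
def unet_2d_channel_plan(n_channels=1, base_filter_num=64, num_blocks=4, cap=512):
    """Hypothetical helper: replay the 2D channel arithmetic of UNet.__init__.

    Returns (encoder, bottleneck, decoder, head_in), where encoder entries are
    (in_channels, filter_num) and decoder entries are (in_channels, filter_num).
    """
    encoder, enc_layer_depths = [], []
    for block_id in range(num_blocks):
        out_ch = min(2 ** block_id * base_filter_num, cap)
        if block_id == 0:
            in_ch = n_channels
        elif 2 ** block_id * base_filter_num <= cap:
            in_ch = out_ch // 2   # previous stage was half as deep
        else:
            in_ch = cap           # previous stage was already capped
        encoder.append((in_ch, out_ch))
        enc_layer_depths.append(out_ch)

    # The 2D bottleneck doubles the deepest encoder depth (this step is not capped)
    bottleneck = (enc_layer_depths[-1], enc_layer_depths[-1] * 2)

    decoder, dec_in = [], bottleneck[1]
    for block_id in range(num_blocks):
        skip = enc_layer_depths[-1 - block_id]  # filter_num == concat_layer_depth
        decoder.append((dec_in, skip))
        dec_in = skip

    head_in = enc_layer_depths[0]  # input depth of the 1x1 output convolution
    return encoder, bottleneck, decoder, head_in


def seed_slices(num_blocks=4):
    """Hypothetical helper: which seeds[start:stop] slice each block gets in forward()."""
    plan, seed_index = [], 0
    for stage in range(num_blocks):
        if stage >= num_blocks - 2:          # last two encoder blocks
            plan.append(("encoder", stage, seed_index, seed_index + 2))
            seed_index += 2
    for block_id in range(num_blocks):
        if block_id < 2:                     # first two decoder blocks
            plan.append(("decoder", block_id, seed_index, seed_index + 2))
            seed_index += 2
    return plan


if __name__ == "__main__":
    encoder, bottleneck, decoder, head_in = unet_2d_channel_plan()
    print(encoder)     # [(1, 64), (64, 128), (128, 256), (256, 512)]
    print(bottleneck)  # (512, 1024)
    print(decoder)     # [(1024, 512), (512, 256), (256, 128), (128, 64)]
    print(seed_slices(4)[-1])  # ('decoder', 1, 6, 8)
```

With the defaults, `forward` therefore consumes eight seed entries when `seeds` is given. With `num_blocks=5` the fifth encoder stage hits the 512 cap, so it maps 512 channels to 512 rather than doubling.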