├── LICENSE.md
├── README.md
├── config.py
├── dataloader.py
├── detection.py
├── image
│   ├── overview.pdf
│   └── overview.png
├── mitigation.py
├── models.py
├── models
│   ├── ULP_model.py
│   ├── __init__.py
│   ├── lenet.py
│   ├── meta_classifier_cifar10_model.py
│   └── preact_resnet.py
├── requirements.txt
├── resnet_nole.py
├── reverse_engineering.py
├── train_models
│   ├── config.py
│   ├── dataloader.py
│   ├── resnet_nole.py
│   └── train_model.py
├── unet_blocks.py
└── unet_model.py
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 RUSSS
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FeatureRE
2 | This repository is the source code for ["Rethinking the Reverse-engineering of Trojan Triggers"](https://arxiv.org/abs/2210.15127) (NeurIPS 2022).
3 |
4 |
5 | ![overview](./image/overview.png)
6 |
7 |
8 | Existing reverse-engineering methods only consider the input-space constraint: they conduct
9 | reverse-engineering by searching for a static trigger pattern in the input space. Such methods fail to
10 | reverse-engineer feature-space Trojans, whose triggers are dynamic in the input space. Our idea instead
11 | is to exploit the feature-space constraint and search for a feature-space trigger, using the constraint
12 | that the Trojan features form a hyperplane. At the same time, we also reverse-engineer the input-space
13 | Trojan transformation based on the feature-space constraint.
14 |
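As a rough PyTorch sketch of this feature-space constraint (illustrative only; the full implementation lives in `reverse_engineering.py`):

```python
import torch

torch.manual_seed(0)
bs, dim = 8, 512                                     # a batch of feature vectors
features = torch.randn(bs, dim, requires_grad=True)  # features of transformed inputs
reference = torch.randn(bs, dim)                     # features of clean samples
mask_tanh = torch.zeros(dim, requires_grad=True)     # learnable feature-space mask

mask = torch.tanh(mask_tanh) / 2 + 0.5               # bound the mask to (0, 1)
# keep the (suspected) Trojan dimensions and fill the rest with clean features;
# `mixed` is what the rest of the classifier would see for the attack-success term
mixed = mask * features + (1 - mask) * reference
# hyperplane constraint: the masked Trojan features should be compact across samples
loss_std = (features * mask).std(0).sum() / mask.norm(p=1)
loss_std.backward()
```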
15 | ## Environment
16 | See requirements.txt
17 |
18 | ## Generating models
19 | Trojaned models can be generated using the existing code of the following attacks:
20 |
21 | - [BadNets](https://github.com/verazuo/badnets-pytorch)
22 | - [WaNet](https://github.com/VinAIResearch/Warping-based_Backdoor_Attack-release)
23 | - [IA](https://github.com/VinAIResearch/input-aware-backdoor-attack-release)
24 | - [CL](https://github.com/MadryLab/label-consistent-backdoor-code)
25 | - [Filter](https://github.com/trojai)
26 | - [SIG](https://github.com/bboylyg/NAD)
27 | - [ISSBA](https://github.com/yuezunli/ISSBA)
28 |
29 | For example, to generate Trojaned models by WaNet:
30 | ```bash
31 | cd train_models
32 | CUDA_VISIBLE_DEVICES=0 python train_model.py --dataset cifar10 --set_arch resnet18 --pc 0.1
33 | ```
34 | To generate benign models:
35 | ```bash
36 | cd train_models
37 | CUDA_VISIBLE_DEVICES=0 python train_model.py --dataset cifar10 --set_arch resnet18 --pc 0
38 | ```
39 |
40 | ## Detection
41 |
42 | For example, to run FeatureRE detection on CIFAR10 with a ResNet18 network:
43 |
44 | ```bash
45 | CUDA_VISIBLE_DEVICES=0 python detection.py \
46 | --dataset cifar10 --set_arch resnet18 \
47 | --hand_set_model_path \
48 | --data_fraction 0.01 \
49 | --lr 1e-3 --bs 256 \
50 | --set_all2one_target all
51 | ```
52 |
53 | ## Mitigation
54 |
55 | For example, to run FeatureRE mitigation on CIFAR10 with a ResNet18 network produced by the WaNet attack (`--asr_test_type` should match the attack used to build the model):
56 |
57 | ```bash
58 | CUDA_VISIBLE_DEVICES=0 python mitigation.py \
59 | --dataset cifar10 --set_arch resnet18 \
60 | --hand_set_model_path \
61 | --data_fraction 0.01 \
62 | --lr 1e-3 --bs 256 \
63 | --set_all2one_target \
64 | --mask_size 0.05 --override_epoch 400 --asr_test_type wanet
65 | ```
66 |
67 | ## Cite this work
68 | You are encouraged to cite the following paper if you use the repo for academic research.
69 |
70 | ```
71 | @inproceedings{wang2022rethinking,
72 | title={Rethinking the Reverse-engineering of Trojan Triggers},
73 | author={Wang, Zhenting and Mei, Kai and Ding, Hailun and Zhai, Juan and Ma, Shiqing},
74 | booktitle={Advances in Neural Information Processing Systems},
75 | year={2022}
76 | }
77 | ```
78 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_argument():
5 | parser = argparse.ArgumentParser()
6 |
7 | # Directory option
8 | parser.add_argument("--checkpoints", type=str, default="../../checkpoints/")
9 | parser.add_argument("--data_root", type=str, default="../../data/")
10 | parser.add_argument("--device", type=str, default="cuda")
11 | parser.add_argument("--dataset", type=str, default="mnist")
12 | parser.add_argument("--attack_mode", type=str, default="all2one")
13 |
14 | parser.add_argument("--data_fraction", type=float, default=1.0)
15 |
16 | parser.add_argument("--hand_set_model_path", type=str, default=None)
17 | parser.add_argument("--set_arch", type=str, default=None)
18 | parser.add_argument("--internal_index", type=int, default=None)
19 |
20 | parser.add_argument("--set_all2one_target", type=str, default=None)
21 |
22 | parser.add_argument("--ae_atk_succ_t", type=float, default=0.9)
23 |
24 | parser.add_argument("--ae_filter_num", type=int, default=32)
25 | parser.add_argument("--ae_num_blocks", type=int, default=4)
26 |
27 | parser.add_argument("--mask_size", type=float, default=0.03)
28 | parser.add_argument("--override_epoch", type=int, default=None)
29 | parser.add_argument("--ignore_dist", action='store_true')
30 | parser.add_argument("--p_loss_bound", type=float, default=0.15)
31 | parser.add_argument("--loss_std_bound", type=float, default=1)
32 | parser.add_argument("--asr_test_type", type=str, default="filter")
33 |
34 |
35 | parser.add_argument("--bs", type=int, default=256)
36 | parser.add_argument("--lr", type=float, default=1e-3)
37 | parser.add_argument("--num_workers", type=int, default=8)
38 |
39 | parser.add_argument("--EPSILON", type=float, default=1e-7)
40 | parser.add_argument("--use_norm", type=int, default=1)
41 |
42 | parser.add_argument("--mixed_value_threshold", type=float, default=-0.75)
43 |
44 | return parser
45 |
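A hypothetical way this parser is consumed (the actual drivers are presumably `detection.py` / `mitigation.py`, which also fill in derived fields such as input sizes):

```python
from config import get_argument  # assumes this file is importable as config.py

parser = get_argument()
opt = parser.parse_args(["--dataset", "cifar10", "--set_arch", "resnet18",
                         "--data_fraction", "0.01", "--lr", "1e-3", "--bs", "256"])
print(opt.dataset, opt.mask_size, opt.asr_test_type)  # cifar10 0.03 filter
```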
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import torch
3 | import torchvision
4 | import torchvision.transforms as transforms
5 | import os
6 | import csv
7 | import random
8 | import numpy as np
9 |
10 | from PIL import Image
11 | from torch.utils.tensorboard import SummaryWriter
12 |
13 | from torch.utils.data import Dataset
14 |
15 | from io import BytesIO
16 |
17 |
18 | def get_transform(opt, train=True, pretensor_transform=False):
19 | add_nad_transform = False
20 |
21 | transforms_list = []
22 | transforms_list.append(transforms.Resize((opt.input_height, opt.input_width)))
23 | if pretensor_transform:
24 | if train:
25 | transforms_list.append(transforms.RandomCrop((opt.input_height, opt.input_width), padding=opt.random_crop))
26 | transforms_list.append(transforms.RandomRotation(opt.random_rotation))
27 | if opt.dataset == "cifar10":
28 | transforms_list.append(transforms.RandomHorizontalFlip(p=0.5))
29 |
30 | if add_nad_transform:
31 | transforms_list.append(transforms.RandomCrop(opt.input_height, padding=4))
32 | transforms_list.append(transforms.RandomHorizontalFlip())
33 |
34 |
35 | transforms_list.append(transforms.ToTensor())
36 | if opt.dataset == "cifar10":
37 | transforms_list.append(transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]))
38 | if add_nad_transform:
39 | transforms_list.append(Cutout(1,9))
40 |
41 | elif opt.dataset == "mnist":
42 | transforms_list.append(transforms.Normalize([0.1307], [0.3081]))
43 | if add_nad_transform:
44 | transforms_list.append(Cutout(1,9))
45 | elif opt.dataset == "gtsrb" or opt.dataset == "celeba":
46 | transforms_list.append(transforms.Normalize((0.3403, 0.3121, 0.3214),(0.2724, 0.2608, 0.2669)))
47 | if add_nad_transform:
48 | transforms_list.append(Cutout(1,9))
49 | elif opt.dataset == "imagenet":
50 | transforms_list.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
51 | if add_nad_transform:
52 | transforms_list.append(Cutout(1,9))
53 | else:
54 | raise Exception("Invalid Dataset")
55 |
56 | return transforms.Compose(transforms_list)
57 |
58 |
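`get_transform` expects `opt` to carry the dataset name and input geometry; `input_height`/`input_width` do not appear in config.py, so they are presumably set by the driver scripts. A hypothetical call, assuming this module is imported:

```python
import argparse

opt = argparse.Namespace(dataset="cifar10", input_height=32, input_width=32,
                         random_crop=5, random_rotation=10)
transform = get_transform(opt, train=True, pretensor_transform=False)
print(transform)  # Compose of Resize -> ToTensor -> Normalize for CIFAR-10
```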
59 |
60 | class GTSRB(data.Dataset):
61 | def __init__(self, opt, train, transforms):
62 | super(GTSRB, self).__init__()
63 | if train:
64 | self.data_folder = os.path.join(opt.data_root, "GTSRB/Train")
65 | self.images, self.labels = self._get_data_train_list()
66 | else:
67 | self.data_folder = os.path.join(opt.data_root, "GTSRB/Test")
68 | self.images, self.labels = self._get_data_test_list()
69 |
70 | self.transforms = transforms
71 |
72 | def _get_data_train_list(self):
73 | images = []
74 | labels = []
75 | for c in range(0, 43):
76 | prefix = self.data_folder + "/" + format(c, "05d") + "/"
77 | gtFile = open(prefix + "GT-" + format(c, "05d") + ".csv")
78 | gtReader = csv.reader(gtFile, delimiter=";")
79 | next(gtReader)
80 | for row in gtReader:
81 | images.append(prefix + row[0])
82 | labels.append(int(row[7]))
83 | gtFile.close()
84 | return images, labels
85 |
86 | def _get_data_test_list(self):
87 | images = []
88 | labels = []
89 | prefix = os.path.join(self.data_folder, "GT-final_test.csv")
90 | gtFile = open(prefix)
91 | gtReader = csv.reader(gtFile, delimiter=";")
92 | next(gtReader)
93 | for row in gtReader:
94 | images.append(self.data_folder + "/" + row[0])
95 | labels.append(int(row[7]))
96 | return images, labels
97 |
98 | def __len__(self):
99 | return len(self.images)
100 |
101 | def __getitem__(self, index):
102 | image = Image.open(self.images[index])
103 | image = self.transforms(image)
104 | label = self.labels[index]
105 | return image, label
106 |
107 | def get_dataloader_partial_split(opt, train_fraction=0.1, train=True, pretensor_transform=False,shuffle=True,return_index = False):
108 | data_fraction = train_fraction
109 |
110 | transform_train = get_transform(opt, True, pretensor_transform)
111 | transform_test = get_transform(opt, False, pretensor_transform)
112 |
113 | transform = transform_train
114 |
115 | if opt.dataset == "gtsrb":
116 | dataset = GTSRB(opt, train, transform_train)
117 | dataset_test = GTSRB(opt, train, transform_test)
118 | class_num=43
119 | elif opt.dataset == "mnist":
120 | dataset = torchvision.datasets.MNIST(opt.data_root, train, transform=transform_train, download=True)
121 | dataset_test = torchvision.datasets.MNIST(opt.data_root, train, transform=transform_test, download=True)
122 |
123 | class_num=10
124 | elif opt.dataset == "cifar10":
125 | dataset = torchvision.datasets.CIFAR10(opt.data_root, train, transform=transform_train, download=True)
126 | dataset_test = torchvision.datasets.CIFAR10(opt.data_root, train, transform=transform_test, download=True)
127 | class_num=10
128 | elif opt.dataset == "celeba":
129 | if train:
130 | split = "train"
131 | else:
132 | split = "test"
133 | dataset = CelebA_attr(opt, split, transform)
134 | class_num=8
135 | elif opt.dataset == "imagenet":
136 | if train:
137 | file_dir = "/workspace/data/imagenet/train"
138 | else:
139 | file_dir = "/workspace/data/imagenet/val"
140 | dataset = torchvision.datasets.ImageFolder(
141 | file_dir,
142 | transform
143 | )
144 | dataset_test = torchvision.datasets.ImageFolder(
145 | file_dir,
146 | transform
147 | )
148 | class_num=1000
149 | else:
150 | raise Exception("Invalid dataset")
151 | #dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.bs, num_workers=opt.num_workers, shuffle=True)
152 | #finetuneset = torch.utils.data.Subset(dataset, range(0,dataset.__len__(),int(1/data_fraction)))
153 | dataloader_total = torch.utils.data.DataLoader(dataset, batch_size=1, pin_memory=True,num_workers=opt.num_workers, shuffle=False)
154 |
155 | idx = []
156 | counter = [0]*class_num
157 | for batch_idx, (inputs, targets) in enumerate(dataloader_total):
158 |
159 | if counter[targets.item()] < int(dataset.__len__() * data_fraction / class_num):
160 | idx.append(batch_idx)
161 | counter[targets.item()] += 1
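A minimal, standalone sketch of this class-balanced subsetting idea (a hypothetical illustration, not the original function, whose tail is not captured in this dump):

```python
import torch
import torchvision
import torchvision.transforms as transforms

dataset = torchvision.datasets.CIFAR10("./data", train=True, download=True,
                                       transform=transforms.ToTensor())
class_num, data_fraction = 10, 0.01
per_class = int(len(dataset) * data_fraction / class_num)  # 50 images per class
idx, counter = [], [0] * class_num

# keep the first `per_class` samples of every class, as in the loop above
for i, (_, target) in enumerate(dataset):
    if counter[target] < per_class:
        idx.append(i)
        counter[target] += 1

subset = torch.utils.data.Subset(dataset, idx)
loader = torch.utils.data.DataLoader(subset, batch_size=256, shuffle=True)
print(len(subset))  # 500
```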
--------------------------------------------------------------------------------
/resnet_nole.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 |
4 |
5 | def conv3x3(in_planes, out_planes, stride=1):
6 | # 3x3 convolution with padding
7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
8 |
9 | class BasicBlock(nn.Module):
10 | expansion = 1
11 |
12 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
13 | base_width=64, dilation=1, norm_layer=None):
14 | super(BasicBlock, self).__init__()
15 | if norm_layer is None:
16 | norm_layer = nn.BatchNorm2d
17 | if groups != 1 or base_width != 64:
18 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
19 | if dilation > 1:
20 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
21 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
22 | self.conv1 = conv3x3(inplanes, planes, stride)
23 | self.bn1 = norm_layer(planes)
24 | self.relu = nn.ReLU(inplace=True)
25 | self.conv2 = conv3x3(planes, planes)
26 | self.bn2 = norm_layer(planes)
27 | self.downsample = downsample
28 | self.stride = stride
29 |
30 | # Added another relu here
31 | self.relu2 = nn.ReLU(inplace=True)
32 |
33 | def forward(self, x):
34 | identity = x
35 |
36 | out = self.conv1(x)
37 | out = self.bn1(out)
38 | out = self.relu(out)
39 |
40 | out = self.conv2(out)
41 | out = self.bn2(out)
42 |
43 | if self.downsample is not None:
44 | identity = self.downsample(x)
45 |
46 | out += identity
47 |
48 | # Modified to use relu2
49 | out = self.relu2(out)
50 |
51 | return out
52 |
53 | class Bottleneck(nn.Module):
54 | expansion = 4
55 |
56 | def __init__(self, inplanes, planes, stride=1, downsample=None):
57 | super(Bottleneck, self).__init__()
58 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
59 | self.bn1 = nn.BatchNorm2d(planes)
60 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
61 | self.bn2 = nn.BatchNorm2d(planes)
62 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
63 | self.bn3 = nn.BatchNorm2d(planes * 4)
64 | self.relu = nn.ReLU(inplace=True)
65 | self.downsample = downsample
66 | self.stride = stride
67 |
68 | def forward(self, x):
69 | residual = x
70 |
71 | x = self.conv1(x)
72 | x = self.bn1(x)
73 | x = self.relu(x)
74 |
75 | x = self.conv2(x)
76 | x = self.bn2(x)
77 | x = self.relu(x)
78 |
79 | x = self.conv3(x)
80 | x = self.bn3(x)
81 |
82 | if self.downsample is not None:
83 | residual = self.downsample(residual)
84 |
85 | x += residual
86 | x = self.relu(x)
87 |
88 | return x
89 |
90 |
91 | class ResNet(nn.Module):
92 |
93 | def __init__(self, block, layers, num_classes=10,in_channels=3):
94 | self.inplanes = 64
95 | super(ResNet, self).__init__()
96 | self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
97 | self.bn1 = nn.BatchNorm2d(64)
98 | self.relu = nn.ReLU(inplace=True)
99 | self.layer1 = self._make_layer(block, 64, layers[0])
100 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
101 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
102 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
103 | self.avgpool = nn.AvgPool2d(kernel_size=4)
104 | self.fc = nn.Linear(512 * block.expansion, num_classes)
105 |
106 | for m in self.modules():
107 | if isinstance(m, nn.Conv2d):
108 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
109 | m.weight.data.normal_(0, math.sqrt(2. / n))
110 | elif isinstance(m, nn.BatchNorm2d):
111 | m.weight.data.fill_(1)
112 | m.bias.data.zero_()
113 |
114 | def _make_layer(self, block, planes, blocks, stride=1):
115 | downsample = None
116 | if stride != 1 or self.inplanes != planes * block.expansion:
117 | downsample = nn.Sequential(
118 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
119 | nn.BatchNorm2d(planes * block.expansion),
120 | )
121 |
122 | layers = []
123 | layers.append(block(self.inplanes, planes, stride, downsample))
124 | self.inplanes = planes * block.expansion
125 | for i in range(1, blocks):
126 | layers.append(block(self.inplanes, planes))
127 | return nn.Sequential(*layers)
128 |
129 | def forward(self, x):
130 | x = self.conv1(x)
131 | x = self.bn1(x)
132 | x = self.relu(x)
133 |
134 | x = self.layer1(x)
135 | x = self.layer2(x)
136 | x = self.layer3(x)
137 | x = self.layer4(x)
138 |
139 | x = self.avgpool(x)
140 | x = x.view(x.size(0), -1)
141 | x = self.fc(x)
142 |
143 | return x
144 |
145 | def from_input_to_features(self, x, index):
146 | x = self.conv1(x)
147 | x = self.bn1(x)
148 | x = self.relu(x)
149 |
150 | x = self.layer1(x)
151 | x = self.layer2(x)
152 | x = self.layer3(x)
153 | x = self.layer4(x)
154 | return x
155 |
156 | def from_features_to_output(self, x, index):
157 | x = self.avgpool(x)
158 | x = x.view(x.size(0), -1)
159 | x = self.fc(x)
160 | return x
161 |
162 | def resnet18(**kwargs):
163 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
164 |
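A quick usage sketch of the `from_input_to_features` / `from_features_to_output` split that FeatureRE hooks into (assuming CIFAR-10 sized inputs; the `index` argument is unused by this architecture):

```python
import torch

model = resnet18(num_classes=10, in_channels=3)
x = torch.randn(2, 3, 32, 32)
feats = model.from_input_to_features(x, None)        # layer4 feature maps: (2, 512, 4, 4)
logits = model.from_features_to_output(feats, None)  # logits: (2, 10)
print(feats.shape, logits.shape)
```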
165 |
166 | def resnet34(**kwargs):
167 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
168 |
169 |
170 | def resnet50(**kwargs):
171 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
172 |
173 |
174 | def resnet101(**kwargs):
175 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
176 |
177 |
178 | def resnet152(**kwargs):
179 | return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
--------------------------------------------------------------------------------
/reverse_engineering.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor, nn
3 | import torchvision
4 | import os
5 | import numpy as np
6 | from resnet_nole import *
7 | from models import meta_classifier_cifar10_model,lenet,ULP_model,preact_resnet
8 | import torch.nn.functional as F
9 |
10 | import unet_model
11 | import random
12 | import pilgram
13 | from PIL import Image
14 | from functools import reduce
15 |
16 | class RegressionModel(nn.Module):
17 | def __init__(self, opt, init_mask):
18 | self._EPSILON = opt.EPSILON
19 | super(RegressionModel, self).__init__()
20 |
21 | if init_mask is not None:
22 | self.mask_tanh = nn.Parameter(torch.tensor(init_mask))
23 |
24 | self.classifier = self._get_classifier(opt)
25 | self.example_features = None
26 |
27 | if opt.dataset == "mnist":
28 | self.AE = unet_model.UNet(n_channels=1,num_classes=1,base_filter_num=opt.ae_filter_num, num_blocks=opt.ae_num_blocks)
29 | else:
30 | self.AE = unet_model.UNet(n_channels=3,num_classes=3,base_filter_num=opt.ae_filter_num, num_blocks=opt.ae_num_blocks)
31 |
32 | self.AE.train()
33 | self.example_ori_img = None
34 | self.example_ae_img = None
35 | self.opt = opt
36 |
37 | def forward_ori(self, x,opt):
38 |
39 | features = self.classifier.from_input_to_features(x, opt.internal_index)
40 | out = self.classifier.from_features_to_output(features, opt.internal_index)
41 |
42 | return out, features
43 |
44 | def forward_flip_mask(self, x,opt):
45 |
46 | strategy = "flip"
47 | features = self.classifier.from_input_to_features(x, opt.internal_index)
48 | if strategy == "flip":
49 | features = (1 - opt.flip_mask) * features - opt.flip_mask * features
50 | elif strategy == "zero":
51 | features = (1 - opt.flip_mask) * features
52 |
53 | out = self.classifier.from_features_to_output(features, opt.internal_index)
54 |
55 | return out, features
56 |
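Since `(1 - m) * f - m * f = f - 2 * m * f`, the "flip" strategy simply negates the masked neurons and leaves the rest untouched; a quick check:

```python
import torch

f = torch.tensor([1.0, -2.0, 3.0])
m = torch.tensor([0.0, 1.0, 0.0])  # flip only the second neuron
print((1 - m) * f - m * f)         # tensor([1., 2., 3.])
```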
57 | def forward_ae(self, x,opt):
58 |
59 | self.example_ori_img = x
60 | x_before_ae = x
61 | x = self.AE(x)
62 | x_after_ae = x
63 | self.example_ae_img = x
64 |
65 | features = self.classifier.from_input_to_features(x, opt.internal_index)
66 | out = self.classifier.from_features_to_output(features, opt.internal_index)
67 |
68 | self.example_features = features
69 |
70 | return out, features, x_before_ae, x_after_ae
71 |
72 |
73 | def forward_ae_mask_p(self, x,opt):
74 | mask = self.get_raw_mask(opt)
75 | self.example_ori_img = x
76 | x_before_ae = x
77 | x = self.AE(x)
78 | x_after_ae = x
79 | self.example_ae_img = x
80 |
81 | features = self.classifier.from_input_to_features(x, opt.internal_index)
82 | reference_features_index_list = np.random.choice(range(opt.all_features.shape[0]), features.shape[0], replace=True)
83 | reference_features = opt.all_features[reference_features_index_list]
84 | features_ori = features
85 | features = mask * features + (1-mask) * reference_features.reshape(features.shape)
86 |
87 | out = self.classifier.from_features_to_output(features, opt.internal_index)
88 |
89 | self.example_features = features_ori
90 |
91 | return out, features, x_before_ae, x_after_ae, features_ori
92 |
93 | def forward_ae_mask_p_test(self, x,opt):
94 | mask = self.get_raw_mask(opt)
95 | self.example_ori_img = x
96 | x_before_ae = x
97 | x = self.AE(x)
98 | x_after_ae = x
99 | self.example_ae_img = x
100 |
101 | features = self.classifier.from_input_to_features(x, opt.internal_index)
102 | bs = features.shape[0]
103 | index_1 = list(range(bs))
104 | random.shuffle(index_1)
105 | reference_features = features[index_1]
106 | features_ori = features
107 | features = mask * features + (1-mask) * reference_features.reshape(features.shape)
108 | out = self.classifier.from_features_to_output(features, opt.internal_index)
109 | self.example_features = features_ori
110 |
111 | return out, features, x_before_ae, x_after_ae, features_ori
112 |
113 | def get_raw_mask(self,opt):
114 | mask = nn.Tanh()(self.mask_tanh)
115 | bounded = mask / (2 + self._EPSILON) + 0.5
116 | return bounded
117 |
118 | def _get_classifier(self, opt):
119 |
120 | if opt.set_arch:
121 | if opt.set_arch == "resnet18":
122 | classifier = resnet18(num_classes = opt.num_classes, in_channels = opt.input_channel)
123 | elif opt.set_arch=="preact_resnet18":
124 | classifier = preact_resnet.PreActResNet18(num_classes=opt.num_classes)
125 | elif opt.set_arch=="meta_classifier_cifar10_model":
126 | classifier = meta_classifier_cifar10_model.MetaClassifierCifar10Model()
127 | elif opt.set_arch=="mnist_lenet":
128 | classifier = lenet.LeNet5()
129 | elif opt.set_arch=="ulp_vgg":
130 | classifier = ULP_model.CNN_classifier()
131 | else:
132 | raise ValueError("invalid arch")
133 |
134 | if opt.hand_set_model_path:
135 | ckpt_path = opt.hand_set_model_path
136 |
137 | state_dict = torch.load(ckpt_path)
138 | try:
139 | classifier.load_state_dict(state_dict["net_state_dict"])
140 | except:
141 | try:
142 | classifier.load_state_dict(state_dict["netC"])
143 | except:
144 | try:
145 | from collections import OrderedDict
146 | new_state_dict = OrderedDict()
147 | for k, v in state_dict["model"].items():
148 | name = k[7:] # remove `module.`
149 | new_state_dict[name] = v
150 | classifier.load_state_dict(new_state_dict)
151 |
152 | except:
153 | classifier.load_state_dict(state_dict)
154 |
155 | for param in classifier.parameters():
156 | param.requires_grad = False
157 | classifier.eval()
158 | return classifier.to(opt.device)
159 |
160 | class Recorder:
161 | def __init__(self, opt):
162 | super().__init__()
163 | self.mixed_value_best = float("inf")
164 |
165 | def test_ori(opt, regression_model, testloader, flip=False):
166 | regression_model.eval()
167 | regression_model.AE.eval()
168 | regression_model.classifier.eval()
169 | total_pred = 0
170 | true_pred = 0
171 | cross_entropy = nn.CrossEntropyLoss()
172 | for inputs,labels in testloader:
173 | inputs = inputs.to(opt.device)
174 | labels = labels.to(opt.device)
175 | sample_num = inputs.shape[0]
176 | total_pred += sample_num
177 | target_labels = torch.ones((sample_num), dtype=torch.int64).to(opt.device) * opt.target_label
178 |
179 | if flip:
180 | out, features = regression_model.forward_flip_mask(inputs,opt)
181 | else:
182 | out, features = regression_model.forward_ori(inputs,opt)
183 | predictions = out
184 |
185 | true_pred += torch.sum(torch.argmax(predictions, dim=1) == labels).detach()
186 | loss_ce = cross_entropy(predictions, target_labels)
187 |
188 | print("BA true_pred:",true_pred)
189 | print("BA total_pred:",total_pred)
190 | print(
191 | "BA test acc:",true_pred * 100.0 / total_pred
192 | )
193 |
194 | def test_ori_attack(opt, regression_model, testloader, flip=False):
195 | regression_model.eval()
196 | regression_model.AE.eval()
197 | regression_model.classifier.eval()
198 | total_pred = 0
199 | true_pred = 0
200 | cross_entropy = nn.CrossEntropyLoss()
201 | for inputs,labels in testloader:
202 |
203 | inputs = inputs.to(opt.device)
204 |
205 | if opt.asr_test_type == "filter":
206 |
207 | t_mean = opt.t_mean.cuda()
208 | t_std = opt.t_std.cuda()
209 | GT_img = inputs
210 | GT_img = (torch.clamp(GT_img*t_std+t_mean, min=0, max=1).detach().cpu().numpy()*255).astype(np.uint8)
211 | for j in range(GT_img.shape[0]):
212 | ori_pil_img = Image.fromarray(GT_img[j].transpose((1,2,0)))
213 | converted_pil_img = pilgram._1977(ori_pil_img)
214 | GT_img[j] = np.asarray(converted_pil_img).transpose((2,0,1))
215 | GT_img = GT_img.astype(np.float32)
216 | GT_img = GT_img/255
217 | GT_img = torch.from_numpy(GT_img).cuda()
218 | GT_img = (GT_img - t_mean)/t_std
219 | inputs = GT_img
220 | elif opt.asr_test_type == "wanet":
221 | inputs = F.grid_sample(inputs, opt.grid_temps.repeat(inputs.shape[0], 1, 1, 1), align_corners=True)
222 |
223 |
224 | inputs = inputs.to(opt.device)
225 | labels = labels.to(opt.device)
226 | sample_num = inputs.shape[0]
227 | total_pred += sample_num
228 | target_labels = torch.ones((sample_num), dtype=torch.int64).to(opt.device) * opt.target_label
229 |
230 | if flip:
231 | out, features = regression_model.forward_flip_mask(inputs,opt)
232 | else:
233 | out, features = regression_model.forward_ori(inputs,opt)
234 | predictions = out
235 |
236 | true_pred += torch.sum(torch.argmax(predictions, dim=1) == target_labels).detach()
237 | loss_ce = cross_entropy(predictions, target_labels)
238 |
239 | print("ASR true_pred:",true_pred)
240 | print("ASR total_pred:",total_pred)
241 | print(
242 | "ASR test acc:",true_pred * 100.0 / total_pred
243 | )
244 |
245 | def fix_neuron_flip(opt,trainloader,testloader,testloader_asr):
246 |
247 | trained_regression_model = opt.trained_regression_model
248 | trained_regression_model.eval()
249 | trained_regression_model.AE.eval()
250 | trained_regression_model.classifier.eval()
251 |
252 | if opt.asr_test_type == "wanet":
253 | ckpt_path = opt.hand_set_model_path
254 | state_dict = torch.load(ckpt_path)
255 | identity_grid = state_dict["identity_grid"]
256 | noise_grid = state_dict["noise_grid"]
257 | grid_temps = (identity_grid + 0.5 * noise_grid / opt.input_height) * 1
258 | grid_temps = torch.clamp(grid_temps, -1, 1)
259 |
260 | opt.grid_temps = grid_temps
261 |
262 | test_ori(opt, trained_regression_model,testloader,flip=False)
263 | test_ori_attack(opt, trained_regression_model,testloader_asr,flip=False)
264 |
265 | neuron_finding_strategy = "hyperplane"
266 |
267 | cross_entropy = nn.CrossEntropyLoss()
268 | for batch_idx, (inputs, labels) in enumerate(trainloader):
269 | inputs = inputs.to(opt.device)
270 | labels = labels.to(opt.device)
271 | out, features_reversed, x_before_ae, x_after_ae = trained_regression_model.forward_ae(inputs,opt)
272 | loss_ce_transformed = cross_entropy(out, labels)
273 |
274 | out, features_ori = trained_regression_model.forward_ori(inputs,opt)
275 | loss_ce_ori = cross_entropy(out, labels)
276 |
277 | feature_dist = torch.nn.MSELoss(reduction='none').cuda()(features_ori,features_reversed).mean(0)
278 | print(feature_dist)
279 |
280 | if neuron_finding_strategy == "diff":
281 | values, indices = feature_dist.reshape(-1).topk(int(0.03*torch.numel(feature_dist)), largest=True, sorted=True)
282 | flip_mask = torch.zeros(feature_dist.reshape(-1).shape).to(opt.device)
283 | for index in indices:
284 | flip_mask[index] = 1
285 | flip_mask = flip_mask.reshape(feature_dist.shape)
286 |
287 | elif neuron_finding_strategy == "hyperplane":
288 | flip_mask = trained_regression_model.get_raw_mask(opt)
289 |
290 | opt.flip_mask = flip_mask
291 |
292 | print("loss_ce_transformed:",loss_ce_transformed)
293 | print("loss_ce_ori:",loss_ce_ori)
294 |
295 |
296 | test_ori(opt, trained_regression_model,testloader,flip=True)
297 | test_ori_attack(opt, trained_regression_model,testloader_asr,flip=True)
298 |
299 | def train(opt, init_mask):
300 |
301 | data_now = opt.data_now
302 | opt.weight_p = 1
303 | opt.weight_acc = 1
304 | opt.weight_std = 1
305 | opt.init_mask = init_mask
306 |
307 | recorder = Recorder(opt)
308 | regression_model = RegressionModel(opt, init_mask).to(opt.device)
309 |
310 | opt.epoch = 400
311 | if opt.override_epoch:
312 | opt.epoch = opt.override_epoch
313 |
314 | optimizerR = torch.optim.Adam(regression_model.AE.parameters(),lr=opt.lr,betas=(0.5,0.9))
315 | optimizerR_mask = torch.optim.Adam([regression_model.mask_tanh],lr=1e-1,betas=(0.5,0.9))
316 |
317 | regression_model.AE.train()
318 | recorder = Recorder(opt)
319 | process = train_step
320 |
321 | warm_up_epoch = 100
322 | for epoch in range(warm_up_epoch):
323 | process(regression_model, optimizerR, optimizerR_mask, data_now, recorder, epoch, opt, warm_up=True)
324 |
325 | for epoch in range(opt.epoch):
326 | process(regression_model, optimizerR, optimizerR_mask, data_now, recorder, epoch, opt)
327 |
328 | opt.trained_regression_model = regression_model
329 |
330 | return recorder, opt
331 |
332 | def get_range(opt, init_mask):
333 |
334 | test_dataloader = opt.re_dataloader_total_fixed
335 | inversion_engine = RegressionModel(opt, init_mask).to(opt.device)
336 |
337 | features_list = []
338 | features_list_class = [[] for i in range(opt.num_classes)]
339 | for batch_idx, (inputs, labels) in enumerate(test_dataloader):
340 | inputs = inputs.to(opt.device)
341 | out, features = inversion_engine.forward_ori(inputs,opt)
342 | print(torch.argmax(out,dim=1))
343 |
344 | features_list.append(features)
345 | for i in range(inputs.shape[0]):
346 | features_list_class[labels[i].item()].append(features[i].unsqueeze(0))
347 | all_features = torch.cat(features_list,dim=0)
348 | opt.all_features = all_features
349 | print(all_features.shape)
350 |
351 | del features_list
352 | del test_dataloader
353 |
354 | weight_map_class = []
355 | for i in range(opt.num_classes):
356 | feature_mean_class = torch.cat(features_list_class[i],dim=0).mean(0)
357 | weight_map_class.append(feature_mean_class)
358 |
359 | opt.weight_map_class = weight_map_class
360 | del all_features
361 | del features_list_class
362 |
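The per-class feature means computed here become the anchors for the cosine term in `train_step`: `dist_loss = cosine_similarity(weight_map_class[target_label], features.mean(0))`. Minimizing it pushes the mean of the reversed features away from the clean target-class direction, and once the similarity goes negative the guard in `train_step` zeroes its contribution.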
363 | def train_step(regression_model, optimizerR, optimizerR_mask, data_now, recorder, epoch, opt, warm_up=False):
364 | print("Epoch {} - Label: {} | {} - {}:".format(epoch, opt.target_label, opt.dataset, opt.attack_mode))
365 | cross_entropy = nn.CrossEntropyLoss()
366 | total_pred = 0
367 | true_pred = 0
368 |
369 | loss_ce_list = []
370 | loss_dist_list = []
371 | loss_list = []
372 | acc_list = []
373 |
374 | p_loss_list = []
375 | loss_mask_norm_list = []
376 | loss_std_list = []
377 |
378 | for inputs in data_now:
379 | regression_model.AE.train()
380 | regression_model.mask_tanh.requires_grad = False
381 |
382 | optimizerR.zero_grad()
383 |
384 | inputs = inputs.to(opt.device)
385 | sample_num = inputs.shape[0]
386 | total_pred += sample_num
387 | target_labels = torch.ones((sample_num), dtype=torch.int64).to(opt.device) * opt.target_label
388 | if warm_up:
389 | predictions, features, x_before_ae, x_after_ae = regression_model.forward_ae(inputs,opt)
390 | else:
391 | predictions, features, x_before_ae, x_after_ae, features_ori = regression_model.forward_ae_mask_p(inputs,opt)
392 |
393 | loss_ce = cross_entropy(predictions, target_labels)
394 |
395 | mse_loss = torch.nn.MSELoss(reduction='mean').cuda()(x_after_ae,x_before_ae)
396 |
397 | if warm_up:
398 | dist_loss = torch.cosine_similarity(opt.weight_map_class[opt.target_label].reshape(-1),features.mean(0).reshape(-1),dim=0)
399 | else:
400 | dist_loss = torch.cosine_similarity(opt.weight_map_class[opt.target_label].reshape(-1),features_ori.mean(0).reshape(-1),dim=0)
401 |
402 | acc_list_ = []
403 | minibatch_accuracy_ = torch.sum(torch.argmax(predictions, dim=1) == target_labels).detach() / sample_num
404 | acc_list_.append(minibatch_accuracy_)
405 | acc_list_ = torch.stack(acc_list_)
406 | avg_acc_G = torch.mean(acc_list_)
407 |
408 | acc_list.append(minibatch_accuracy_)
409 |
410 | p_loss = mse_loss
411 | p_loss_bound = opt.p_loss_bound
412 | loss_std_bound = opt.loss_std_bound
413 |
414 | atk_succ_threshold = opt.ae_atk_succ_t
415 |
416 | if opt.ignore_dist:
417 | dist_loss = dist_loss*0
418 |
419 | if warm_up:
420 | if (p_loss>p_loss_bound):
421 | total_loss = loss_ce + p_loss*100
422 | else:
423 | total_loss = loss_ce
424 | else:
425 | loss_std = (features_ori*regression_model.get_raw_mask(opt)).std(0).sum()
426 | loss_std = loss_std/(torch.norm(regression_model.get_raw_mask(opt), 1))
427 |
428 | total_loss = dist_loss*5
429 | if dist_loss<0:
430 | total_loss = total_loss - dist_loss*5
431 | if loss_std>loss_std_bound:
432 | total_loss = total_loss + loss_std*10*(1+opt.weight_std)
433 | if (p_loss>p_loss_bound):
434 | total_loss = total_loss + p_loss*10*(1+opt.weight_p)
435 |
436 | if avg_acc_G.item() < atk_succ_threshold:
454 | if loss_mask_norm > mask_norm_bound:
455 | loss_mask_total = loss_mask_total + loss_mask_norm
456 |
457 | loss_mask_total.backward()
458 | optimizerR_mask.step()
459 |
460 | loss_ce_list.append(loss_ce.detach())
461 | loss_dist_list.append(dist_loss.detach())
462 | loss_list.append(total_loss.detach())
463 |
464 | true_pred += torch.sum(torch.argmax(predictions, dim=1) == target_labels).detach()
465 |
466 | if not warm_up:
467 | p_loss_list.append(p_loss)
468 | loss_mask_norm_list.append(loss_mask_norm)
469 | loss_std_list.append(loss_std)
470 |
471 | loss_ce_list = torch.stack(loss_ce_list)
472 | loss_dist_list = torch.stack(loss_dist_list)
473 | loss_list = torch.stack(loss_list)
474 | acc_list = torch.stack(acc_list)
475 |
476 | avg_loss_ce = torch.mean(loss_ce_list)
477 | avg_loss_dist = torch.mean(loss_dist_list)
478 | avg_loss = torch.mean(loss_list)
479 | avg_acc = torch.mean(acc_list)
480 |
481 | if not warm_up:
482 | p_loss_list = torch.stack(p_loss_list)
483 | loss_mask_norm_list = torch.stack(loss_mask_norm_list)
484 | loss_std_list = torch.stack(loss_std_list)
485 |
486 | avg_p_loss = torch.mean(p_loss_list)
487 | avg_loss_mask_norm = torch.mean(loss_mask_norm_list)
488 | avg_loss_std = torch.mean(loss_std_list)
489 | print("avg_ce_loss:",avg_loss_ce)
490 | print("avg_asr:",avg_acc)
491 | print("avg_p_loss:",avg_p_loss)
492 | print("avg_loss_mask_norm:",avg_loss_mask_norm)
493 | print("avg_loss_std:",avg_loss_std)
494 |
495 |
496 | if avg_acc.item() < atk_succ_threshold:
497 | print("@avg_acc lower than bound")
498 | if avg_p_loss > 1.0*p_loss_bound:
499 | print("@avg_p_loss larger than bound")
500 | if avg_loss_mask_norm>1.0*mask_norm_bound:
501 | print("@avg_loss_mask_norm larger than bound")
502 | if avg_loss_std>1.0*loss_std_bound:
503 | print("@avg_loss_std larger than bound")
504 |
505 |
506 | mixed_value = avg_loss_dist.detach() - avg_acc + max(avg_p_loss.detach()-p_loss_bound,0)/p_loss_bound + max(avg_loss_mask_norm.detach()-mask_norm_bound,0)/mask_norm_bound + max(avg_loss_std.detach()-loss_std_bound,0)/loss_std_bound
507 | print("mixed_value:",mixed_value)
508 | if mixed_value < recorder.mixed_value_best:
509 | recorder.mixed_value_best = mixed_value
510 | opt.weight_p = max(avg_p_loss.detach()-p_loss_bound,0)/p_loss_bound
511 | opt.weight_acc = max(atk_succ_threshold-avg_acc,0)/atk_succ_threshold
512 | opt.weight_std = max(avg_loss_std.detach()-loss_std_bound,0)/loss_std_bound
513 |
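In symbols, the per-epoch selection score computed above is `mixed_value = L_dist - ASR + max(L_p - b_p, 0)/b_p + max(||m||_1 - b_mask, 0)/b_mask + max(L_std - b_std, 0)/b_std`; the lowest value marks the epoch where the transformation hits the target label with a small, compact feature mask and little input distortion, and the `weight_*` terms re-weight whichever constraint is still violated in later epochs.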
514 |
515 | print(
516 | " Result: ASR: {:.3f} | Cross Entropy Loss: {:.6f} | Dist Loss: {:.6f} | Mixed_value best: {:.6f}".format(
517 | true_pred * 100.0 / total_pred, avg_loss_ce, avg_loss_dist, recorder.mixed_value_best
518 | )
519 | )
520 |
521 | recorder.final_asr = avg_acc
522 |
523 | return avg_acc
524 |
525 | if __name__ == "__main__":
526 | pass
527 |
--------------------------------------------------------------------------------
/train_models/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_arguments():
5 | parser = argparse.ArgumentParser()
6 |
7 | parser.add_argument("--data_root", type=str, default="./data/")
8 | parser.add_argument("--checkpoints", type=str, default="./checkpoints")
9 | parser.add_argument("--temps", type=str, default="./temps")
10 | parser.add_argument("--device", type=str, default="cuda")
11 | parser.add_argument("--continue_training", action="store_true")
12 |
13 | parser.add_argument("--model_filepath", type=str, default="./checkpoints")
14 |
15 | parser.add_argument("--dataset", type=str, default="cifar10")
16 | parser.add_argument("--set_arch", type=str, default=None)
17 | parser.add_argument("--attack_mode", type=str, default="all2one")
18 |
19 | parser.add_argument("--save_all", action="store_true")
20 | parser.add_argument("--save_freq", type=int, default=50)
21 |
22 | parser.add_argument("--bs", type=int, default=128)
23 | parser.add_argument("--lr_C", type=float, default=1e-2)
24 | parser.add_argument("--schedulerC_milestones", type=int, nargs="+", default=[100, 200, 300, 400])
25 | parser.add_argument("--schedulerC_lambda", type=float, default=0.1)
26 | parser.add_argument("--n_iters", type=int, default=1000)
27 | parser.add_argument("--num_workers", type=int, default=6)
28 |
29 | parser.add_argument("--target_label", type=int, default=0)
30 | parser.add_argument("--pc", type=float, default=0.1)
31 | parser.add_argument("--cross_ratio", type=float, default=2) # rho_a = pc, rho_n = pc * cross_ratio
32 |
33 | parser.add_argument("--random_rotation", type=int, default=10)
34 | parser.add_argument("--random_crop", type=int, default=5)
35 |
36 | parser.add_argument("--extra_flag", type=str, default="")
37 |
38 | parser.add_argument("--s", type=float, default=0.5)
39 | parser.add_argument("--k", type=int, default=4)
40 | parser.add_argument(
41 | "--grid-rescale", type=float, default=1
42 | ) # scale grid values to avoid pixel values going out of [-1, 1]. For example, grid-rescale = 0.98
43 |
44 | return parser
45 |
--------------------------------------------------------------------------------
/train_models/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import torch
3 | import torchvision
4 | import torchvision.transforms as transforms
5 | import os
6 | import csv
7 | import kornia.augmentation as A
8 | import random
9 | import numpy as np
10 |
11 | from PIL import Image
12 | from torch.utils.tensorboard import SummaryWriter
13 |
14 | from torch.utils.data import Dataset
15 | from natsort import natsorted
16 |
17 | from io import BytesIO
18 |
19 | class ToNumpy:
20 | def __call__(self, x):
21 | x = np.array(x)
22 | if len(x.shape) == 2:
23 | x = np.expand_dims(x, axis=2)
24 | return x
25 |
26 |
27 | class ProbTransform(torch.nn.Module):
28 | def __init__(self, f, p=1):
29 | super(ProbTransform, self).__init__()
30 | self.f = f
31 | self.p = p
32 |
33 | def forward(self, x): # , **kwargs):
34 | if random.random() < self.p:
35 | return self.f(x)
36 | else:
37 | return x
38 |
39 | def get_transform(opt, train=True, pretensor_transform=False):
40 | add_nad_transform = False
41 |
42 | if opt.dataset == "trojai":
43 | return transforms.Compose([transforms.CenterCrop(opt.input_height),transforms.ToTensor()])
44 |
45 | transforms_list = []
46 | transforms_list.append(transforms.Resize((opt.input_height, opt.input_width)))
47 | if pretensor_transform:
48 | if train:
49 | transforms_list.append(transforms.RandomCrop((opt.input_height, opt.input_width), padding=opt.random_crop))
50 | transforms_list.append(transforms.RandomRotation(opt.random_rotation))
51 | if opt.dataset == "cifar10":
52 | transforms_list.append(transforms.RandomHorizontalFlip(p=0.5))
53 |
54 | if add_nad_transform:
55 | transforms_list.append(transforms.RandomCrop(opt.input_height, padding=4))
56 | transforms_list.append(transforms.RandomHorizontalFlip())
57 |
58 |
59 | transforms_list.append(transforms.ToTensor())
60 | if (opt.set_arch is not None) and (("nole" in opt.set_arch) or ("mnist_lenet" in opt.set_arch)):
61 | if opt.dataset == "cifar10":
62 | transforms_list.append(transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]))
63 | if add_nad_transform:
64 | transforms_list.append(Cutout(1,9))
65 |
66 | elif opt.dataset == "mnist":
67 | transforms_list.append(transforms.Normalize([0.1307], [0.3081]))
68 | if add_nad_transform:
69 | transforms_list.append(Cutout(1,9))
70 | elif opt.dataset == "gtsrb" or opt.dataset == "celeba":
71 | transforms_list.append(transforms.Normalize((0.3403, 0.3121, 0.3214),(0.2724, 0.2608, 0.2669)))
72 | if add_nad_transform:
73 | transforms_list.append(Cutout(1,9))
74 | elif opt.dataset == "imagenet":
75 | transforms_list.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
76 | if add_nad_transform:
77 | transforms_list.append(Cutout(1,9))
78 | else:
79 | raise Exception("Invalid Dataset")
80 | else:
81 | if opt.dataset == "cifar10":
82 | transforms_list.append(transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]))
83 | if add_nad_transform:
84 | transforms_list.append(Cutout(1,9))
85 | elif opt.dataset == "mnist":
86 | transforms_list.append(transforms.Normalize([0.5], [0.5]))
87 | if add_nad_transform:
88 | transforms_list.append(Cutout(1,9))
89 | elif opt.dataset == "gtsrb" or opt.dataset == "celeba":
90 | pass
91 | elif opt.dataset == "imagenet":
92 | transforms_list.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
93 | if add_nad_transform:
94 | transforms_list.append(Cutout(1,9))
95 | else:
96 | raise Exception("Invalid Dataset")
97 | return transforms.Compose(transforms_list)
98 | class Cutout(object):
99 | """Randomly mask out one or more patches from an image.
100 | Args:
101 | n_holes (int): Number of patches to cut out of each image.
102 | length (int): The length (in pixels) of each square patch.
103 | """
104 | def __init__(self, n_holes, length):
105 | self.n_holes = n_holes
106 | self.length = length
107 |
108 | def __call__(self, img):
109 | """
110 | Args:
111 | img (Tensor): Tensor image of size (C, H, W).
112 | Returns:
113 | Tensor: Image with n_holes of dimension length x length cut out of it.
114 | """
115 | h = img.size(1)
116 | w = img.size(2)
117 |
118 | mask = np.ones((h, w), np.float32)
119 |
120 | for n in range(self.n_holes):
121 | y = np.random.randint(h)
122 | x = np.random.randint(w)
123 |
124 | y1 = np.clip(y - self.length // 2, 0, h)
125 | y2 = np.clip(y + self.length // 2, 0, h)
126 | x1 = np.clip(x - self.length // 2, 0, w)
127 | x2 = np.clip(x + self.length // 2, 0, w)
128 |
129 | mask[y1: y2, x1: x2] = 0.
130 |
131 | mask = torch.from_numpy(mask)
132 | mask = mask.expand_as(img)
133 | img = img * mask
134 | #print(img)
135 |
136 | return img
137 |
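A quick check of `Cutout` on a toy tensor:

```python
import numpy as np
import torch

np.random.seed(0)
img = torch.ones(3, 8, 8)
cut = Cutout(n_holes=1, length=4)(img)
print((cut == 0).sum().item())  # values zeroed by the single square hole, across 3 channels
```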
138 | class PostTensorTransform(torch.nn.Module):
139 | def __init__(self, opt):
140 | super(PostTensorTransform, self).__init__()
141 | self.random_crop = ProbTransform(
142 | A.RandomCrop((opt.input_height, opt.input_width), padding=opt.random_crop), p=0.8
143 | )
144 | self.random_rotation = ProbTransform(A.RandomRotation(opt.random_rotation), p=0.5)
145 | if opt.dataset == "cifar10":
146 | self.random_horizontal_flip = A.RandomHorizontalFlip(p=0.5)
147 |
148 | def forward(self, x):
149 | for module in self.children():
150 | x = module(x)
151 | return x
152 |
153 | def get_dataloader(opt, train=True, pretensor_transform=False, shuffle=True, return_dataset = False):
154 | transform = get_transform(opt, train, pretensor_transform)
155 | if opt.dataset == "cifar10":
156 | dataset = torchvision.datasets.CIFAR10(opt.data_root, train, transform=transform, download=True)
157 | else:
158 | raise Exception("Invalid dataset")
159 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.bs, num_workers=opt.num_workers, shuffle=shuffle)
160 | if return_dataset:
161 | return dataset, dataloader, transform
162 | else:
163 | return dataloader, transform
164 |
165 | def get_dataloader_random_ratio(opt, train=True, pretensor_transform=False, shuffle=True):
166 | transform = get_transform(opt, train, pretensor_transform)
167 | if opt.dataset == "cifar10":
168 | dataset = torchvision.datasets.CIFAR10(opt.data_root, train, transform=transform, download=True)
169 | else:
170 | raise Exception("Invalid dataset")
171 |
172 | idx = random.sample(range(dataset.__len__()),int(dataset.__len__()*opt.random_ratio))
173 | dataset = torch.utils.data.Subset(dataset,idx)
174 | #trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)
175 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.bs, num_workers=opt.num_workers, shuffle=shuffle)
176 | return dataloader, transform
177 |
178 | def main():
179 | pass
180 |
181 |
182 | if __name__ == "__main__":
183 | main()
184 |
--------------------------------------------------------------------------------
/train_models/resnet_nole.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 |
4 |
5 | def conv3x3(in_planes, out_planes, stride=1):
6 | # 3x3 convolution with padding
7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
8 |
9 |
10 | '''class BasicBlock(nn.Module):
11 | expansion = 1
12 |
13 | def __init__(self, inplanes, planes, stride=1, downsample=None):
14 | super(BasicBlock, self).__init__()
15 | self.conv1 = conv3x3(inplanes, planes, stride)
16 | self.bn1 = nn.BatchNorm2d(planes)
17 | self.relu = nn.ReLU(inplace=True)
18 | self.conv2 = conv3x3(planes, planes)
19 | self.bn2 = nn.BatchNorm2d(planes)
20 | #print(downsample)
21 | self.downsample = downsample
22 | self.stride = stride
23 |
24 | def forward(self, x):
25 | residual = x
26 |
27 | x = self.conv1(x)
28 | x = self.bn1(x)
29 | x = self.relu(x)
30 |
31 | x = self.conv2(x)
32 | x = self.bn2(x)
33 |
34 | if self.downsample is not None:
35 | #print(x.shape)
36 | residual = self.downsample(residual)
37 |
38 | x += residual
39 | x = self.relu(x)
40 |
41 | return x
42 |
43 | def input_to_residual(self, x):
44 | residual = x
45 | if self.downsample is not None:
46 | residual = self.downsample(residual)
47 | return residual
48 |
49 | def residual_to_output(self, residual,conv2):
50 | x = residual + conv2
51 | x = self.relu(x)
52 |
53 | return x
54 |
55 |
56 | def input_to_conv2(self, x):
57 | residual = x
58 | x = self.conv1(x)
59 | x = self.bn1(x)
60 | x = self.relu(x)
61 | x = self.conv2(x)
62 | return x
63 |
64 | def conv2_to_output(self, x, residual):
65 | x = self.bn2(x)
66 | x = residual + x
67 | x = self.relu(x)
68 | return x
69 |
70 | def conv2_to_output_mask(self, x, residual,mask,pattern):
71 | x = self.bn2(x)
72 | x = residual + x
73 | x = (1 - mask) * x + mask * pattern
74 | x = self.relu(x)
75 | return x
76 |
77 | def input_to_conv1(self, x):
78 | x = self.conv1(x)
79 | return x
80 |
81 | def conv1_to_output(self, x, residual):
82 | x = self.bn1(x)
83 | x = self.relu(x)
84 |
85 | x = self.conv2(x)
86 | x = self.bn2(x)
87 |
88 | x += residual
89 | x = self.relu(x)
90 |
91 | return x'''
92 |
93 | class BasicBlock(nn.Module):
94 | expansion = 1
95 |
96 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
97 | base_width=64, dilation=1, norm_layer=None):
98 | super(BasicBlock, self).__init__()
99 | if norm_layer is None:
100 | norm_layer = nn.BatchNorm2d
101 | if groups != 1 or base_width != 64:
102 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
103 | if dilation > 1:
104 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
105 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
106 | self.conv1 = conv3x3(inplanes, planes, stride)
107 | self.bn1 = norm_layer(planes)
108 | self.relu = nn.ReLU(inplace=True)
109 | self.conv2 = conv3x3(planes, planes)
110 | self.bn2 = norm_layer(planes)
111 | self.downsample = downsample
112 | self.stride = stride
113 |
114 | # Added another relu here
115 | self.relu2 = nn.ReLU(inplace=True)
116 |
117 | def forward(self, x):
118 | identity = x
119 |
120 | out = self.conv1(x)
121 | out = self.bn1(out)
122 | out = self.relu(out)
123 |
124 | out = self.conv2(out)
125 | out = self.bn2(out)
126 |
127 | if self.downsample is not None:
128 | identity = self.downsample(x)
129 |
130 | out += identity
131 |
132 | # Modified to use relu2
133 | out = self.relu2(out)
134 |
135 | return out
136 |
137 | class Bottleneck(nn.Module):
138 | expansion = 4
139 |
140 | def __init__(self, inplanes, planes, stride=1, downsample=None):
141 | super(Bottleneck, self).__init__()
142 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
143 | self.bn1 = nn.BatchNorm2d(planes)
144 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
145 | self.bn2 = nn.BatchNorm2d(planes)
146 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
147 | self.bn3 = nn.BatchNorm2d(planes * 4)
148 | self.relu = nn.ReLU(inplace=True)
149 | self.downsample = downsample
150 | self.stride = stride
151 |
152 | def forward(self, x):
153 | residual = x
154 |
155 | x = self.conv1(x)
156 | x = self.bn1(x)
157 | x = self.relu(x)
158 |
159 | x = self.conv2(x)
160 | x = self.bn2(x)
161 | x = self.relu(x)
162 |
163 | x = self.conv3(x)
164 | x = self.bn3(x)
165 |
166 | if self.downsample is not None:
167 | residual = self.downsample(residual)
168 |
169 | x += residual
170 | x = self.relu(x)
171 |
172 | return x
173 |
174 |
175 | class ResNet(nn.Module):
176 |
177 | def __init__(self, block, layers, num_classes=10,in_channels=3):
178 | self.inplanes = 64
179 | super(ResNet, self).__init__()
180 | self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
181 | self.bn1 = nn.BatchNorm2d(64)
182 | self.relu = nn.ReLU(inplace=True)
183 | self.layer1 = self._make_layer(block, 64, layers[0])
184 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
185 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
186 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
187 | self.avgpool = nn.AvgPool2d(kernel_size=4)
188 | self.fc = nn.Linear(512 * block.expansion, num_classes)
189 |
190 | self.inter_feature = {}
191 | self.inter_gradient = {}
192 |
193 | self.register_all_hooks()
194 |
195 | for m in self.modules():
196 | if isinstance(m, nn.Conv2d):
197 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
198 | m.weight.data.normal_(0, math.sqrt(2. / n))
199 | elif isinstance(m, nn.BatchNorm2d):
200 | m.weight.data.fill_(1)
201 | m.bias.data.zero_()
202 |
203 | def _make_layer(self, block, planes, blocks, stride=1):
204 | downsample = None
205 | if stride != 1 or self.inplanes != planes * block.expansion:
206 | downsample = nn.Sequential(
207 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
208 | nn.BatchNorm2d(planes * block.expansion),
209 | )
210 |
211 | layers = []
212 | layers.append(block(self.inplanes, planes, stride, downsample))
213 | self.inplanes = planes * block.expansion
214 | for i in range(1, blocks):
215 | layers.append(block(self.inplanes, planes))
216 | return nn.Sequential(*layers)
217 |
218 | def forward(self, x):
219 | x = self.conv1(x)
220 | x = self.bn1(x)
221 | x = self.relu(x)
222 |
223 | x = self.layer1(x)
224 | x = self.layer2(x)
225 | x = self.layer3(x)
226 | x = self.layer4(x)
227 |
228 | x = self.avgpool(x)
229 | x = x.view(x.size(0), -1)
230 | x = self.fc(x)
231 |
232 | return x
233 |
234 | def get_fm(self, x):
235 | x = self.conv1(x)
236 | x = self.bn1(x)
237 | x = self.relu(x)
238 |
239 | x = self.layer1(x)
240 | x = self.layer2(x)
241 | x = self.layer3(x)
242 | x = self.layer4(x)
243 |
244 | #x = self.avgpool(x)
245 |
246 | return x
247 |
248 | def input_to_conv1(self, x):
249 | x = self.conv1(x)
250 |
251 | return x
252 |
253 | def conv1_to_output(self, x):
254 | #x = self.conv1(x)
255 | x = self.bn1(x)
256 | x = self.relu(x)
257 |
258 | x = self.layer1(x)
259 | x = self.layer2(x)
260 | x = self.layer3(x)
261 | x = self.layer4(x)
262 |
263 | x = self.avgpool(x)
264 | x = x.view(x.size(0), -1)
265 | x = self.fc(x)
266 | return x
267 |
268 | def input_to_layer1(self, x):
269 | x = self.conv1(x)
270 | x = self.bn1(x)
271 | x = self.relu(x)
272 |
273 | x = self.layer1(x)
274 |
275 | return x
276 |
277 | def layer1_to_output(self, x):
278 | #x = self.conv1(x)
279 | #x = self.bn1(x)
280 | #x = self.relu(x)
281 |
282 | #x = self.layer1(x)
283 | x = self.layer2(x)
284 | x = self.layer3(x)
285 | x = self.layer4(x)
286 |
287 | x = self.avgpool(x)
288 | x = x.view(x.size(0), -1)
289 | x = self.fc(x)
290 | return x
291 |
292 | def input_to_layer2(self, x):
293 | x = self.conv1(x)
294 | x = self.bn1(x)
295 | x = self.relu(x)
296 |
297 | x = self.layer1(x)
298 | x = self.layer2(x)
299 |
300 | return x
301 |
302 | def layer2_to_output(self, x):
303 | #x = self.conv1(x)
304 | #x = self.bn1(x)
305 | #x = self.relu(x)
306 |
307 | #x = self.layer1(x)
308 | #x = self.layer2(x)
309 | x = self.layer3(x)
310 | x = self.layer4(x)
311 |
312 | x = self.avgpool(x)
313 | x = x.view(x.size(0), -1)
314 | x = self.fc(x)
315 | return x
316 |
317 | def input_to_layer3(self, x):
318 | x = self.conv1(x)
319 | x = self.bn1(x)
320 | x = self.relu(x)
321 |
322 | x = self.layer1(x)
323 | x = self.layer2(x)
324 | x = self.layer3(x)
325 |
326 | return x
327 |
328 | def layer3_to_output(self, x):
329 |
330 | #x = self.conv1(x)
331 | #x = self.bn1(x)
332 | #x = self.relu(x)
333 |
334 | #x = self.layer1(x)
335 | #x = self.layer2(x)
336 | #x = self.layer3(x)
337 | x = self.layer4(x)
338 |
339 | x = self.avgpool(x)
340 | x = x.view(x.size(0), -1)
341 | x = self.fc(x)
342 | return x
343 |
344 | def input_to_layer4(self, x):
345 | x = self.conv1(x)
346 | x = self.bn1(x)
347 | x = self.relu(x)
348 |
349 | x = self.layer1(x)
350 | x = self.layer2(x)
351 | x = self.layer3(x)
352 | x = self.layer4(x)
353 |
354 | return x
355 |
356 | def layer4_to_output(self, x):
357 |
358 | x = self.avgpool(x)
359 | x = x.view(x.size(0), -1)
360 | x = self.fc(x)
361 | return x
362 |
363 | def make_hook(self, name, flag):
364 | if flag == 'forward':
365 | def hook(m, input, output):
366 | self.inter_feature[name] = output
367 | return hook
368 | elif flag == 'backward':
369 | def hook(m, input, output):
370 | self.inter_gradient[name] = output
371 | return hook
372 | else:
373 | assert False
374 |
375 | def register_all_hooks(self):
376 | self.conv1.register_forward_hook(self.make_hook("Conv1_Conv1_Conv1_", 'forward'))
377 | self.layer1[0].conv1.register_forward_hook(self.make_hook("Layer1_0_Conv1_", 'forward'))
378 | self.layer1[0].conv2.register_forward_hook(self.make_hook("Layer1_0_Conv2_", 'forward'))
379 | self.layer1[1].conv1.register_forward_hook(self.make_hook("Layer1_1_Conv1_", 'forward'))
380 | self.layer1[1].conv2.register_forward_hook(self.make_hook("Layer1_1_Conv2_", 'forward'))
381 |
382 | self.layer2[0].conv1.register_forward_hook(self.make_hook("Layer2_0_Conv1_", 'forward'))
383 | self.layer2[0].downsample.register_forward_hook(self.make_hook("Layer2_0_Downsample_", 'forward'))
384 | self.layer2[0].conv2.register_forward_hook(self.make_hook("Layer2_0_Conv2_", 'forward'))
385 | self.layer2[1].conv1.register_forward_hook(self.make_hook("Layer2_1_Conv1_", 'forward'))
386 | self.layer2[1].conv2.register_forward_hook(self.make_hook("Layer2_1_Conv2_", 'forward'))
387 |
388 | self.layer3[0].conv1.register_forward_hook(self.make_hook("Layer3_0_Conv1_", 'forward'))
389 | self.layer3[0].downsample.register_forward_hook(self.make_hook("Layer3_0_Downsample_", 'forward'))
390 | self.layer3[0].conv2.register_forward_hook(self.make_hook("Layer3_0_Conv2_", 'forward'))
391 | self.layer3[1].conv1.register_forward_hook(self.make_hook("Layer3_1_Conv1_", 'forward'))
392 | self.layer3[1].conv2.register_forward_hook(self.make_hook("Layer3_1_Conv2_", 'forward'))
393 |
394 | self.layer4[0].conv1.register_forward_hook(self.make_hook("Layer4_0_Conv1_", 'forward'))
395 | self.layer4[0].downsample.register_forward_hook(self.make_hook("Layer4_0_Downsample_", 'forward'))
396 | self.layer4[0].conv2.register_forward_hook(self.make_hook("Layer4_0_Conv2_", 'forward'))
397 | self.layer4[1].conv1.register_forward_hook(self.make_hook("Layer4_1_Conv1_", 'forward'))
398 | self.layer4[1].conv2.register_forward_hook(self.make_hook("Layer4_1_Conv2_", 'forward'))
399 |
400 |
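A minimal sketch of reading the activations captured by `register_all_hooks` (assuming CIFAR-10 sized inputs):

```python
import torch

net = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=10)
_ = net(torch.randn(1, 3, 32, 32))                 # forward pass fills inter_feature
print(len(net.inter_feature))                      # 20 hooked conv/downsample outputs
print(net.inter_feature["Layer4_1_Conv2_"].shape)  # torch.Size([1, 512, 4, 4])
```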
401 |
402 | '''def get_all_inner_activation(self, x):
403 | inner_output_index = [0,2,4,8,10,12,16,18]
404 | inner_output_list = []
405 | for i in range(23):
406 | x = self.classifier[i](x)
407 | if i in inner_output_index:
408 | inner_output_list.append(x)
409 | x = x.view(x.size(0), self.num_classes)
410 | return x,inner_output_list'''
411 |
412 | #############################################################################
413 | def input_to_conv1(self, x):
414 | x = self.conv1(x)
415 | return x
416 |
417 | def conv1_to_output(self, x):
418 | x = self.bn1(x)
419 | x = self.relu(x)
420 |
421 | x = self.layer1(x)
422 | x = self.layer2(x)
423 | x = self.layer3(x)
424 | x = self.layer4(x)
425 |
426 | x = self.avgpool(x)
427 | x = x.view(x.size(0), -1)
428 | x = self.fc(x)
429 |
430 | return x
431 |
432 | #############################################################################
433 | def input_to_layer1_0_residual(self, x):
434 | x = self.conv1(x)
435 | x = self.bn1(x)
436 | x = self.relu(x)
437 |
438 | x = self.layer1[0].input_to_residual(x)
439 |
440 | return x
441 |
442 | def layer1_0_residual_to_output(self, residual, conv2):
443 |
444 | x = self.layer1[0].residual_to_output(residual,conv2)
445 | x = self.layer1[1](x)
446 | x = self.layer2(x)
447 | x = self.layer3(x)
448 | x = self.layer4(x)
449 |
450 | x = self.avgpool(x)
451 | x = x.view(x.size(0), -1)
452 | x = self.fc(x)
453 | return x
454 |
455 | def input_to_layer1_0_conv2(self, x):
456 | x = self.conv1(x)
457 | x = self.bn1(x)
458 | x = self.relu(x)
459 | x = self.layer1[0].input_to_conv2(x)
460 | return x
461 |
462 | def layer1_0_conv2_to_output(self, x, residual):
463 | x = self.layer1[0].conv2_to_output(x, residual)
464 | x = self.layer1[1](x)
465 | x = self.layer2(x)
466 | x = self.layer3(x)
467 | x = self.layer4(x)
468 | x = self.avgpool(x)
469 | x = x.view(x.size(0), -1)
470 | x = self.fc(x)
471 | return x
472 |
473 | def input_to_layer1_0_conv1(self, x):
474 | x = self.conv1(x)
475 | x = self.bn1(x)
476 | x = self.relu(x)
477 | x = self.layer1[0].input_to_conv1(x)
478 | return x
479 |
480 | def layer1_0_conv1_to_output(self, x, residual):
481 | x = self.layer1[0].conv1_to_output(x, residual)
482 | x = self.layer1[1](x)
483 | x = self.layer2(x)
484 | x = self.layer3(x)
485 | x = self.layer4(x)
486 | x = self.avgpool(x)
487 | x = x.view(x.size(0), -1)
488 | x = self.fc(x)
489 | return x
490 | #############################################################################
491 |
492 | def input_to_layer1_1_residual(self, x):
493 | x = self.conv1(x)
494 | x = self.bn1(x)
495 | x = self.relu(x)
496 | x = self.layer1[0](x)
497 | x = self.layer1[1].input_to_residual(x)
498 |
499 | return x
500 |
501 | def input_to_layer1_1_conv2(self, x):
502 | x = self.conv1(x)
503 | x = self.bn1(x)
504 | x = self.relu(x)
505 | x = self.layer1[0](x)
506 | x = self.layer1[1].input_to_conv2(x)
507 | return x
508 |
509 | def layer1_1_conv2_to_output(self, x, residual):
510 | x = self.layer1[1].conv2_to_output(x, residual)
511 | x = self.layer2(x)
512 | x = self.layer3(x)
513 | x = self.layer4(x)
514 | x = self.avgpool(x)
515 | x = x.view(x.size(0), -1)
516 | x = self.fc(x)
517 | return x
518 |
519 | def layer1_1_conv2_to_output_mask(self, x, residual,mask,pattern):
520 | x = self.layer1[1].conv2_to_output_mask(x, residual,mask,pattern)
521 | x = self.layer2(x)
522 | x = self.layer3(x)
523 | x = self.layer4(x)
524 | x = self.avgpool(x)
525 | x = x.view(x.size(0), -1)
526 | x = self.fc(x)
527 | return x
528 |
529 | def input_to_layer1_1_conv1(self, x):
530 | x = self.conv1(x)
531 | x = self.bn1(x)
532 | x = self.relu(x)
533 | x = self.layer1[0](x)
534 | x = self.layer1[1].input_to_conv1(x)
535 | return x
536 |
537 | def layer1_1_conv1_to_output(self, x, residual):
538 | x = self.layer1[1].conv1_to_output(x, residual)
539 | x = self.layer2(x)
540 | x = self.layer3(x)
541 | x = self.layer4(x)
542 | x = self.avgpool(x)
543 | x = x.view(x.size(0), -1)
544 | x = self.fc(x)
545 | return x
546 |
547 | #############################################################################
548 |
549 | #############################################################################
550 | def input_to_layer2_0_residual(self, x):
551 | x = self.conv1(x)
552 | x = self.bn1(x)
553 | x = self.relu(x)
554 |
555 | x = self.layer1(x)
556 | x = self.layer2[0].input_to_residual(x)
557 |
558 | return x
559 |
560 | def layer2_0_residual_to_output(self, residual, conv2):
561 |
562 | x = self.layer2[0].residual_to_output(residual,conv2)
563 | x = self.layer2[1](x)
564 | x = self.layer3(x)
565 | x = self.layer4(x)
566 |
567 | x = self.avgpool(x)
568 | x = x.view(x.size(0), -1)
569 | x = self.fc(x)
570 | return x
571 |
572 | def input_to_layer2_0_conv2(self, x):
573 | x = self.conv1(x)
574 | x = self.bn1(x)
575 | x = self.relu(x)
576 | x = self.layer1(x)
577 | x = self.layer2[0].input_to_conv2(x)
578 | return x
579 |
580 | def layer2_0_conv2_to_output(self, x, residual):
581 | x = self.layer2[0].conv2_to_output(x, residual)
582 | x = self.layer2[1](x)
583 | x = self.layer3(x)
584 | x = self.layer4(x)
585 | x = self.avgpool(x)
586 | x = x.view(x.size(0), -1)
587 | x = self.fc(x)
588 | return x
589 |
590 | def input_to_layer2_0_conv1(self, x):
591 | x = self.conv1(x)
592 | x = self.bn1(x)
593 | x = self.relu(x)
594 | x = self.layer1(x)
595 | x = self.layer2[0].input_to_conv1(x)
596 | return x
597 |
598 | def layer2_0_conv1_to_output(self, x, residual):
599 | x = self.layer2[0].conv1_to_output(x, residual)
600 | x = self.layer2[1](x)
601 | x = self.layer3(x)
602 | x = self.layer4(x)
603 | x = self.avgpool(x)
604 | x = x.view(x.size(0), -1)
605 | x = self.fc(x)
606 | return x
607 | #############################################################################
608 |
609 | def input_to_layer2_1_residual(self, x):
610 | x = self.conv1(x)
611 | x = self.bn1(x)
612 | x = self.relu(x)
613 |
614 | x = self.layer1(x)
615 | x = self.layer2[0](x)
616 | x = self.layer2[1].input_to_residual(x)
617 |
618 | return x
619 |
620 | def input_to_layer2_1_conv2(self, x):
621 | x = self.conv1(x)
622 | x = self.bn1(x)
623 | x = self.relu(x)
624 | x = self.layer1(x)
625 | x = self.layer2[0](x)
626 | x = self.layer2[1].input_to_conv2(x)
627 | return x
628 |
629 | def layer2_1_conv2_to_output(self, x, residual):
630 | x = self.layer2[1].conv2_to_output(x, residual)
631 | x = self.layer3(x)
632 | x = self.layer4(x)
633 | x = self.avgpool(x)
634 | x = x.view(x.size(0), -1)
635 | x = self.fc(x)
636 | return x
637 |
638 |
639 | def layer2_1_conv2_to_output_mask(self, x, residual,mask,pattern):
640 | x = self.layer2[1].conv2_to_output_mask(x, residual,mask,pattern)
641 | x = self.layer3(x)
642 | x = self.layer4(x)
643 | x = self.avgpool(x)
644 | x = x.view(x.size(0), -1)
645 | x = self.fc(x)
646 | return x
647 |
648 | def input_to_layer2_1_conv1(self, x):
649 | x = self.conv1(x)
650 | x = self.bn1(x)
651 | x = self.relu(x)
652 | x = self.layer1(x)
653 | x = self.layer2[0](x)
654 | x = self.layer2[1].input_to_conv1(x)
655 | return x
656 |
657 | def layer2_1_conv1_to_output(self, x, residual):
658 | x = self.layer2[1].conv1_to_output(x, residual)
659 | x = self.layer3(x)
660 | x = self.layer4(x)
661 | x = self.avgpool(x)
662 | x = x.view(x.size(0), -1)
663 | x = self.fc(x)
664 | return x
665 |
666 | #############################################################################
667 |
668 | #############################################################################
669 | def input_to_layer3_0_residual(self, x):
670 | x = self.conv1(x)
671 | x = self.bn1(x)
672 | x = self.relu(x)
673 |
674 | x = self.layer1(x)
675 | x = self.layer2(x)
676 | x = self.layer3[0].input_to_residual(x)
677 |
678 | return x
679 |
680 | def layer3_0_residual_to_output(self, residual, conv2):
681 |
682 | x = self.layer3[0].residual_to_output(residual,conv2)
683 | x = self.layer3[1](x)
684 | x = self.layer4(x)
685 |
686 | x = self.avgpool(x)
687 | x = x.view(x.size(0), -1)
688 | x = self.fc(x)
689 | return x
690 |
691 | def input_to_layer3_0_conv2(self, x):
692 | x = self.conv1(x)
693 | x = self.bn1(x)
694 | x = self.relu(x)
695 | x = self.layer1(x)
696 | x = self.layer2(x)
697 | x = self.layer3[0].input_to_conv2(x)
698 | return x
699 |
700 | def layer3_0_conv2_to_output(self, x, residual):
701 | x = self.layer3[0].conv2_to_output(x, residual)
702 | x = self.layer3[1](x)
703 | x = self.layer4(x)
704 | x = self.avgpool(x)
705 | x = x.view(x.size(0), -1)
706 | x = self.fc(x)
707 | return x
708 |
709 | def input_to_layer3_0_conv1(self, x):
710 | x = self.conv1(x)
711 | x = self.bn1(x)
712 | x = self.relu(x)
713 | x = self.layer1(x)
714 | x = self.layer2(x)
715 | x = self.layer3[0].input_to_conv1(x)
716 | return x
717 |
718 | def layer3_0_conv1_to_output(self, x, residual):
719 | x = self.layer3[0].conv1_to_output(x, residual)
720 | x = self.layer3[1](x)
721 | x = self.layer4(x)
722 | x = self.avgpool(x)
723 | x = x.view(x.size(0), -1)
724 | x = self.fc(x)
725 | return x
726 | #############################################################################
727 |
728 | def input_to_layer3_1_residual(self, x):
729 | x = self.conv1(x)
730 | x = self.bn1(x)
731 | x = self.relu(x)
732 |
733 | x = self.layer1(x)
734 | x = self.layer2(x)
735 | x = self.layer3[0](x)
736 | x = self.layer3[1].input_to_residual(x)
737 |
738 | return x
739 |
740 | def input_to_layer3_1_conv2(self, x):
741 | x = self.conv1(x)
742 | x = self.bn1(x)
743 | x = self.relu(x)
744 | x = self.layer1(x)
745 | x = self.layer2(x)
746 | x = self.layer3[0](x)
747 | x = self.layer3[1].input_to_conv2(x)
748 | return x
749 |
750 | def layer3_1_conv2_to_output(self, x, residual):
751 | x = self.layer3[1].conv2_to_output(x, residual)
752 | x = self.layer4(x)
753 | x = self.avgpool(x)
754 | x = x.view(x.size(0), -1)
755 | x = self.fc(x)
756 | return x
757 |
758 | def layer3_1_conv2_to_output_mask(self, x, residual,mask,pattern):
759 | x = self.layer3[1].conv2_to_output_mask(x, residual,mask,pattern)
760 | x = self.layer4(x)
761 | x = self.avgpool(x)
762 | x = x.view(x.size(0), -1)
763 | x = self.fc(x)
764 | return x
765 |
766 | def input_to_layer3_1_conv1(self, x):
767 | x = self.conv1(x)
768 | x = self.bn1(x)
769 | x = self.relu(x)
770 | x = self.layer1(x)
771 | x = self.layer2(x)
772 | x = self.layer3[0](x)
773 | x = self.layer3[1].input_to_conv1(x)
774 | return x
775 |
776 | def layer3_1_conv1_to_output(self, x, residual):
777 | x = self.layer3[1].conv1_to_output(x, residual)
778 | x = self.layer4(x)
779 | x = self.avgpool(x)
780 | x = x.view(x.size(0), -1)
781 | x = self.fc(x)
782 | return x
783 |
784 | #############################################################################
785 | def input_to_layer4_0_residual(self, x):
786 | x = self.conv1(x)
787 | x = self.bn1(x)
788 | x = self.relu(x)
789 |
790 | x = self.layer1(x)
791 | x = self.layer2(x)
792 | x = self.layer3(x)
793 | x = self.layer4[0].input_to_residual(x)
794 |
795 | return x
796 |
797 | def layer4_0_residual_to_output(self, residual, conv2):
798 |
799 | x = self.layer4[0].residual_to_output(residual,conv2)
800 | x = self.layer4[1](x)
801 |
802 | x = self.avgpool(x)
803 | x = x.view(x.size(0), -1)
804 | x = self.fc(x)
805 | return x
806 |
807 | def input_to_layer4_0_conv2(self, x):
808 | x = self.conv1(x)
809 | x = self.bn1(x)
810 | x = self.relu(x)
811 | x = self.layer1(x)
812 | x = self.layer2(x)
813 | x = self.layer3(x)
814 | x = self.layer4[0].input_to_conv2(x)
815 | return x
816 |
817 | def layer4_0_conv2_to_output(self, x, residual):
818 | x = self.layer4[0].conv2_to_output(x, residual)
819 | x = self.layer4[1](x)
820 | x = self.avgpool(x)
821 | x = x.view(x.size(0), -1)
822 | x = self.fc(x)
823 | return x
824 |
825 | def input_to_layer4_0_conv1(self, x):
826 | x = self.conv1(x)
827 | x = self.bn1(x)
828 | x = self.relu(x)
829 | x = self.layer1(x)
830 | x = self.layer2(x)
831 | x = self.layer3(x)
832 | x = self.layer4[0].input_to_conv1(x)
833 | return x
834 |
835 | def layer4_0_conv1_to_output(self, x, residual):
836 | x = self.layer4[0].conv1_to_output(x, residual)
837 | x = self.layer4[1](x)
838 | x = self.avgpool(x)
839 | x = x.view(x.size(0), -1)
840 | x = self.fc(x)
841 | return x
842 | #############################################################################
843 | def input_to_layer4_1_residual(self, x):
844 | x = self.conv1(x)
845 | x = self.bn1(x)
846 | x = self.relu(x)
847 |
848 | x = self.layer1(x)
849 | x = self.layer2(x)
850 | x = self.layer3(x)
851 | x = self.layer4[0](x)
852 | x = self.layer4[1].input_to_residual(x)
853 |
854 | return x
855 |
856 | def input_to_layer4_1_conv2(self, x):
857 | x = self.conv1(x)
858 | x = self.bn1(x)
859 | x = self.relu(x)
860 | x = self.layer1(x)
861 | x = self.layer2(x)
862 | x = self.layer3(x)
863 | x = self.layer4[0](x)
864 | x = self.layer4[1].input_to_conv2(x)
865 | return x
866 |
867 | def layer4_1_conv2_to_output(self, x, residual):
868 | x = self.layer4[1].conv2_to_output(x, residual)
869 | x = self.avgpool(x)
870 | x = x.view(x.size(0), -1)
871 | x = self.fc(x)
872 | return x
873 |
874 | def layer4_1_conv2_to_output_mask(self, x, residual,mask,pattern):
875 | x = self.layer4[1].conv2_to_output_mask(x, residual,mask,pattern)
876 | x = self.avgpool(x)
877 | x = x.view(x.size(0), -1)
878 | x = self.fc(x)
879 | return x
880 |
881 | def input_to_layer4_1_conv1(self, x):
882 | x = self.conv1(x)
883 | x = self.bn1(x)
884 | x = self.relu(x)
885 | x = self.layer1(x)
886 | x = self.layer2(x)
887 | x = self.layer3(x)
888 | x = self.layer4[0](x)
889 | x = self.layer4[1].input_to_conv1(x)
890 | return x
891 |
892 | def layer4_1_conv1_to_output(self, x, residual):
893 | x = self.layer4[1].conv1_to_output(x, residual)
894 | x = self.avgpool(x)
895 | x = x.view(x.size(0), -1)
896 | x = self.fc(x)
897 | return x
898 | #############################################################################
899 |
900 | def resnet18(**kwargs):
901 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
902 |
903 |
904 | def resnet34(**kwargs):
905 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
906 |
907 |
908 | def resnet50(**kwargs):
909 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
910 |
911 |
912 | def resnet101(**kwargs):
913 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
914 |
915 |
916 | def resnet152(**kwargs):
917 | return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
--------------------------------------------------------------------------------
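Note on resnet_nole.py: each `input_to_X`/`X_to_output` pair above splits the network at one internal feature map, and composing the two halves in eval mode reproduces the plain forward pass; `register_all_hooks` additionally caches every convolution's output in `self.inter_feature`. A minimal sketch of this interface (illustrative only; `inter_feature` is initialized by hand here in case the constructor does not):

```python
import torch
from resnet_nole import resnet18

model = resnet18(num_classes=10, in_channels=3).eval()
x = torch.randn(4, 3, 32, 32)

# Splitting at layer3 and re-composing matches the full forward pass, which is
# what allows a trigger to be optimized directly on the layer3 feature map.
feat = model.input_to_layer3(x)
logits = model.layer3_to_output(feat)
assert torch.allclose(logits, model(x), atol=1e-5)

# Forward hooks cache intermediate activations keyed by layer name.
model.inter_feature = {}
model.register_all_hooks()
_ = model(x)
print(sorted(model.inter_feature.keys()))
```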
/train_models/train_model.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import shutil
4 | from time import time
5 |
6 | import config
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 | import torchvision
11 | from torch import nn
12 | from torch.utils.tensorboard import SummaryWriter
13 | from torchvision.transforms import RandomErasing
14 | from dataloader import PostTensorTransform, get_dataloader, get_dataloader_random_ratio
15 | from resnet_nole import *
16 |
17 | import random
18 |
19 |
20 | class Normalize:
21 | def __init__(self, opt, expected_values, variance):
22 | self.n_channels = opt.input_channel
23 | self.expected_values = expected_values
24 | self.variance = variance
25 | assert self.n_channels == len(self.expected_values)
26 |
27 | def __call__(self, x):
28 | x_clone = x.clone()
29 | for channel in range(self.n_channels):
30 | x_clone[:, channel] = (x[:, channel] - self.expected_values[channel]) / self.variance[channel]
31 | return x_clone
32 |
33 |
34 | class Denormalize:
35 | def __init__(self, opt, expected_values, variance):
36 | self.n_channels = opt.input_channel
37 | self.expected_values = expected_values
38 | self.variance = variance
39 | assert self.n_channels == len(self.expected_values)
40 |
41 | def __call__(self, x):
42 | x_clone = x.clone()
43 | for channel in range(self.n_channels):
44 | x_clone[:, channel] = x[:, channel] * self.variance[channel] + self.expected_values[channel]
45 | return x_clone
46 |
47 |
48 | class Normalizer:
49 | def __init__(self, opt):
50 | self.normalizer = self._get_normalizer(opt)
51 |
52 | def _get_normalizer(self, opt):
53 | if opt.dataset == "cifar10":
54 | normalizer = Normalize(opt, [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
55 | else:
56 | raise Exception("Invalid dataset")
57 | return normalizer
58 |
59 | def __call__(self, x):
60 | if self.normalizer:
61 | x = self.normalizer(x)
62 | return x
63 |
64 |
65 | class Denormalizer:
66 | def __init__(self, opt):
67 | self.denormalizer = self._get_denormalizer(opt)
68 |
69 | def _get_denormalizer(self, opt):
70 | print(opt.dataset)
71 | if opt.dataset == "cifar10":
72 | denormalizer = Denormalize(opt, [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
73 | else:
74 | raise Exception("Invalid dataset")
75 | return denormalizer
76 |
77 | def __call__(self, x):
78 | if self.denormalizer:
79 | x = self.denormalizer(x)
80 | return x
81 |
82 |
83 | def get_model(opt):
84 | netC = None
85 | optimizerC = None
86 | schedulerC = None
87 |
88 | if opt.set_arch:
89 |
90 | if opt.set_arch=="resnet18":
91 | netC = resnet18(num_classes = opt.num_classes, in_channels = opt.input_channel)
92 | netC = netC.to(opt.device)
93 |
94 | optimizerC = torch.optim.SGD(netC.parameters(), opt.lr_C, momentum=0.9, weight_decay=5e-4)
95 |
96 | schedulerC = torch.optim.lr_scheduler.MultiStepLR(optimizerC, opt.schedulerC_milestones, opt.schedulerC_lambda)
97 |
98 | return netC, optimizerC, schedulerC
99 |
100 |
101 | def train(train_transform, netC, optimizerC, schedulerC, train_dl, noise_grid, identity_grid, tf_writer, epoch, opt):
102 | print(" Train:")
103 |
104 | netC.train()
105 | rate_bd = opt.pc
106 | total_loss_ce = 0
107 | total_sample = 0
108 |
109 | total_clean = 0
110 | total_bd = 0
111 | total_cross = 0
112 | total_clean_correct = 0
113 | total_bd_correct = 0
114 | total_cross_correct = 0
115 | criterion_CE = torch.nn.CrossEntropyLoss()
116 | criterion_BCE = torch.nn.BCELoss()
117 |
118 | denormalizer = Denormalizer(opt)
119 | transforms = PostTensorTransform(opt).to(opt.device)
120 | total_time = 0
121 |
122 | avg_acc_cross = 0
123 |
124 | for batch_idx, (inputs, targets) in enumerate(train_dl):
125 | optimizerC.zero_grad()
126 |
127 | inputs, targets = inputs.to(opt.device), targets.to(opt.device)
128 | bs = inputs.shape[0]
129 |
130 | num_bd = int(bs * rate_bd)
131 | num_cross = int(num_bd * opt.cross_ratio)
132 | grid_temps = (identity_grid + opt.s * noise_grid / opt.input_height) * opt.grid_rescale
133 | grid_temps = torch.clamp(grid_temps, -1, 1)
134 |
135 | ins = torch.rand(num_cross, opt.input_height, opt.input_height, 2).to(opt.device) * 2 - 1
136 | grid_temps2 = grid_temps.repeat(num_cross, 1, 1, 1) + ins / opt.input_height
137 | grid_temps2 = torch.clamp(grid_temps2, -1, 1)
138 |
139 | if num_bd!=0:
140 |
141 | inputs_bd = F.grid_sample(inputs[:num_bd], grid_temps.repeat(num_bd, 1, 1, 1), align_corners=True)
142 |
143 | if opt.attack_mode == "all2one":
144 | targets_bd = torch.ones_like(targets[:num_bd]) * opt.target_label
145 | if opt.attack_mode == "all2all":
146 | targets_bd = torch.remainder(targets[:num_bd] + 1, opt.num_classes)
147 |
148 | inputs_cross = F.grid_sample(inputs[num_bd : (num_bd + num_cross)], grid_temps2, align_corners=True)
149 |
150 | if (num_bd==0 and num_cross==0):
151 | total_inputs = inputs
152 | total_targets = targets
153 | else:
154 | total_inputs = torch.cat([inputs_bd, inputs_cross, inputs[(num_bd + num_cross) :]], dim=0)
155 | total_targets = torch.cat([targets_bd, targets[num_bd:]], dim=0)
156 |
157 | total_inputs = transforms(total_inputs)
158 | start = time()
159 | total_preds = netC(total_inputs)
160 | total_time += time() - start
161 |
162 | loss_ce = criterion_CE(total_preds, total_targets)
163 |
164 | loss = loss_ce
165 | loss.backward()
166 |
167 | optimizerC.step()
168 |
169 | total_sample += bs
170 | total_loss_ce += loss_ce.detach()
171 |
172 | total_clean += bs - num_bd - num_cross
173 | total_bd += num_bd
174 | total_cross += num_cross
175 | total_clean_correct += torch.sum(
176 | torch.argmax(total_preds[(num_bd + num_cross) :], dim=1) == total_targets[(num_bd + num_cross) :]
177 | )
178 | if num_bd:
179 | total_bd_correct += torch.sum(torch.argmax(total_preds[:num_bd], dim=1) == targets_bd)
180 | avg_acc_bd = total_bd_correct * 100.0 / total_bd
181 | else:
182 | avg_acc_bd = 0
183 |
184 | if num_cross:
185 | total_cross_correct += torch.sum(
186 | torch.argmax(total_preds[num_bd : (num_bd + num_cross)], dim=1)
187 | == total_targets[num_bd : (num_bd + num_cross)]
188 | )
189 | avg_acc_cross = total_cross_correct * 100.0 / total_cross
190 | else:
191 | avg_acc_cross = 0
192 |
193 | avg_acc_clean = total_clean_correct * 100.0 / total_clean
194 | avg_loss_ce = total_loss_ce / total_sample
195 |
196 | # Save image for debugging
197 | if not batch_idx % 50:
198 | if not os.path.exists(opt.temps):
199 | os.makedirs(opt.temps)
200 |
201 | path = os.path.join(opt.temps, "backdoor_image.png")
202 | path_cross = os.path.join(opt.temps, "cross_image.png")
203 | if num_bd>0:
204 | torchvision.utils.save_image(inputs_bd, path, normalize=True)
205 | if num_cross>0:
206 | torchvision.utils.save_image(inputs_cross, path_cross, normalize=True)
207 |
208 | if (num_bd>0 and num_cross==0):
209 | print(
210 | batch_idx,
211 | len(train_dl),
212 | "CE Loss: {:.4f} | Clean Acc: {:.4f} | Bd Acc: {:.4f}".format(
213 | avg_loss_ce, avg_acc_clean, avg_acc_bd,
214 | ))
215 |             elif (num_bd>0 and num_cross>0):
216 | print(
217 | batch_idx,
218 | len(train_dl),
219 | "CE Loss: {:.4f} | Clean Acc: {:.4f} | Bd Acc: {:.4f} | Cross Acc: {:.4f}".format(
220 | avg_loss_ce, avg_acc_clean, avg_acc_bd, avg_acc_cross
221 | ))
222 | else:
223 | print(
224 | batch_idx,
225 | len(train_dl),
226 | "CE Loss: {:.4f} | Clean Acc: {:.4f}".format(avg_loss_ce, avg_acc_clean))
227 | # Image for tensorboard
228 | if batch_idx == len(train_dl) - 2:
229 | if num_bd>0:
230 | residual = inputs_bd - inputs[:num_bd]
231 | batch_img = torch.cat([inputs[:num_bd], inputs_bd, total_inputs[:num_bd], residual], dim=2)
232 | batch_img = denormalizer(batch_img)
233 |                 batch_img = F.interpolate(batch_img, scale_factor=(4, 4))  # F.upsample is deprecated
234 | grid = torchvision.utils.make_grid(batch_img, normalize=True)
235 | path = os.path.join(opt.temps, "batch_img.png")
236 | torchvision.utils.save_image(batch_img, path, normalize=True)
237 |
238 | # for tensorboard
239 | if not epoch % 1:
240 | tf_writer.add_scalars(
241 | "Clean Accuracy", {"Clean": avg_acc_clean, "Bd": avg_acc_bd, "Cross": avg_acc_cross}, epoch
242 | )
243 | if num_bd>0:
244 | tf_writer.add_image("Images", grid, global_step=epoch)
245 |
246 | schedulerC.step()
247 |
248 |
249 | def eval(
250 | test_transform,
251 | netC,
252 | optimizerC,
253 | schedulerC,
254 | test_dl,
255 | noise_grid,
256 | identity_grid,
257 | best_clean_acc,
258 | best_bd_acc,
259 | best_cross_acc,
260 | tf_writer,
261 | epoch,
262 | opt,
263 | ):
264 | print(" Eval:")
265 |
266 | netC.eval()
267 |
268 | total_sample = 0
269 | total_clean_correct = 0
270 | total_bd_correct = 0
271 | total_cross_correct = 0
272 | total_ae_loss = 0
273 |
274 | criterion_BCE = torch.nn.BCELoss()
275 |
276 | for batch_idx, (inputs, targets) in enumerate(test_dl):
277 | with torch.no_grad():
278 | inputs, targets = inputs.to(opt.device), targets.to(opt.device)
279 | #inputs = test_transform(inputs)
280 | bs = inputs.shape[0]
281 | total_sample += bs
282 |
283 | # Evaluate Clean
284 | preds_clean = netC(inputs)
285 | total_clean_correct += torch.sum(torch.argmax(preds_clean, 1) == targets)
286 |
287 | # Evaluate Backdoor
288 | grid_temps = (identity_grid + opt.s * noise_grid / opt.input_height) * opt.grid_rescale
289 | grid_temps = torch.clamp(grid_temps, -1, 1)
290 |
291 | ins = torch.rand(bs, opt.input_height, opt.input_height, 2).to(opt.device) * 2 - 1
292 | grid_temps2 = grid_temps.repeat(bs, 1, 1, 1) + ins / opt.input_height
293 | grid_temps2 = torch.clamp(grid_temps2, -1, 1)
294 |
295 | inputs_bd = F.grid_sample(inputs, grid_temps.repeat(bs, 1, 1, 1), align_corners=True)
296 |
297 | if opt.attack_mode == "all2one":
298 | targets_bd = torch.ones_like(targets) * opt.target_label
299 | if opt.attack_mode == "all2all":
300 | targets_bd = torch.remainder(targets + 1, opt.num_classes)
301 |
302 | preds_bd = netC(inputs_bd)
303 | total_bd_correct += torch.sum(torch.argmax(preds_bd, 1) == targets_bd)
304 |
305 | acc_clean = total_clean_correct * 100.0 / total_sample
306 | acc_bd = total_bd_correct * 100.0 / total_sample
307 |
308 | # Evaluate cross
309 | if opt.cross_ratio:
310 | inputs_cross = F.grid_sample(inputs, grid_temps2, align_corners=True)
311 | preds_cross = netC(inputs_cross)
312 | total_cross_correct += torch.sum(torch.argmax(preds_cross, 1) == targets)
313 |
314 | acc_cross = total_cross_correct * 100.0 / total_sample
315 |
316 | info_string = (
317 |                     "Clean Acc: {:.4f} - Best: {:.4f} | Bd Acc: {:.4f} - Best: {:.4f} | Cross Acc: {:.4f} - Best: {:.4f}".format(
318 | acc_clean, best_clean_acc, acc_bd, best_bd_acc, acc_cross, best_cross_acc
319 | )
320 | )
321 | else:
322 | info_string = "Clean Acc: {:.4f} - Best: {:.4f} | Bd Acc: {:.4f} - Best: {:.4f}".format(
323 | acc_clean, best_clean_acc, acc_bd, best_bd_acc
324 | )
325 | print(batch_idx, len(test_dl), info_string)
326 |
327 |
328 | # tensorboard
329 | if not epoch % 1:
330 | tf_writer.add_scalars("Test Accuracy", {"Clean": acc_clean, "Bd": acc_bd}, epoch)
331 |
332 | # Save checkpoint
333 | if acc_clean > best_clean_acc or (acc_clean > best_clean_acc - 0.1 and acc_bd > best_bd_acc):
334 | print(" Saving...")
335 | best_clean_acc = acc_clean
336 | best_bd_acc = acc_bd
337 | if opt.cross_ratio:
338 | best_cross_acc = acc_cross
339 | else:
340 | best_cross_acc = torch.tensor([0])
341 | state_dict = {
342 | "netC": netC.state_dict(),
343 | "schedulerC": schedulerC.state_dict(),
344 | "optimizerC": optimizerC.state_dict(),
345 | "best_clean_acc": best_clean_acc,
346 | "best_bd_acc": best_bd_acc,
347 | "best_cross_acc": best_cross_acc,
348 | "epoch_current": epoch,
349 | "identity_grid": identity_grid,
350 | "noise_grid": noise_grid,
351 | }
352 | torch.save(state_dict, opt.ckpt_path)
353 | with open(os.path.join(opt.ckpt_folder, "results.txt"), "w+") as f:
354 | results_dict = {
355 | "clean_acc": best_clean_acc.item(),
356 | "bd_acc": best_bd_acc.item(),
357 | "cross_acc": best_cross_acc.item(),
358 | }
359 | json.dump(results_dict, f, indent=2)
360 |
361 | return best_clean_acc, best_bd_acc, best_cross_acc
362 |
363 |
364 | def main():
365 | opt = config.get_arguments().parse_args()
366 |
367 | if opt.dataset in ["cifar10"]:
368 | opt.num_classes = 10
369 |
370 |
371 | if opt.dataset == "cifar10":
372 | opt.input_height = 32
373 | opt.input_width = 32
374 | opt.input_channel = 3
375 |
376 |
377 | # Dataset
378 |
379 | opt.random_ratio = 0.95
380 | train_dl, train_transform = get_dataloader_random_ratio(opt, True)
381 | test_dl, test_transform = get_dataloader(opt, False)
382 |
383 | # prepare model
384 | netC, optimizerC, schedulerC = get_model(opt)
385 |
386 | # Load pretrained model
387 | mode = opt.attack_mode
388 | opt.ckpt_folder = os.path.join(opt.checkpoints, opt.dataset)
389 | if opt.set_arch:
390 | opt.ckpt_folder = opt.ckpt_folder + "/neurips_wanet/" + opt.set_arch + "_" + opt.extra_flag+"_"+str(opt.target_label)
391 | else:
392 | opt.ckpt_folder = opt.ckpt_folder + "/neurips_wanet/" + opt.extra_flag+"_"+str(opt.target_label)
393 | opt.ckpt_path = os.path.join(opt.ckpt_folder, "{}_{}_morph_wanet.pth.tar".format(opt.dataset, mode))
394 | opt.log_dir = os.path.join(opt.ckpt_folder, "log_dir")
395 | if not os.path.exists(opt.log_dir):
396 | os.makedirs(opt.log_dir)
397 |
398 | if opt.continue_training:
399 | if os.path.exists(opt.ckpt_path):
400 | print("Continue training!!")
401 | state_dict = torch.load(opt.ckpt_path)
402 | netC.load_state_dict(state_dict["netC"])
403 | optimizerC.load_state_dict(state_dict["optimizerC"])
404 | schedulerC.load_state_dict(state_dict["schedulerC"])
405 | best_clean_acc = state_dict["best_clean_acc"]
406 | best_bd_acc = state_dict["best_bd_acc"]
407 | best_cross_acc = state_dict["best_cross_acc"]
408 | epoch_current = state_dict["epoch_current"]
409 | identity_grid = state_dict["identity_grid"]
410 | noise_grid = state_dict["noise_grid"]
411 | tf_writer = SummaryWriter(log_dir=opt.log_dir)
412 | else:
413 |             print("Pretrained model doesn't exist")
414 | exit()
415 | else:
416 | print("Train from scratch!!!")
417 | best_clean_acc = 0.0
418 | best_bd_acc = 0.0
419 | best_cross_acc = 0.0
420 | epoch_current = 0
421 |
422 |         # Prepare grid (only when training from scratch; on resume the grids come from the checkpoint)
423 |         ins = torch.rand(1, 2, opt.k, opt.k) * 2 - 1
424 |         ins = ins / torch.mean(torch.abs(ins))
425 |         noise_grid = (
426 |             F.interpolate(ins, size=opt.input_height, mode="bicubic", align_corners=True)
427 |             .permute(0, 2, 3, 1)
428 |             .to(opt.device)
429 |         )
430 |         array1d = torch.linspace(-1, 1, steps=opt.input_height)
431 |         x, y = torch.meshgrid(array1d, array1d)
432 |         identity_grid = torch.stack((y, x), 2)[None, ...].to(opt.device)
433 | 
434 |         shutil.rmtree(opt.ckpt_folder, ignore_errors=True)
435 |         os.makedirs(opt.log_dir)
436 |         with open(os.path.join(opt.ckpt_folder, "opt.json"), "w+") as f:
437 |             json.dump(opt.__dict__, f, indent=2)
438 |         tf_writer = SummaryWriter(log_dir=opt.log_dir)
439 |
440 |
441 | for epoch in range(epoch_current, opt.n_iters):
442 | print("Epoch {}:".format(epoch + 1))
443 | train(train_transform,netC, optimizerC, schedulerC, train_dl, noise_grid, identity_grid, tf_writer, epoch, opt)
444 | best_clean_acc, best_bd_acc, best_cross_acc = eval(
445 | test_transform,
446 | netC,
447 | optimizerC,
448 | schedulerC,
449 | test_dl,
450 | noise_grid,
451 | identity_grid,
452 | best_clean_acc,
453 | best_bd_acc,
454 | best_cross_acc,
455 | tf_writer,
456 | epoch,
457 | opt,
458 | )
459 |
460 | if opt.save_all:
461 | if (epoch)%opt.save_freq == 0:
462 | state_dict = {
463 | "netC": netC.state_dict(),
464 | "schedulerC": schedulerC.state_dict(),
465 | "optimizerC": optimizerC.state_dict(),
466 | "epoch_current": epoch,
467 | }
468 | epoch_path = os.path.join(opt.ckpt_folder, "{}_{}_epoch{}.pth.tar".format(opt.dataset, mode,epoch))
469 | torch.save(state_dict, epoch_path)
470 |
471 |
472 | if __name__ == "__main__":
473 | main()
474 |
--------------------------------------------------------------------------------
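Note on train_model.py: the WaNet trigger applied above is a perturbed sampling grid. A k x k random flow field is normalized, upsampled to the input size, combined with the identity grid as (identity_grid + s * noise_grid / H) * grid_rescale, clamped to [-1, 1], and applied with `F.grid_sample`; the cross samples add per-image uniform noise to the same grid. A self-contained sketch with assumed values for `k`, `s`, and `grid_rescale` (the real ones come from config.py):

```python
import torch
import torch.nn.functional as F

k, H, s, grid_rescale = 4, 32, 0.5, 1.0    # assumed hyperparameters; see config.py for the real values
x = torch.rand(8, 3, H, H)                 # a batch of images in [0, 1]

# Low-resolution random flow, normalized to unit mean magnitude, upsampled to H x H.
ins = torch.rand(1, 2, k, k) * 2 - 1
ins = ins / torch.mean(torch.abs(ins))
noise_grid = F.interpolate(ins, size=H, mode="bicubic", align_corners=True).permute(0, 2, 3, 1)

# Identity sampling grid in grid_sample's normalized (x, y) convention.
array1d = torch.linspace(-1, 1, steps=H)
xg, yg = torch.meshgrid(array1d, array1d)
identity_grid = torch.stack((yg, xg), 2)[None, ...]

# Backdoored images: every pixel is resampled at a slightly warped location.
grid_temps = torch.clamp((identity_grid + s * noise_grid / H) * grid_rescale, -1, 1)
x_bd = F.grid_sample(x, grid_temps.repeat(x.shape[0], 1, 1, 1), align_corners=True)
print(x_bd.shape)  # torch.Size([8, 3, 32, 32])
```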
/unet_blocks.py:
--------------------------------------------------------------------------------
1 | """
2 | Class definitions for a standard U-Net Up-and Down-sampling blocks
3 | http://arxiv.org/abs/1505.04597
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class EncoderBlock(nn.Module):
11 | """
12 |     Instantiates the encoder block that forms part of a U-Net
13 | Parameters:
14 | in_channels (int): Depth (or number of channels) of the tensor that the block acts on
15 | filter_num (int) : Number of filters used in the convolution ops inside the block,
16 | depth of the output of the enc block
17 | dropout(bool) : Flag to decide whether a dropout layer should be applied
18 | dropout_rate (float) : Probability of dropping a convolution output feature channel
19 | """
20 | def __init__(self, filter_num=64, in_channels=1, dropout=False, dropout_rate=0.3):
21 |
22 | super(EncoderBlock,self).__init__()
23 | self.filter_num = int(filter_num)
24 | self.in_channels = int(in_channels)
25 | self.dropout = dropout
26 | self.dropout_rate = dropout_rate
27 |
28 | self.conv1 = nn.Conv2d(in_channels=self.in_channels,
29 | out_channels=self.filter_num,
30 | kernel_size=3,
31 | padding=1)
32 |
33 | self.conv2 = nn.Conv2d(in_channels=self.filter_num,
34 | out_channels=self.filter_num,
35 | kernel_size=3,
36 | padding=1)
37 |
38 | self.bn_op_1 = nn.InstanceNorm2d(num_features=self.filter_num, affine=True)
39 | self.bn_op_2 = nn.InstanceNorm2d(num_features=self.filter_num, affine=True)
40 |
41 | # Use Dropout ops as nn.Module instead of nn.functional definition
42 | # So using .train() and .eval() flags, can modify their behavior for MC-Dropout
43 | if dropout is True:
44 | self.dropout_1 = nn.Dropout(p=dropout_rate)
45 | self.dropout_2 = nn.Dropout(p=dropout_rate)
46 |
47 | def apply_manual_dropout_mask(self, x, seed):
48 |         # Mask size : [Batch_size, Channels, Height, Width]; each entry is kept with probability 1 - dropout_rate
49 |         dropout_mask = torch.bernoulli(input=torch.empty(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).fill_(1.0 - self.dropout_rate),
50 |                                        generator=torch.Generator().manual_seed(seed))
51 | 
52 |         x = x*dropout_mask.to(x.device)/(1.0 - self.dropout_rate)  # inverted-dropout scaling, matching nn.Dropout
53 |
54 | return x
55 |
56 | def forward(self, x, seeds=None):
57 |
58 | if seeds is not None:
59 | assert(seeds.shape[0] == 2)
60 |
61 | x = self.conv1(x)
62 | x = self.bn_op_1(x)
63 | x = F.leaky_relu(x)
64 | if self.dropout is True:
65 | if seeds is None:
66 | x = self.dropout_1(x)
67 | else:
68 | x = self.apply_manual_dropout_mask(x, seeds[0].item())
69 |
70 | x = self.conv2(x)
71 | x = self.bn_op_2(x)
72 | x = F.leaky_relu(x)
73 | if self.dropout is True:
74 | if seeds is None:
75 | x = self.dropout_2(x)
76 | else:
77 | x = self.apply_manual_dropout_mask(x, seeds[1].item())
78 |
79 | return x
80 |
81 |
82 | class DecoderBlock(nn.Module):
83 | """
84 | Decoder block used in the U-Net
85 | Parameters:
86 | in_channels (int) : Number of channels of the incoming tensor for the upsampling op
87 | concat_layer_depth (int) : Number of channels to be concatenated via skip connections
88 | filter_num (int) : Number of filters used in convolution, the depth of the output of the dec block
89 |         interpolate (bool) : Decides whether upsampling is performed via interpolation or transposed convolution
90 | dropout(bool) : Flag to decide whether a dropout layer should be applied
91 | dropout_rate (float) : Probability of dropping a convolution output feature channel
92 | """
93 | def __init__(self, in_channels, concat_layer_depth, filter_num, interpolate=False, dropout=False, dropout_rate=0.3):
94 |
95 | # Up-sampling (interpolation or transposed conv) --> EncoderBlock
96 | super(DecoderBlock, self).__init__()
97 | self.filter_num = int(filter_num)
98 | self.in_channels = int(in_channels)
99 | self.concat_layer_depth = int(concat_layer_depth)
100 | self.interpolate = interpolate
101 | self.dropout = dropout
102 | self.dropout_rate = dropout_rate
103 |
104 | # Upsample by interpolation followed by a 3x3 convolution to obtain desired depth
105 | self.up_sample_interpolate = nn.Sequential(nn.Upsample(scale_factor=2,
106 | mode='bilinear',
107 | align_corners=True),
108 |
109 | nn.Conv2d(in_channels=self.in_channels,
110 | out_channels=self.in_channels,
111 | kernel_size=3,
112 | padding=1)
113 | )
114 |
115 |         # Upsample via transposed convolution (known to produce artifacts)
116 | self.up_sample_tranposed = nn.ConvTranspose2d(in_channels=self.in_channels,
117 | out_channels=self.in_channels,
118 | kernel_size=3,
119 | stride=2,
120 | padding=1,
121 | output_padding=1)
122 |
123 | self.down_sample = EncoderBlock(in_channels=self.in_channels+self.concat_layer_depth,
124 | filter_num=self.filter_num,
125 | dropout=self.dropout,
126 | dropout_rate=self.dropout_rate)
127 |
128 | def forward(self, x, skip_layer, seeds=None):
129 | if self.interpolate is True:
130 | up_sample_out = F.leaky_relu(self.up_sample_interpolate(x))
131 | else:
132 | up_sample_out = F.leaky_relu(self.up_sample_tranposed(x))
133 |
134 | merged_out = torch.cat([up_sample_out, skip_layer], dim=1)
135 | out = self.down_sample(merged_out, seeds=seeds)
136 | return out
137 |
138 |
139 | class EncoderBlock3D(nn.Module):
140 |
141 | """
142 |     Instantiates the 3D encoder block that forms part of a 3D U-Net
143 | Parameters:
144 | in_channels (int): Depth (or number of channels) of the tensor that the block acts on
145 | filter_num (int) : Number of filters used in the convolution ops inside the block,
146 | depth of the output of the enc block
147 | """
148 | def __init__(self, filter_num=64, in_channels=1, dropout=False):
149 |
150 | super(EncoderBlock3D, self).__init__()
151 | self.filter_num = int(filter_num)
152 | self.in_channels = int(in_channels)
153 | self.dropout = dropout
154 |
155 | self.conv1 = nn.Conv3d(in_channels=self.in_channels,
156 | out_channels=self.filter_num,
157 | kernel_size=3,
158 | padding=1)
159 |
160 | self.conv2 = nn.Conv3d(in_channels=self.filter_num,
161 | out_channels=self.filter_num*2,
162 | kernel_size=3,
163 | padding=1)
164 |
165 | self.bn_op_1 = nn.InstanceNorm3d(num_features=self.filter_num)
166 | self.bn_op_2 = nn.InstanceNorm3d(num_features=self.filter_num*2)
167 |
168 | def forward(self, x):
169 |
170 | x = self.conv1(x)
171 | x = self.bn_op_1(x)
172 | x = F.leaky_relu(x)
173 | if self.dropout is True:
174 | x = F.dropout3d(x, p=0.3)
175 |
176 | x = self.conv2(x)
177 | x = self.bn_op_2(x)
178 | x = F.leaky_relu(x)
179 |
180 | if self.dropout is True:
181 | x = F.dropout3d(x, p=0.3)
182 |
183 | return x
184 |
185 |
186 | class DecoderBlock3D(nn.Module):
187 | """
188 | Decoder block used in the 3D U-Net
189 | Parameters:
190 | in_channels (int) : Number of channels of the incoming tensor for the upsampling op
191 | concat_layer_depth (int) : Number of channels to be concatenated via skip connections
192 | filter_num (int) : Number of filters used in convolution, the depth of the output of the dec block
193 |         interpolate (bool) : Decides whether upsampling is performed via interpolation or transposed convolution
194 | """
195 | def __init__(self, in_channels, concat_layer_depth, filter_num, interpolate=False, dropout=False):
196 |
197 | super(DecoderBlock3D, self).__init__()
198 | self.filter_num = int(filter_num)
199 | self.in_channels = int(in_channels)
200 | self.concat_layer_depth = int(concat_layer_depth)
201 | self.interpolate = interpolate
202 | self.dropout = dropout
203 |
204 | # Upsample by interpolation followed by a 3x3x3 convolution to obtain desired depth
205 | self.up_sample_interpolate = nn.Sequential(nn.Upsample(scale_factor=2,
206 | mode='nearest'),
207 |
208 | nn.Conv3d(in_channels=self.in_channels,
209 | out_channels=self.in_channels,
210 | kernel_size=3,
211 | padding=1)
212 | )
213 |
214 |         # Upsample via transposed convolution (known to produce artifacts)
215 | self.up_sample_transposed = nn.ConvTranspose3d(in_channels=self.in_channels,
216 | out_channels=self.in_channels,
217 | kernel_size=3,
218 | stride=2,
219 | padding=1,
220 | output_padding=1)
221 |
222 | if self.dropout is True:
223 | self.down_sample = nn.Sequential(nn.Conv3d(in_channels=self.in_channels+self.concat_layer_depth,
224 | out_channels=self.filter_num,
225 | kernel_size=3,
226 | padding=1),
227 |
228 | nn.InstanceNorm3d(num_features=self.filter_num),
229 |
230 | nn.LeakyReLU(),
231 |
232 | nn.Dropout3d(p=0.3),
233 |
234 | nn.Conv3d(in_channels=self.filter_num,
235 | out_channels=self.filter_num,
236 | kernel_size=3,
237 | padding=1),
238 |
239 | nn.InstanceNorm3d(num_features=self.filter_num),
240 |
241 | nn.LeakyReLU(),
242 |
243 | nn.Dropout3d(p=0.3))
244 | else:
245 | self.down_sample = nn.Sequential(nn.Conv3d(in_channels=self.in_channels+self.concat_layer_depth,
246 | out_channels=self.filter_num,
247 | kernel_size=3,
248 | padding=1),
249 |
250 | nn.InstanceNorm3d(num_features=self.filter_num),
251 |
252 | nn.LeakyReLU(),
253 |
254 | nn.Conv3d(in_channels=self.filter_num,
255 | out_channels=self.filter_num,
256 | kernel_size=3,
257 | padding=1),
258 |
259 | nn.InstanceNorm3d(num_features=self.filter_num),
260 |
261 | nn.LeakyReLU())
262 |
263 | def forward(self, x, skip_layer):
264 |
265 | if self.interpolate is True:
266 | up_sample_out = F.leaky_relu(self.up_sample_interpolate(x))
267 | else:
268 | up_sample_out = F.leaky_relu(self.up_sample_transposed(x))
269 |
270 | merged_out = torch.cat([up_sample_out, skip_layer], dim=1)
271 | out = self.down_sample(merged_out)
272 | return out
--------------------------------------------------------------------------------
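Note on unet_blocks.py: `EncoderBlock` preserves spatial size (downsampling happens in the U-Net class), while `DecoderBlock` doubles the resolution and concatenates the skip tensor on the channel axis, which is why its inner `EncoderBlock` expects `in_channels + concat_layer_depth` channels; the optional `seeds` make the dropout masks reproducible for MC-dropout. A small shape-check sketch (all sizes illustrative):

```python
import torch
from unet_blocks import EncoderBlock, DecoderBlock

enc = EncoderBlock(filter_num=64, in_channels=1)
dec = DecoderBlock(in_channels=128, concat_layer_depth=64, filter_num=64)

x = torch.randn(2, 1, 64, 64)
skip = enc(x)                          # [2, 64, 64, 64]: channels grow, spatial size is kept
bottom = torch.randn(2, 128, 32, 32)   # stand-in for deeper features after one downsampling step
out = dec(bottom, skip)                # upsample 32->64, concat 128+64=192 channels, conv down to 64
print(skip.shape, out.shape)           # torch.Size([2, 64, 64, 64]) torch.Size([2, 64, 64, 64])
```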
/unet_model.py:
--------------------------------------------------------------------------------
1 | """
2 | A PyTorch Implementation of a U-Net.
3 | Supports 2D (https://arxiv.org/abs/1505.04597) and 3D(https://arxiv.org/abs/1606.06650) variants
4 | Author: Ishaan Bhat
5 | Email: ishaan@isi.uu.nl
6 | """
7 | from unet_blocks import *
8 | from math import pow
9 |
10 |
11 | class UNet(nn.Module):
12 | """
13 | PyTorch class definition for the U-Net architecture for image segmentation
14 | Parameters:
15 | n_channels (int) : Number of image channels
16 | base_filter_num (int) : Number of filters for the first convolution (doubled for every subsequent block)
17 | num_blocks (int) : Number of encoder/decoder blocks
18 | num_classes(int) : Number of classes that need to be segmented
19 | mode (str): 2D or 3D
20 |         use_pooling (bool): Set to 'True' to use MaxPool as the downsampling op.
21 |                             If 'False', strided convolution is used to downsample feature maps (http://arxiv.org/abs/1908.02182)
22 | dropout (bool) : Whether dropout should be added to central encoder and decoder blocks (eg: BayesianSegNet)
23 | dropout_rate (float) : Dropout probability
24 | Returns:
25 | out (torch.Tensor) : Prediction of the segmentation map
26 | """
27 | def __init__(self, n_channels=1, base_filter_num=64, num_blocks=4, num_classes=5, mode='2D', dropout=False, dropout_rate=0.3, use_pooling=True):
28 |
29 | super(UNet, self).__init__()
30 | self.contracting_path = nn.ModuleList()
31 | self.expanding_path = nn.ModuleList()
32 | self.downsampling_ops = nn.ModuleList()
33 |
34 | self.num_blocks = num_blocks
35 | self.n_channels = int(n_channels)
36 | self.n_classes = int(num_classes)
37 | self.base_filter_num = int(base_filter_num)
38 | self.enc_layer_depths = [] # Keep track of the output depths of each encoder block
39 | self.mode = mode
40 | self.pooling = use_pooling
41 | self.dropout = dropout
42 | self.dropout_rate = dropout_rate
43 |
44 | if mode == '2D':
45 | self.encoder = EncoderBlock
46 | self.decoder = DecoderBlock
47 | self.pool = nn.MaxPool2d
48 |
49 | elif mode == '3D':
50 | self.encoder = EncoderBlock3D
51 | self.decoder = DecoderBlock3D
52 | self.pool = nn.MaxPool3d
53 | else:
54 |             raise ValueError('{} mode is invalid'.format(mode))  # fail fast: encoder/decoder would be undefined
55 |
56 | for block_id in range(num_blocks):
57 | # Due to GPU mem constraints, we cap the filter depth at 512
58 | enc_block_filter_num = min(int(pow(2, block_id)*self.base_filter_num), 512) # Output depth of current encoder stage of the 2-D variant
59 | if block_id == 0:
60 | enc_in_channels = self.n_channels
61 | else:
62 | if self.mode == '2D':
63 | if int(pow(2, block_id)*self.base_filter_num) <= 512:
64 | enc_in_channels = enc_block_filter_num//2
65 | else:
66 | enc_in_channels = 512
67 | else:
68 | enc_in_channels = enc_block_filter_num # In the 3D UNet arch, the encoder features double in the 2nd convolution op
69 |
70 |
71 | # Dropout only applied to central encoder blocks -- See BayesianSegNet by Kendall et al.
72 | if self.dropout is True and block_id >= num_blocks-2:
73 | self.contracting_path.append(self.encoder(in_channels=enc_in_channels,
74 | filter_num=enc_block_filter_num,
75 | dropout=True,
76 | dropout_rate=self.dropout_rate))
77 | else:
78 | self.contracting_path.append(self.encoder(in_channels=enc_in_channels,
79 | filter_num=enc_block_filter_num,
80 | dropout=False))
81 | if self.mode == '2D':
82 | self.enc_layer_depths.append(enc_block_filter_num)
83 | if self.pooling is False:
84 | self.downsampling_ops.append(nn.Sequential(nn.Conv2d(in_channels=self.enc_layer_depths[-1],
85 | out_channels=self.enc_layer_depths[-1],
86 | kernel_size=3,
87 | stride=2,
88 | padding=1),
89 |                                                            nn.InstanceNorm2d(num_features=self.enc_layer_depths[-1]),
90 | nn.LeakyReLU()))
91 | else:
92 | self.enc_layer_depths.append(enc_block_filter_num*2) # Specific to 3D U-Net architecture (due to doubling of #feature_maps inside the 3-D Encoder)
93 | if self.pooling is False:
94 | self.downsampling_ops.append(nn.Sequential(nn.Conv3d(in_channels=self.enc_layer_depths[-1],
95 | out_channels=self.enc_layer_depths[-1],
96 | kernel_size=3,
97 | stride=2,
98 | padding=1),
99 | nn.InstanceNorm3d(num_features=self.enc_layer_depths[-1]),
100 | nn.LeakyReLU()))
101 |
102 | # Bottleneck layer
103 | if self.mode == '2D':
104 | bottle_neck_filter_num = self.enc_layer_depths[-1]*2
105 | bottle_neck_in_channels = self.enc_layer_depths[-1]
106 | self.bottle_neck_layer = self.encoder(filter_num=bottle_neck_filter_num,
107 | in_channels=bottle_neck_in_channels)
108 |
109 | else: # Modified for the 3D UNet architecture
110 | bottle_neck_in_channels = self.enc_layer_depths[-1]
111 | bottle_neck_filter_num = self.enc_layer_depths[-1]*2
112 | self.bottle_neck_layer = nn.Sequential(nn.Conv3d(in_channels=bottle_neck_in_channels,
113 | out_channels=bottle_neck_in_channels,
114 | kernel_size=3,
115 | padding=1),
116 |
117 | nn.InstanceNorm3d(num_features=bottle_neck_in_channels),
118 |
119 | nn.LeakyReLU(),
120 |
121 | nn.Conv3d(in_channels=bottle_neck_in_channels,
122 | out_channels=bottle_neck_filter_num,
123 | kernel_size=3,
124 | padding=1),
125 |
126 | nn.InstanceNorm3d(num_features=bottle_neck_filter_num),
127 |
128 | nn.LeakyReLU())
129 |
130 | # Decoder Path
131 | dec_in_channels = int(bottle_neck_filter_num)
132 | for block_id in range(num_blocks):
133 | if self.dropout is True and block_id < 2:
134 | self.expanding_path.append(self.decoder(in_channels=dec_in_channels,
135 | filter_num=self.enc_layer_depths[-1-block_id],
136 | concat_layer_depth=self.enc_layer_depths[-1-block_id],
137 | interpolate=False,
138 | dropout=True,
139 | dropout_rate=self.dropout_rate))
140 | else:
141 | self.expanding_path.append(self.decoder(in_channels=dec_in_channels,
142 | filter_num=self.enc_layer_depths[-1-block_id],
143 | concat_layer_depth=self.enc_layer_depths[-1-block_id],
144 | interpolate=False,
145 | dropout=False))
146 |
147 | dec_in_channels = self.enc_layer_depths[-1-block_id]
148 |
149 | # Output Layer
150 | if mode == '2D':
151 | self.output = nn.Conv2d(in_channels=int(self.enc_layer_depths[0]),
152 | out_channels=self.n_classes,
153 | kernel_size=1)
154 | else:
155 | self.output = nn.Conv3d(in_channels=int(self.enc_layer_depths[0]),
156 | out_channels=self.n_classes,
157 | kernel_size=1)
158 |
159 | def forward(self, x, seeds=None):
160 |
161 | if self.mode == '2D':
162 | h, w = x.shape[-2:]
163 | else:
164 | d, h, w = x.shape[-3:]
165 |
166 | # Encoder
167 | enc_outputs = []
168 | seed_index = 0
169 | for stage, enc_op in enumerate(self.contracting_path):
170 | if stage >= len(self.contracting_path) - 2:
171 | if seeds is not None:
172 | x = enc_op(x, seeds[seed_index:seed_index+2])
173 | else:
174 | x = enc_op(x)
175 | seed_index += 2 # 2 seeds required per block
176 | else:
177 | x = enc_op(x)
178 | enc_outputs.append(x)
179 |
180 | if self.pooling is True:
181 | x = self.pool(kernel_size=2)(x)
182 | else:
183 | x = self.downsampling_ops[stage](x)
184 |
185 | # Bottle-neck layer
186 | x = self.bottle_neck_layer(x)
187 | # Decoder
188 | for block_id, dec_op in enumerate(self.expanding_path):
189 | if block_id < 2:
190 | if seeds is not None:
191 | x = dec_op(x, enc_outputs[-1-block_id], seeds[seed_index:seed_index+2])
192 | else:
193 | x = dec_op(x, enc_outputs[-1-block_id])
194 | seed_index += 2
195 | else:
196 | x = dec_op(x, enc_outputs[-1-block_id])
197 |
198 |
199 | # Output
200 | x = self.output(x)
201 |
202 | return x
--------------------------------------------------------------------------------
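Note on unet_model.py: assembled, the 2D U-Net returns one score map per class at the input resolution whenever each spatial side is divisible by 2**num_blocks; with dropout enabled, the last two encoder and first two decoder blocks each consume two seeds, so passing the same `seeds` tensor twice replays the identical stochastic forward pass. A minimal usage sketch:

```python
import torch
from unet_model import UNet

net = UNet(n_channels=3, base_filter_num=16, num_blocks=4, num_classes=5, mode='2D')
x = torch.randn(1, 3, 64, 64)
print(net(x).shape)  # torch.Size([1, 5, 64, 64]): one score map per class

# MC-dropout: 2 seeds for each of the 4 dropout-equipped blocks -> 8 seeds total.
mc_net = UNet(n_channels=3, base_filter_num=16, num_blocks=4, num_classes=5,
              mode='2D', dropout=True, dropout_rate=0.3)
seeds = torch.arange(8)
assert torch.allclose(mc_net(x, seeds=seeds), mc_net(x, seeds=seeds))
```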