├── cfg.py
├── presentation.pdf
├── dataset.py
├── statistic.py
├── evaluation.py
├── README.md
├── train_base.py
├── train_adv.py
├── core.py
├── resnet.py
└── _jit_internal.py

/cfg.py:
--------------------------------------------------------------------------------
nb_class = 110    # number of classes
img_size = 299    # size of the full (uncropped) images, used at evaluation time
crop_size = 267   # size of the random crops used during training

root = '/home/zzd/Dataset/IJCAI_2019_train_299'
--------------------------------------------------------------------------------
/presentation.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zzd1992/Adversarial-Defense-by-Suppressing-High-Frequencies/HEAD/presentation.pdf
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
from PIL import Image
import os


class DatasetFile(Dataset):
    """Dataset backed by a text file in which each line is '<relative_path>,<label>'."""

    def __init__(self, root, file, transform=None, target_transform=None):
        self.samples = []
        self.root = root
        with open(file, 'r') as f:
            for line in f:
                img, label = line.strip().split(',')
                self.samples.append((img, int(label)))

        self.transform = transform
        self.target_transform = target_transform

    def loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            # Convert while the file is still open; PIL loads lazily, so
            # decoding after the `with` block would read from a closed file.
            return img.convert('RGB')

    def __getitem__(self, index):
        path, target = self.samples[index]
        image = self.loader(os.path.join(self.root, path))
        if self.transform is not None:
            image = self.transform(image)
        # The original stored target_transform but never applied it; do so here.
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def __len__(self):
        return len(self.samples)
--------------------------------------------------------------------------------
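For reference, a minimal usage sketch of `DatasetFile` (not part of the repository; it assumes a `train.txt` and the images under `cfg.root` actually exist):

```python
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from dataset import DatasetFile
import cfg

# Hypothetical file: each line is "<relative_path>,<label>", e.g. "image_00000.jpg,0".
transform = transforms.Compose([
    transforms.CenterCrop(cfg.crop_size),
    transforms.ToTensor(),
])
dataset = DatasetFile(cfg.root, 'train.txt', transform=transform)
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)

images, labels = next(iter(loader))
print(images.shape, labels.shape)  # torch.Size([64, 3, 267, 267]) torch.Size([64])
```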
/statistic.py:
--------------------------------------------------------------------------------
import torch
import numpy as np


class SpectralEnergy(torch.nn.Module):
    """Energy spectrum of a batch of images. The two-sided spectrum is folded so
    that entry (i, j) accumulates the energy of frequency magnitude (i, j);
    spectral_energy() below turns the folded map into a cumulative curve."""

    def __init__(self, w, h):
        super(SpectralEnergy, self).__init__()
        self.w = w
        self.h = h
        self._make_interval()

    def _make_interval(self):
        # w_s/h_s mark where the folding of negative frequencies starts;
        # w_r/h_r count how many bins are folded.
        if self.w % 2 == 0:
            self.w_s, self.w_r = 0, self.w // 2
        else:
            self.w_s, self.w_r = 1, self.w // 2

        if self.h % 2 == 0:
            self.h_s, self.h_r = 0, self.h // 2
        else:
            self.h_s, self.h_r = 1, self.h // 2

    def forward(self, x):
        # torch.rfft was removed in PyTorch 1.8; this module targets older versions.
        y = torch.rfft(x, 2, onesided=False, normalized=True)
        y = y[..., 0] ** 2 + y[..., 1] ** 2  # squared magnitude of each coefficient
        # Fold the negative-frequency energy onto the matching positive frequencies.
        y[:, :, self.h_s:self.h_s+self.h_r] += y[:, :, -self.h_r:][:, :, torch.arange(self.h_r-1, -1, -1)]
        y[..., self.w_s:self.w_s+self.w_r] += y[..., -self.w_r:][..., torch.arange(self.w_r-1, -1, -1)]

        return y[:, :, :self.h_s+self.h_r, :self.w_s+self.w_r]


def spectral_energy(x):
    """Given a folded 2D energy map, return the cumulative fraction of the total
    energy contained in the square low-frequency corner region of each size.

    :param x: 2D numpy array of spectral energies
    :return: 1D array, normalized so that the last entry is 1
    """
    y = np.cumsum(x, 0)
    y = np.cumsum(y, 1)
    y = np.diag(y)  # y[k] = energy inside the (k+1) x (k+1) corner square

    return y / y[-1]


if __name__ == '__main__':
    N = 25
    batchsize = 4
    F = SpectralEnergy(N, N)
    x = torch.rand(batchsize, 3, N, N)
    y = F(x)
    y = y.mean(0).mean(0)
    y = spectral_energy(y.numpy())
    print(y)
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
import torch


def clean(net, loader):
    """Return the number of correctly classified samples (callers divide by len(dataset))."""
    net.eval()
    with torch.no_grad():
        positive = 0.0
        for img, label in loader:
            img = img.cuda()
            label = label.cuda()
            preds = net(img)
            preds = torch.argmax(preds, -1)
            positive += preds.eq(label).float().sum().item()

    return positive


def white_box(net, loader, criterion):
    """Correct-count under a white-box attack: adversarial examples are crafted on `net` itself.

    Note: although this runs under torch.no_grad(), criterion.PGD_L2 re-enables
    gradients locally via torch.enable_grad(), so the attack still works.
    """
    net.eval()
    with torch.no_grad():
        positive = 0.0
        for img, label in loader:
            img = img.cuda()
            label = label.cuda()

            preds = net(img)
            img_adv = criterion.PGD_L2(net, img, preds)

            preds_adv = torch.argmax(net(img_adv), -1)
            positive += preds_adv.eq(label).float().sum().item()

    return positive


def black_box(net_defense, net_attack, loader, criterion):
    """Correct-count under a black-box (transfer) attack: adversarial examples are
    crafted on net_attack and evaluated on net_defense."""
    net_defense.eval()
    net_attack.eval()
    with torch.no_grad():
        positive = 0.0
        for img, label in loader:
            img = img.cuda()
            label = label.cuda()

            preds = net_attack(img)
            img_adv = criterion.PGD_L2(net_attack, img, preds)

            preds_adv = torch.argmax(net_defense(img_adv), -1)
            positive += preds_adv.eq(label).float().sum().item()

    return positive
--------------------------------------------------------------------------------
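A minimal evaluation driver built from the pieces above (not part of the repository; the checkpoint path is hypothetical):

```python
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from resnet import ResNet18
from core import Trades
from dataset import DatasetFile
import cfg
import evaluation

transform = transforms.Compose([transforms.CenterCrop(cfg.crop_size), transforms.ToTensor()])
testset = DatasetFile(cfg.root, 'valid.txt', transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=4)

net = ResNet18(dim=cfg.nb_class, r=16).cuda()
net.load_state_dict(torch.load('checkpoints/adv_r16_epoch39.pth'))  # hypothetical checkpoint

criterion = Trades(step_size=0.0075, epsilon=0.05, perturb_steps=3)
print('clean acc:     {:.4f}'.format(evaluation.clean(net, testloader) / len(testset)))
print('white-box acc: {:.4f}'.format(evaluation.white_box(net, testloader, criterion) / len(testset)))
```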
/README.md:
--------------------------------------------------------------------------------
# Adversarial defense by suppressing high frequencies
We develop a **high frequency suppression module** based on the discrete Fourier transform for adversarial defense. It is **efficient, differentiable and controllable**. Together with adversarial training, we won fifth place in the [IJCAI-2019 Alibaba Adversarial AI Challenge](https://security.alibaba.com/alibs2019) (AAAC). This project is a minimal implementation of our solution.

The motivation of our solution is that adversarial perturbations are dominated by high frequencies, while the information in clean images is concentrated in low frequencies. Thus, if we suppress the high frequencies of adversarial images, the effect of adversarial perturbations is reduced while the basic information in clean images is preserved. Our module operates in the frequency domain. See the details in our [technical report](https://arxiv.org/abs/1908.06566) and presentation.

# Results on AAAC
The dataset contains about 110,000 e-commerce product images in 110 categories. The goal is to correctly classify an image under adversarial perturbations generated by the top-5 **black-box attackers**. The score of a defense model on an image is 0 if the image is misclassified; otherwise it is the L2 norm of the perturbation (so defenses that force attackers to use larger perturbations score higher).

| high frequency suppression | adversarial training | model ensemble | score |
| :------: | :------: | :------: | ------: |
| no | no | no | 2.04 |
| no | yes | no | 9.99 |
| yes | no | no | 14.97 |
| yes | yes | no | 19.05 |
| yes | yes | yes | 19.75 |

As we can see, our high frequency suppression module works well. On this challenge it is even more effective than adversarial training. The official leaderboard is [here](https://tianchi.aliyun.com/competition/entrance/231701/rankingList/5).

# How to use our code
### Requirements
PyTorch >= 1.0 and < 1.8 is recommended: the code ships a vendored `_jit_internal.py` from the 1.0 era and relies on `torch.rfft`/`torch.irfft`, which were removed in PyTorch 1.8 (a `torch.fft`-based port is sketched after `core.py` below).

### Prepare your data
First, modify the meta information in `cfg.py`. `root` is the root path of your data. `crop_size` is the size to which images are randomly cropped during training. We suggest resizing your images to a fixed size before training.

Then, generate text files for the training data and the validation data (a generation sketch follows this file). Each line of a file records the relative path and the label of one image. The label is an integer starting from 0. Here is an example of a line:
```
image_00000.jpg,0
```
### Run the scripts
Suppose your training data is recorded in `train.txt` and your validation data is recorded in `valid.txt`. First train the base model (without adversarial training):
```bash
python train_base.py train.txt valid.txt
```
The model file will be saved; suppose it is `base.pth`. Then train the final model (with adversarial training):
```bash
python train_adv.py train.txt valid.txt -pth base.pth
```
See the help of `train_base.py` and `train_adv.py` for more details. The default network is **ResNet-18**.

# Citation
If you find our method useful, please cite our technical report:
```
@article{zhang2019adversarial,
  title={Adversarial Defense by Suppressing High-Frequency Components},
  author={Zhang, Zhendong and Jung, Cheolkon and Liang, Xiaolong},
  journal={arXiv preprint arXiv:1908.06566},
  year={2019}
}
```
--------------------------------------------------------------------------------
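A sketch of the file generation step, assuming the hypothetical layout of one sub-folder per class under `cfg.root` (adapt the loop to your own directory structure):

```python
import os
import random

root = '/home/zzd/Dataset/IJCAI_2019_train_299'  # same as cfg.root

# Assumed layout: root/<class_dir>/<image>.jpg, one sub-folder per class.
classes = sorted(d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)))
lines = []
for label, cls in enumerate(classes):
    for name in sorted(os.listdir(os.path.join(root, cls))):
        lines.append('{},{}'.format(os.path.join(cls, name), label))

random.seed(0)
random.shuffle(lines)
split = int(0.9 * len(lines))  # 90/10 train/validation split
with open('train.txt', 'w') as f:
    f.write('\n'.join(lines[:split]))
with open('valid.txt', 'w') as f:
    f.write('\n'.join(lines[split:]))
```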
/train_base.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torchvision.transforms as transforms
from torch import nn
import os, argparse, time
from torch.utils.data import DataLoader
from resnet import ResNet18
from dataset import DatasetFile
import cfg
import evaluation

parse = argparse.ArgumentParser()
parse.add_argument('train_file', help='file which contains the paths and labels of training images')
parse.add_argument('valid_file', help='file which contains the paths and labels of validation images')
parse.add_argument('-seed', default=0, type=int, help='random seed')
parse.add_argument('-workers', default=4, type=int, help='workers for data loading')
parse.add_argument('-bs', default=64, type=int, help='batch size for training')
parse.add_argument('-epoch', default=40, type=int, help='number of training epochs')
parse.add_argument('-lr', default=0.1, type=float, help='learning rate')
parse.add_argument('-wd', default=1e-4, type=float, help='weight decay')
parse.add_argument('-init_channels', default=64, type=int, help='channels of the first block of ResNet-18')
parse.add_argument('-r', default=16, type=int, help='radius of our suppression module')
parse.add_argument('-pth', default=None, help='pre-trained model file')
args = parse.parse_args()
print(args)

# The original parsed -seed but never used it; apply it here.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

if not os.path.isdir("checkpoints"): os.mkdir("checkpoints")

transform_train = transforms.Compose([
    transforms.RandomCrop(cfg.crop_size),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.CenterCrop(cfg.crop_size),
    transforms.ToTensor(),
])

torch.backends.cudnn.benchmark = True

trainset = DatasetFile(cfg.root, args.train_file, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=args.bs, shuffle=True, num_workers=args.workers)

testset = DatasetFile(cfg.root, args.valid_file, transform=transform_test)
testloader = DataLoader(testset, batch_size=args.bs, shuffle=False, num_workers=args.workers)

net = ResNet18(dim=cfg.nb_class, r=args.r, c=args.init_channels)
if args.pth is not None:
    net.load_state_dict(torch.load(args.pth))
net.cuda()
net.train()
opt = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.wd)
ce_loss = nn.CrossEntropyLoss()

epoch = args.epoch
best_acc = 0.0

for i in range(epoch):
    # Divide the learning rate by 10 at 50% and 80% of training.
    if i + 1 in [int(epoch * 0.5), int(epoch * 0.8)]:
        for param_group in opt.param_groups:
            param_group['lr'] /= 10
    error, t = 0.0, time.time()
    net.train()
    for num, (img, label) in enumerate(trainloader):
        opt.zero_grad()
        img = img.cuda()
        label = label.cuda()
        preds = net(img)
        loss = ce_loss(preds, label)
        loss.backward()
        opt.step()
        error += loss.item()
    print("epoch {}: {:.5f}\t{:.1f}s".format(i, error/(num+1), time.time()-t))

    # Validate every other epoch; checkpoint late-stage improvements.
    if i % 2 == 0:
        accs = evaluation.clean(net, testloader)
        accs /= len(testset)
        print("metric: {:.4f}".format(accs))
        if accs > best_acc:
            best_acc = accs
            if i >= int(epoch * 0.65):
                torch.save(net.state_dict(), "checkpoints/base_r{}_epoch{}.pth".format(args.r, i))

torch.save(net.state_dict(), "checkpoints/base_r{}_epoch{}.pth".format(args.r, i))
--------------------------------------------------------------------------------
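The learning-rate drops in `train_base.py` are done by hand inside the loop. A roughly equivalent, self-contained sketch using `torch.optim.lr_scheduler.MultiStepLR` (note the manual schedule fires one epoch earlier than these milestones):

```python
import torch
from torch.optim.lr_scheduler import MultiStepLR

# Dummy parameter so the sketch runs stand-alone; in practice use net.parameters().
params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.SGD(params, lr=0.1, momentum=0.9, weight_decay=1e-4)

epochs = 40
scheduler = MultiStepLR(opt, milestones=[epochs // 2, int(epochs * 0.8)], gamma=0.1)
for epoch in range(epochs):
    # ... one pass over trainloader would go here ...
    opt.step()        # placeholder step so scheduler.step() is well-ordered
    scheduler.step()  # step the schedule once per epoch

print(opt.param_groups[0]['lr'])  # 0.001 after both drops
```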
/train_adv.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torchvision.transforms as transforms
from torch import nn
import os, argparse, time
from dataset import DatasetFile
from torch.utils.data import DataLoader
from resnet import ResNet18
from core import Trades
import cfg
import evaluation

parse = argparse.ArgumentParser()
parse.add_argument('train_file', help='file of training images')
parse.add_argument('valid_file', help='file of validation images')
parse.add_argument('-pth', default=None, help='pre-trained model file (from train_base.py)')
parse.add_argument("-seed", default=0, type=int, help='random seed')
parse.add_argument('-workers', default=4, type=int, help='workers for data loading')
parse.add_argument('-bs', default=64, type=int, help='batch size for training')
parse.add_argument('-epoch', default=40, type=int, help='number of training epochs')
parse.add_argument('-lr', default=0.1, type=float, help='learning rate')
parse.add_argument('-wd', default=1e-4, type=float, help='weight decay')
parse.add_argument('-init_channels', default=64, type=int, help='channels of the first block of ResNet-18')
parse.add_argument('-r', default=16, type=int, help='radius of the high frequency suppression module')
parse.add_argument('-epsilon', default=0.05, type=float, help='distance constraint for adversarial perturbations')
parse.add_argument('-step_size', default=0.0075, type=float, help='step size for adversarial sample generation')
parse.add_argument('-perturb_steps', default=3, type=int, help='iterations for adversarial sample generation')
parse.add_argument('-beta', default=1.0, type=float, help='weight of the adversarial loss')
args = parse.parse_args()
print(args)

np.random.seed(args.seed)
torch.manual_seed(args.seed)

if not os.path.isdir("checkpoints"): os.mkdir("checkpoints")

transform_train = transforms.Compose([
    transforms.RandomCrop(cfg.crop_size),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.CenterCrop(cfg.crop_size),
    transforms.ToTensor(),
])

torch.backends.cudnn.benchmark = True

trainset = DatasetFile(cfg.root, args.train_file, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=args.bs, shuffle=True, num_workers=args.workers)

testset = DatasetFile(cfg.root, args.valid_file, transform=transform_test)
testloader = DataLoader(testset, batch_size=args.bs, shuffle=False, num_workers=args.workers)

# The original ignored -r and -init_channels here; pass them as train_base.py does.
net = ResNet18(dim=cfg.nb_class, r=args.r, c=args.init_channels)
if args.pth is not None:
    net.load_state_dict(torch.load(args.pth))
net.cuda()
net.train()

opt = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.wd)
criterion = Trades(args.step_size, args.epsilon, args.perturb_steps, args.beta)
epoch = args.epoch
best_acc = 0.0

for i in range(epoch):
    if i + 1 in [int(epoch * 0.4), int(epoch * 0.8)]:
        for param_group in opt.param_groups:
            param_group['lr'] /= 10
    error, t = 0.0, time.time()
    net.train()
    for num, (img, label) in enumerate(trainloader):
        opt.zero_grad()
        img = img.cuda()
        label = label.cuda()

        preds = net(img)
        # Detach the clean logits: they serve as a fixed target for PGD.
        img_adv = criterion.PGD_L2(net, img, preds.detach())
        loss = criterion.loss(net, preds, img_adv, label, opt)
        loss.backward()
        opt.step()
        error += loss.item()

    print("epoch {}: {:.5f}\t{:.1f}s".format(i, error/(num+1), time.time()-t))

    if i % 2 == 0:
        accs = evaluation.clean(net, testloader)
        accs /= len(testset)
        print("metric: {:.4f}".format(accs))
        if accs > best_acc:
            best_acc = accs
            if i >= int(epoch * 0.65):
                torch.save(net.state_dict(), "checkpoints/adv_r{}_epoch{}.pth".format(args.r, i))

torch.save(net.state_dict(), "checkpoints/adv_r{}_epoch{}.pth".format(args.r, i))
--------------------------------------------------------------------------------
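For reference, the objective assembled in the training loop above is the TRADES loss (Zhang et al., 2019), which `core.Trades` implements: a clean cross-entropy term plus a KL term that pulls the prediction on the adversarial example toward the prediction on the clean input,

$$
\min_\theta\; \mathrm{CE}\big(f_\theta(x),\,y\big) \;+\; \beta\,\mathrm{KL}\big(f_\theta(x)\,\|\,f_\theta(x_{\mathrm{adv}})\big),
\qquad
x_{\mathrm{adv}} \;=\; \arg\max_{\|x'-x\|_2\,\le\,\epsilon}\; \mathrm{KL}\big(f_\theta(x)\,\|\,f_\theta(x')\big),
$$

where the inner maximization is approximated by `PGD_L2` with `-perturb_steps` normalized gradient steps of size `-step_size`, and `-beta` trades off clean accuracy against robustness.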
/core.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from _jit_internal import weak_script_method


def squared_l2_norm(x):
    # Note: this is the *mean* of squared entries per sample, not the sum, so the
    # "norm" is normalized by dimension; epsilon and step_size are calibrated to it.
    flattened = x.view(x.shape[0], -1)
    return (flattened ** 2).mean(1)


def l2_norm(x):
    return squared_l2_norm(x).sqrt()


class Trades:
    """TRADES-style adversarial training (Zhang et al., 2019): a clean cross-entropy
    term plus beta times a KL term between clean and adversarial predictions."""

    def __init__(self, step_size=0.003, epsilon=0.047, perturb_steps=5, beta=1.0):
        self.step_size = step_size
        self.epsilon = epsilon
        self.perturb_steps = perturb_steps
        self.beta = beta
        self.criterion_kl = nn.KLDivLoss(reduction="batchmean")

    def reset_steps(self, k):
        self.perturb_steps = k

    @weak_script_method
    def PGD_L2(self, model, x_natural, logits):
        model.eval()
        # Start from a slightly perturbed copy of the input.
        x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape, device='cuda').detach()
        prob = F.softmax(logits, dim=-1)

        for _ in range(self.perturb_steps):
            with torch.enable_grad():  # re-enables gradients even under torch.no_grad()
                x_adv.requires_grad_()
                loss_kl = self.criterion_kl(F.log_softmax(model(x_adv), dim=1), prob)
            grad = torch.autograd.grad(loss_kl, [x_adv])[0].detach()
            grad /= l2_norm(grad).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + 1e-8
            x_adv = x_adv.detach() + self.step_size * grad

        # Project onto the L2 ball of radius epsilon around x_natural, then clip to [0, 1].
        delta = x_adv - x_natural
        delta_norm = l2_norm(delta)
        cond = delta_norm > self.epsilon
        delta[cond] *= self.epsilon / delta_norm[cond].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        x_adv = x_natural + delta
        x_adv = torch.clamp(x_adv, 0.0, 1.0)

        return x_adv

    @weak_script_method
    def PGD_Linf(self, model, x_natural, logits):
        model.eval()
        x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape, device='cuda').detach()
        prob = F.softmax(logits, dim=-1)

        for _ in range(self.perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = self.criterion_kl(F.log_softmax(model(x_adv), dim=1), prob)
            grad = torch.autograd.grad(loss_kl, [x_adv])[0].detach()
            x_adv = x_adv.detach() + self.step_size * torch.sign(grad)
            # Project onto the L-inf ball of radius epsilon, then clip to [0, 1].
            x_adv = torch.min(torch.max(x_adv, x_natural - self.epsilon), x_natural + self.epsilon)
            x_adv = torch.clamp(x_adv, 0.0, 1.0)

        return x_adv

    @weak_script_method
    def loss(self, model, logits, x_adv, labels, optimizer):
        model.train()
        optimizer.zero_grad()
        prob = F.softmax(logits, dim=-1)
        loss_natural = F.cross_entropy(logits, labels)
        # KLDivLoss(input=log q, target=p) computes KL(p || q): pull the adversarial
        # prediction toward the clean prediction.
        loss_robust = self.criterion_kl(F.log_softmax(model(x_adv), dim=1), prob)
        loss = loss_natural + self.beta * loss_robust

        return loss


class HighFreqSuppress(torch.nn.Module):
    """Zeroes all DFT coefficients outside a square of half-width ~r around the
    DC component, i.e. keeps only the low frequencies."""

    def __init__(self, w, h, r):
        super(HighFreqSuppress, self).__init__()
        self.w = w
        self.h = h
        self.r = r
        self._make_template()

    def _make_template(self):
        temp = np.zeros((self.w, self.h), "float32")
        cw = self.w // 2
        ch = self.h // 2
        dw = self.r if self.w % 2 == 0 else self.r + 1
        dh = self.r if self.h % 2 == 0 else self.r + 1

        # Build the pass-band around the center, then roll it so that the DC
        # component sits at index (0, 0), matching torch.rfft's layout.
        temp[cw - self.r:cw + dw, ch - self.r:ch + dh] = 1.0
        temp = np.roll(temp, -cw, axis=0)
        temp = np.roll(temp, -ch, axis=1)
        temp = torch.tensor(temp)
        temp = temp.unsqueeze(0).unsqueeze(0).unsqueeze(-1)  # (1, 1, w, h, 1)
        # Not registered as a buffer, so it does not follow net.to(device);
        # it is moved to the GPU once here instead.
        self.temp = temp
        if torch.cuda.is_available():
            self.temp = self.temp.cuda()

    @weak_script_method
    def forward(self, x):
        # torch.rfft/torch.irfft were removed in PyTorch 1.8; see the torch.fft
        # port sketched below for newer versions.
        x_hat = torch.rfft(x, 2, onesided=False)
        x_hat = x_hat * self.temp
        y = torch.irfft(x_hat, 2, onesided=False)

        return y

    def extra_repr(self):
        return 'feature_width={}, feature_height={}, radius={}'.format(self.w, self.h, self.r)
--------------------------------------------------------------------------------
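As noted above, `torch.rfft` and `torch.irfft` no longer exist in PyTorch >= 1.8. A minimal port of `HighFreqSuppress` to the `torch.fft` API might look like the following (the class name is ours and the port is a sketch, not the authors' code):

```python
import numpy as np
import torch


class HighFreqSuppressFFT(torch.nn.Module):
    """Sketch of HighFreqSuppress for PyTorch >= 1.8, using the torch.fft module."""

    def __init__(self, w, h, r):
        super().__init__()
        temp = np.zeros((w, h), "float32")
        cw, ch = w // 2, h // 2
        dw = r if w % 2 == 0 else r + 1
        dh = r if h % 2 == 0 else r + 1
        temp[cw - r:cw + dw, ch - r:ch + dh] = 1.0  # low-frequency pass-band
        temp = np.roll(temp, -cw, axis=0)           # shift DC to index (0, 0)
        temp = np.roll(temp, -ch, axis=1)
        # A buffer follows .to(device)/.cuda() automatically, unlike the original.
        self.register_buffer('temp', torch.from_numpy(temp)[None, None])  # (1, 1, w, h)

    def forward(self, x):
        x_hat = torch.fft.fft2(x)            # complex spectrum over the last two dims
        y = torch.fft.ifft2(x_hat * self.temp)
        return y.real                        # input is real, so drop the ~zero imaginary part
```

One caveat of this design choice: because the mask is registered as a buffer, it appears in `state_dict()`, so loading one of the original checkpoints into a network built on this module would need `load_state_dict(..., strict=False)`.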
/resnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from core import HighFreqSuppress
from _jit_internal import weak_script_method
import cfg


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        # Projection shortcut when the spatial size or channel count changes.
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    @weak_script_method
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    @weak_script_method
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, c=64, num_classes=10, r=16):
        super(ResNet, self).__init__()
        self.in_planes = c
        # Two suppression modules: one sized for training crops, one for full-size images.
        self.hfs_train = HighFreqSuppress(cfg.crop_size, cfg.crop_size, r)
        self.hfs_eval = HighFreqSuppress(cfg.img_size, cfg.img_size, r)
        self.head = head(c)
        self.layer1 = self._make_layer(block, c, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 2*c, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 4*c, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 8*c, num_blocks[3], stride=2)

        self.linear = nn.Linear(8*c*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    @weak_script_method
    def forward(self, x, training=True):
        # The flag selects the suppression module matching the input size
        # (crop_size vs. img_size); it is independent of self.training.
        if training:
            out = self.hfs_train(x)
        else:
            out = self.hfs_eval(x)

        out = self.head(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = out.mean(-1).mean(-1)  # global average pooling
        out = self.linear(out)

        return out


def head(c):
    # Stem: two convolutions and an average pool that downsample the input
    # before the residual blocks.
    head = nn.Sequential(nn.Conv2d(3, c//2, kernel_size=5, stride=2, bias=False),
                         nn.BatchNorm2d(c//2),
                         nn.ReLU(True),
                         nn.Conv2d(c//2, c, kernel_size=3, bias=False),
                         nn.AvgPool2d(3, 2),
                         nn.BatchNorm2d(c),
                         nn.ReLU(True))
    return head


def ResNet18(dim, c=64, r=16):
    return ResNet(BasicBlock, [2,2,2,2], c, dim, r)

def ResNet34(dim, c=64, r=16):
    return ResNet(BasicBlock, [3,4,6,3], c, dim, r)

def ResNet50(dim, c=64, r=16):
    return ResNet(Bottleneck, [3,4,6,3], c, dim, r)


if __name__=='__main__':
    net = ResNet18(10, c=64)
    print(net)
--------------------------------------------------------------------------------
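A quick smoke test of the two forward paths (a sketch; it requires a PyTorch version that still provides `torch.rfft`):

```python
import torch
from resnet import ResNet18
import cfg

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = ResNet18(dim=cfg.nb_class, r=16).to(device).eval()

# Default path: suppression module sized for crop_size inputs.
x = torch.rand(2, 3, cfg.crop_size, cfg.crop_size, device=device)
with torch.no_grad():
    print(net(x).shape)  # torch.Size([2, 110])

# Full-size path: the training=False flag selects the img_size-sized module.
x_full = torch.rand(2, 3, cfg.img_size, cfg.img_size, device=device)
with torch.no_grad():
    print(net(x_full, training=False).shape)  # torch.Size([2, 110])
```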
/_jit_internal.py:
--------------------------------------------------------------------------------
"""
The weak_script annotation needs to be here instead of inside torch/jit/ so it
can be used in other places in torch/ (namely torch.nn) without running into
circular dependency problems
"""

import weakref
import inspect
from torch._six import builtins  # note: torch._six was removed in newer PyTorch releases

# Tracks standalone weak script functions
_compiled_weak_fns = weakref.WeakKeyDictionary()

# Tracks which methods should be converted to strong methods
_weak_script_methods = weakref.WeakKeyDictionary()

# Converted modules and their corresponding WeakScriptModuleProxy objects
_weak_modules = weakref.WeakKeyDictionary()

# Types that have been declared as weak modules
_weak_types = weakref.WeakKeyDictionary()

# Wrapper functions that can call either of 2 functions depending on a boolean
# argument
_boolean_dispatched = weakref.WeakKeyDictionary()

COMPILATION_PENDING = object()
COMPILED = object()


def createResolutionCallback(frames_up=0):
    """
    Creates a function which, given a string variable name,
    returns the value of the variable in the scope of the caller of
    the function which called createResolutionCallback (by default).

    This is used to enable access to in-scope Python variables inside
    TorchScript fragments.

    frames_up is the number of additional frames to go up on the stack.
    The default value is 0, which corresponds to the frame of the caller
    of createResolutionCallback. For example, if frames_up is set
    to 1, then the frame of the caller's caller of createResolutionCallback
    will be taken.

    For example, the following program prints 2::

        def bar():
            cb = createResolutionCallback(1)
            print(cb("foo"))

        def baz():
            foo = 2
            bar()

        baz()
    """
    frame = inspect.currentframe()
    i = 0
    while i < frames_up + 1:
        frame = frame.f_back
        i += 1

    f_locals = frame.f_locals
    f_globals = frame.f_globals

    def env(key):
        if key in f_locals:
            return f_locals[key]
        elif key in f_globals:
            return f_globals[key]
        elif hasattr(builtins, key):
            return getattr(builtins, key)
        else:
            return None

    return env


def weak_script(fn, _frames_up=0):
    """
    Marks a function as a weak script function. When used in a script function
    or ScriptModule, the weak script function will be lazily compiled and
    inlined in the graph. When not used in a script function, the weak script
    annotation has no effect.
    """
    _compiled_weak_fns[fn] = {
        "status": COMPILATION_PENDING,
        "compiled_fn": None,
        "rcb": createResolutionCallback(_frames_up + 1)
    }
    return fn


def weak_module(cls):
    _weak_types[cls] = {
        "method_stubs": None
    }
    return cls


def weak_script_method(fn):
    _weak_script_methods[fn] = {
        "rcb": createResolutionCallback(frames_up=2),
        "original_method": fn
    }
    return fn

def boolean_dispatch(arg_name, arg_index, default, if_true, if_false):
    """
    Dispatches to either of 2 weak script functions based on a boolean argument.
    In TorchScript, the boolean argument must be constant so that the correct
    function to use can be determined at compile time.
    """
    if _compiled_weak_fns.get(if_true) is None or _compiled_weak_fns.get(if_false) is None:
        raise RuntimeError("both functions must be weak script")

    def fn(*args, **kwargs):
        dispatch_flag = False
        if arg_name in kwargs:
            dispatch_flag = kwargs[arg_name]
        elif arg_index < len(args):
            dispatch_flag = args[arg_index]

        if dispatch_flag:
            return if_true(*args, **kwargs)
        else:
            return if_false(*args, **kwargs)

    if if_true.__doc__ is None and if_false.__doc__ is not None:
        doc = if_false.__doc__
        if_true.__doc__ = doc
    elif if_false.__doc__ is None and if_true.__doc__ is not None:
        doc = if_true.__doc__
        if_false.__doc__ = doc
    elif if_false.__doc__ is None and if_true.__doc__ is None:
        # neither function has a docstring
        doc = None
    else:
        raise RuntimeError("only one function can have a docstring")
    fn.__doc__ = doc

    _boolean_dispatched[fn] = {
        "if_true": if_true,
        "if_false": if_false,
        "index": arg_index,
        "default": default,
        "arg_name": arg_name
    }
    return fn


try:
    import typing
    from typing import Tuple, List

    def is_tuple(ann):
        # For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule
        return ann.__module__ == 'typing' and \
            (getattr(ann, '__origin__', None) is typing.Tuple or
             getattr(ann, '__origin__', None) is tuple)
except ImportError:
    # A minimal polyfill for versions of Python that don't have typing.
    # Note that this means that they also don't support the fancy annotation syntax, so
    # those instances will only be used in our tiny `type: ` comment interpreter.

    # The __getitem__ in typing is implemented using metaclasses, but I'm too lazy for that.
    class TupleCls(object):
        def __getitem__(self, types):
            return TupleInstance(types)

    class TupleInstance(object):
        def __init__(self, types):
            setattr(self, '__args__', types)

    class ListInstance(object):
        def __init__(self, types):
            setattr(self, '__args__', types)

    class ListCls(object):
        def __getitem__(self, types):
            # The original returned TupleInstance here, which made is_tuple(List[...])
            # true; that looks like a copy-paste slip, fixed to ListInstance.
            return ListInstance(types)

    Tuple = TupleCls()
    List = ListCls()

    def is_tuple(ann):
        return isinstance(ann, TupleInstance)


# allows BroadcastingList instance to be subscriptable
class BroadcastingListCls(object):
    def __getitem__(self, types):
        return

# mypy doesn't support parameters on types, so we have to explicitly type each
# list size
BroadcastingList1 = BroadcastingListCls()
for i in range(2, 7):
    globals()["BroadcastingList{}".format(i)] = BroadcastingList1
--------------------------------------------------------------------------------
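A small, hypothetical demonstration of `boolean_dispatch` (function names are ours; runnable only on the PyTorch era this file targets, since the module imports `torch._six`):

```python
from _jit_internal import weak_script, boolean_dispatch

@weak_script
def _scale_double(x, double=True):
    """Doubling branch."""
    return x * 2

@weak_script
def _scale_identity(x, double=False):
    return x

# Dispatch on the 'double' flag (keyword name 'double', positional index 1, default False).
scale = boolean_dispatch('double', 1, False, _scale_double, _scale_identity)
print(scale(3, double=True))  # 6: dispatches to _scale_double
print(scale(3))               # 3: flag absent, so the falsy default path runs
```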