├── train_eval_scripts ├── recoloradv │ ├── __init__.py │ ├── mister_ed │ │ ├── __init__.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── pytorch_ssim.py │ │ │ ├── image_utils.py │ │ │ └── discretization.py │ │ ├── cifar10 │ │ │ ├── __init__.py │ │ │ ├── wide_resnets.py │ │ │ ├── cifar_resnets.py │ │ │ └── cifar_loader.py │ │ ├── README.md │ │ ├── config.py │ │ └── scripts │ │ │ └── setup_cifar.py │ ├── norms.py │ ├── examples │ │ ├── evaluate_cifar10.py │ │ └── evaluate_imagenet.py │ ├── utils.py │ ├── perturbations.py │ └── color_spaces.py ├── README.md ├── recolor.py ├── eval_cifar100.py ├── corruption.py ├── eval.py ├── stadv.py ├── attack.py ├── train.py ├── sam.py └── model.py ├── SAM_segmentation ├── checkpoints │ └── exp_log_and_checkpoints_will_be_saved_here.txt ├── metrics │ ├── __init__.py │ └── stream_metrics.py ├── network │ ├── __init__.py │ ├── backbone │ │ ├── __init__.py │ │ ├── mobilenetv2.py │ │ └── xception.py │ ├── utils.py │ └── _deeplab.py ├── requirements.txt ├── datasets │ ├── __init__.py │ ├── iccv09.py │ ├── utils.py │ ├── voc.py │ └── cityscapes.py ├── utils │ ├── __init__.py │ ├── scheduler.py │ ├── loss.py │ ├── utils.py │ ├── sam.py │ ├── visualizer.py │ └── attack.py ├── .gitignore ├── LICENSE └── README.md ├── README.md ├── eval.py ├── eval_cifar100.py ├── sam.py ├── model.py ├── utils.py └── train.py /train_eval_scripts/recoloradv/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /train_eval_scripts/recoloradv/mister_ed/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /train_eval_scripts/recoloradv/mister_ed/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /train_eval_scripts/recoloradv/mister_ed/cifar10/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SAM_segmentation/checkpoints/exp_log_and_checkpoints_will_be_saved_here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SAM_segmentation/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .stream_metrics import StreamSegMetrics, AverageMeter 2 | 3 | -------------------------------------------------------------------------------- /SAM_segmentation/network/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling import * 2 | from ._deeplab import convert_to_separable_conv -------------------------------------------------------------------------------- /SAM_segmentation/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | numpy 4 | pillow 5 | scikit-learn 6 | tqdm 7 | matplotlib 8 | visdom -------------------------------------------------------------------------------- /SAM_segmentation/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .voc import VOCSegmentation 2 | from .cityscapes import Cityscapes 3 | from .iccv09 import 
Iccv2009Dataset -------------------------------------------------------------------------------- /SAM_segmentation/network/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from . import resnet 2 | from . import mobilenetv2 3 | from . import hrnetv2 4 | from . import xception 5 | -------------------------------------------------------------------------------- /train_eval_scripts/recoloradv/mister_ed/README.md: -------------------------------------------------------------------------------- 1 | Code in this directory is adapted from the [`mister_ed`](https://github.com/revbucket/mister_ed) library. -------------------------------------------------------------------------------- /SAM_segmentation/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .visualizer import Visualizer 3 | from .scheduler import PolyLR 4 | from .loss import FocalLoss 5 | from .attack import PGD, normalize_voc, normalize_city, normalize_iccv09 6 | from .sam import SAM -------------------------------------------------------------------------------- /SAM_segmentation/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | VOCdevkit 3 | checkpoints/*/*.pth 4 | .vscode 5 | *.pyc 6 | .idea/ 7 | __pycache__ 8 | results 9 | checkpoints_bak 10 | cityscapes 11 | test_results 12 | datasets/data 13 | samples/ 14 | *.zip 15 | iccv09-celoss.csv 16 | iccv09-sam.csv 17 | iccv09.csv 18 | wandb/ -------------------------------------------------------------------------------- /train_eval_scripts/README.md: -------------------------------------------------------------------------------- 1 | This folder contains the code for the image classification project. The main files are as follows: 2 | 3 | 1. `train.py`: train a classification model on CIFAR10(100) / TinyImageNet using SGD/Adam/SAM/AT. For AWP we use the code from the [official repository](https://github.com/csdongxian/AWP) 4 | 2. `attack.py`: test adversarial robustness of a model using torchattacks 5 | 3. `corruption.py`: test general robustness of a model using robustbench 6 | 4. 
`sam_trainer.py`: train and test a text classification model 7 | -------------------------------------------------------------------------------- /SAM_segmentation/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler, StepLR 2 | 3 | class PolyLR(_LRScheduler): 4 | def __init__(self, optimizer, max_iters, power=0.9, last_epoch=-1, min_lr=1e-6): 5 | self.power = power 6 | self.max_iters = max_iters # avoid zero lr 7 | self.min_lr = min_lr 8 | super(PolyLR, self).__init__(optimizer, last_epoch) 9 | 10 | def get_lr(self): 11 | return [ max( base_lr * ( 1 - self.last_epoch/self.max_iters )**self.power, self.min_lr) 12 | for base_lr in self.base_lrs] -------------------------------------------------------------------------------- /SAM_segmentation/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | class FocalLoss(nn.Module): 6 | def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=255): 7 | super(FocalLoss, self).__init__() 8 | self.alpha = alpha 9 | self.gamma = gamma 10 | self.ignore_index = ignore_index 11 | self.size_average = size_average 12 | 13 | def forward(self, inputs, targets): 14 | ce_loss = F.cross_entropy( 15 | inputs, targets, reduction='none', ignore_index=self.ignore_index) 16 | pt = torch.exp(-ce_loss) 17 | focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss 18 | if self.size_average: 19 | return focal_loss.mean() 20 | else: 21 | return focal_loss.sum() -------------------------------------------------------------------------------- /train_eval_scripts/recoloradv/mister_ed/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | config_dir = os.path.abspath(os.path.dirname(__file__)) 4 | 5 | def path_resolver(path): 6 | if path.startswith('~/'): 7 | return os.path.expanduser(path) 8 | 9 | if path.startswith('./'): 10 | return os.path.join(*[config_dir] + path.split('/')[1:]) 11 | 12 | 13 | DEFAULT_DATASETS_DIR = path_resolver('~/datasets') 14 | MODEL_PATH = path_resolver('./pretrained_models/') 15 | OUTPUT_IMAGE_PATH = path_resolver('./output_images/') 16 | 17 | 18 | DEFAULT_BATCH_SIZE = 128 19 | DEFAULT_WORKERS = 4 20 | CIFAR10_MEANS = [0.485, 0.456, 0.406] 21 | CIFAR10_STDS = [0.229, 0.224, 0.225] 22 | 23 | WIDE_CIFAR10_MEANS = [0.4914, 0.4822, 0.4465] 24 | WIDE_CIFAR10_STDS = [0.2023, 0.1994, 0.2010] 25 | 26 | 27 | IMAGENET_MEANS = [0.485, 0.456, 0.406] 28 | IMAGENET_STDS = [0.229, 0.224, 0.225] 29 | -------------------------------------------------------------------------------- /SAM_segmentation/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Gongfan Fang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # On the Duality Between Sharpness-Aware Minimization and Adversarial Training 2 | ## ICML 2024 3 | 4 | Yihao Zhang\*, Hangzhou He\*, Jingyu Zhu\*, Huanran Chen, Yifei Wang, [Zeming Wei](https://weizeming.github.io)${}^\dagger$ 5 | 6 | 7 | ## Sharpness-Aware Minimization Alone can Improve Adversarial Robustness (Workshop version) 8 | ### ICML 2023 AdvML-Frontiers Workshop 9 | [Zeming Wei](https://weizeming.github.io)${}^\dagger$\*, Jingyu Zhu\* and [Yihao Zhang](https://zhang-yihao.github.io/)\* 10 | 11 | ## Citation 12 | ``` 13 | @InProceedings{zhang2024duality, 14 | author = {Zhang, Yihao and He, Hangzhou and Zhu, Jingyu and Chen, Huanran and Wang, Yifei and Wei, Zeming}, 15 | title = {On the Duality Between Sharpness-Aware Minimization and Adversarial Training}, 16 | booktitle = {ICML}, 17 | year = {2024} 18 | } 19 | ``` 20 | and/or 21 | ``` 22 | @InProceedings{wei2023sharpness, 23 | author = {Wei, Zeming and Zhu, Jingyu and Zhang, Yihao}, 24 | title = {Sharpness-Aware Minimization Alone can Improve Adversarial Robustness}, 25 | booktitle = {ICML 2023 Workshop on New Frontiers in Adversarial Machine Learning}, 26 | year = {2023} 27 | } 28 | ``` 29 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import argparse 4 | import torch.nn.functional as F 5 | import os 6 | from model import PreActResNet18 7 | from utils import * 8 | 9 | 10 | if __name__ == '__main__': 11 | file_list = os.listdir('models') 12 | model = PreActResNet18() 13 | 14 | PGD1 = PGD(10, 0.25/255., 1./255., 'linf') 15 | PGD2 = PGD(10, 0.5/255., 2./255., 'linf') 16 | 17 | PGD16 = PGD(10, 2./255., 16./255., 'l2') 18 | PGD32 = PGD(10, 4./255., 32./255., 'l2') 19 | 20 | _, loader = load_dataset('cifar10', 1024) 21 | 22 | for m in file_list: 23 | ckpt = torch.load('models/' + m, map_location='cpu') 24 | model.load_state_dict(ckpt) 25 | model.eval() 26 | model.cuda() 27 | accs = [] 28 | for id, attack in enumerate([PGD1, PGD2, PGD16, PGD32]): 29 | acc = 0 30 | for x,y in loader: 31 | x, y = x.cuda(), y.cuda() 32 | delta = attack.perturb(model, x, y) 33 | pred = model((normalize_cifar(x+delta))) 34 | acc += (pred.max(1)[1] == y).float().sum().item() 35 | acc /= 100 36 | accs.append(acc) 37 | print(m) 38 | print(' & '.join([str(a) for a in accs])) -------------------------------------------------------------------------------- /train_eval_scripts/recolor.py: -------------------------------------------------------------------------------- 1 | import recoloradv.mister_ed.config as config 2 | from recoloradv.mister_ed.utils.pytorch_utils import DifferentiableNormalize 3 | 4 | # ReColorAdv 5 | from recoloradv.utils import get_attack_from_name 6 | from model import PreActResNet18 7 | from utils 
import * 8 | 9 | 10 | class Model(nn.Module): 11 | def __init__(self, model, norm): 12 | super(Model, self).__init__() 13 | self.model = model 14 | self.norm = norm 15 | 16 | def forward(self, x): 17 | return self.model(self.norm(x)) 18 | 19 | 20 | model = PreActResNet18(10) 21 | model.load_state_dict(torch.load('./cifar10_models/cifar10_prn_sgd_sub.pth')) 22 | model.eval() 23 | model.cuda() 24 | 25 | # PGD attack 26 | # Mod = Model(model, normalize_cifar) 27 | # Mod.eval() 28 | # Mod.cuda() 29 | 30 | # get imgs and labels 31 | train_loader, test_loader = load_dataset('cifar10', 1024) 32 | normalizer = DifferentiableNormalize( 33 | mean=config.CIFAR10_MEANS, 34 | std=config.CIFAR10_STDS, 35 | ) 36 | attack = get_attack_from_name('recoloradv', model, normalizer, verbose=True) 37 | acc = 0 38 | for x, y in test_loader: 39 | x, y = x.cuda(), y.cuda() 40 | adv_x = attack.attack(x, y)[0] 41 | pred = model(normalizer(adv_x)) 42 | acc += (pred.max(1)[1] == y).float().sum().item() 43 | break 44 | acc /= 1024 45 | print(acc) 46 | -------------------------------------------------------------------------------- /eval_cifar100.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import argparse 4 | import torch.nn.functional as F 5 | import os 6 | from model import PreActResNet18 7 | from utils import * 8 | 9 | 10 | if __name__ == '__main__': 11 | file_list = os.listdir('cifar100_models') 12 | model = PreActResNet18(100) 13 | 14 | PGD1 = PGD(10, 0.25/255., 1./255., 'linf', False, normalize_cifar100) 15 | PGD2 = PGD(10, 0.5/255., 2./255., 'linf', False, normalize_cifar100) 16 | 17 | PGD16 = PGD(10, 2./255., 16./255., 'l2', False, normalize_cifar100) 18 | PGD32 = PGD(10, 4./255., 32./255., 'l2', False, normalize_cifar100) 19 | 20 | _, loader = load_dataset('cifar100', 1024) 21 | 22 | for m in file_list: 23 | ckpt = torch.load('cifar100_models/' + m, map_location='cpu') 24 | model.load_state_dict(ckpt) 25 | model.eval() 26 | model.cuda() 27 | accs = [] 28 | for id, attack in enumerate([PGD1, PGD2, PGD16, PGD32]): 29 | acc = 0 30 | for x,y in loader: 31 | x, y = x.cuda(), y.cuda() 32 | delta = attack.perturb(model, x, y) 33 | pred = model((normalize_cifar(x+delta))) 34 | acc += (pred.max(1)[1] == y).float().sum().item() 35 | acc /= 100 36 | accs.append(acc) 37 | print(m) 38 | print(' & '.join([str(a) for a in accs])) -------------------------------------------------------------------------------- /train_eval_scripts/eval_cifar100.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import argparse 4 | import torch.nn.functional as F 5 | import os 6 | from model import PreActResNet18 7 | from utils import * 8 | 9 | 10 | if __name__ == '__main__': 11 | file_list = os.listdir('cifar100_models') 12 | model = PreActResNet18(100) 13 | 14 | PGD1 = PGD(10, 0.25/255., 1./255., 'linf', False, normalize_cifar100) 15 | PGD2 = PGD(10, 0.5/255., 2./255., 'linf', False, normalize_cifar100) 16 | 17 | PGD16 = PGD(10, 2./255., 16./255., 'l2', False, normalize_cifar100) 18 | PGD32 = PGD(10, 4./255., 32./255., 'l2', False, normalize_cifar100) 19 | 20 | _, loader = load_dataset('cifar100', 1024) 21 | 22 | for m in file_list: 23 | ckpt = torch.load('cifar100_models/' + m, map_location='cpu') 24 | model.load_state_dict(ckpt) 25 | model.eval() 26 | model.cuda() 27 | accs = [] 28 | for id, attack in enumerate([PGD1, PGD2, PGD16, PGD32]): 29 | acc = 0 30 | for x,y in loader: 
31 | x, y = x.cuda(), y.cuda() 32 | delta = attack.perturb(model, x, y) 33 | pred = model((normalize_cifar(x+delta))) 34 | acc += (pred.max(1)[1] == y).float().sum().item() 35 | acc /= 100 36 | accs.append(acc) 37 | print(m) 38 | print(' & '.join([str(a) for a in accs])) -------------------------------------------------------------------------------- /train_eval_scripts/recoloradv/norms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | 4 | 5 | def smoothness(grid): 6 | """ 7 | Given a variable of dimensions (N, X, Y, [Z], C), computes the sum of 8 | the differences between adjacent points in the grid formed by the 9 | dimensions X, Y, and (optionally) Z. Returns a tensor of dimension N. 10 | """ 11 | 12 | num_dims = len(grid.size()) - 2 13 | batch_size = grid.size()[0] 14 | norm = Variable(torch.zeros(batch_size, dtype=grid.data.dtype, 15 | device=grid.data.device)) 16 | 17 | for dim in range(num_dims): 18 | slice_before = (slice(None),) * (dim + 1) 19 | slice_after = (slice(None),) * (num_dims - dim) 20 | shifted_grids = [ 21 | # left 22 | torch.cat([ 23 | grid[slice_before + (slice(1, None),) + slice_after], 24 | grid[slice_before + (slice(-1, None),) + slice_after], 25 | ], dim + 1), 26 | # right 27 | torch.cat([ 28 | grid[slice_before + (slice(None, 1),) + slice_after], 29 | grid[slice_before + (slice(None, -1),) + slice_after], 30 | ], dim + 1) 31 | ] 32 | for shifted_grid in shifted_grids: 33 | delta = shifted_grid - grid 34 | norm_components = (delta.pow(2).sum(-1) + 1e-10).pow(0.5) 35 | norm.add_(norm_components.sum( 36 | tuple(range(1, len(norm_components.size()))))) 37 | 38 | return norm 39 | -------------------------------------------------------------------------------- /SAM_segmentation/utils/utils.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms.functional import normalize 2 | import torch.nn as nn 3 | import numpy as np 4 | import os 5 | import sys 6 | 7 | def denormalize(tensor, mean, std): 8 | mean = np.array(mean) 9 | std = np.array(std) 10 | 11 | _mean = -mean/std 12 | _std = 1/std 13 | return normalize(tensor, _mean, _std) 14 | 15 | class Logger(object): 16 | # 作用:将print的内容保存到文件中,同时在屏幕上显示,且没次输出都刷新文件,但是屏幕不刷新 17 | def __init__(self, filename="log.txt"): 18 | self.terminal = sys.stdout 19 | self.log = open(filename, 'a') 20 | 21 | def write(self, message): 22 | self.terminal.write(message) 23 | self.log.write(message) 24 | self.log.flush() 25 | 26 | def flush(self): 27 | pass 28 | 29 | class Denormalize(object): 30 | def __init__(self, mean, std): 31 | mean = np.array(mean) 32 | std = np.array(std) 33 | self._mean = -mean/std 34 | self._std = 1/std 35 | 36 | def __call__(self, tensor): 37 | if isinstance(tensor, np.ndarray): 38 | return (tensor - self._mean.reshape(-1,1,1)) / self._std.reshape(-1,1,1) 39 | return normalize(tensor, self._mean, self._std) 40 | 41 | def set_bn_momentum(model, momentum=0.1): 42 | for m in model.modules(): 43 | if isinstance(m, nn.BatchNorm2d): 44 | m.momentum = momentum 45 | 46 | def fix_bn(model): 47 | for m in model.modules(): 48 | if isinstance(m, nn.BatchNorm2d): 49 | m.eval() 50 | 51 | def mkdir(path): 52 | if not os.path.exists(path): 53 | os.mkdir(path) 54 | -------------------------------------------------------------------------------- /sam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class 
SAM(torch.optim.Optimizer): 5 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs): 6 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" 7 | 8 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs) 9 | super(SAM, self).__init__(params, defaults) 10 | 11 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs) 12 | self.param_groups = self.base_optimizer.param_groups 13 | self.defaults.update(self.base_optimizer.defaults) 14 | 15 | @torch.no_grad() 16 | def first_step(self, zero_grad=False): 17 | grad_norm = self._grad_norm() 18 | for group in self.param_groups: 19 | scale = group["rho"] / (grad_norm + 1e-12) 20 | 21 | for p in group["params"]: 22 | if p.grad is None: continue 23 | self.state[p]["old_p"] = p.data.clone() 24 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p) 25 | p.add_(e_w) # climb to the local maximum "w + e(w)" 26 | 27 | if zero_grad: self.zero_grad() 28 | 29 | @torch.no_grad() 30 | def second_step(self, zero_grad=False): 31 | for group in self.param_groups: 32 | for p in group["params"]: 33 | if p.grad is None: continue 34 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)" 35 | 36 | self.base_optimizer.step() # do the actual "sharpness-aware" update 37 | 38 | if zero_grad: self.zero_grad() 39 | 40 | @torch.no_grad() 41 | def step(self, closure=None): 42 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided" 43 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass 44 | 45 | self.first_step(zero_grad=True) 46 | closure() 47 | self.second_step() 48 | 49 | def _grad_norm(self): 50 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism 51 | norm = torch.norm( 52 | torch.stack([ 53 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device) 54 | for group in self.param_groups for p in group["params"] 55 | if p.grad is not None 56 | ]), 57 | p=2 58 | ) 59 | return norm 60 | 61 | def load_state_dict(self, state_dict): 62 | super().load_state_dict(state_dict) 63 | self.base_optimizer.param_groups = self.param_groups 64 | -------------------------------------------------------------------------------- /SAM_segmentation/utils/sam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class SAM(torch.optim.Optimizer): 4 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs): 5 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" 6 | 7 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs) 8 | super().__init__(params, defaults) 9 | if isinstance(base_optimizer, torch.optim.Optimizer): 10 | self.base_optimizer = base_optimizer 11 | print("SAM is applied to inner optimizer: ", base_optimizer.__class__.__name__) 12 | else: 13 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs) 14 | self.param_groups = self.base_optimizer.param_groups 15 | self.defaults.update(self.base_optimizer.defaults) 16 | 17 | @torch.no_grad() 18 | def first_step(self, zero_grad=False): 19 | grad_norm = self._grad_norm() 20 | for group in self.param_groups: 21 | scale = group["rho"] / (grad_norm + 1e-12) 22 | 23 | for p in group["params"]: 24 | if p.grad is None: continue 25 | self.state[p]["old_p"] = p.data.clone() 26 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p) 27 | p.add_(e_w) # climb to 
the local maximum "w + e(w)" 28 | 29 | if zero_grad: self.zero_grad() 30 | 31 | @torch.no_grad() 32 | def second_step(self, zero_grad=False): 33 | for group in self.param_groups: 34 | for p in group["params"]: 35 | if p.grad is None: continue 36 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)" 37 | 38 | self.base_optimizer.step() # do the actual "sharpness-aware" update 39 | 40 | if zero_grad: self.zero_grad() 41 | 42 | @torch.no_grad() 43 | def step(self, closure=None): 44 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided" 45 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass 46 | 47 | self.first_step(zero_grad=True) 48 | closure() 49 | self.second_step() 50 | 51 | def _grad_norm(self): 52 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism 53 | norm = torch.norm( 54 | torch.stack([ 55 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device) 56 | for group in self.param_groups for p in group["params"] 57 | if p.grad is not None 58 | ]), 59 | p=2 60 | ) 61 | return norm 62 | 63 | def load_state_dict(self, state_dict): 64 | super().load_state_dict(state_dict) 65 | self.base_optimizer.param_groups = self.param_groups -------------------------------------------------------------------------------- /train_eval_scripts/recoloradv/mister_ed/utils/pytorch_ssim.py: -------------------------------------------------------------------------------- 1 | """ Implementation directly lifted from Po-Hsun-Su for pytorch ssim 2 | See github repo here: https://github.com/Po-Hsun-Su/pytorch-ssim 3 | """ 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | import numpy as np 8 | from math import exp 9 | 10 | def gaussian(window_size, sigma): 11 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 12 | return gauss/gauss.sum() 13 | 14 | def create_window(window_size, channel): 15 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 16 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 17 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 18 | return window 19 | 20 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 21 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 22 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 23 | 24 | mu1_sq = mu1.pow(2) 25 | mu2_sq = mu2.pow(2) 26 | mu1_mu2 = mu1*mu2 27 | 28 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 29 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 30 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 31 | 32 | C1 = 0.01**2 33 | C2 = 0.03**2 34 | 35 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 36 | 37 | if size_average: 38 | return ssim_map.mean() 39 | else: 40 | return ssim_map.mean(1).mean(1).mean(1) 41 | 42 | class SSIM(torch.nn.Module): 43 | def __init__(self, window_size = 11, size_average = True): 44 | super(SSIM, self).__init__() 45 | self.window_size = window_size 46 | self.size_average = size_average 47 | self.channel = 1 48 | self.window = create_window(window_size, self.channel) 49 | 50 | def forward(self, img1, 
img2): 51 | (_, channel, _, _) = img1.size() 52 | 53 | if channel == self.channel and self.window.data.type() == img1.data.type(): 54 | window = self.window 55 | else: 56 | window = create_window(self.window_size, channel) 57 | 58 | if img1.is_cuda: 59 | window = window.cuda(img1.get_device()) 60 | window = window.type_as(img1) 61 | 62 | self.window = window 63 | self.channel = channel 64 | 65 | 66 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 67 | 68 | def ssim(img1, img2, window_size = 11, size_average = True): 69 | (_, channel, _, _) = img1.size() 70 | window = create_window(window_size, channel) 71 | 72 | if img1.is_cuda: 73 | window = window.cuda(img1.get_device()) 74 | window = window.type_as(img1) 75 | 76 | return _ssim(img1, img2, window, window_size, channel, size_average) -------------------------------------------------------------------------------- /SAM_segmentation/utils/visualizer.py: -------------------------------------------------------------------------------- 1 | from visdom import Visdom 2 | import json 3 | 4 | class Visualizer(object): 5 | """ Visualizer 6 | """ 7 | def __init__(self, port='13579', env='main', id=None): 8 | #self.cur_win = {} 9 | self.vis = Visdom(port=port, env=env) 10 | self.id = id 11 | self.env = env 12 | # Restore 13 | #ori_win = self.vis.get_window_data() 14 | #ori_win = json.loads(ori_win) 15 | #print(ori_win) 16 | #self.cur_win = { v['title']: k for k, v in ori_win.items() } 17 | 18 | def vis_scalar(self, name, x, y, opts=None): 19 | if not isinstance(x, list): 20 | x = [x] 21 | if not isinstance(y, list): 22 | y = [y] 23 | 24 | if self.id is not None: 25 | name = "[%s]"%self.id + name 26 | default_opts = { 'title': name } 27 | if opts is not None: 28 | default_opts.update(opts) 29 | 30 | #win = self.cur_win.get(name, None) 31 | #if win is not None: 32 | self.vis.line( X=x, Y=y, win=name, opts=default_opts, update='append') 33 | #else: 34 | # self.cur_win[name] = self.vis.line( X=x, Y=y, opts=default_opts) 35 | 36 | def vis_image(self, name, img, env=None, opts=None): 37 | """ vis image in visdom 38 | """ 39 | if env is None: 40 | env = self.env 41 | if self.id is not None: 42 | name = "[%s]"%self.id + name 43 | #win = self.cur_win.get(name, None) 44 | default_opts = { 'title': name } 45 | if opts is not None: 46 | default_opts.update(opts) 47 | #if win is not None: 48 | self.vis.image( img=img, win=name, opts=opts, env=env ) 49 | #else: 50 | # self.cur_win[name] = self.vis.image( img=img, opts=default_opts, env=env ) 51 | 52 | def vis_table(self, name, tbl, opts=None): 53 | #win = self.cur_win.get(name, None) 54 | 55 | tbl_str = " " 56 | tbl_str+=" \ 57 | \ 58 | \ 59 | " 60 | for k, v in tbl.items(): 61 | tbl_str+= " \ 62 | \ 63 | \ 64 | "%(k, v) 65 | 66 | tbl_str+="
</table>
" 67 | 68 | default_opts = { 'title': name } 69 | if opts is not None: 70 | default_opts.update(opts) 71 | #if win is not None: 72 | self.vis.text(tbl_str, win=name, opts=default_opts) 73 | #else: 74 | #self.cur_win[name] = self.vis.text(tbl_str, opts=default_opts) 75 | 76 | 77 | if __name__=='__main__': 78 | import numpy as np 79 | vis = Visualizer(port=35588, env='main') 80 | tbl = {"lr": 214, "momentum": 0.9} 81 | vis.vis_table("test_table", tbl) 82 | tbl = {"lr": 244444, "momentum": 0.9, "haha": "hoho"} 83 | vis.vis_table("test_table", tbl) 84 | 85 | vis.vis_scalar(name='loss', x=0, y=1) 86 | vis.vis_scalar(name='loss', x=2, y=4) 87 | vis.vis_scalar(name='loss', x=4, y=6) -------------------------------------------------------------------------------- /SAM_segmentation/README.md: -------------------------------------------------------------------------------- 1 | # Sharpness-Aware Minimization Alone can Improve Adversarial Robustness in Semantic Segmentation 2 | 3 | The semantic segmentation code is adapted from [VainF](https://github.com/VainF/DeepLabV3Plus-Pytorch) 4 | 5 | ## Pascal VOC 6 | 7 | ### 1. Requirements 8 | 9 | ```bash 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ### 2. Prepare Datasets 14 | 15 | #### 2.1 Standard Pascal VOC 16 | You can run train.py with "--download" option to download and extract pascal voc dataset. The defaut path is './datasets/data': 17 | 18 | ``` 19 | /datasets 20 | /data 21 | /VOCdevkit 22 | /VOC2012 23 | /SegmentationClass 24 | /JPEGImages 25 | ... 26 | ... 27 | /VOCtrainval_11-May-2012.tar 28 | ... 29 | ``` 30 | 31 | #### 2.2 Pascal VOC trainaug (Recommended!!) 32 | 33 | See chapter 4 of [2] 34 | 35 | The original dataset contains 1464 (train), 1449 (val), and 1456 (test) pixel-level annotated images. We augment the dataset by the extra annotations provided by [76], resulting in 10582 (trainaug) training images. The performance is measured in terms of pixel intersection-over-union averaged across the 21 classes (mIOU). 36 | 37 | *./datasets/data/train_aug.txt* includes the file names of 10582 trainaug images (val images are excluded). Please to download their labels from [Dropbox](https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip?dl=0) or [Tencent Weiyun](https://share.weiyun.com/5NmJ6Rk). Those labels come from [DrSleep's repo](https://github.com/DrSleep/tensorflow-deeplab-resnet). 38 | 39 | Extract trainaug labels (SegmentationClassAug) to the VOC2012 directory. 40 | 41 | ``` 42 | /datasets 43 | /data 44 | /VOCdevkit 45 | /VOC2012 46 | /SegmentationClass 47 | /SegmentationClassAug # <= the trainaug labels 48 | /JPEGImages 49 | ... 50 | ... 51 | /VOCtrainval_11-May-2012.tar 52 | ... 53 | ``` 54 | 55 | ### 3. Training on Pascal VOC2012 Aug 56 | 57 | #### 3.2 Training with OS=16 58 | 59 | Run main.py with *"--year 2012_aug"* to train your model on Pascal VOC2012 Aug. You can also parallel your training on 4 GPUs with '--gpu_id 0,1,2,3' 60 | 61 | **Note: There is no SyncBN in this repo, so training with *multple GPUs and small batch size* may degrades the performance. 
See [PyTorch-Encoding](https://hangzhang.org/PyTorch-Encoding/tutorials/syncbn.html) for more details about SyncBN** 62 | 63 | ```bash 64 | python main.py --model deeplabv3_mobilenet --gpu_id 3 --year 2012_aug --lr 0.01 --crop_size 513 --batch_size 16 --output_stride 16 --optimizer SAM --rho 0.02 --exp_name voc-SAM 65 | ``` 66 | 67 | 68 | ## Reference 69 | 70 | [1] [Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587) 71 | 72 | [2] [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611) 73 | 74 | [3] [VainF/DeepLabV3Plus-Pytorch](https://github.com/VainF/DeepLabV3Plus-Pytorch) 75 | 76 | [4] [SAM implementation](https://github.com/weizeming/SAM_AT) -------------------------------------------------------------------------------- /train_eval_scripts/recoloradv/examples/evaluate_cifar10.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import argparse 4 | import sys 5 | import os 6 | from torch import optim 7 | from torch.utils.data import DataLoader 8 | from torchvision.models import resnet50 9 | from torchvision.datasets import ImageNet 10 | from torchvision import transforms 11 | 12 | # mister_ed 13 | from recoloradv.mister_ed import loss_functions as lf 14 | from recoloradv.mister_ed import adversarial_training as advtrain 15 | from recoloradv.mister_ed import adversarial_perturbations as ap 16 | from recoloradv.mister_ed import adversarial_attacks as aa 17 | from recoloradv.mister_ed import spatial_transformers as st 18 | from recoloradv.mister_ed.utils import pytorch_utils as utils 19 | from recoloradv.mister_ed.cifar10 import cifar_loader 20 | 21 | # ReColorAdv 22 | from recoloradv import perturbations as pt 23 | from recoloradv import color_transformers as ct 24 | from recoloradv import color_spaces as cs 25 | from recoloradv.utils import get_attack_from_name, load_pretrained_cifar10_model 26 | 27 | 28 | if __name__ == '__main__': 29 | parser = argparse.ArgumentParser( 30 | description='Evaluate a model trained on CIFAR-10 ' 31 | 'against ReColorAdv and other attacks' 32 | ) 33 | 34 | parser.add_argument('--checkpoint', type=str, 35 | help='checkpoint to evaluate') 36 | parser.add_argument('--attack', type=str, 37 | help='attack to run, such as "recoloradv" or ' 38 | '"stadv+delta"') 39 | parser.add_argument('--batch_size', type=int, default=100, 40 | help='number of examples/minibatch') 41 | parser.add_argument('--num_batches', type=int, required=False, 42 | help='number of batches (default entire dataset)') 43 | args = parser.parse_args() 44 | 45 | model, normalizer = load_pretrained_cifar10_model(args.checkpoint) 46 | val_loader = cifar_loader.load_cifar_data('val', batch_size=args.batch_size) 47 | 48 | model.eval() 49 | if torch.cuda.is_available(): 50 | model.cuda() 51 | 52 | attack = get_attack_from_name(args.attack, model, normalizer) 53 | 54 | batches_correct = [] 55 | for batch_index, (inputs, labels) in enumerate(val_loader): 56 | if ( 57 | args.num_batches is not None and 58 | batch_index >= args.num_batches 59 | ): 60 | break 61 | 62 | if torch.cuda.is_available(): 63 | inputs = inputs.cuda() 64 | labels = labels.cuda() 65 | 66 | adv_inputs = attack.attack( 67 | inputs, 68 | labels, 69 | )[0] 70 | with torch.no_grad(): 71 | adv_logits = model(normalizer(adv_inputs)) 72 | batch_correct = (adv_logits.argmax(1) == labels).detach() 73 | 74 | batch_accuracy = batch_correct.float().mean().item() 75 | print(f'BATCH 
{batch_index:05d}', 76 | f'accuracy = {batch_accuracy * 100:.1f}', 77 | sep='\t') 78 | batches_correct.append(batch_correct) 79 | 80 | accuracy = torch.cat(batches_correct).float().mean().item() 81 | print('OVERALL ', 82 | f'accuracy = {accuracy * 100:.1f}', 83 | sep='\t') 84 | -------------------------------------------------------------------------------- /train_eval_scripts/corruption.py: -------------------------------------------------------------------------------- 1 | import torchattacks 2 | from model import PreActResNet18, WRN28_10, DeiT 3 | from utils import * 4 | import recoloradv.mister_ed.config as config 5 | from recoloradv.mister_ed.utils.pytorch_utils import DifferentiableNormalize 6 | from recoloradv.utils import get_attack_from_name 7 | from argparse import ArgumentParser 8 | 9 | from robustbench.data import load_cifar10c, load_cifar100c 10 | from robustbench.utils import clean_accuracy 11 | 12 | parser = ArgumentParser() 13 | parser.add_argument('--model_path', default='put filename here', type=str) 14 | args = parser.parse_args() 15 | file_name = args.model_path 16 | 17 | class Model(nn.Module): 18 | def __init__(self, model, norm): 19 | super(Model, self).__init__() 20 | self.model = model 21 | self.norm = norm 22 | 23 | def forward(self, x): 24 | return self.model(self.norm(x)) 25 | 26 | label_dim = 10 27 | if 'cifar10_' in file_name: 28 | label_dim = 10 29 | normalizer = DifferentiableNormalize( 30 | mean=config.CIFAR10_MEANS, 31 | std=config.CIFAR10_STDS, 32 | ) 33 | norm = normalize_cifar 34 | train_loader, test_loader = load_dataset('cifar10', 1000) 35 | elif 'cifar100_' in file_name: 36 | label_dim = 100 37 | normalizer = DifferentiableNormalize( 38 | mean=CIFAR100_MEAN, 39 | std=CIFAR100_STD, 40 | ) 41 | norm = normalize_cifar100 42 | train_loader, test_loader = load_dataset('cifar100', 1000) 43 | elif 'tiny' in file_name: 44 | label_dim = 200 45 | normalizer = DifferentiableNormalize( 46 | mean=TINYIMAGENET_MEAN, 47 | std=TINYIMAGENET_STD, 48 | ) 49 | norm = normalize_tinyimagenet 50 | train_loader, test_loader = load_dataset('tiny-imagenet-200', 1000) 51 | else: 52 | raise ValueError('Unknown dataset') 53 | 54 | if 'prn' in file_name and 'deit' not in file_name and 'wrn' not in file_name: 55 | model = PreActResNet18(label_dim) 56 | elif 'wrn' in file_name: 57 | model = WRN28_10(label_dim) 58 | elif 'deit' in file_name: 59 | model = DeiT(label_dim) 60 | 61 | corruption_test_types = [['brightness'], ['fog'], ['frost'], ['gaussian_blur'], ['impulse_noise'], ['jpeg_compression'], ['shot_noise'], ['snow'], ['speckle_noise']] 62 | for corruptions in corruption_test_types: 63 | print(f'\n##### corruption type: {corruptions}\n') 64 | x_test, y_test = load_cifar10c(n_examples=1000, corruptions=corruptions, severity=3) 65 | for model_name in ['put file name here', 'put file name here']: 66 | model = PreActResNet18(label_dim) 67 | if 'awp' in model_name: 68 | d = torch.load('./models/' + model_name, map_location='cuda:0') 69 | for k in list(d.keys()): 70 | if k.startswith('module.'): 71 | d[k[7:]] = d[k] 72 | del d[k] 73 | model.load_state_dict(d) 74 | 75 | else: 76 | model.load_state_dict(torch.load('./models/' + model_name, map_location='cuda:0')) 77 | model.eval() 78 | model.cuda() 79 | acc = clean_accuracy(model, x_test, y_test, device=torch.device('cuda')) 80 | print(f'Model: {model_name}, CIFAR-10-C accuracy: {acc:.1%}') -------------------------------------------------------------------------------- /train_eval_scripts/eval.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import argparse 4 | import torch.nn.functional as F 5 | import os 6 | from model import PreActResNet18, WRN28_10, DeiT 7 | from autoattack import AutoAttack 8 | from utils import * 9 | import argparse 10 | 11 | def get_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--model_path', type=str, required=True) 14 | parser.add_argument('--model', type=str, default='PRN', choices=['PRN', 'WRN', 'DeiT']) 15 | parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100', 'tiny-imagenet-200']) 16 | parser.add_argument('--attacker', default='PGD', choices=['PGD', 'FGSM', 'CW', 'AutoAttack']) 17 | parser.add_argument('--eps', default=8./255., type=float) 18 | parser.add_argument('--batch-size', default=1024, type=int) 19 | parser.add_argument('--norm', default='linf', choices=['linf', 'l2']) 20 | return parser.parse_args() 21 | 22 | args = get_args() 23 | 24 | if __name__ == '__main__': 25 | model_path = args.model_path 26 | dataset = args.dataset 27 | model_name = args.model 28 | label_dim = {'cifar10': 10, 'cifar100': 100, 'tiny-imagenet-200': 200}[dataset] 29 | model = {'PRN': PreActResNet18(label_dim), 'WRN': WRN28_10(label_dim), 'DeiT': DeiT(label_dim)}[model_name] 30 | normalizer = {'cifar10': normalize_cifar, 'cifar100': normalize_cifar, 'tiny-imagenet-200': normalize_tinyimagenet}[dataset] 31 | attacker = args.attacker 32 | 33 | #PGD1 = PGD(10, 0.25/255., 1./255., 'linf') 34 | #PGD2 = PGD(10, 0.5/255., 2./255., 'linf') 35 | 36 | #PGD16 = PGD(10, 2./255., 16./255., 'l2') 37 | #PGD32 = PGD(10, 4./255., 32./255., 'l2') 38 | #FGSM1 = PGD(1, 0.25/255., 1./255., 'linf') 39 | #FGSM16 = PGD(1, 2./255., 16./255., 'l2') 40 | 41 | pgd_iters = 10 if attacker == 'PGD' else 1 42 | eps = args.eps 43 | alpha = eps / 4 44 | norm = args.norm 45 | pgd = PGD(pgd_iters, alpha, eps, norm, False, normalizer) 46 | 47 | _, loader = load_dataset(dataset, args.batch_size) 48 | 49 | ckpt = torch.load(model_path, map_location='cpu') 50 | model.load_state_dict(ckpt) 51 | model.eval() 52 | model.cuda() 53 | acc = 0 54 | if args.attacker in ['PGD', 'FGSM']: 55 | for x,y in loader: 56 | x, y = x.cuda(), y.cuda() 57 | delta = pgd.perturb(model, x, y) 58 | pred = model((normalizer(x+delta))) 59 | acc += (pred.max(1)[1] == y).float().sum().item() 60 | acc /= 100 61 | elif args.attacker == 'CW': 62 | for x,y in loader: 63 | x, y = x.cuda(), y.cuda() 64 | x = normalizer(x) 65 | attacked_images = cw_l2_attack(model, x, y) 66 | pred = model(attacked_images) 67 | acc += (pred.max(1)[1] == y).float().sum().item() 68 | acc /= 100 69 | elif args.attacker == 'AutoAttack': 70 | norm = 'Linf' if args.norm == 'linf' else 'L2' 71 | adversary = AutoAttack(model, norm=norm, eps=args.eps, version='standard') 72 | for x,y in loader: 73 | x, y = x.cuda(), y.cuda() 74 | x = normalizer(x) 75 | adv_images = adversary.run_standard_evaluation(x, y, bs=64) 76 | pred = model(adv_images) 77 | acc += (pred.max(1)[1] == y).float().sum().item() 78 | acc /= 100 79 | print("Model: {}, Dataset: {}, Attack: {}, Accuracy: {}".format(model_name, dataset, args.attacker, acc)) -------------------------------------------------------------------------------- /train_eval_scripts/recoloradv/mister_ed/cifar10/wide_resnets.py: -------------------------------------------------------------------------------- 1 | """ Wide Resnet architecture implementation taken from this repo: 2 | 
https://github.com/meliketoy/wide-resnet.pytorch 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | 11 | import sys 12 | import numpy as np 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 16 | 17 | def conv_init(m): 18 | classname = m.__class__.__name__ 19 | if classname.find('Conv') != -1: 20 | init.xavier_uniform(m.weight, gain=np.sqrt(2)) 21 | init.constant(m.bias, 0) 22 | elif classname.find('BatchNorm') != -1: 23 | init.constant(m.weight, 1) 24 | init.constant(m.bias, 0) 25 | 26 | class wide_basic(nn.Module): 27 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 28 | super(wide_basic, self).__init__() 29 | self.bn1 = nn.BatchNorm2d(in_planes) 30 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 31 | self.dropout = nn.Dropout(p=dropout_rate) 32 | self.bn2 = nn.BatchNorm2d(planes) 33 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 34 | 35 | self.shortcut = nn.Sequential() 36 | if stride != 1 or in_planes != planes: 37 | self.shortcut = nn.Sequential( 38 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 39 | ) 40 | 41 | def forward(self, x): 42 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 43 | out = self.conv2(F.relu(self.bn2(out))) 44 | out += self.shortcut(x) 45 | 46 | return out 47 | 48 | class Wide_ResNet(nn.Module): 49 | def __init__(self, depth, widen_factor, dropout_rate, num_classes): 50 | super(Wide_ResNet, self).__init__() 51 | self.in_planes = 16 52 | 53 | assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4' 54 | n = (depth-4)/6 55 | k = widen_factor 56 | 57 | print('| Wide-Resnet %dx%d' %(depth, k)) 58 | nStages = [16, 16*k, 32*k, 64*k] 59 | 60 | self.conv1 = conv3x3(3,nStages[0]) 61 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 62 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 63 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 64 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 65 | self.linear = nn.Linear(nStages[3], num_classes) 66 | 67 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 68 | strides = [stride] + [1]*(num_blocks-1) 69 | layers = [] 70 | 71 | for stride in strides: 72 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 73 | self.in_planes = planes 74 | 75 | return nn.Sequential(*layers) 76 | 77 | def forward(self, x): 78 | out = self.conv1(x) 79 | out = self.layer1(out) 80 | out = self.layer2(out) 81 | out = self.layer3(out) 82 | out = F.relu(self.bn1(out)) 83 | out = F.avg_pool2d(out, 8) 84 | out = out.view(out.size(0), -1) 85 | out = self.linear(out) 86 | 87 | return out 88 | 89 | if __name__ == '__main__': 90 | net=Wide_ResNet(28, 10, 0.3, 10) 91 | y = net(Variable(torch.randn(1,3,32,32))) 92 | 93 | print(y.size()) -------------------------------------------------------------------------------- /SAM_segmentation/datasets/iccv09.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tarfile 4 | import collections 5 | import torch.utils.data as data 6 | import shutil 7 | import numpy as np 8 | 9 | from PIL import Image 10 | from torchvision.datasets.utils import download_url, 
check_integrity 11 | 12 | """ 13 | class_names,r,g,b 14 | sky,68,1,84 15 | tree,72,40,140 16 | road,62,74,137 17 | grass,38,130,142 18 | water,31,158,137 19 | building,53,183,121 20 | mountain,109,205,89 21 | foreground,180,222,44 22 | unknown,49,104,142 23 | """ 24 | 25 | mean = [0.4813, 0.4901, 0.4747] # rgb 26 | std = [0.2495, 0.2492, 0.2748] # rgb 27 | 28 | class Iccv2009Dataset(data.Dataset): 29 | 30 | rgb2id = { 31 | (68, 1, 84): 0, 32 | (72, 40, 140): 1, 33 | (62, 74, 137): 2, 34 | (38, 130, 142): 3, 35 | (31, 158, 137): 4, 36 | (53, 183, 121): 5, 37 | (109, 205, 89): 6, 38 | (180, 222, 44): 7, 39 | (49, 104, 142): 8, 40 | } 41 | 42 | def __init__(self, root, split, transform=None): 43 | 44 | self.image_root = os.path.join(root, 'images') 45 | self.mask_root = os.path.join(root, 'labels_colored') 46 | self.split = split 47 | self.images = [] 48 | self.targets = [] 49 | self.transform = transform 50 | 51 | for filename in os.listdir(self.image_root): 52 | if filename.endswith('.jpg'): 53 | self.images.append(os.path.join(self.image_root, filename)) 54 | self.targets.append(os.path.join(self.mask_root, filename[:-4] + '.png')) 55 | 56 | if self.split == 'train': 57 | self.images = self.images[:int(0.7*len(self.images))] 58 | self.targets = self.targets[:int(0.7*len(self.targets))] 59 | elif self.split == 'val': 60 | self.images = self.images[int(0.7*len(self.images)):] 61 | self.targets = self.targets[int(0.7*len(self.targets)):] 62 | else: 63 | raise ValueError('Invalid split name: {}'.format(self.split)) 64 | 65 | def __getitem__(self, index): 66 | image = Image.open(self.images[index]).convert('RGB') 67 | target = Image.open(self.targets[index]) 68 | target = self.encode_mask(np.array(target)) 69 | target = Image.fromarray(target) 70 | if self.transform is not None: 71 | image, target = self.transform(image, target) 72 | 73 | # tensor min-max normalization image, type(image) = Tensor 74 | image = (image - image.min())/(image.max() - image.min()) 75 | 76 | return image, target 77 | 78 | def __len__(self): 79 | return len(self.images) 80 | 81 | @classmethod 82 | def encode_mask(cls, mask): 83 | for k in cls.rgb2id: 84 | mask[(mask == k).all(axis=2)] = cls.rgb2id[k] 85 | return mask[:, :, 0] 86 | 87 | @classmethod 88 | def decode_target(cls, target): 89 | target_rgb = np.zeros((target.shape[0], target.shape[1], 3), dtype=np.uint8) 90 | for k in cls.rgb2id: 91 | target_rgb[(target == cls.rgb2id[k])] = k 92 | return target_rgb 93 | 94 | if __name__ == "__main__": 95 | dataset = Iccv2009Dataset('/mnt/nasv2/hhz/DeepLabV3Plus-Pytorch-master/datasets/data/iccv09', 'train') 96 | # test mask shape and value 97 | for i in range(len(dataset)): 98 | img, mask = dataset[i] 99 | img = np.array(img) 100 | mask = np.array(mask) 101 | print(img.shape, mask.shape) 102 | print(np.unique(mask)) 103 | if i == 10: 104 | break 105 | -------------------------------------------------------------------------------- /SAM_segmentation/metrics/stream_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import confusion_matrix 3 | 4 | class _StreamMetrics(object): 5 | def __init__(self): 6 | """ Overridden by subclasses """ 7 | raise NotImplementedError() 8 | 9 | def update(self, gt, pred): 10 | """ Overridden by subclasses """ 11 | raise NotImplementedError() 12 | 13 | def get_results(self): 14 | """ Overridden by subclasses """ 15 | raise NotImplementedError() 16 | 17 | def to_str(self, metrics): 18 | """ Overridden by 
subclasses """ 19 | raise NotImplementedError() 20 | 21 | def reset(self): 22 | """ Overridden by subclasses """ 23 | raise NotImplementedError() 24 | 25 | class StreamSegMetrics(_StreamMetrics): 26 | """ 27 | Stream Metrics for Semantic Segmentation Task 28 | """ 29 | def __init__(self, n_classes): 30 | self.n_classes = n_classes 31 | self.confusion_matrix = np.zeros((n_classes, n_classes)) 32 | 33 | def update(self, label_trues, label_preds): 34 | for lt, lp in zip(label_trues, label_preds): 35 | self.confusion_matrix += self._fast_hist( lt.flatten(), lp.flatten() ) 36 | 37 | @staticmethod 38 | def to_str(results): 39 | string = "\n" 40 | for k, v in results.items(): 41 | if k!="Class IoU": 42 | string += "%s: %f\n"%(k, v) 43 | 44 | #string+='Class IoU:\n' 45 | #for k, v in results['Class IoU'].items(): 46 | # string += "\tclass %d: %f\n"%(k, v) 47 | return string 48 | 49 | def _fast_hist(self, label_true, label_pred): 50 | mask = (label_true >= 0) & (label_true < self.n_classes) 51 | hist = np.bincount( 52 | self.n_classes * label_true[mask].astype(int) + label_pred[mask], 53 | minlength=self.n_classes ** 2, 54 | ).reshape(self.n_classes, self.n_classes) 55 | return hist 56 | 57 | def get_results(self): 58 | """Returns accuracy score evaluation result. 59 | - overall accuracy 60 | - mean accuracy 61 | - mean IU 62 | - fwavacc 63 | """ 64 | hist = self.confusion_matrix 65 | acc = np.diag(hist).sum() / hist.sum() 66 | acc_cls = np.diag(hist) / hist.sum(axis=1) 67 | acc_cls = np.nanmean(acc_cls) 68 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 69 | mean_iu = np.nanmean(iu) 70 | freq = hist.sum(axis=1) / hist.sum() 71 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 72 | cls_iu = dict(zip(range(self.n_classes), iu)) 73 | 74 | return { 75 | "Overall Acc": acc, 76 | "Mean Acc": acc_cls, 77 | "FreqW Acc": fwavacc, 78 | "Mean IoU": mean_iu, 79 | "Class IoU": cls_iu, 80 | } 81 | 82 | def reset(self): 83 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes)) 84 | 85 | class AverageMeter(object): 86 | """Computes average values""" 87 | def __init__(self): 88 | self.book = dict() 89 | 90 | def reset_all(self): 91 | self.book.clear() 92 | 93 | def reset(self, id): 94 | item = self.book.get(id, None) 95 | if item is not None: 96 | item[0] = 0 97 | item[1] = 0 98 | 99 | def update(self, id, val): 100 | record = self.book.get(id, None) 101 | if record is None: 102 | self.book[id] = [val, 1] 103 | else: 104 | record[0]+=val 105 | record[1]+=1 106 | 107 | def get_results(self, id): 108 | record = self.book.get(id, None) 109 | assert record is not None 110 | return record[0] / record[1] 111 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class PreActBlock(nn.Module): 7 | '''Pre-activation version of the BasicBlock.''' 8 | expansion = 1 9 | 10 | def __init__(self, in_planes, planes, stride=1): 11 | super(PreActBlock, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(in_planes) 13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(planes) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 16 | 17 | if stride != 1 or in_planes != self.expansion*planes: 18 | self.shortcut = nn.Sequential( 19 | nn.Conv2d(in_planes, self.expansion*planes, 
kernel_size=1, stride=stride, bias=False) 20 | ) 21 | 22 | def forward(self, x): 23 | out = F.relu(self.bn1(x)) 24 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 25 | out = self.conv1(out) 26 | out = self.conv2(F.relu(self.bn2(out))) 27 | out += shortcut 28 | return out 29 | 30 | 31 | class PreActBottleneck(nn.Module): 32 | '''Pre-activation version of the original Bottleneck module.''' 33 | expansion = 4 34 | 35 | def __init__(self, in_planes, planes, stride=1): 36 | super(PreActBottleneck, self).__init__() 37 | self.bn1 = nn.BatchNorm2d(in_planes) 38 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 39 | self.bn2 = nn.BatchNorm2d(planes) 40 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 41 | self.bn3 = nn.BatchNorm2d(planes) 42 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 43 | 44 | if stride != 1 or in_planes != self.expansion*planes: 45 | self.shortcut = nn.Sequential( 46 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 47 | ) 48 | 49 | def forward(self, x): 50 | out = F.relu(self.bn1(x)) 51 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 52 | out = self.conv1(out) 53 | out = self.conv2(F.relu(self.bn2(out))) 54 | out = self.conv3(F.relu(self.bn3(out))) 55 | out += shortcut 56 | return out 57 | 58 | 59 | class PreActResNet(nn.Module): 60 | def __init__(self, block, num_blocks, num_classes=10): 61 | super(PreActResNet, self).__init__() 62 | self.in_planes = 64 63 | 64 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 65 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 66 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 67 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 68 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 69 | self.bn = nn.BatchNorm2d(512 * block.expansion) 70 | self.linear = nn.Linear(512*block.expansion, num_classes) 71 | 72 | def _make_layer(self, block, planes, num_blocks, stride): 73 | strides = [stride] + [1]*(num_blocks-1) 74 | layers = [] 75 | for stride in strides: 76 | layers.append(block(self.in_planes, planes, stride)) 77 | self.in_planes = planes * block.expansion 78 | return nn.Sequential(*layers) 79 | 80 | def forward(self, x): 81 | out = self.conv1(x) 82 | out = self.layer1(out) 83 | out = self.layer2(out) 84 | out = self.layer3(out) 85 | out = self.layer4(out) 86 | out = F.relu(self.bn(out)) 87 | out = F.avg_pool2d(out, 4) 88 | out = out.view(out.size(0), -1) 89 | out = self.linear(out) 90 | return out 91 | 92 | 93 | def PreActResNet18(num_classes=10): 94 | return PreActResNet(PreActBlock, [2,2,2,2], num_classes=num_classes) 95 | -------------------------------------------------------------------------------- /SAM_segmentation/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import hashlib 4 | import errno 5 | from tqdm import tqdm 6 | 7 | 8 | def gen_bar_updater(pbar): 9 | def bar_update(count, block_size, total_size): 10 | if pbar.total is None and total_size: 11 | pbar.total = total_size 12 | progress_bytes = count * block_size 13 | pbar.update(progress_bytes - pbar.n) 14 | 15 | return bar_update 16 | 17 | 18 | def check_integrity(fpath, md5=None): 19 | if md5 is None: 20 | return True 21 | if not os.path.isfile(fpath): 22 | return False 23 | md5o = hashlib.md5() 24 | with 
open(fpath, 'rb') as f: 25 | # read in 1MB chunks 26 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 27 | md5o.update(chunk) 28 | md5c = md5o.hexdigest() 29 | if md5c != md5: 30 | return False 31 | return True 32 | 33 | 34 | def makedir_exist_ok(dirpath): 35 | """ 36 | Python2 support for os.makedirs(.., exist_ok=True) 37 | """ 38 | try: 39 | os.makedirs(dirpath) 40 | except OSError as e: 41 | if e.errno == errno.EEXIST: 42 | pass 43 | else: 44 | raise 45 | 46 | 47 | def download_url(url, root, filename=None, md5=None): 48 | """Download a file from a url and place it in root. 49 | Args: 50 | url (str): URL to download file from 51 | root (str): Directory to place downloaded file in 52 | filename (str): Name to save the file under. If None, use the basename of the URL 53 | md5 (str): MD5 checksum of the download. If None, do not check 54 | """ 55 | from six.moves import urllib 56 | 57 | root = os.path.expanduser(root) 58 | if not filename: 59 | filename = os.path.basename(url) 60 | fpath = os.path.join(root, filename) 61 | 62 | makedir_exist_ok(root) 63 | 64 | # downloads file 65 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 66 | print('Using downloaded and verified file: ' + fpath) 67 | else: 68 | try: 69 | print('Downloading ' + url + ' to ' + fpath) 70 | urllib.request.urlretrieve( 71 | url, fpath, 72 | reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True)) 73 | ) 74 | except OSError: 75 | if url[:5] == 'https': 76 | url = url.replace('https:', 'http:') 77 | print('Failed download. Trying https -> http instead.' 78 | ' Downloading ' + url + ' to ' + fpath) 79 | urllib.request.urlretrieve( 80 | url, fpath, 81 | reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True)) 82 | ) 83 | 84 | 85 | def list_dir(root, prefix=False): 86 | """List all directories at a given root 87 | Args: 88 | root (str): Path to directory whose folders need to be listed 89 | prefix (bool, optional): If true, prepends the path to each result, otherwise 90 | only returns the name of the directories found 91 | """ 92 | root = os.path.expanduser(root) 93 | directories = list( 94 | filter( 95 | lambda p: os.path.isdir(os.path.join(root, p)), 96 | os.listdir(root) 97 | ) 98 | ) 99 | 100 | if prefix is True: 101 | directories = [os.path.join(root, d) for d in directories] 102 | 103 | return directories 104 | 105 | 106 | def list_files(root, suffix, prefix=False): 107 | """List all files ending with a suffix at a given root 108 | Args: 109 | root (str): Path to directory whose folders need to be listed 110 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 
111 | It uses the Python "str.endswith" method and is passed directly 112 | prefix (bool, optional): If true, prepends the path to each result, otherwise 113 | only returns the name of the files found 114 | """ 115 | root = os.path.expanduser(root) 116 | files = list( 117 | filter( 118 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 119 | os.listdir(root) 120 | ) 121 | ) 122 | 123 | if prefix is True: 124 | files = [os.path.join(root, d) for d in files] 125 | 126 | return files -------------------------------------------------------------------------------- /train_eval_scripts/recoloradv/examples/evaluate_imagenet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import argparse 4 | import sys 5 | import os 6 | from torch import optim 7 | from torch.utils.data import DataLoader 8 | from torchvision.models import resnet50 9 | from torchvision.datasets import ImageNet 10 | from torchvision import transforms 11 | 12 | # mister_ed 13 | from recoloradv.mister_ed import loss_functions as lf 14 | from recoloradv.mister_ed import adversarial_training as advtrain 15 | from recoloradv.mister_ed import adversarial_perturbations as ap 16 | from recoloradv.mister_ed import adversarial_attacks as aa 17 | from recoloradv.mister_ed import spatial_transformers as st 18 | from recoloradv.mister_ed.utils import pytorch_utils as utils 19 | 20 | # ReColorAdv 21 | from recoloradv import perturbations as pt 22 | from recoloradv import color_transformers as ct 23 | from recoloradv import color_spaces as cs 24 | 25 | 26 | if __name__ == '__main__': 27 | parser = argparse.ArgumentParser( 28 | description='Evaluate a ResNet-50 trained on Imagenet ' 29 | 'against ReColorAdv' 30 | ) 31 | 32 | parser.add_argument('--imagenet_path', type=str, required=True, 33 | help='path to ImageNet dataset') 34 | parser.add_argument('--batch_size', type=int, default=100, 35 | help='number of examples/minibatch') 36 | parser.add_argument('--num_batches', type=int, required=False, 37 | help='number of batches (default entire dataset)') 38 | args = parser.parse_args() 39 | 40 | model = resnet50(pretrained=True, progress=True) 41 | normalizer = utils.DifferentiableNormalize(mean=[0.485, 0.456, 0.406], 42 | std=[0.229, 0.224, 0.225]) 43 | 44 | dataset = ImageNet( 45 | args.imagenet_path, 46 | split='val', 47 | transform=transforms.Compose([ 48 | transforms.CenterCrop(224), 49 | transforms.ToTensor(), 50 | ]), 51 | ) 52 | val_loader = DataLoader( 53 | dataset, 54 | batch_size=args.batch_size, 55 | shuffle=True, 56 | ) 57 | 58 | model.eval() 59 | if torch.cuda.is_available(): 60 | model.cuda() 61 | 62 | cw_loss = lf.CWLossF6(model, normalizer, kappa=float('inf')) 63 | perturbation_loss = lf.PerturbationNormLoss(lp=2) 64 | adv_loss = lf.RegularizedLoss( 65 | {'cw': cw_loss, 'pert': perturbation_loss}, 66 | {'cw': 1.0, 'pert': 0.05}, 67 | negate=True, 68 | ) 69 | 70 | pgd_attack = aa.PGD( 71 | model, 72 | normalizer, 73 | ap.ThreatModel(pt.ReColorAdv, { 74 | 'xform_class': ct.FullSpatial, 75 | 'cspace': cs.CIELUVColorSpace(), 76 | 'lp_style': 'inf', 77 | 'lp_bound': 0.06, 78 | 'xform_params': { 79 | 'resolution_x': 16, 80 | 'resolution_y': 32, 81 | 'resolution_z': 32, 82 | }, 83 | 'use_smooth_loss': True, 84 | }), 85 | adv_loss, 86 | ) 87 | 88 | batches_correct = [] 89 | for batch_index, (inputs, labels) in enumerate(val_loader): 90 | if ( 91 | args.num_batches is not None and 92 | batch_index >= args.num_batches 93 | ): 94 | break 95 | 96 | if 
torch.cuda.is_available(): 97 | inputs = inputs.cuda() 98 | labels = labels.cuda() 99 | 100 | adv_inputs = pgd_attack.attack( 101 | inputs, 102 | labels, 103 | optimizer=optim.Adam, 104 | optimizer_kwargs={'lr': 0.001}, 105 | signed=False, 106 | verbose=False, 107 | num_iterations=(100, 300), 108 | ).adversarial_tensors() 109 | with torch.no_grad(): 110 | adv_logits = model(normalizer(adv_inputs)) 111 | batch_correct = (adv_logits.argmax(1) == labels).detach() 112 | 113 | batch_accuracy = batch_correct.float().mean().item() 114 | print(f'BATCH {batch_index:05d}', 115 | f'accuracy = {batch_accuracy * 100:.1f}', 116 | sep='\t') 117 | batches_correct.append(batch_correct) 118 | 119 | accuracy = torch.cat(batches_correct).float().mean().item() 120 | print('OVERALL ', 121 | f'accuracy = {accuracy * 100:.1f}', 122 | sep='\t') 123 | -------------------------------------------------------------------------------- /train_eval_scripts/recoloradv/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from torch import nn 4 | from torch import optim 5 | from typing import Tuple 6 | 7 | from .mister_ed.cifar10 import cifar_resnets 8 | from .mister_ed.utils.pytorch_utils import DifferentiableNormalize 9 | from .mister_ed import config 10 | from .mister_ed import adversarial_perturbations as ap 11 | from .mister_ed import adversarial_attacks as aa 12 | from .mister_ed import spatial_transformers as st 13 | from .mister_ed import loss_functions as lf 14 | from .mister_ed import adversarial_training as advtrain 15 | 16 | from . import perturbations as pt 17 | from . import color_transformers as ct 18 | from . import color_spaces as cs 19 | 20 | 21 | def load_pretrained_cifar10_model( 22 | path: str, resnet_size: int = 32, 23 | ) -> Tuple[nn.Module, DifferentiableNormalize]: 24 | """ 25 | Loads a pretrained CIFAR-10 ResNet from the given path along with its 26 | associated normalizer. 27 | """ 28 | 29 | model: nn.Module = getattr(cifar_resnets, f'resnet{resnet_size}')() 30 | model_state = torch.load(path, map_location=torch.device('cpu')) 31 | model.load_state_dict({re.sub(r'^module\.', '', k): v for k, v in 32 | model_state['state_dict'].items()}) 33 | 34 | normalizer = DifferentiableNormalize( 35 | mean=config.CIFAR10_MEANS, 36 | std=config.CIFAR10_STDS, 37 | ) 38 | 39 | return model, normalizer 40 | 41 | 42 | def get_attack_from_name( 43 | name: str, 44 | classifier: nn.Module, 45 | normalizer: DifferentiableNormalize, 46 | verbose: bool = False, 47 | ) -> advtrain.AdversarialAttackParameters: 48 | """ 49 | Builds an attack from a name like "recoloradv" or "stadv+delta" or 50 | "recoloradv+stadv+delta". 
51 | """ 52 | 53 | threats = [] 54 | norm_weights = [] 55 | 56 | for attack_part in name.split('+'): 57 | if attack_part == 'delta': 58 | threats.append(ap.ThreatModel( 59 | ap.DeltaAddition, 60 | ap.PerturbationParameters( 61 | lp_style='inf', 62 | lp_bound=1.0 / 255, 63 | ), 64 | )) 65 | norm_weights.append(0.0) 66 | elif attack_part == 'stadv': 67 | threats.append(ap.ThreatModel( 68 | ap.ParameterizedXformAdv, 69 | ap.PerturbationParameters( 70 | lp_style='inf', 71 | lp_bound=1.0 / 255, 72 | xform_class=st.FullSpatial, 73 | use_stadv=True, 74 | ), 75 | )) 76 | norm_weights.append(1.0) 77 | elif attack_part == 'recoloradv': 78 | threats.append(ap.ThreatModel( 79 | pt.ReColorAdv, 80 | ap.PerturbationParameters( 81 | lp_style='inf', 82 | lp_bound=[8.0/255, 8.0/255, 8.0/255], 83 | xform_params={ 84 | 'resolution_x': 16, 85 | 'resolution_y': 32, 86 | 'resolution_z': 32, 87 | }, 88 | xform_class=ct.FullSpatial, 89 | use_smooth_loss=True, 90 | cspace=cs.CIELUVColorSpace(), 91 | ), 92 | )) 93 | norm_weights.append(1.0) 94 | else: 95 | raise ValueError(f'Invalid attack "{attack_part}"') 96 | 97 | sequence_threat = ap.ThreatModel( 98 | ap.SequentialPerturbation, 99 | threats, 100 | ap.PerturbationParameters(norm_weights=norm_weights), 101 | ) 102 | 103 | # use PGD attack 104 | adv_loss = lf.CWLossF6(classifier, normalizer, kappa=float('inf')) 105 | st_loss = lf.PerturbationNormLoss(lp=2) 106 | loss_fxn = lf.RegularizedLoss({'adv': adv_loss, 'pert': st_loss}, 107 | {'adv': 1.0, 'pert': 0.05}, 108 | negate=True) 109 | 110 | pgd_attack = aa.PGD(classifier, normalizer, sequence_threat, loss_fxn) 111 | return advtrain.AdversarialAttackParameters( 112 | pgd_attack, 113 | 1.0, 114 | attack_specific_params={'attack_kwargs': { 115 | 'num_iterations': 10, 116 | 'optimizer': optim.Adam, 117 | 'optimizer_kwargs': {'lr': 0.001}, 118 | 'signed': False, 119 | 'verbose': verbose, 120 | }}, 121 | ) 122 | -------------------------------------------------------------------------------- /SAM_segmentation/network/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from collections import OrderedDict 6 | 7 | class _SimpleSegmentationModel(nn.Module): 8 | def __init__(self, backbone, classifier): 9 | super(_SimpleSegmentationModel, self).__init__() 10 | self.backbone = backbone 11 | self.classifier = classifier 12 | 13 | def forward(self, x): 14 | input_shape = x.shape[-2:] 15 | features = self.backbone(x) 16 | x = self.classifier(features) 17 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) 18 | return x 19 | 20 | 21 | class IntermediateLayerGetter(nn.ModuleDict): 22 | """ 23 | Module wrapper that returns intermediate layers from a model 24 | 25 | It has a strong assumption that the modules have been registered 26 | into the model in the same order as they are used. 27 | This means that one should **not** reuse the same nn.Module 28 | twice in the forward if you want this to work. 29 | 30 | Additionally, it is only able to query submodules that are directly 31 | assigned to the model. So if `model` is passed, `model.feature1` can 32 | be returned, but not `model.feature1.layer2`. 
33 | 34 | Arguments: 35 | model (nn.Module): model on which we will extract the features 36 | return_layers (Dict[name, new_name]): a dict containing the names 37 | of the modules for which the activations will be returned as 38 | the key of the dict, and the value of the dict is the name 39 | of the returned activation (which the user can specify). 40 | 41 | Examples:: 42 | 43 | >>> m = torchvision.models.resnet18(pretrained=True) 44 | >>> # extract layer1 and layer3, giving as names `feat1` and feat2` 45 | >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, 46 | >>> {'layer1': 'feat1', 'layer3': 'feat2'}) 47 | >>> out = new_m(torch.rand(1, 3, 224, 224)) 48 | >>> print([(k, v.shape) for k, v in out.items()]) 49 | >>> [('feat1', torch.Size([1, 64, 56, 56])), 50 | >>> ('feat2', torch.Size([1, 256, 14, 14]))] 51 | """ 52 | def __init__(self, model, return_layers, hrnet_flag=False): 53 | if not set(return_layers).issubset([name for name, _ in model.named_children()]): 54 | raise ValueError("return_layers are not present in model") 55 | 56 | self.hrnet_flag = hrnet_flag 57 | 58 | orig_return_layers = return_layers 59 | return_layers = {k: v for k, v in return_layers.items()} 60 | layers = OrderedDict() 61 | for name, module in model.named_children(): 62 | layers[name] = module 63 | if name in return_layers: 64 | del return_layers[name] 65 | if not return_layers: 66 | break 67 | 68 | super(IntermediateLayerGetter, self).__init__(layers) 69 | self.return_layers = orig_return_layers 70 | 71 | def forward(self, x): 72 | out = OrderedDict() 73 | for name, module in self.named_children(): 74 | if self.hrnet_flag and name.startswith('transition'): # if using hrnet, you need to take care of transition 75 | if name == 'transition1': # in transition1, you need to split the module to two streams first 76 | x = [trans(x) for trans in module] 77 | else: # all other transition is just an extra one stream split 78 | x.append(module(x[-1])) 79 | else: # other models (ex:resnet,mobilenet) are convolutions in series. 80 | x = module(x) 81 | 82 | if name in self.return_layers: 83 | out_name = self.return_layers[name] 84 | if name == 'stage4' and self.hrnet_flag: # In HRNetV2, we upsample and concat all outputs streams together 85 | output_h, output_w = x[0].size(2), x[0].size(3) # Upsample to size of highest resolution stream 86 | x1 = F.interpolate(x[1], size=(output_h, output_w), mode='bilinear', align_corners=False) 87 | x2 = F.interpolate(x[2], size=(output_h, output_w), mode='bilinear', align_corners=False) 88 | x3 = F.interpolate(x[3], size=(output_h, output_w), mode='bilinear', align_corners=False) 89 | x = torch.cat([x[0], x1, x2, x3], dim=1) 90 | out[out_name] = x 91 | else: 92 | out[out_name] = x 93 | return out 94 | -------------------------------------------------------------------------------- /train_eval_scripts/recoloradv/perturbations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .mister_ed import adversarial_perturbations as ap 5 | from .mister_ed.adversarial_perturbations import initialized 6 | from .mister_ed.utils import pytorch_utils as utils 7 | 8 | from . import color_transformers as ct 9 | from . import norms 10 | from . import color_spaces as cs 11 | 12 | 13 | class ReColorAdv(ap.AdversarialPerturbation): 14 | """ 15 | Puts the color at each pixel in the image through the same transformation. 
16 | 17 | Parameters: 18 | - lp_style: number or 'inf' 19 | - lp_bound: maximum norm of color transformation. Can be a tensor of size 20 | (num_channels,), in which case each channel will be bounded by the 21 | cooresponding bound in the tensor. For instance, passing 22 | [0.1, 0.15, 0.05] would allow a norm of 0.1 for R, 0.15 for G, and 0.05 23 | for B. Not supported by all transformations. 24 | - use_smooth_loss: whether to optimize using the loss function 25 | for FullSpatial that rewards smooth vector fields 26 | - xform_class: a subclass of 27 | color_transformers.ParameterizedTransformation 28 | - xform_params: dict of parameters to pass to the xform_class. 29 | - cspace_class: a subclass of color_spaces.ColorSpace that indicates 30 | in which color space the transformation should be performed 31 | (RGB by default) 32 | """ 33 | 34 | def __init__(self, threat_model, perturbation_params, *other_args): 35 | super().__init__(threat_model, perturbation_params) 36 | assert issubclass(perturbation_params.xform_class, 37 | ct.ParameterizedTransformation) 38 | 39 | self.lp_style = perturbation_params.lp_style 40 | self.lp_bound = perturbation_params.lp_bound 41 | self.use_smooth_loss = perturbation_params.use_smooth_loss 42 | self.scalar_step = perturbation_params.scalar_step or 1.0 43 | self.cspace = perturbation_params.cspace or cs.RGBColorSpace() 44 | 45 | def _merge_setup(self, num_examples, new_xform): 46 | """ DANGEROUS TO BE CALLED OUTSIDE OF THIS FILE!!!""" 47 | self.num_examples = num_examples 48 | self.xform = new_xform 49 | self.initialized = True 50 | 51 | def setup(self, originals): 52 | super().setup(originals) 53 | self.xform = self.perturbation_params.xform_class( 54 | shape=originals.shape, manual_gpu=self.use_gpu, 55 | cspace=self.cspace, 56 | **(self.perturbation_params.xform_params or {}), 57 | ) 58 | self.initialized = True 59 | 60 | @initialized 61 | def perturbation_norm(self, x=None, lp_style=None): 62 | lp_style = lp_style or self.lp_style 63 | if self.use_smooth_loss: 64 | assert isinstance(self.xform, ct.FullSpatial) 65 | return self.xform.smoothness_norm() 66 | else: 67 | return self.xform.norm(lp=lp_style) 68 | 69 | @initialized 70 | def constrain_params(self, x=None): 71 | # Do lp projections 72 | if isinstance(self.lp_style, int) or self.lp_style == 'inf': 73 | self.xform.project_params(self.lp_style, self.lp_bound) 74 | 75 | @initialized 76 | def update_params(self, step_fxn): 77 | param_list = list(self.xform.parameters()) 78 | assert len(param_list) == 1 79 | params = param_list[0] 80 | assert params.grad.data is not None 81 | self.add_to_params(step_fxn(params.grad.data) * self.scalar_step) 82 | 83 | @initialized 84 | def add_to_params(self, grad_data): 85 | """ Assumes only one parameters object in the Spatial Transform """ 86 | param_list = list(self.xform.parameters()) 87 | assert len(param_list) == 1 88 | params = param_list[0] 89 | params.data.add_(grad_data) 90 | 91 | @initialized 92 | def random_init(self): 93 | param_list = list(self.xform.parameters()) 94 | assert len(param_list) == 1 95 | param = param_list[0] 96 | random_perturb = utils.random_from_lp_ball(param.data, 97 | self.lp_style, 98 | self.lp_bound) 99 | 100 | param.data.add_(self.xform.identity_params + 101 | random_perturb - self.xform.xform_params.data) 102 | 103 | @initialized 104 | def merge_perturbation(self, other, self_mask): 105 | super().merge_perturbation(other, self_mask) 106 | new_perturbation = ReColorAdv(self.threat_model, 107 | self.perturbation_params) 108 | 109 | 
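# Merge the two color transformations example-by-example according to self_mask,
# then wrap the merged transform in a fresh ReColorAdv perturbation below.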
new_xform = self.xform.merge_xform(other.xform, self_mask) 110 | new_perturbation._merge_setup(self.num_examples, new_xform) 111 | 112 | return new_perturbation 113 | 114 | def forward(self, x): 115 | if not self.initialized: 116 | self.setup(x) 117 | self.constrain_params() 118 | 119 | return self.cspace.to_rgb( 120 | self.xform.forward(self.cspace.from_rgb(x))) 121 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import dataset, dataloader 4 | from torchvision import datasets, transforms 5 | 6 | cifar10_mean = (0.4914, 0.4822, 0.4465) # equals np.mean(train_set.train_data, axis=(0,1,2))/255 7 | cifar10_std = (0.2471, 0.2435, 0.2616) # equals np.std(train_set.train_data, axis=(0,1,2))/255 8 | 9 | mu = torch.tensor(cifar10_mean).view(3,1,1) 10 | std = torch.tensor(cifar10_std).view(3,1,1) 11 | 12 | def normalize_cifar(x): 13 | return (x - mu.to(x.device))/(std.to(x.device)) 14 | 15 | CIFAR100_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343) 16 | CIFAR100_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404) 17 | 18 | mu_cifar100 = torch.tensor(CIFAR100_MEAN).view(3,1,1).cuda() 19 | std_cifar100 = torch.tensor(CIFAR100_STD).view(3,1,1).cuda() 20 | 21 | def normalize_cifar100(x): 22 | return (x - mu_cifar100.to(x.device))/(std_cifar100.to(x.device)) 23 | 24 | def load_dataset(dataset='cifar10', batch_size=128): 25 | if dataset == 'cifar10': 26 | transform_ = transforms.Compose([transforms.ToTensor()]) 27 | train_transform_ = transforms.Compose([ 28 | transforms.RandomCrop(32, padding=4, padding_mode='reflect'), 29 | transforms.RandomHorizontalFlip(), 30 | transforms.ToTensor()]) 31 | 32 | train_loader = torch.utils.data.DataLoader( 33 | datasets.CIFAR10('/data/cifar_data', train=True, download=True, transform=train_transform_), 34 | batch_size=batch_size, shuffle=True) 35 | 36 | test_loader = torch.utils.data.DataLoader( 37 | datasets.CIFAR10('/data/cifar_data', train=False, download=True, transform=transform_), 38 | batch_size=batch_size, shuffle=False) 39 | 40 | return train_loader, test_loader 41 | 42 | elif dataset == 'cifar100': 43 | transform_ = transforms.Compose([transforms.ToTensor()]) 44 | train_transform_ = transforms.Compose([ 45 | transforms.RandomCrop(32, padding=4, padding_mode='reflect'), 46 | transforms.RandomHorizontalFlip(), 47 | transforms.ToTensor()]) 48 | train_loader = torch.utils.data.DataLoader( 49 | datasets.CIFAR100('/data/cifar_data', train=True, download=True, transform=train_transform_), 50 | batch_size=batch_size, shuffle=True) 51 | test_loader = torch.utils.data.DataLoader( 52 | datasets.CIFAR100('/data/cifar_data', train=False, download=True, transform=transform_), 53 | batch_size=batch_size, shuffle=False) 54 | 55 | return train_loader, test_loader 56 | 57 | class Attack(): 58 | def __init__(self, iters, alpha, eps, norm, criterion, rand_init, rand_perturb, targeted, normalize=normalize_cifar): 59 | self.iters = iters 60 | self.alpha = alpha 61 | self.eps = eps 62 | self.norm = norm 63 | assert norm in ['linf', 'l2'] 64 | self.criterion = criterion # loss function for perturb 65 | self.rand_init = rand_init # random initialization before perturb 66 | self.rand_perturb = rand_perturb # add random noise in each step 67 | self.targetd = targeted # targeted attack 68 | self.normalize = normalize # normalize_cifar 69 | 70 | def 
perturb(self, model, x, y): 71 | delta = torch.zeros_like(x).to(x.device) 72 | if self.rand_init: 73 | 74 | if self.norm == "linf": 75 | delta.uniform_(-self.eps, self.eps) 76 | elif self.norm == "l2": 77 | delta.normal_() 78 | d_flat = delta.view(delta.size(0),-1) 79 | n = d_flat.norm(p=2,dim=1).view(delta.size(0),1,1,1) 80 | r = torch.zeros_like(n).uniform_(0, 1) 81 | delta *= r/n*self.eps 82 | else: 83 | raise ValueError 84 | 85 | delta = torch.clamp(delta, 0-x, 1-x) 86 | delta.requires_grad = True 87 | 88 | for _ in range(self.iters): 89 | output = model(self.normalize(x+delta)) 90 | loss = self.criterion(output, y) 91 | if self.targetd: 92 | loss *= -1 93 | loss.backward() 94 | g = delta.grad.detach() 95 | if self.norm == "linf": 96 | d = torch.clamp(delta + self.alpha * torch.sign(g), min=-self.eps, max=self.eps).detach() 97 | elif self.norm == "l2": 98 | g_norm = torch.norm(g.view(g.shape[0],-1),dim=1).view(-1,1,1,1) 99 | scaled_g = g/(g_norm + 1e-10) 100 | d = (delta + scaled_g*self.alpha).view(delta.size(0),-1).renorm(p=2,dim=0,maxnorm=self.eps).view_as(delta).detach() 101 | d = torch.clamp(d, 0 - x, 1 - x) 102 | delta.data = d 103 | delta.grad.zero_() 104 | 105 | return delta.detach() 106 | 107 | class PGD(Attack): 108 | def __init__(self, iters, alpha, eps, norm, targeted=False, normalize=normalize_cifar): 109 | super().__init__(iters, alpha, eps, norm, nn.CrossEntropyLoss(), True, False, targeted, normalize=normalize) -------------------------------------------------------------------------------- /train_eval_scripts/recoloradv/mister_ed/cifar10/cifar_resnets.py: -------------------------------------------------------------------------------- 1 | ''' 2 | MISTER_ED_NOTE: I blatantly copied this code from this github repository: 3 | https://github.com/akamaster/pytorch_resnet_cifar10 4 | 5 | Huge kudos to Yerlan Idelbayev. 6 | ''' 7 | 8 | 9 | 10 | ''' 11 | Properly implemented ResNet-s for CIFAR10 as described in paper [1]. 12 | 13 | The implementation and structure of this file is hugely influenced by [2] 14 | which is implemented for ImageNet and doesn't have option A for identity. 15 | Moreover, most of the implementations on the web is copy-paste from 16 | torchvision's resnet and has wrong number of params. 17 | 18 | Proper ResNet-s for CIFAR10 (for fair comparision and etc.) has following 19 | number of layers and parameters: 20 | 21 | name | layers | params 22 | ResNet20 | 20 | 0.27M 23 | ResNet32 | 32 | 0.46M 24 | ResNet44 | 44 | 0.66M 25 | ResNet56 | 56 | 0.85M 26 | ResNet110 | 110 | 1.7M 27 | ResNet1202| 1202 | 19.4m 28 | 29 | which this implementation indeed has. 30 | 31 | Reference: 32 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 33 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 34 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 35 | 36 | If you use this implementation in you work, please don't forget to mention the 37 | author, Yerlan Idelbayev. 
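A minimal sketch (not from the original sources) of how the Attack/PGD helper defined in utils.py above is driven, following the pattern in train.py; the freshly initialized PreActResNet18 is a stand-in for a trained checkpoint, and load_dataset uses the CIFAR data path hardcoded in utils.py.

import torch
from utils import load_dataset, normalize_cifar, PGD
from model import PreActResNet18

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = PreActResNet18(10).to(device).eval()   # stand-in for a trained model
_, test_loader = load_dataset('cifar10', batch_size=128)

# 10-step L-inf PGD with the usual 8/255 budget and 2/255 step size.
attack = PGD(iters=10, alpha=2 / 255., eps=8 / 255., norm='linf', normalize=normalize_cifar)

x, y = next(iter(test_loader))
x, y = x.to(device), y.to(device)
delta = attack.perturb(model, x, y)            # returns the perturbation, not x + delta
with torch.no_grad():
    robust_acc = (model(normalize_cifar(x + delta)).argmax(1) == y).float().mean().item()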
38 | ''' 39 | import torch 40 | import torch.nn as nn 41 | import torch.nn.functional as F 42 | import torch.nn.init as init 43 | 44 | from torch.autograd import Variable 45 | 46 | __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'] 47 | 48 | def _weights_init(m): 49 | classname = m.__class__.__name__ 50 | # print(classname) 51 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 52 | try: 53 | init.kaiming_normal_(m.weight) 54 | except AttributeError: 55 | init.kaiming_normal(m.weight) 56 | 57 | class LambdaLayer(nn.Module): 58 | def __init__(self, lambd): 59 | super(LambdaLayer, self).__init__() 60 | self.lambd = lambd 61 | 62 | def forward(self, x): 63 | return self.lambd(x) 64 | 65 | 66 | class BasicBlock(nn.Module): 67 | expansion = 1 68 | 69 | def __init__(self, in_planes, planes, stride=1, option='A'): 70 | super(BasicBlock, self).__init__() 71 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 72 | self.bn1 = nn.BatchNorm2d(planes) 73 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 74 | self.bn2 = nn.BatchNorm2d(planes) 75 | 76 | self.shortcut = nn.Sequential() 77 | if stride != 1 or in_planes != planes: 78 | if option == 'A': 79 | """ 80 | For CIFAR10 ResNet paper uses option A. 81 | """ 82 | self.shortcut = LambdaLayer(lambda x: 83 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 84 | elif option == 'B': 85 | self.shortcut = nn.Sequential( 86 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 87 | nn.BatchNorm2d(self.expansion * planes) 88 | ) 89 | 90 | def forward(self, x): 91 | out = F.relu(self.bn1(self.conv1(x))) 92 | out = self.bn2(self.conv2(out)) 93 | out += self.shortcut(x) 94 | out = F.relu(out) 95 | return out 96 | 97 | 98 | class ResNet(nn.Module): 99 | def __init__(self, block, num_blocks, num_classes=10): 100 | super(ResNet, self).__init__() 101 | self.in_planes = 16 102 | 103 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 104 | self.bn1 = nn.BatchNorm2d(16) 105 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 106 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 107 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 108 | self.linear = nn.Linear(64, num_classes) 109 | self.apply(_weights_init) 110 | 111 | 112 | def _make_layer(self, block, planes, num_blocks, stride): 113 | strides = [stride] + [1]*(num_blocks-1) 114 | layers = [] 115 | for stride in strides: 116 | layers.append(block(self.in_planes, planes, stride)) 117 | self.in_planes = planes * block.expansion 118 | 119 | return nn.Sequential(*layers) 120 | 121 | def forward(self, x): 122 | out = F.relu(self.bn1(self.conv1(x))) 123 | out = self.layer1(out) 124 | out = self.layer2(out) 125 | out = self.layer3(out) 126 | out = F.avg_pool2d(out, out.size()[3]) 127 | out = out.view(out.size(0), -1) 128 | out = self.linear(out) 129 | return out 130 | 131 | 132 | def resnet20(): 133 | return ResNet(BasicBlock, [3, 3, 3]) 134 | 135 | 136 | def resnet32(): 137 | return ResNet(BasicBlock, [5, 5, 5]) 138 | 139 | 140 | def resnet44(): 141 | return ResNet(BasicBlock, [7, 7, 7]) 142 | 143 | 144 | def resnet56(): 145 | return ResNet(BasicBlock, [9, 9, 9]) 146 | 147 | 148 | def resnet110(): 149 | return ResNet(BasicBlock, [18, 18, 18]) 150 | 151 | 152 | def resnet1202(): 153 | return ResNet(BasicBlock, [200, 200, 200]) 154 | 155 | 156 
| def test(net): 157 | import numpy as np 158 | total_params = 0 159 | 160 | for x in filter(lambda p: p.requires_grad, net.parameters()): 161 | total_params += np.prod(x.data.numpy().shape) 162 | print("Total number of params", total_params) 163 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters())))) 164 | -------------------------------------------------------------------------------- /train_eval_scripts/stadv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torchvision import transforms, datasets 8 | from scipy import optimize 9 | from utils import * 10 | from model import PreActResNet18 11 | 12 | 13 | 14 | def flow_st(images, flows): 15 | images_shape = images.size() 16 | flows_shape = flows.size() 17 | batch_size = images_shape[0] 18 | H = images_shape[2] 19 | W = images_shape[3] 20 | basegrid = torch.stack(torch.meshgrid(torch.arange(0, H), torch.arange(0, W))) # (2,H,W) 21 | sampling_grid = basegrid.unsqueeze(0).type(torch.float32).cuda() + flows.cuda() 22 | sampling_grid_x = torch.clamp(sampling_grid[:, 1], 0.0, W - 1.0).type(torch.float32) 23 | sampling_grid_y = torch.clamp(sampling_grid[:, 0], 0.0, H - 1.0).type(torch.float32) 24 | 25 | x0 = torch.floor(sampling_grid_x).type(torch.int64) 26 | x1 = x0 + 1 27 | y0 = torch.floor(sampling_grid_y).type(torch.int64) 28 | y1 = y0 + 1 29 | 30 | x0 = torch.clamp(x0, 0, W - 2) 31 | x1 = torch.clamp(x1, 0, W - 1) 32 | y0 = torch.clamp(y0, 0, H - 2) 33 | y1 = torch.clamp(y1, 0, H - 1) 34 | 35 | Ia = images[:, :, y0[0, :, :], x0[0, :, :]] 36 | Ib = images[:, :, y1[0, :, :], x0[0, :, :]] 37 | Ic = images[:, :, y0[0, :, :], x1[0, :, :]] 38 | Id = images[:, :, y1[0, :, :], x1[0, :, :]] 39 | 40 | x0 = x0.type(torch.float32) 41 | x1 = x1.type(torch.float32) 42 | y0 = y0.type(torch.float32) 43 | y1 = y1.type(torch.float32) 44 | 45 | wa = (x1 - sampling_grid_x) * (y1 - sampling_grid_y) 46 | wb = (x1 - sampling_grid_x) * (sampling_grid_y - y0) 47 | wc = (sampling_grid_x - x0) * (y1 - sampling_grid_y) 48 | wd = (sampling_grid_x - x0) * (sampling_grid_y - y0) 49 | 50 | perturbed_image = wa.unsqueeze(0) * Ia + wb.unsqueeze(0) * Ib + wc.unsqueeze(0) * Ic + wd.unsqueeze(0) * Id 51 | 52 | return perturbed_image.type(torch.float32).cuda() 53 | 54 | 55 | def flow_loss(flows, padding_mode='constant', epsilon=1e-8): 56 | paddings = (1, 1, 1, 1) 57 | padded_flows = F.pad(flows, paddings, mode=padding_mode, value=0) 58 | shifted_flows = [ 59 | padded_flows[:, :, 2:, 2:], # bottom right (+1,+1) 60 | padded_flows[:, :, 2:, :-2], # bottom left (+1,-1) 61 | padded_flows[:, :, :-2, 2:], # top right (-1,+1) 62 | padded_flows[:, :, :-2, :-2] # top left (-1,-1) 63 | ] 64 | # ||\Delta u^{(p)} - \Delta u^{(q)}||_2^2 + # ||\Delta v^{(p)} - \Delta v^{(q)}||_2^2 65 | loss = 0 66 | for shifted_flow in shifted_flows: 67 | loss += torch.sum(torch.square(flows[:, 1] - shifted_flow[:, 1]) + torch.square( 68 | flows[:, 0] - shifted_flow[:, 0]) + epsilon).cuda() 69 | return loss.type(torch.float32) 70 | 71 | 72 | def adv_loss(logits, targets, confidence=0.0): 73 | confidence = torch.tensor(confidence).cuda() 74 | real = torch.sum(logits * targets, -1) 75 | other = torch.max((1 - targets) * logits - (targets * 10000), -1)[0] 76 | return torch.max(other - real, confidence)[0].type(torch.float32) 77 | 78 | 79 | def 
func(flows, input, target, model, const=0.05): 80 | input = torch.from_numpy(input).cuda() 81 | target = torch.from_numpy(target).cuda() 82 | flows = torch.from_numpy(flows).view((1, 2,) + input.size()[2:]).cuda() 83 | flows.requires_grad = True 84 | pert_out = flow_st(input, flows) 85 | output = model(pert_out) 86 | L_flow = flow_loss(flows) 87 | L_adv = adv_loss(output, target) 88 | L_final = L_adv + const * L_flow 89 | model.zero_grad() 90 | L_final.backward() 91 | gradient = flows.grad.data.view(-1).detach().cpu().numpy() 92 | return L_final.item(), gradient 93 | 94 | 95 | def attack(input, target, model): 96 | init_flows = np.zeros((1, 2,) + input.size()[2:]).reshape(-1) 97 | results = optimize.fmin_l_bfgs_b(func, init_flows, args=(input.cpu().numpy(), target.cpu().numpy(), model)) 98 | flows = torch.from_numpy(results[0]).view((1, 2,) + input.size()[2:]) 99 | pert_out = flow_st(input, flows) 100 | return pert_out 101 | 102 | 103 | class Model(nn.Module): 104 | def __init__(self, model, norm): 105 | super(Model, self).__init__() 106 | self.model = model 107 | self.norm = norm 108 | 109 | def forward(self, x): 110 | return self.model(self.norm(x)) 111 | 112 | if __name__ == '__main__': 113 | np.random.seed(42) 114 | torch.manual_seed(42) 115 | model = PreActResNet18(10) 116 | model.load_state_dict(torch.load('./cifar10_models/cifar10_prn_sgd_sub.pth')) 117 | Mod = Model(model, normalize_cifar) 118 | Mod.eval() 119 | Mod.cuda() 120 | 121 | use_cuda = True 122 | device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu") 123 | 124 | train_loader, test_loader = load_dataset('cifar10', 1) 125 | norm = normalize_cifar 126 | 127 | adv = [] 128 | adv_label = [] 129 | correct_label = [] 130 | sample = 10000 131 | success = 0 132 | target_s = 0 133 | for i, (x, y) in enumerate(test_loader): 134 | x, y = x.cuda(), y.cuda() 135 | # y : [x] -> [x+1 mod 10] 136 | target = (y + 1) % 10 137 | pert_out = attack(x, target, model) 138 | if pert_out is not None: 139 | output = model(pert_out) 140 | success += (output.max(1)[1] != y).float().sum().item() 141 | target_s += (output.max(1)[1] == target).float().sum().item() 142 | print(output, y, target) 143 | else: 144 | break 145 | 146 | print('success: ', success, 'sample: ', i, 'target: ', target_s) 147 | print(success / sample) 148 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import argparse 8 | from time import time 9 | 10 | from utils import * 11 | from model import PreActResNet18 12 | from sam import SAM 13 | 14 | 15 | def get_args(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--fname', type=str, required=True) 18 | parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100']) 19 | parser.add_argument('--epochs', default=100, type=int) 20 | parser.add_argument('--max-lr', default=0.1, type=float) 21 | parser.add_argument('--opt', default='SAM', choices=['SAM', 'SGD']) 22 | parser.add_argument('--batch-size', default=128, type=int) 23 | parser.add_argument('--device', default=0, type=int) 24 | parser.add_argument('--adv', action='store_true') 25 | parser.add_argument('--rho', default=0.05, type=float) # for SAM 26 | 27 | parser.add_argument('--norm', default='linf', choices=['linf', 'l2']) 28 | 
parser.add_argument('--train-eps', default=8., type=float)
29 | parser.add_argument('--train-alpha', default=2., type=float)
30 | parser.add_argument('--train-step', default=5, type=int)
31 |
32 | parser.add_argument('--test-eps', default=1., type=float)
33 | parser.add_argument('--test-alpha', default=0.5, type=float)
34 | parser.add_argument('--test-step', default=5, type=int)
35 | return parser.parse_args()
36 |
37 | args = get_args()
38 |
39 | def lr_schedule(epoch):
40 | if epoch < args.epochs * 0.75:
41 | return args.max_lr
42 | elif epoch < args.epochs * 0.9:
43 | return args.max_lr * 0.1
44 | else:
45 | return args.max_lr * 0.01
46 |
47 | if __name__ == '__main__':
48 |
49 | dataset = args.dataset
50 | device = f'cuda:{args.device}'
51 | model = PreActResNet18(10 if dataset == 'cifar10' else 100).to(device)
52 | train_loader, test_loader = load_dataset(dataset, args.batch_size)
53 | params = model.parameters()
54 | criterion = nn.CrossEntropyLoss()
55 |
56 |
57 | if args.opt == 'SGD':
58 | opt = torch.optim.SGD(params, lr=args.max_lr, momentum=0.9, weight_decay=5e-4)
59 | elif args.opt == 'SAM':
60 | base_opt = torch.optim.SGD
61 | opt = SAM(params, base_opt, lr=args.max_lr, momentum=0.9, weight_decay=5e-4, rho=args.rho)
62 | normalize = normalize_cifar if dataset == 'cifar10' else normalize_cifar100
63 |
64 | all_log_data = []
65 | train_pgd = PGD(args.train_step, args.train_alpha / 255., args.train_eps / 255., args.norm, False, normalize)
66 | test_pgd = PGD(args.test_step, args.test_alpha / 255., args.test_eps / 255., args.norm, False, normalize)
67 |
68 | for epoch in range(args.epochs):
69 | start_time = time()
70 | log_data = [0,0,0,0,0,0] # train_loss, train_acc, test_loss, test_acc, test_robust_loss, test_robust_acc
71 | # train
72 | model.train()
73 | lr = lr_schedule(epoch)
74 | opt.param_groups[0].update(lr=lr)
75 | for x, y in train_loader:
76 | x, y = x.to(device), y.to(device)
77 | if args.adv:
78 | delta = train_pgd.perturb(model, x, y)
79 | else:
80 | delta = torch.zeros_like(x).to(x.device)
81 |
82 | output = model(normalize(x + delta))
83 | loss = criterion(output, y)
84 |
85 | if args.opt == 'SGD':
86 | opt.zero_grad()
87 | loss.backward()
88 | opt.step()
89 |
90 | elif args.opt == 'SAM':
91 | loss.backward()
92 | opt.first_step(zero_grad=True)
93 |
94 | output_2 = model(normalize(x + delta))
95 | criterion(output_2, y).backward()
96 | opt.second_step(zero_grad=True)
97 |
98 | log_data[0] += (loss * len(y)).item()
99 | log_data[1] += (output.max(1)[1] == y).float().sum().item()
100 |
101 | # test
102 | model.eval()
103 | for x, y in test_loader:
104 |
105 | x, y = x.to(device), y.to(device)
106 | # clean
107 | output = model(normalize(x)).detach()
108 | loss = criterion(output, y)
109 |
110 | log_data[2] += (loss * len(y)).item()
111 | log_data[3] += (output.max(1)[1] == y).float().sum().item()
112 | # robust
113 | delta = test_pgd.perturb(model, x, y)
114 | output = model(normalize(x + delta)).detach()
115 | loss = criterion(output, y)
116 |
117 | log_data[4] += (loss * len(y)).item()
118 | log_data[5] += (output.max(1)[1] == y).float().sum().item()
119 |
120 | log_data = np.array(log_data)
121 | log_data[0] /= 50000 # CIFAR-10/100 training sets contain 50000 images
122 | log_data[1] /= 50000
123 | log_data[2] /= 10000
124 | log_data[3] /= 10000
125 | log_data[4] /= 10000
126 | log_data[5] /= 10000
127 | all_log_data.append(log_data)
128 |
129 | print(f'Epoch {epoch}:\t',log_data,f'\tTime {time()-start_time:.1f}s')
130 | torch.save(model.state_dict(), f'models/{args.fname}.pth' if args.dataset == 'cifar10' else
f'cifar100_models/{args.fname}.pth') 131 | 132 | all_log_data = np.stack(all_log_data,axis=0) 133 | 134 | df = pd.DataFrame(all_log_data) 135 | df.to_csv(f'logs/{args.fname}.csv') 136 | 137 | 138 | plt.plot(all_log_data[:, [2,4]]) 139 | plt.grid() 140 | # plt.title(f'{dataset} {args.opt}{" adv" if args.adv else ""} Loss', fontsize=16) 141 | plt.legend(['clean', 'robust'], fontsize=16) 142 | plt.savefig(f'figs/{args.fname}_loss.png', dpi=200) 143 | plt.clf() 144 | 145 | plt.plot(all_log_data[:, [3,5]]) 146 | plt.grid() 147 | #plt.title(f'{dataset} {args.opt}{" adv" if args.adv else ""} Acc', fontsize=16) 148 | plt.legend(['clean', 'robust'], fontsize=16) 149 | plt.savefig(f'figs/{args.fname}_acc.png', dpi=200) 150 | plt.clf() 151 | -------------------------------------------------------------------------------- /train_eval_scripts/recoloradv/mister_ed/scripts/setup_cifar.py: -------------------------------------------------------------------------------- 1 | """ Script to ensure that: 2 | 1) all dependencies are installed correctly 3 | 2) CIFAR data can be accessed locally 4 | 3) a functional classifier for CIFAR has been loaded. 5 | 6 | """ 7 | 8 | 9 | ############################################################################## 10 | # # 11 | # STEP ONE: DEPENDENCIES ARE INSTALLED # 12 | # # 13 | ############################################################################## 14 | from __future__ import print_function 15 | print("Checking imports...") 16 | import sys 17 | import os 18 | sys.path.append(os.path.abspath(os.path.split(os.path.split(__file__)[0])[0])) 19 | 20 | import torch 21 | import glob 22 | import numpy as np 23 | import math 24 | import config 25 | import torchvision.datasets as datasets 26 | 27 | try: #This block from: https://stackoverflow.com/a/17510727 28 | # For Python 3.0 and later 29 | from urllib.request import urlopen 30 | except ImportError: 31 | # Fall back to Python 2's urllib2 32 | from urllib2 import urlopen 33 | 34 | import hashlib 35 | print("...imports look okay!") 36 | 37 | 38 | ############################################################################## 39 | # # 40 | # STEP TWO: CIFAR DATA HAS BEEN LOADED # 41 | # # 42 | ############################################################################## 43 | 44 | def check_cifar_data_loaded(): 45 | print("Checking CIFAR10 data loaded...") 46 | dataset_dir = config.DEFAULT_DATASETS_DIR 47 | 48 | train_set = datasets.CIFAR10(root=dataset_dir, train=True, download=True) 49 | val_set = datasets.CIFAR10(root=dataset_dir, train=False, download=True) 50 | 51 | print("...CIFAR10 data looks okay!") 52 | 53 | 54 | check_cifar_data_loaded() 55 | 56 | ############################################################################## 57 | # # 58 | # STEP THREE: LOAD CLASSIFIER FOR CIFAR10 # 59 | # # 60 | ############################################################################## 61 | 62 | 63 | # https://stackoverflow.com/a/44873382 64 | def file_hash(filename): 65 | h = hashlib.sha256() 66 | with open(filename, 'rb', buffering=0) as f: 67 | for b in iter(lambda : f.read(128*1024), b''): 68 | h.update(b) 69 | return h.hexdigest() 70 | 71 | 72 | 73 | def load_cifar_classifiers(): 74 | print("Checking CIFAR10 classifier exists...") 75 | 76 | # NOTE: pretrained models are produced by Yerlan Idelbayev 77 | # https://github.com/akamaster/pytorch_resnet_cifar10 78 | # I'm just hosting these on my dropbox for stability purposes 79 | 80 | # Check which models already exist in model directory 81 | resnet_name = lambda flavor: 
'cifar10_resnet%s.th' % flavor 82 | total_cifar_files = set([resnet_name(flavor) for flavor in 83 | [1202, 110, 56, 44, 32, 20]]) 84 | total_cifar_files.add('Wide-Resnet28x10') 85 | 86 | try: 87 | os.makedirs(config.MODEL_PATH) 88 | except OSError as err: 89 | if not os.path.isdir(config.MODEL_PATH): 90 | raise err 91 | 92 | extant_models = set([_.split('/')[-1] for _ in 93 | glob.glob(os.path.join(*[config.MODEL_PATH, '*']))]) 94 | 95 | lacking_models = total_cifar_files - extant_models 96 | 97 | LINK_DEPOT = {resnet_name(20) : 'https://www.dropbox.com/s/glchyr9ljnpgvb5/cifar10_resnet20.th?dl=1', 98 | resnet_name(32) : 'https://www.dropbox.com/s/kis991c5w2qtgpq/cifar10_resnet32.th?dl=1', 99 | resnet_name(44) : 'https://www.dropbox.com/s/sigj56ysrti6s6a/cifar10_resnet44.th?dl=1', 100 | resnet_name(56) : 'https://www.dropbox.com/s/3p6d5tkvdgcbru5/c7ifar10_resnet56.th?dl=1', 101 | resnet_name(110) : 'https://www.dropbox.com/s/sp172x5vjlypfw6/cifar10_resnet110.th?dl=1', 102 | resnet_name(1202): 'https://www.dropbox.com/s/4qxfa6dmdliw9ko/cifar10_resnet1202.th?dl=1', 103 | 'Wide-Resnet28x10': 'https://www.dropbox.com/s/5ln2gow7mnxub29/cifar10_wide-resnet28x10.th?dl=1' 104 | } 105 | 106 | 107 | HASH_DEPOT = {resnet_name(20) : '12fca82f0bebc4135bf1f32f6e3710e61d5108578464b84fd6d7f5c1b04036c8', 108 | resnet_name(32) : 'd509ac1820d7f25398913559d7e81a13229b1e7adc5648e3bfa5e22dc137f850', 109 | resnet_name(44) : '014dd6541728a1c700b1642ab640e211dc6eb8ed507d70697458dc8f8a0ae2e4', 110 | resnet_name(56) : '4bfd97631478d6b638d2764fd2baff3edb1d7d82252d54439343b6596b9b5367', 111 | resnet_name(110) : '1d1ed7c27571399c1fef66969bc4df68d6a92c8e6c41170f444e120e5354e3bc', 112 | resnet_name(1202): 'f3b1deed382cd4c986ff8aa090c805d99a646e99d1f9227d7178183648844f62', 113 | 'Wide-Resnet28x10': 'd6a68ec2135294d91f9014abfdb45232d07fda0cdcd67f8c3b3653b28f08a88f'} 114 | 115 | for name in lacking_models: 116 | link = LINK_DEPOT[name] 117 | print("Downloading %s..." % name) 118 | u = urlopen(link) 119 | data = u.read() 120 | u.close() 121 | filename = os.path.join(config.MODEL_PATH, name) 122 | with open(filename, 'wb') as f: 123 | f.write(data) 124 | 125 | try: 126 | assert file_hash(filename) == HASH_DEPOT[name] 127 | except AssertionError as err: 128 | print("Something went wrong downloading %s" % name) 129 | os.remove(filename) 130 | raise err 131 | 132 | # Then load up all that doesn't already exist 133 | 134 | print("...CIFAR10 classifier looks okay") 135 | 136 | 137 | 138 | load_cifar_classifiers() 139 | 140 | 141 | print("\n Okay, you should be good to go now! 
") 142 | print("Try running tutorial_{1,2,3}.ipynb in notebooks/") 143 | 144 | -------------------------------------------------------------------------------- /SAM_segmentation/datasets/voc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tarfile 4 | import collections 5 | import torch.utils.data as data 6 | import shutil 7 | import numpy as np 8 | 9 | from PIL import Image 10 | from torchvision.datasets.utils import download_url, check_integrity 11 | 12 | DATASET_YEAR_DICT = { 13 | '2012': { 14 | 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', 15 | 'filename': 'VOCtrainval_11-May-2012.tar', 16 | 'md5': '6cd6e144f989b92b3379bac3b3de84fd', 17 | 'base_dir': 'VOCdevkit/VOC2012' 18 | }, 19 | '2011': { 20 | 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar', 21 | 'filename': 'VOCtrainval_25-May-2011.tar', 22 | 'md5': '6c3384ef61512963050cb5d687e5bf1e', 23 | 'base_dir': 'TrainVal/VOCdevkit/VOC2011' 24 | }, 25 | '2010': { 26 | 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar', 27 | 'filename': 'VOCtrainval_03-May-2010.tar', 28 | 'md5': 'da459979d0c395079b5c75ee67908abb', 29 | 'base_dir': 'VOCdevkit/VOC2010' 30 | }, 31 | '2009': { 32 | 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar', 33 | 'filename': 'VOCtrainval_11-May-2009.tar', 34 | 'md5': '59065e4b188729180974ef6572f6a212', 35 | 'base_dir': 'VOCdevkit/VOC2009' 36 | }, 37 | '2008': { 38 | 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar', 39 | 'filename': 'VOCtrainval_11-May-2012.tar', 40 | 'md5': '2629fa636546599198acfcfbfcf1904a', 41 | 'base_dir': 'VOCdevkit/VOC2008' 42 | }, 43 | '2007': { 44 | 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar', 45 | 'filename': 'VOCtrainval_06-Nov-2007.tar', 46 | 'md5': 'c52e279531787c972589f7e41ab4ae64', 47 | 'base_dir': 'VOCdevkit/VOC2007' 48 | } 49 | } 50 | 51 | 52 | def voc_cmap(N=256, normalized=False): 53 | def bitget(byteval, idx): 54 | return ((byteval & (1 << idx)) != 0) 55 | 56 | dtype = 'float32' if normalized else 'uint8' 57 | cmap = np.zeros((N, 3), dtype=dtype) 58 | for i in range(N): 59 | r = g = b = 0 60 | c = i 61 | for j in range(8): 62 | r = r | (bitget(c, 0) << 7-j) 63 | g = g | (bitget(c, 1) << 7-j) 64 | b = b | (bitget(c, 2) << 7-j) 65 | c = c >> 3 66 | 67 | cmap[i] = np.array([r, g, b]) 68 | 69 | cmap = cmap/255 if normalized else cmap 70 | return cmap 71 | 72 | class VOCSegmentation(data.Dataset): 73 | """`Pascal VOC `_ Segmentation Dataset. 74 | Args: 75 | root (string): Root directory of the VOC Dataset. 76 | year (string, optional): The dataset year, supports years 2007 to 2012. 77 | image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val`` 78 | download (bool, optional): If true, downloads the dataset from the internet and 79 | puts it in root directory. If dataset is already downloaded, it is not 80 | downloaded again. 81 | transform (callable, optional): A function/transform that takes in an PIL image 82 | and returns a transformed version. 
E.g, ``transforms.RandomCrop`` 83 | """ 84 | cmap = voc_cmap() 85 | def __init__(self, 86 | root, 87 | year='2012', 88 | image_set='train', 89 | download=False, 90 | transform=None): 91 | 92 | is_aug=False 93 | if year=='2012_aug': 94 | is_aug = True 95 | year = '2012' 96 | 97 | self.root = os.path.expanduser(root) 98 | self.year = year 99 | self.url = DATASET_YEAR_DICT[year]['url'] 100 | self.filename = DATASET_YEAR_DICT[year]['filename'] 101 | self.md5 = DATASET_YEAR_DICT[year]['md5'] 102 | self.transform = transform 103 | 104 | self.image_set = image_set 105 | base_dir = DATASET_YEAR_DICT[year]['base_dir'] 106 | voc_root = os.path.join(self.root, base_dir) 107 | image_dir = os.path.join(voc_root, 'JPEGImages') 108 | 109 | if download: 110 | download_extract(self.url, self.root, self.filename, self.md5) 111 | 112 | if not os.path.isdir(voc_root): 113 | raise RuntimeError('Dataset not found or corrupted.' + 114 | ' You can use download=True to download it') 115 | 116 | if is_aug and image_set=='train': 117 | mask_dir = os.path.join(voc_root, 'SegmentationClassAug') 118 | assert os.path.exists(mask_dir), "SegmentationClassAug not found, please refer to README.md and prepare it manually" 119 | split_f = os.path.join( self.root, 'train_aug.txt')#'./datasets/data/train_aug.txt' 120 | else: 121 | mask_dir = os.path.join(voc_root, 'SegmentationClass') 122 | splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation') 123 | split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt') 124 | 125 | if not os.path.exists(split_f): 126 | raise ValueError( 127 | 'Wrong image_set entered! Please use image_set="train" ' 128 | 'or image_set="trainval" or image_set="val"') 129 | 130 | with open(os.path.join(split_f), "r") as f: 131 | file_names = [x.strip() for x in f.readlines()] 132 | 133 | self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] 134 | self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names] 135 | assert (len(self.images) == len(self.masks)) 136 | 137 | def __getitem__(self, index): 138 | """ 139 | Args: 140 | index (int): Index 141 | Returns: 142 | tuple: (image, target) where target is the image segmentation. 
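A small sketch (not part of the original file) showing how the color map above turns a predicted label mask into an RGB visualization via VOCSegmentation.decode_target; the 2x3 mask is made up, and the import assumes the script is run from the SAM_segmentation directory.

import numpy as np
from PIL import Image
from datasets.voc import VOCSegmentation

# A toy 2x3 label mask containing classes 0 (background), 1, and 15.
mask = np.array([[0, 1, 15],
                 [15, 1, 0]], dtype=np.uint8)
rgb = VOCSegmentation.decode_target(mask)   # (2, 3, 3) uint8 RGB array
Image.fromarray(rgb).save('decoded_mask.png')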
143 | """ 144 | img = Image.open(self.images[index]).convert('RGB') 145 | target = Image.open(self.masks[index]) 146 | if self.transform is not None: 147 | img, target = self.transform(img, target) 148 | 149 | if img.max() > 1 or img.min() < 0: 150 | img = (img - img.min()) / (img.max() - img.min()) 151 | 152 | return img, target 153 | 154 | 155 | def __len__(self): 156 | return len(self.images) 157 | 158 | @classmethod 159 | def decode_target(cls, mask): 160 | """decode semantic mask to RGB image""" 161 | return cls.cmap[mask] 162 | 163 | def download_extract(url, root, filename, md5): 164 | download_url(url, root, filename, md5) 165 | with tarfile.open(os.path.join(root, filename), "r") as tar: 166 | tar.extractall(path=root) -------------------------------------------------------------------------------- /train_eval_scripts/attack.py: -------------------------------------------------------------------------------- 1 | import torchattacks 2 | from model import PreActResNet18, WRN28_10, DeiT 3 | from utils import * 4 | import recoloradv.mister_ed.config as config 5 | from recoloradv.mister_ed.utils.pytorch_utils import DifferentiableNormalize 6 | from recoloradv.utils import get_attack_from_name 7 | from argparse import ArgumentParser 8 | 9 | parser = ArgumentParser() 10 | parser.add_argument('--model_path', default='cifar10_prn_sam_0_1.pth', type=str) 11 | args = parser.parse_args() 12 | file_name = args.model_path 13 | 14 | class Model(nn.Module): 15 | def __init__(self, model, norm): 16 | super(Model, self).__init__() 17 | self.model = model 18 | self.norm = norm 19 | 20 | def forward(self, x): 21 | return self.model(self.norm(x)) 22 | 23 | label_dim = 10 24 | if 'cifar10_' in file_name: 25 | label_dim = 10 26 | normalizer = DifferentiableNormalize( 27 | mean=config.CIFAR10_MEANS, 28 | std=config.CIFAR10_STDS, 29 | ) 30 | norm = normalize_cifar 31 | train_loader, test_loader = load_dataset('cifar10', 1000) 32 | elif 'cifar100_' in file_name: 33 | label_dim = 100 34 | normalizer = DifferentiableNormalize( 35 | mean=CIFAR100_MEAN, 36 | std=CIFAR100_STD, 37 | ) 38 | norm = normalize_cifar100 39 | train_loader, test_loader = load_dataset('cifar100', 1000) 40 | elif 'tiny' in file_name: 41 | label_dim = 200 42 | normalizer = DifferentiableNormalize( 43 | mean=TINYIMAGENET_MEAN, 44 | std=TINYIMAGENET_STD, 45 | ) 46 | norm = normalize_tinyimagenet 47 | train_loader, test_loader = load_dataset('tiny-imagenet-200', 1000) 48 | else: 49 | raise ValueError('Unknown dataset') 50 | 51 | if 'prn' in file_name and 'deit' not in file_name and 'wrn' not in file_name: 52 | model = PreActResNet18(label_dim) 53 | elif 'wrn' in file_name: 54 | model = WRN28_10(label_dim) 55 | elif 'deit' in file_name: 56 | model = DeiT(label_dim) 57 | 58 | d = torch.load('./models/' + file_name, map_location='cuda:0') 59 | for k in list(d.keys()): 60 | if k.startswith('module.'): 61 | d[k[7:]] = d[k] 62 | del d[k] 63 | 64 | model.load_state_dict(d) 65 | model.eval() 66 | model.cuda() 67 | 68 | normed_model = Model(model, norm) 69 | normed_model.eval() 70 | normed_model.cuda() 71 | 72 | # test clean accuracy on the whole test set 73 | acc = 0. 
74 | for x, y in test_loader: 75 | x, y = x.cuda(), y.cuda() 76 | with torch.no_grad(): 77 | pred = normed_model(x) 78 | acc += (pred.max(1)[1] == y).float().sum().item() 79 | acc /= len(test_loader.dataset) 80 | print('Model: {}, Clean Accuracy: {:.4f}'.format(file_name, acc)) 81 | with open(f'./logs/{file_name}-attack_log.txt', 'a') as f: 82 | f.write('Model: {}, Clean Accuracy: {:.4f}\n'.format(file_name, acc)) 83 | 84 | # attackers 85 | pgd_1 = torchattacks.PGD(normed_model, eps=1 / 255, alpha=0.25 / 255, steps=10) 86 | pgd_2 = torchattacks.PGD(normed_model, eps=2 / 255, alpha=0.5 / 255, steps=10) 87 | pgd_4 = torchattacks.PGD(normed_model, eps=4 / 255, alpha=1 / 255, steps=10) 88 | pgd_8 = torchattacks.PGD(normed_model, eps=8 / 255, alpha=2 / 255, steps=10) 89 | fgsm_1 = torchattacks.FGSM(normed_model, eps=1 / 255) 90 | fgsm_8 = torchattacks.FGSM(normed_model, eps=8 / 255) 91 | cw = torchattacks.CW(normed_model, c=1, kappa=0, steps=10) 92 | autoattack = torchattacks.APGDT(normed_model, norm='Linf', eps=4/255, steps=5, n_restarts=1, seed=0 93 | , eot_iter=1, rho=.75, verbose=False, n_classes=label_dim) 94 | pgd_l2_32 = torchattacks.PGDL2(normed_model, eps=32 / 255, alpha=8 / 255, steps=10) 95 | pgd_l2_64 = torchattacks.PGDL2(normed_model, eps=64 / 255, alpha=16 / 255, steps=10) 96 | pixle = torchattacks.Pixle(normed_model, max_iterations=5, restarts=5) 97 | fab = torchattacks.FAB(normed_model, eps=8 / 255, norm='L2') 98 | 99 | recolor_attack = get_attack_from_name('recoloradv+stadv+delta', model, normalizer, verbose=True) 100 | stadv_attack = get_attack_from_name('stadv', model, normalizer, verbose=True) 101 | 102 | lib_attacker_list = [pgd_1, pgd_2, pgd_4, pgd_8, fgsm_1, fgsm_8, cw, autoattack, pgd_l2_32, 103 | pgd_l2_64, pixle, fab] 104 | lib_atkname_list = ['pgd_1', 'pgd_2', 'pgd_4', 'pgd_8', 'fgsm_1', 'fgsm_8', 'cw', 'autoattack', 105 | 'pgd_l2_32', 'pgd_l2_64', 'pixle', 'fab'] 106 | sem_attacker_list = [recolor_attack, stadv_attack] 107 | sem_atkname_list = ['recolor', 'stadv'] 108 | for i in range(len(lib_attacker_list)): 109 | try: 110 | lib_attacker = lib_attacker_list[i] 111 | lib_atkname = lib_atkname_list[i] 112 | acc = 0 113 | # get first 1000 imgs and calculate acc 114 | for x, y in test_loader: 115 | x, y = x.cuda(), y.cuda() 116 | adv_x = lib_attacker(x, y) 117 | pred = normed_model(adv_x) 118 | acc += (pred.max(1)[1] == y).float().sum().item() 119 | break 120 | acc /= 1000 121 | print('Model: {}, Attack: {}, Accuracy: {}'.format(file_name, lib_atkname, acc)) 122 | # write to log in ./logs/attack_log.txt 123 | with open(f'./logs/{file_name}-attack_log.txt', 'a') as f: 124 | f.write('Model: {}, Attack: {}, Accuracy: {}\n'.format(file_name, lib_atkname, acc)) 125 | except: 126 | print('Model: {}, Attack: {}, Accuracy: {}'.format(file_name, lib_atkname, 'failed')) 127 | # write to log in ./logs/attack_log.txt 128 | with open(f'./logs/{file_name}-attack_log.txt', 'a') as f: 129 | f.write('Model: {}, Attack: {}, Accuracy: {}\n'.format(file_name, lib_atkname, 'failed')) 130 | 131 | for i in range(len(sem_attacker_list)): 132 | try: 133 | sem_attacker = sem_attacker_list[i] 134 | sem_atkname = sem_atkname_list[i] 135 | acc = 0 136 | # get first 1000 imgs and calculate acc 137 | for x, y in test_loader: 138 | x, y = x.cuda(), y.cuda() 139 | adv_x = sem_attacker.attack(x, y)[0] 140 | pred = normed_model(adv_x) 141 | acc += (pred.max(1)[1] == y).float().sum().item() 142 | break 143 | acc /= 1000 144 | print('Model: {}, Attack: {}, Accuracy: {}'.format(file_name, sem_atkname, acc)) 
145 | # write to log in ./logs/attack_log.txt 146 | with open(f'./logs/{file_name}-attack_log.txt', 'a') as f: 147 | f.write('Model: {}, Attack: {}, Accuracy: {}\n'.format(file_name, sem_atkname, acc)) 148 | except: 149 | print('Model: {}, Attack: {}, Accuracy: {}'.format(file_name, sem_atkname, 'failed')) 150 | # write to log in ./logs/attack_log.txt 151 | with open(f'./logs/{file_name}-attack_log.txt', 'a') as f: 152 | f.write('Model: {}, Attack: {}, Accuracy: {}\n'.format(file_name, sem_atkname, 'failed')) 153 | -------------------------------------------------------------------------------- /SAM_segmentation/utils/attack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import transforms 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | voc_mu = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) 8 | voc_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) 9 | def normalize_voc(x): 10 | return (x - voc_mu.to(x.device))/(voc_std.to(x.device)) 11 | 12 | city_mu = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) 13 | city_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) 14 | def normalize_city(x): 15 | return (x - city_mu.to(x.device))/(city_std.to(x.device)) 16 | 17 | iccv09_mu = torch.tensor([0.4813, 0.4901, 0.4747]).view(3, 1, 1) 18 | iccv09_std = torch.tensor([0.2495, 0.2492, 0.2748]).view(3, 1, 1) 19 | def normalize_iccv09(x): 20 | return (x - iccv09_mu.to(x.device))/(iccv09_std.to(x.device)) 21 | 22 | class Attack(): 23 | def __init__(self, iters, alpha, eps, norm, criterion, rand_init, rand_perturb, targeted, normalize): 24 | self.iters = iters 25 | self.alpha = alpha 26 | self.eps = eps 27 | self.norm = norm 28 | assert norm in ['linf', 'l2'] 29 | self.criterion = criterion # loss function for perturb 30 | self.rand_init = rand_init # random initialization before perturb 31 | self.rand_perturb = rand_perturb # add random noise in each step 32 | self.targetd = targeted # targeted attack 33 | self.normalize = normalize # normalize_cifar 34 | 35 | def perturb(self, model, x, y): 36 | assert x.min() >= 0 and x.max() <= 1 37 | delta = torch.zeros_like(x, device=x.device) 38 | if self.rand_init: 39 | if self.norm == "linf": 40 | delta.uniform_(-self.eps, self.eps) 41 | elif self.norm == "l2": 42 | delta.normal_() 43 | d_flat = delta.view(delta.size(0), -1) 44 | n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1) 45 | r = torch.zeros_like(n).uniform_(0, 1) 46 | delta *= r/n*self.eps 47 | else: 48 | raise NotImplementedError("Only linf and l2 norms are implemented.") 49 | 50 | delta = torch.clamp(delta, 0-x, 1-x) 51 | delta.requires_grad = True 52 | 53 | for i in range(self.iters): 54 | # output = model(self.normalize(x+delta)) 55 | output = model(x+delta) 56 | loss = self.criterion(output, y) 57 | if self.targetd: 58 | loss = -loss 59 | loss.backward() 60 | grad = delta.grad.detach() 61 | if self.norm == "linf": 62 | d = torch.clamp(delta + self.alpha * torch.sign(grad), min=-self.eps, max=self.eps).detach() 63 | elif self.norm == "l2": 64 | grad_norm = torch.norm(grad.view(grad.size(0), -1), dim=1).view(-1, 1, 1, 1) 65 | scaled_grad = grad / (grad_norm + 1e-10) 66 | d = (delta + scaled_grad * self.alpha).view(delta.size(0), -1).renorm(p=2, dim=0, maxnorm=self.eps).view_as(delta).detach() 67 | 68 | d = torch.clamp(d, 0-x, 1-x) 69 | delta.data = d 70 | delta.grad.zero_() 71 | 72 | return delta.detach() 73 | 74 | def make_one_hot(input, num_classes): 75 | 
"""Convert class index tensor to one hot encoding tensor. 76 | 77 | Args: 78 | input: A tensor of shape [N, 1, *] 79 | num_classes: An int of number of class 80 | Returns: 81 | A tensor of shape [N, num_classes, *] 82 | """ 83 | shape = np.array(input.shape) 84 | shape[1] = num_classes 85 | shape = tuple(shape) 86 | result = torch.zeros(shape) 87 | result = result.scatter_(1, input.cpu(), 1) 88 | 89 | return result 90 | 91 | class BinaryDiceLoss(nn.Module): 92 | """Dice loss of binary class 93 | Args: 94 | smooth: A float number to smooth loss, and avoid NaN error, default: 1 95 | p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2 96 | predict: A tensor of shape [N, *] 97 | target: A tensor of shape same with predict 98 | reduction: Reduction method to apply, return mean over batch if 'mean', 99 | return sum if 'sum', return a tensor of shape [N,] if 'none' 100 | Returns: 101 | Loss tensor according to arg reduction 102 | Raise: 103 | Exception if unexpected reduction 104 | """ 105 | def __init__(self, smooth=1, p=2, reduction='mean'): 106 | super(BinaryDiceLoss, self).__init__() 107 | self.smooth = smooth 108 | self.p = p 109 | self.reduction = reduction 110 | 111 | def forward(self, predict, target): 112 | assert predict.shape[0] == target.shape[0], "predict & target batch size don't match" 113 | predict = predict.contiguous().view(predict.shape[0], -1) 114 | target = target.contiguous().view(target.shape[0], -1) 115 | 116 | num = torch.sum(torch.mul(predict, target), dim=1) + self.smooth 117 | den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth 118 | 119 | loss = 1 - num / den 120 | 121 | if self.reduction == 'mean': 122 | return loss.mean() 123 | elif self.reduction == 'sum': 124 | return loss.sum() 125 | elif self.reduction == 'none': 126 | return loss 127 | else: 128 | raise Exception('Unexpected reduction {}'.format(self.reduction)) 129 | 130 | 131 | class DiceLoss(nn.Module): 132 | def __init__(self, weight=None, ignore_index=None, **kwargs): 133 | super(DiceLoss, self).__init__() 134 | self.kwargs = kwargs 135 | self.weight = weight 136 | self.ignore_index = ignore_index 137 | 138 | 139 | def forward(self, predict, target): 140 | target = self._convert_target(target) 141 | assert predict.shape == target.shape, 'predict & target shape do not match' 142 | dice = BinaryDiceLoss(**self.kwargs) 143 | total_loss = 0 144 | predict = F.softmax(predict, dim=1) 145 | 146 | for i in range(target.shape[1]): 147 | if i != self.ignore_index: 148 | dice_loss = dice(predict[:, i], target[:, i]) 149 | if self.weight is not None: 150 | assert self.weight.shape[0] == target.shape[1], \ 151 | 'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0]) 152 | dice_loss *= self.weights[i] 153 | total_loss += dice_loss 154 | 155 | return total_loss/target.shape[1] 156 | 157 | def _convert_target(self, target): 158 | device = target.device 159 | target = make_one_hot(target.unsqueeze(1), 9) 160 | target = target.to(device) 161 | return target 162 | 163 | class PGD(Attack): 164 | def __init__(self, iters, alpha, eps, norm, rand_init, targeted=False, normalize=normalize_voc): 165 | # super().__init__(iters, alpha, eps, norm, DiceLoss(ignore_index=255), rand_init=rand_init, rand_perturb=False, targeted=targeted, normalize=normalize) 166 | super().__init__(iters, alpha, eps, norm, nn.CrossEntropyLoss(ignore_index=255, reduction='mean'), rand_init=rand_init, rand_perturb=False, targeted=targeted, normalize=normalize) 
-------------------------------------------------------------------------------- /SAM_segmentation/network/_deeplab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from .utils import _SimpleSegmentationModel 6 | 7 | 8 | __all__ = ["DeepLabV3"] 9 | 10 | 11 | class DeepLabV3(_SimpleSegmentationModel): 12 | """ 13 | Implements DeepLabV3 model from 14 | `"Rethinking Atrous Convolution for Semantic Image Segmentation" 15 | `_. 16 | 17 | Arguments: 18 | backbone (nn.Module): the network used to compute the features for the model. 19 | The backbone should return an OrderedDict[Tensor], with the key being 20 | "out" for the last feature map used, and "aux" if an auxiliary classifier 21 | is used. 22 | classifier (nn.Module): module that takes the "out" element returned from 23 | the backbone and returns a dense prediction. 24 | aux_classifier (nn.Module, optional): auxiliary classifier used during training 25 | """ 26 | pass 27 | 28 | class DeepLabHeadV3Plus(nn.Module): 29 | def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]): 30 | super(DeepLabHeadV3Plus, self).__init__() 31 | self.project = nn.Sequential( 32 | nn.Conv2d(low_level_channels, 48, 1, bias=False), 33 | nn.BatchNorm2d(48), 34 | nn.ReLU(inplace=True), 35 | ) 36 | 37 | self.aspp = ASPP(in_channels, aspp_dilate) 38 | 39 | self.classifier = nn.Sequential( 40 | nn.Conv2d(304, 256, 3, padding=1, bias=False), 41 | nn.BatchNorm2d(256), 42 | nn.ReLU(inplace=True), 43 | nn.Conv2d(256, num_classes, 1) 44 | ) 45 | self._init_weight() 46 | 47 | def forward(self, feature): 48 | low_level_feature = self.project( feature['low_level'] ) 49 | output_feature = self.aspp(feature['out']) 50 | output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False) 51 | return self.classifier( torch.cat( [ low_level_feature, output_feature ], dim=1 ) ) 52 | 53 | def _init_weight(self): 54 | for m in self.modules(): 55 | if isinstance(m, nn.Conv2d): 56 | nn.init.kaiming_normal_(m.weight) 57 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 58 | nn.init.constant_(m.weight, 1) 59 | nn.init.constant_(m.bias, 0) 60 | 61 | class DeepLabHead(nn.Module): 62 | def __init__(self, in_channels, num_classes, aspp_dilate=[12, 24, 36]): 63 | super(DeepLabHead, self).__init__() 64 | 65 | self.classifier = nn.Sequential( 66 | ASPP(in_channels, aspp_dilate), 67 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 68 | nn.BatchNorm2d(256), 69 | nn.ReLU(inplace=True), 70 | nn.Conv2d(256, num_classes, 1) 71 | ) 72 | self._init_weight() 73 | 74 | def forward(self, feature): 75 | return self.classifier( feature['out'] ) 76 | 77 | def _init_weight(self): 78 | for m in self.modules(): 79 | if isinstance(m, nn.Conv2d): 80 | nn.init.kaiming_normal_(m.weight) 81 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 82 | nn.init.constant_(m.weight, 1) 83 | nn.init.constant_(m.bias, 0) 84 | 85 | class AtrousSeparableConvolution(nn.Module): 86 | """ Atrous Separable Convolution 87 | """ 88 | def __init__(self, in_channels, out_channels, kernel_size, 89 | stride=1, padding=0, dilation=1, bias=True): 90 | super(AtrousSeparableConvolution, self).__init__() 91 | self.body = nn.Sequential( 92 | # Separable Conv 93 | nn.Conv2d( in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, groups=in_channels ), 94 | # PointWise Conv 
95 | nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias), 96 | ) 97 | 98 | self._init_weight() 99 | 100 | def forward(self, x): 101 | return self.body(x) 102 | 103 | def _init_weight(self): 104 | for m in self.modules(): 105 | if isinstance(m, nn.Conv2d): 106 | nn.init.kaiming_normal_(m.weight) 107 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 108 | nn.init.constant_(m.weight, 1) 109 | nn.init.constant_(m.bias, 0) 110 | 111 | class ASPPConv(nn.Sequential): 112 | def __init__(self, in_channels, out_channels, dilation): 113 | modules = [ 114 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), 115 | nn.BatchNorm2d(out_channels), 116 | nn.ReLU(inplace=True) 117 | ] 118 | super(ASPPConv, self).__init__(*modules) 119 | 120 | class ASPPPooling(nn.Sequential): 121 | def __init__(self, in_channels, out_channels): 122 | super(ASPPPooling, self).__init__( 123 | nn.AdaptiveAvgPool2d(1), 124 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 125 | nn.BatchNorm2d(out_channels), 126 | nn.ReLU(inplace=True)) 127 | 128 | def forward(self, x): 129 | size = x.shape[-2:] 130 | x = super(ASPPPooling, self).forward(x) 131 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False) 132 | 133 | class ASPP(nn.Module): 134 | def __init__(self, in_channels, atrous_rates): 135 | super(ASPP, self).__init__() 136 | out_channels = 256 137 | modules = [] 138 | modules.append(nn.Sequential( 139 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 140 | nn.BatchNorm2d(out_channels), 141 | nn.ReLU(inplace=True))) 142 | 143 | rate1, rate2, rate3 = tuple(atrous_rates) 144 | modules.append(ASPPConv(in_channels, out_channels, rate1)) 145 | modules.append(ASPPConv(in_channels, out_channels, rate2)) 146 | modules.append(ASPPConv(in_channels, out_channels, rate3)) 147 | modules.append(ASPPPooling(in_channels, out_channels)) 148 | 149 | self.convs = nn.ModuleList(modules) 150 | 151 | self.project = nn.Sequential( 152 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 153 | nn.BatchNorm2d(out_channels), 154 | nn.ReLU(inplace=True), 155 | nn.Dropout(0.1),) 156 | 157 | def forward(self, x): 158 | res = [] 159 | for conv in self.convs: 160 | res.append(conv(x)) 161 | res = torch.cat(res, dim=1) 162 | return self.project(res) 163 | 164 | 165 | 166 | def convert_to_separable_conv(module): 167 | new_module = module 168 | if isinstance(module, nn.Conv2d) and module.kernel_size[0]>1: 169 | new_module = AtrousSeparableConvolution(module.in_channels, 170 | module.out_channels, 171 | module.kernel_size, 172 | module.stride, 173 | module.padding, 174 | module.dilation, 175 | module.bias) 176 | for name, child in module.named_children(): 177 | new_module.add_module(name, convert_to_separable_conv(child)) 178 | return new_module -------------------------------------------------------------------------------- /train_eval_scripts/recoloradv/mister_ed/cifar10/cifar_loader.py: -------------------------------------------------------------------------------- 1 | """ Code to build a cifar10 data loader """ 2 | 3 | 4 | import torch 5 | import torchvision.transforms as transforms 6 | import torchvision.datasets as datasets 7 | from . import cifar_resnets 8 | from . import wide_resnets 9 | from ..utils import pytorch_utils as utils 10 | from .. 
import config 11 | import os 12 | import re 13 | 14 | 15 | ############################################################################### 16 | # PARSE CONFIGS # 17 | ############################################################################### 18 | 19 | DEFAULT_DATASETS_DIR = config.DEFAULT_DATASETS_DIR 20 | RESNET_WEIGHT_PATH = config.MODEL_PATH 21 | DEFAULT_BATCH_SIZE = config.DEFAULT_BATCH_SIZE 22 | DEFAULT_WORKERS = config.DEFAULT_WORKERS 23 | CIFAR10_MEANS = config.CIFAR10_MEANS 24 | CIFAR10_STDS = config.CIFAR10_STDS 25 | WIDE_CIFAR10_MEANS = config.WIDE_CIFAR10_MEANS 26 | WIDE_CIFAR10_STDS = config.WIDE_CIFAR10_STDS 27 | ############################################################################### 28 | # END PARSE CONFIGS # 29 | ############################################################################### 30 | 31 | 32 | ############################################################################## 33 | # # 34 | # MODEL LOADER # 35 | # # 36 | ############################################################################## 37 | 38 | def load_pretrained_cifar_resnet(flavor=32, 39 | return_normalizer=False, 40 | manual_gpu=None): 41 | """ Helper fxn to initialize/load the pretrained cifar resnet 42 | """ 43 | 44 | # Resolve load path 45 | valid_flavor_numbers = [110, 1202, 20, 32, 44, 56] 46 | assert flavor in valid_flavor_numbers 47 | weight_path = os.path.join(RESNET_WEIGHT_PATH, 48 | 'cifar10_resnet%s.th' % flavor) 49 | 50 | 51 | # Resolve CPU/GPU stuff 52 | if manual_gpu is not None: 53 | use_gpu = manual_gpu 54 | else: 55 | use_gpu = utils.use_gpu() 56 | 57 | if use_gpu: 58 | map_location = None 59 | else: 60 | map_location = (lambda s, l: s) 61 | 62 | 63 | # need to modify the resnet state dict to be proper 64 | # TODO: LOAD THESE INTO MODEL ZOO 65 | bad_state_dict = torch.load(weight_path, map_location=map_location) 66 | correct_state_dict = {re.sub(r'^module\.', '', k): v for k, v in 67 | bad_state_dict['state_dict'].items()} 68 | 69 | 70 | classifier_net = eval("cifar_resnets.resnet%s" % flavor)() 71 | classifier_net.load_state_dict(correct_state_dict) 72 | 73 | if return_normalizer: 74 | normalizer = utils.DifferentiableNormalize(mean=CIFAR10_MEANS, 75 | std=CIFAR10_STDS) 76 | return classifier_net, normalizer 77 | 78 | return classifier_net 79 | 80 | 81 | def load_pretrained_cifar_wide_resnet(use_gpu=False, return_normalizer=False): 82 | """ Helper fxn to initialize/load a pretrained 28x10 CIFAR resnet """ 83 | 84 | weight_path = os.path.join(RESNET_WEIGHT_PATH, 85 | 'cifar10_wide-resnet28x10.th') 86 | state_dict = torch.load(weight_path) 87 | classifier_net = wide_resnets.Wide_ResNet(28, 10, 0, 10) 88 | 89 | classifier_net.load_state_dict(state_dict) 90 | 91 | if return_normalizer: 92 | normalizer = utils.DifferentiableNormalize(mean=WIDE_CIFAR10_MEANS, 93 | std=WIDE_CIFAR10_STDS) 94 | return classifier_net, normalizer 95 | 96 | return classifier_net 97 | 98 | 99 | 100 | 101 | 102 | 103 | ############################################################################## 104 | # # 105 | # DATA LOADER # 106 | # # 107 | ############################################################################## 108 | 109 | def load_cifar_data(train_or_val, extra_args=None, dataset_dir=None, 110 | normalize=False, batch_size=None, manual_gpu=None, 111 | shuffle=True, no_transform=False): 112 | """ Builds a CIFAR10 data loader for either training or evaluation of 113 | CIFAR10 data. 
See the 'DEFAULTS' section in the fxn for default args 114 | ARGS: 115 | train_or_val: string - one of 'train' or 'val' for whether we should 116 | load training or validation datap 117 | extra_args: dict - if not None is the kwargs to be passed to DataLoader 118 | constructor 119 | dataset_dir: string - if not None is a directory to load the data from 120 | normalize: boolean - if True, we normalize the data by subtracting out 121 | means and dividing by standard devs 122 | manual_gpu : boolean or None- if None, we use the GPU if we can 123 | else, we use the GPU iff this is True 124 | shuffle: boolean - if True, we load the data in a shuffled order 125 | no_transform: boolean - if True, we don't do any random cropping/ 126 | reflections of the data 127 | """ 128 | 129 | ################################################################## 130 | # DEFAULTS # 131 | ################################################################## 132 | # dataset directory 133 | dataset_dir = dataset_dir or DEFAULT_DATASETS_DIR 134 | batch_size = batch_size or DEFAULT_BATCH_SIZE 135 | 136 | # Extra arguments for DataLoader constructor 137 | if manual_gpu is not None: 138 | use_gpu = manual_gpu 139 | else: 140 | use_gpu = utils.use_gpu() 141 | 142 | constructor_kwargs = {'batch_size': batch_size, 143 | 'shuffle': shuffle, 144 | 'num_workers': DEFAULT_WORKERS, 145 | 'pin_memory': use_gpu} 146 | constructor_kwargs.update(extra_args or {}) 147 | 148 | # transform chain 149 | transform_list = [] 150 | if no_transform is False: 151 | transform_list.extend([transforms.RandomHorizontalFlip(), 152 | transforms.RandomCrop(32, 4)]) 153 | transform_list.append(transforms.ToTensor()) 154 | 155 | if normalize: 156 | normalizer = transforms.Normalize(mean=CIFAR10_MEANS, 157 | std=CIFAR10_STDS) 158 | transform_list.append(normalizer) 159 | 160 | 161 | transform_chain = transforms.Compose(transform_list) 162 | # train_or_val validation 163 | assert train_or_val in ['train', 'val'] 164 | 165 | ################################################################## 166 | # Build DataLoader # 167 | ################################################################## 168 | return torch.utils.data.DataLoader( 169 | datasets.CIFAR10(root=dataset_dir, train=train_or_val=='train', 170 | transform=transform_chain, download=True), 171 | **constructor_kwargs) 172 | 173 | 174 | 175 | 176 | 177 | -------------------------------------------------------------------------------- /train_eval_scripts/recoloradv/mister_ed/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | """ Specific utilities for image classification 2 | (i.e. RGB images i.e. tensors of the form NxCxHxW ) 3 | """ 4 | from __future__ import print_function 5 | from . import pytorch_utils as utils 6 | import torch 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import torchvision.transforms as transforms 10 | import random 11 | 12 | def nhwc255_xform(img_np_array): 13 | """ Takes in a numpy array and transposes it so that the channel is the last 14 | axis. 
Also multiplies all values by 255.0 15 | ARGS: 16 | img_np_array : np.ndarray - array of shape (NxHxWxC) or (NxCxHxW) 17 | [assumes that we're in NCHW by default, 18 | but if not ambiguous will handle NHWC too ] 19 | RETURNS: 20 | array of form NHWC 21 | """ 22 | assert isinstance(img_np_array, np.ndarray) 23 | shape = img_np_array.shape 24 | assert len(shape) == 4 25 | 26 | # determine which configuration we're in 27 | ambiguous = (shape[1] == shape[3] == 3) 28 | nhwc = (shape[1] == 3) 29 | 30 | # transpose unless we're unambiguously in nhwc case 31 | if nhwc and not ambiguous: 32 | return img_np_array * 255.0 33 | else: 34 | return np.transpose(img_np_array, (0, 2, 3, 1)) * 255.0 35 | 36 | 37 | def show_images(images, normalize=None, ipython=True, 38 | margin_height=2, margin_color='red', 39 | figsize=(18,16)): 40 | """ Shows pytorch tensors/variables as images """ 41 | 42 | 43 | # first format the first arg to be hz-stacked numpy arrays 44 | if not isinstance(images, list): 45 | images = [images] 46 | images = [np.dstack(image.cpu().numpy()) for image in images] 47 | image_shape = images[0].shape 48 | assert all(image.shape == image_shape for image in images) 49 | assert all(image.ndim == 3 for image in images) # CxHxW 50 | 51 | # now build the list of final rows 52 | rows = [] 53 | if margin_height >0: 54 | assert margin_color in ['red', 'black'] 55 | margin_shape = list(image_shape) 56 | margin_shape[1] = margin_height 57 | margin = np.zeros(margin_shape) 58 | if margin_color == 'red': 59 | margin[0] = 1 60 | else: 61 | margin = None 62 | 63 | for image_row in images: 64 | rows.append(margin) 65 | rows.append(image_row) 66 | 67 | rows = [_ for _ in rows[1:] if _ is not None] 68 | plt.figure(figsize=figsize, dpi=80, facecolor='w', edgecolor='k') 69 | 70 | cat_rows = np.concatenate(rows, 1).transpose(1, 2, 0) 71 | imshow_kwargs = {} 72 | if cat_rows.shape[-1] == 1: # 1 channel: greyscale 73 | cat_rows = cat_rows.squeeze() 74 | imshow_kwargs['cmap'] = 'gray' 75 | 76 | plt.imshow(cat_rows, **imshow_kwargs) 77 | 78 | plt.show() 79 | 80 | 81 | 82 | 83 | def display_adversarial_2row(classifier_net, normalizer, original_images, 84 | adversarial_images, num_to_show=4, which='incorrect', 85 | ipython=False, margin_width=2): 86 | """ Displays adversarial images side-by-side with their unperturbed 87 | counterparts. Opens a window displaying two rows: top row is original 88 | images, bottom row is perturbed 89 | ARGS: 90 | classifier_net : nn - with a .forward method that takes normalized 91 | variables and outputs logits 92 | normalizer : object w/ .forward method - should probably be an instance 93 | of utils.DifferentiableNormalize or utils.IdentityNormalize 94 | original_images: Variable or Tensor (NxCxHxW) - original images to 95 | display. Images in [0., 1.] range 96 | adversarial_images: Variable or Tensor (NxCxHxW) - perturbed images to 97 | display. Should be same shape as original_images 98 | num_to_show : int - number of images to show 99 | which : string in ['incorrect', 'random', 'correct'] - which images to 100 | show. 101 | -- 'incorrect' means successfully attacked images, 102 | -- 'random' means some random selection of images 103 | -- 'correct' means unsuccessfully attacked images 104 | ipython: bool - if True, we use in an ipython notebook so slightly 105 | different way to show Images 106 | margin_width - int : height in pixels of the red margin separating top 107 | and bottom rows. 
Set to 0 for no margin 108 | RETURNS: 109 | None, but displays images 110 | """ 111 | assert which in ['incorrect', 'random', 'correct'] 112 | 113 | 114 | # If not 'random' selection, prune to only the valid things 115 | to_sample_idxs = [] 116 | if which != 'random': 117 | classifier_net.eval() # can never be too safe =) 118 | 119 | # classify the originals with top1 120 | original_norm_var = normalizer.forward(original_images) 121 | original_out_logits = classifier_net.forward(original_norm_var) 122 | _, original_out_classes = original_out_logits.max(1) 123 | 124 | # classify the adversarials with top1 125 | adv_norm_var = normalizer.forward(adversarial_images) 126 | adv_out_logits = classifier_net.forward(adv_norm_var) 127 | _, adv_out_classes = adv_out_logits.max(1) 128 | 129 | 130 | # collect indices of matching 131 | selector = lambda var: (which == 'correct') == bool(float(var)) 132 | for idx, var_el in enumerate(original_out_classes == adv_out_classes): 133 | if selector(var_el): 134 | to_sample_idxs.append(idx) 135 | else: 136 | to_sample_idxs = list(range(original_images.shape[0])) 137 | 138 | # Now select some indices to show 139 | if to_sample_idxs == []: 140 | print("Couldn't show anything. Try changing the 'which' argument here") 141 | return 142 | 143 | to_show_idxs = random.sample(to_sample_idxs, min([num_to_show, 144 | len(to_sample_idxs)])) 145 | 146 | # Now start building up the images : first horizontally, then vertically 147 | top_row = torch.cat([original_images[idx] for idx in to_show_idxs], dim=2) 148 | bottom_row = torch.cat([adversarial_images[idx] for idx in to_show_idxs], 149 | dim=2) 150 | 151 | if margin_width > 0: 152 | margin = torch.zeros(3, margin_width, top_row.shape[-1]) 153 | margin[0] = 1.0 # make it red 154 | margin = margin.type(type(top_row)) 155 | stack = [top_row, margin, bottom_row] 156 | else: 157 | stack = [top_row, bottom_row] 158 | 159 | plt.imshow(torch.cat(stack, dim=1).cpu().numpy().transpose(1, 2, 0)) 160 | plt.show() 161 | 162 | 163 | def display_adversarial_notebook(): 164 | pass 165 | 166 | def nchw_l2(x, y, squared=True): 167 | """ Computes l2 norm between two NxCxHxW images 168 | ARGS: 169 | x, y: Tensor/Variable (NxCxHxW) - x, y must be same type & shape. 170 | squared : bool - if True we return squared loss, otherwise we return 171 | square root of l2 172 | RETURNS: 173 | ||x - y ||_2 ^2 (no exponent if squared == False), 174 | shape is (Nx1x1x1) 175 | """ 176 | temp = torch.pow(x - y, 2) # square diff 177 | 178 | 179 | for i in range(1, temp.dim()): # reduce on all but first dimension 180 | temp = torch.sum(temp, i, keepdim=True) 181 | 182 | if not squared: 183 | temp = torch.pow(temp, 0.5) 184 | 185 | return temp.squeeze() 186 | -------------------------------------------------------------------------------- /SAM_segmentation/network/backbone/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | try: # for torchvision<0.4 3 | from torchvision.models.utils import load_state_dict_from_url 4 | except: # for torchvision>=0.4 5 | from torch.hub import load_state_dict_from_url 6 | import torch.nn.functional as F 7 | 8 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 9 | 10 | 11 | model_urls = { 12 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 13 | } 14 | 15 | 16 | def _make_divisible(v, divisor, min_value=None): 17 | """ 18 | This function is taken from the original tf repo. 
19 | It ensures that all layers have a channel number that is divisible by 8 20 | It can be seen here: 21 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 22 | :param v: 23 | :param divisor: 24 | :param min_value: 25 | :return: 26 | """ 27 | if min_value is None: 28 | min_value = divisor 29 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 30 | # Make sure that round down does not go down by more than 10%. 31 | if new_v < 0.9 * v: 32 | new_v += divisor 33 | return new_v 34 | 35 | 36 | class ConvBNReLU(nn.Sequential): 37 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, dilation=1, groups=1): 38 | #padding = (kernel_size - 1) // 2 39 | super(ConvBNReLU, self).__init__( 40 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, 0, dilation=dilation, groups=groups, bias=False), 41 | nn.BatchNorm2d(out_planes), 42 | nn.ReLU6(inplace=True) 43 | ) 44 | 45 | def fixed_padding(kernel_size, dilation): 46 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 47 | pad_total = kernel_size_effective - 1 48 | pad_beg = pad_total // 2 49 | pad_end = pad_total - pad_beg 50 | return (pad_beg, pad_end, pad_beg, pad_end) 51 | 52 | class InvertedResidual(nn.Module): 53 | def __init__(self, inp, oup, stride, dilation, expand_ratio): 54 | super(InvertedResidual, self).__init__() 55 | self.stride = stride 56 | assert stride in [1, 2] 57 | 58 | hidden_dim = int(round(inp * expand_ratio)) 59 | self.use_res_connect = self.stride == 1 and inp == oup 60 | 61 | layers = [] 62 | if expand_ratio != 1: 63 | # pw 64 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 65 | 66 | layers.extend([ 67 | # dw 68 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, dilation=dilation, groups=hidden_dim), 69 | # pw-linear 70 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 71 | nn.BatchNorm2d(oup), 72 | ]) 73 | self.conv = nn.Sequential(*layers) 74 | 75 | self.input_padding = fixed_padding( 3, dilation ) 76 | 77 | def forward(self, x): 78 | x_pad = F.pad(x, self.input_padding) 79 | if self.use_res_connect: 80 | return x + self.conv(x_pad) 81 | else: 82 | return self.conv(x_pad) 83 | 84 | class MobileNetV2(nn.Module): 85 | def __init__(self, num_classes=1000, output_stride=8, width_mult=1.0, inverted_residual_setting=None, round_nearest=8): 86 | """ 87 | MobileNet V2 main class 88 | 89 | Args: 90 | num_classes (int): Number of classes 91 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 92 | inverted_residual_setting: Network structure 93 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 94 | Set to 1 to turn off rounding 95 | """ 96 | super(MobileNetV2, self).__init__() 97 | block = InvertedResidual 98 | input_channel = 32 99 | last_channel = 1280 100 | self.output_stride = output_stride 101 | current_stride = 1 102 | if inverted_residual_setting is None: 103 | inverted_residual_setting = [ 104 | # t, c, n, s 105 | [1, 16, 1, 1], 106 | [6, 24, 2, 2], 107 | [6, 32, 3, 2], 108 | [6, 64, 4, 2], 109 | [6, 96, 3, 1], 110 | [6, 160, 3, 2], 111 | [6, 320, 1, 1], 112 | ] 113 | 114 | # only check the first element, assuming user knows t,c,n,s are required 115 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 116 | raise ValueError("inverted_residual_setting should be non-empty " 117 | "or a 4-element list, got {}".format(inverted_residual_setting)) 118 | 119 | # building first layer 120 | input_channel = 
_make_divisible(input_channel * width_mult, round_nearest) 121 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 122 | features = [ConvBNReLU(3, input_channel, stride=2)] 123 | current_stride *= 2 124 | dilation=1 125 | previous_dilation = 1 126 | 127 | # building inverted residual blocks 128 | for t, c, n, s in inverted_residual_setting: 129 | output_channel = _make_divisible(c * width_mult, round_nearest) 130 | previous_dilation = dilation 131 | if current_stride == output_stride: 132 | stride = 1 133 | dilation *= s 134 | else: 135 | stride = s 136 | current_stride *= s 137 | output_channel = int(c * width_mult) 138 | 139 | for i in range(n): 140 | if i==0: 141 | features.append(block(input_channel, output_channel, stride, previous_dilation, expand_ratio=t)) 142 | else: 143 | features.append(block(input_channel, output_channel, 1, dilation, expand_ratio=t)) 144 | input_channel = output_channel 145 | # building last several layers 146 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 147 | # make it nn.Sequential 148 | self.features = nn.Sequential(*features) 149 | 150 | # building classifier 151 | self.classifier = nn.Sequential( 152 | nn.Dropout(0.2), 153 | nn.Linear(self.last_channel, num_classes), 154 | ) 155 | 156 | # weight initialization 157 | for m in self.modules(): 158 | if isinstance(m, nn.Conv2d): 159 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 160 | if m.bias is not None: 161 | nn.init.zeros_(m.bias) 162 | elif isinstance(m, nn.BatchNorm2d): 163 | nn.init.ones_(m.weight) 164 | nn.init.zeros_(m.bias) 165 | elif isinstance(m, nn.Linear): 166 | nn.init.normal_(m.weight, 0, 0.01) 167 | nn.init.zeros_(m.bias) 168 | 169 | def forward(self, x): 170 | x = self.features(x) 171 | x = x.mean([2, 3]) 172 | x = self.classifier(x) 173 | return x 174 | 175 | 176 | def mobilenet_v2(pretrained=False, progress=True, **kwargs): 177 | """ 178 | Constructs a MobileNetV2 architecture from 179 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 
180 | 181 | Args: 182 | pretrained (bool): If True, returns a model pre-trained on ImageNet 183 | progress (bool): If True, displays a progress bar of the download to stderr 184 | """ 185 | model = MobileNetV2(**kwargs) 186 | if pretrained: 187 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 188 | progress=progress) 189 | model.load_state_dict(state_dict) 190 | return model 191 | -------------------------------------------------------------------------------- /train_eval_scripts/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import argparse 8 | from time import time 9 | 10 | from utils import * 11 | from model import PreActResNet18, WRN28_10, DeiT 12 | from sam import SAM, ASAM, ESAM 13 | 14 | 15 | def get_args(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--fname', type=str, required=True) 18 | parser.add_argument('--model', type=str, default='PreActResNet18', choices=['PRN', 'WRN', 'DeiT']) 19 | parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100', 'tiny-imagenet-200']) 20 | parser.add_argument('--epochs', default=100, type=int) 21 | parser.add_argument('--max-lr', default=0.1, type=float) 22 | parser.add_argument('--opt', default='SGD', choices=['Adam', 'SGD']) 23 | parser.add_argument('--sam', default='NO', choices=['SAM', 'ASAM', 'ESAM', 'NO']) 24 | parser.add_argument('--batch-size', default=128, type=int) 25 | parser.add_argument('--device', default=0, type=int) 26 | parser.add_argument('--adv', action='store_true') 27 | parser.add_argument('--rho', default=0.05, type=float) # for SAM 28 | 29 | parser.add_argument('--norm', default='linf', choices=['linf', 'l2']) 30 | parser.add_argument('--train-eps', default=8., type=float) 31 | parser.add_argument('--train-alpha', default=2., type=float) 32 | parser.add_argument('--train-step', default=5, type=int) 33 | 34 | parser.add_argument('--test-eps', default=1., type=float) 35 | parser.add_argument('--test-alpha', default=0.5, type=float) 36 | parser.add_argument('--test-step', default=5, type=int) 37 | return parser.parse_args() 38 | 39 | 40 | args = get_args() 41 | 42 | 43 | def lr_schedule(epoch): 44 | if epoch < args.epochs * 0.75: 45 | return args.max_lr 46 | elif epoch < args.epochs * 0.9: 47 | return args.max_lr * 0.1 48 | else: 49 | return args.max_lr * 0.01 50 | 51 | 52 | if __name__ == '__main__': 53 | dataset = args.dataset 54 | device = f'cuda:{args.device}' 55 | model_name = args.model 56 | label_dim = {'cifar10': 10, 'cifar100': 100, 'tiny-imagenet-200': 200}[dataset] 57 | model = {'PRN': PreActResNet18(label_dim), 'WRN': WRN28_10(label_dim), 'DeiT': DeiT(label_dim)}[model_name].to( 58 | device) 59 | train_loader, test_loader = load_dataset(dataset, args.batch_size) 60 | params = model.parameters() 61 | criterion = nn.CrossEntropyLoss() 62 | 63 | if args.sam == 'NO': 64 | if args.opt == 'SGD': 65 | opt = torch.optim.SGD(params, lr=args.max_lr, momentum=0.9, weight_decay=5e-4) 66 | elif args.opt == 'Adam': 67 | opt = torch.optim.Adam(params, lr=args.max_lr, weight_decay=5e-4) 68 | else: 69 | raise "Invalid optimizer" 70 | else: 71 | if args.sam == 'SAM': 72 | base_opt = torch.optim.SGD 73 | opt = SAM(params, base_opt, lr=args.max_lr, momentum=0.9, weight_decay=5e-4, rho=args.rho) 74 | elif args.sam == 'ASAM': 75 | base_opt = torch.optim.SGD(params, lr=args.max_lr, momentum=0.9, 
weight_decay=5e-4) 76 | opt = ASAM(base_opt, model, rho=args.rho) 77 | elif args.sam == 'ESAM': 78 | base_opt = torch.optim.SGD(model.parameters(), lr=args.max_lr, momentum=0.9, weight_decay=5e-4) 79 | opt = ESAM(params, base_opt, rho=args.rho) 80 | else: 81 | raise "Invalid SAM optimizer" 82 | 83 | normalize = \ 84 | {'cifar10': normalize_cifar, 'cifar100': normalize_cifar100, 'tiny-imagenet-200': normalize_tinyimagenet}[dataset] 85 | 86 | all_log_data = [] 87 | train_pgd = PGD(args.train_step, args.train_alpha / 255., args.train_eps / 255., args.norm, False, normalize) 88 | test_pgd = PGD(args.test_step, args.test_alpha / 255., args.test_eps / 255., args.norm, False, normalize) 89 | 90 | for epoch in range(args.epochs): 91 | start_time = time() 92 | log_data = [0, 0, 0, 0, 0, 0] # train_loss, train_acc, test_loss, test_acc, test_robust_loss, test_robust 93 | # train 94 | model.train() 95 | lr = lr_schedule(epoch) 96 | if args.sam == 'ASAM': 97 | opt.optimizer.param_groups[0].update(lr=lr) 98 | else: 99 | opt.param_groups[0].update(lr=lr) 100 | for x, y in train_loader: 101 | x, y = x.to(device), y.to(device) 102 | if args.adv: 103 | delta = train_pgd.perturb(model, x, y) 104 | else: 105 | delta = torch.zeros_like(x).to(x.device) 106 | 107 | if args.sam == 'NO': 108 | output = model(normalize(x + delta)) 109 | loss = criterion(output, y) 110 | opt.zero_grad() 111 | loss.backward() 112 | opt.step() 113 | 114 | else: 115 | if args.sam == 'SAM': 116 | output = model(normalize(x + delta)) 117 | loss = criterion(output, y) 118 | loss.backward() 119 | opt.first_step(zero_grad=True) 120 | output_2 = model(normalize(x + delta)) 121 | criterion(output_2, y).backward() 122 | opt.second_step(zero_grad=True) 123 | elif args.sam == 'ASAM': 124 | output = model(normalize(x + delta)) 125 | loss = criterion(output, y) 126 | loss.backward() 127 | opt.ascent_step() 128 | output_2 = model(normalize(x + delta)) 129 | criterion(output_2, y).backward() 130 | opt.descent_step() 131 | elif args.sam == 'ESAM': 132 | def defined_backward(loss): 133 | loss.backward() 134 | paras = [normalize(x + delta), y, criterion, model, defined_backward] 135 | opt.paras = paras 136 | opt.step() 137 | output, loss = opt.returnthings 138 | 139 | log_data[0] += (loss * len(y)).item() 140 | log_data[1] += (output.max(1)[1] == y).float().sum().item() 141 | 142 | # test 143 | model.eval() 144 | for x, y in test_loader: 145 | x, y = x.to(device), y.to(device) 146 | # clean 147 | output = model(normalize(x)).detach() 148 | loss = criterion(output, y) 149 | 150 | log_data[2] += (loss * len(y)).item() 151 | log_data[3] += (output.max(1)[1] == y).float().sum().item() 152 | delta = test_pgd.perturb(model, x, y) 153 | output = model(normalize(x + delta)).detach() 154 | loss = criterion(output, y) 155 | 156 | log_data[4] += (loss * len(y)).item() 157 | log_data[5] += (output.max(1)[1] == y).float().sum().item() 158 | 159 | log_data = np.array(log_data) 160 | num_train = 60000 if 'cifar' in dataset else 100000 161 | num_test = 10000 if 'cifar' in dataset else 10000 162 | log_data[0] /= num_train 163 | log_data[1] /= num_train 164 | log_data[2] /= num_test 165 | log_data[3] /= num_test 166 | log_data[4] /= num_test 167 | log_data[5] /= num_test 168 | all_log_data.append(log_data) 169 | 170 | print(f'Epoch {epoch}:\t', log_data, f'\tTime {time() - start_time:.1f}s') 171 | save_path = '{dataset}_models/{fname}.pth' 172 | torch.save(model.state_dict(), save_path.format(dataset=dataset, fname=args.fname)) 173 | 174 | all_log_data = 
np.stack(all_log_data, axis=0) 175 | 176 | df = pd.DataFrame(all_log_data) 177 | df.to_csv(f'logs/{args.fname}.csv') 178 | 179 | plt.plot(all_log_data[:, [2, 4]]) 180 | plt.grid() 181 | # plt.title(f'{dataset} {args.opt}{" adv" if args.adv else ""} Loss', fontsize=16) 182 | plt.legend(['clean', 'robust'], fontsize=16) 183 | plt.savefig(f'figs/{args.fname}_loss.png', dpi=200) 184 | plt.clf() 185 | 186 | plt.plot(all_log_data[:, [3, 5]]) 187 | plt.grid() 188 | # plt.title(f'{dataset} {args.opt}{" adv" if args.adv else ""} Acc', fontsize=16) 189 | plt.legend(['clean', 'robust'], fontsize=16) 190 | plt.savefig(f'figs/{args.fname}_acc.png', dpi=200) 191 | plt.clf() 192 | -------------------------------------------------------------------------------- /SAM_segmentation/datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import namedtuple 4 | 5 | import torch 6 | import torch.utils.data as data 7 | from PIL import Image 8 | import numpy as np 9 | 10 | 11 | class Cityscapes(data.Dataset): 12 | """Cityscapes Dataset. 13 | 14 | **Parameters:** 15 | - **root** (string): Root directory of dataset where directory 'leftImg8bit' and 'gtFine' or 'gtCoarse' are located. 16 | - **split** (string, optional): The image split to use, 'train', 'test' or 'val' if mode="gtFine" otherwise 'train', 'train_extra' or 'val' 17 | - **mode** (string, optional): The quality mode to use, 'gtFine' or 'gtCoarse' or 'color'. Can also be a list to output a tuple with all specified target types. 18 | - **transform** (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` 19 | - **target_transform** (callable, optional): A function/transform that takes in the target and transforms it. 
20 | """ 21 | 22 | # Based on https://github.com/mcordts/cityscapesScripts 23 | CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id', 24 | 'has_instances', 'ignore_in_eval', 'color']) 25 | classes = [ 26 | CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)), 27 | CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)), 28 | CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)), 29 | CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)), 30 | CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)), 31 | CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)), 32 | CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)), 33 | CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)), 34 | CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)), 35 | CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)), 36 | CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)), 37 | CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)), 38 | CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)), 39 | CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)), 40 | CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)), 41 | CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)), 42 | CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)), 43 | CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)), 44 | CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)), 45 | CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)), 46 | CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)), 47 | CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)), 48 | CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)), 49 | CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)), 50 | CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)), 51 | CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)), 52 | CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)), 53 | CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)), 54 | CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)), 55 | CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)), 56 | CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)), 57 | CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)), 58 | CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)), 59 | CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)), 60 | CityscapesClass('license plate', -1, 255, 'vehicle', 7, False, True, (0, 0, 142)), 61 | ] 62 | 63 | train_id_to_color = [c.color for c in classes if (c.train_id != -1 and c.train_id != 255)] 64 | train_id_to_color.append([0, 0, 0]) 65 | train_id_to_color = np.array(train_id_to_color) 66 | id_to_train_id = np.array([c.train_id for c in classes]) 67 | 68 | #train_id_to_color = [(0, 0, 0), (128, 64, 128), (70, 70, 70), (153, 153, 153), (107, 142, 35), 69 | # (70, 130, 180), (220, 20, 60), 
(0, 0, 142)] 70 | #train_id_to_color = np.array(train_id_to_color) 71 | #id_to_train_id = np.array([c.category_id for c in classes], dtype='uint8') - 1 72 | 73 | def __init__(self, root, split='train', mode='fine', target_type='semantic', transform=None): 74 | self.root = os.path.expanduser(root) 75 | self.mode = 'gtFine' 76 | self.target_type = target_type 77 | self.images_dir = os.path.join(self.root, 'leftImg8bit', split) 78 | 79 | self.targets_dir = os.path.join(self.root, self.mode, split) 80 | self.transform = transform 81 | 82 | self.split = split 83 | self.images = [] 84 | self.targets = [] 85 | 86 | if split not in ['train', 'test', 'val']: 87 | raise ValueError('Invalid split for mode! Please use split="train", split="test"' 88 | ' or split="val"') 89 | 90 | if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir): 91 | raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the' 92 | ' specified "split" and "mode" are inside the "root" directory') 93 | 94 | for city in os.listdir(self.images_dir): 95 | img_dir = os.path.join(self.images_dir, city) 96 | target_dir = os.path.join(self.targets_dir, city) 97 | 98 | for file_name in os.listdir(img_dir): 99 | self.images.append(os.path.join(img_dir, file_name)) 100 | target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0], 101 | self._get_target_suffix(self.mode, self.target_type)) 102 | self.targets.append(os.path.join(target_dir, target_name)) 103 | 104 | @classmethod 105 | def encode_target(cls, target): 106 | return cls.id_to_train_id[np.array(target)] 107 | 108 | @classmethod 109 | def decode_target(cls, target): 110 | target[target == 255] = 19 111 | #target = target.astype('uint8') + 1 112 | return cls.train_id_to_color[target] 113 | 114 | def __getitem__(self, index): 115 | """ 116 | Args: 117 | index (int): Index 118 | Returns: 119 | tuple: (image, target) where target is a tuple of all target types if target_type is a list with more 120 | than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation. 
121 | """ 122 | image = Image.open(self.images[index]).convert('RGB') 123 | target = Image.open(self.targets[index]) 124 | if self.transform: 125 | image, target = self.transform(image, target) 126 | target = self.encode_target(target) 127 | if image.max() > 1 or image.min() < 0: 128 | image = (image - image.min()) / (image.max() - image.min()) 129 | return image, target 130 | 131 | def __len__(self): 132 | return len(self.images) 133 | 134 | def _load_json(self, path): 135 | with open(path, 'r') as file: 136 | data = json.load(file) 137 | return data 138 | 139 | def _get_target_suffix(self, mode, target_type): 140 | if target_type == 'instance': 141 | return '{}_instanceIds.png'.format(mode) 142 | elif target_type == 'semantic': 143 | return '{}_labelIds.png'.format(mode) 144 | elif target_type == 'color': 145 | return '{}_color.png'.format(mode) 146 | elif target_type == 'polygon': 147 | return '{}_polygons.json'.format(mode) 148 | elif target_type == 'depth': 149 | return '{}_disparity.png'.format(mode) -------------------------------------------------------------------------------- /train_eval_scripts/sam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from collections import defaultdict 4 | 5 | class SAM(torch.optim.Optimizer): 6 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs): 7 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" 8 | 9 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs) 10 | super(SAM, self).__init__(params, defaults) 11 | 12 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs) 13 | self.param_groups = self.base_optimizer.param_groups 14 | self.defaults.update(self.base_optimizer.defaults) 15 | 16 | @torch.no_grad() 17 | def first_step(self, zero_grad=False): 18 | grad_norm = self._grad_norm() 19 | for group in self.param_groups: 20 | scale = group["rho"] / (grad_norm + 1e-12) 21 | 22 | for p in group["params"]: 23 | if p.grad is None: continue 24 | self.state[p]["old_p"] = p.data.clone() 25 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p) 26 | p.add_(e_w) # climb to the local maximum "w + e(w)" 27 | 28 | if zero_grad: self.zero_grad() 29 | 30 | @torch.no_grad() 31 | def second_step(self, zero_grad=False): 32 | for group in self.param_groups: 33 | for p in group["params"]: 34 | if p.grad is None: continue 35 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)" 36 | 37 | self.base_optimizer.step() # do the actual "sharpness-aware" update 38 | 39 | if zero_grad: self.zero_grad() 40 | 41 | @torch.no_grad() 42 | def step(self, closure=None): 43 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided" 44 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass 45 | 46 | self.first_step(zero_grad=True) 47 | closure() 48 | self.second_step() 49 | 50 | def _grad_norm(self): 51 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism 52 | norm = torch.norm( 53 | torch.stack([ 54 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device) 55 | for group in self.param_groups for p in group["params"] 56 | if p.grad is not None 57 | ]), 58 | p=2 59 | ) 60 | return norm 61 | 62 | def load_state_dict(self, state_dict): 63 | super().load_state_dict(state_dict) 64 | self.base_optimizer.param_groups = self.param_groups 65 | 66 
| 67 | class ASAM: 68 | def __init__(self, optimizer, model, rho=0.5, eta=0.01): 69 | self.optimizer = optimizer 70 | self.model = model 71 | self.rho = rho 72 | self.eta = eta 73 | self.state = defaultdict(dict) 74 | 75 | @torch.no_grad() 76 | def ascent_step(self): 77 | wgrads = [] 78 | for n, p in self.model.named_parameters(): 79 | if p.grad is None: 80 | continue 81 | t_w = self.state[p].get("eps") 82 | if t_w is None: 83 | t_w = torch.clone(p).detach() 84 | self.state[p]["eps"] = t_w 85 | if 'weight' in n: 86 | t_w[...] = p[...] 87 | t_w.abs_().add_(self.eta) 88 | p.grad.mul_(t_w) 89 | wgrads.append(torch.norm(p.grad, p=2)) 90 | wgrad_norm = torch.norm(torch.stack(wgrads), p=2) + 1.e-16 91 | for n, p in self.model.named_parameters(): 92 | if p.grad is None: 93 | continue 94 | t_w = self.state[p].get("eps") 95 | if 'weight' in n: 96 | p.grad.mul_(t_w) 97 | eps = t_w 98 | eps[...] = p.grad[...] 99 | eps.mul_(self.rho / wgrad_norm) 100 | p.add_(eps) 101 | self.optimizer.zero_grad() 102 | 103 | @torch.no_grad() 104 | def descent_step(self): 105 | for n, p in self.model.named_parameters(): 106 | if p.grad is None: 107 | continue 108 | p.sub_(self.state[p]["eps"]) 109 | self.optimizer.step() 110 | self.optimizer.zero_grad() 111 | 112 | 113 | class ESAM(torch.optim.Optimizer): 114 | def __init__(self, params, base_optimizer, rho=0.05, beta=1.0, gamma=1.0, adaptive=False, **kwargs): 115 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" 116 | self.beta = beta 117 | self.gamma = gamma 118 | 119 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs) 120 | super(ESAM, self).__init__(params, defaults) 121 | 122 | self.base_optimizer = base_optimizer 123 | self.param_groups = self.base_optimizer.param_groups 124 | for group in self.param_groups: 125 | group["rho"] = rho 126 | group["adaptive"] = adaptive 127 | self.paras = None 128 | 129 | @torch.no_grad() 130 | def first_step(self, zero_grad=False): 131 | # first order sum 132 | grad_norm = self._grad_norm() 133 | for group in self.param_groups: 134 | scale = group["rho"] / (grad_norm + 1e-7) / self.beta 135 | for p in group["params"]: 136 | p.requires_grad = True 137 | if p.grad is None: continue 138 | # original sam 139 | # e_w = p.grad * scale.to(p) 140 | # asam 141 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p) 142 | p.add_(e_w * 1) # climb to the local maximum "w + e(w)" 143 | self.state[p]["e_w"] = e_w 144 | 145 | if zero_grad: self.zero_grad() 146 | 147 | ''' 148 | @torch.no_grad() 149 | def first_half(self, zero_grad=False): 150 | #first order sum 151 | for group in self.param_groups: 152 | for p in group["params"]: 153 | if self.state[p]: 154 | p.add_(self.state[p]["e_w"]*0.90) # climb to the local maximum "w + e(w)" 155 | ''' 156 | 157 | @torch.no_grad() 158 | def second_step(self, zero_grad=False): 159 | for group in self.param_groups: 160 | for p in group["params"]: 161 | if p.grad is None or not self.state[p]: continue 162 | p.sub_(self.state[p]["e_w"]) # get back to "w" from "w + e(w)" 163 | self.state[p]["e_w"] = 0 164 | 165 | if random.random() > self.beta: 166 | p.requires_grad = False 167 | 168 | self.base_optimizer.step() # do the actual "sharpness-aware" update 169 | 170 | if zero_grad: self.zero_grad() 171 | 172 | def step(self): 173 | inputs, targets, loss_fct, model, defined_backward = self.paras 174 | assert defined_backward is not None, "Sharpness Aware Minimization requires defined_backward, but it was not provided" 175 | 176 | model.require_backward_grad_sync = False 177 
| model.require_forward_param_sync = True 178 | 179 | logits = model(inputs) 180 | loss = loss_fct(logits, targets) 181 | 182 | l_before = loss.clone().detach() 183 | predictions = logits 184 | return_loss = loss.clone().detach() 185 | loss = loss.mean() 186 | defined_backward(loss) 187 | 188 | # first step to w + e(w) 189 | self.first_step(True) 190 | 191 | with torch.no_grad(): 192 | l_after = loss_fct(model(inputs), targets) 193 | instance_sharpness = l_after - l_before 194 | 195 | # codes for sorting 196 | prob = self.gamma 197 | if prob >= 0.99: 198 | indices = range(len(targets)) 199 | else: 200 | position = int(len(targets) * prob) 201 | cutoff, _ = torch.topk(instance_sharpness, position) 202 | cutoff = cutoff[-1] 203 | 204 | # cutoff = 0 205 | # select top k% 206 | 207 | indices = [instance_sharpness > cutoff] 208 | 209 | # second forward-backward step 210 | # self.first_half() 211 | 212 | model.require_backward_grad_sync = True 213 | model.require_forward_param_sync = False 214 | 215 | loss = loss_fct(model(inputs[indices]), targets[indices]) 216 | loss = loss.mean() 217 | defined_backward(loss) 218 | self.second_step(True) 219 | 220 | self.returnthings = (predictions, return_loss) 221 | 222 | def _grad_norm(self): 223 | shared_device = self.param_groups[0]["params"][ 224 | 0].device # put everything on the same device, in case of model parallelism 225 | norm = torch.norm( 226 | torch.stack([ 227 | # original sam 228 | # p.grad.norm(p=2).to(shared_device) 229 | # asam 230 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device) 231 | for group in self.param_groups for p in group["params"] 232 | if p.grad is not None 233 | ]), 234 | p=2 235 | ) 236 | return norm -------------------------------------------------------------------------------- /train_eval_scripts/recoloradv/color_spaces.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains classes that convert from RGB to various other color spaces and back. 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | from .mister_ed.utils import pytorch_utils as utils 8 | from torch.autograd import Variable 9 | import numpy as np 10 | from recoloradv import norms 11 | import math 12 | 13 | 14 | class ColorSpace(object): 15 | """ 16 | Base class for color spaces. 17 | """ 18 | 19 | def from_rgb(self, imgs): 20 | """ 21 | Converts an Nx3xWxH tensor in RGB color space to a Nx3xWxH tensor in 22 | this color space. All outputs should be in the 0-1 range. 23 | """ 24 | raise NotImplementedError() 25 | 26 | def to_rgb(self, imgs): 27 | """ 28 | Converts an Nx3xWxH tensor in this color space to a Nx3xWxH tensor in 29 | RGB color space. 30 | """ 31 | raise NotImplementedError() 32 | 33 | 34 | class RGBColorSpace(ColorSpace): 35 | """ 36 | RGB color space. Just applies identity transformation. 37 | """ 38 | 39 | def from_rgb(self, imgs): 40 | return imgs 41 | 42 | def to_rgb(self, imgs): 43 | return imgs 44 | 45 | 46 | class YPbPrColorSpace(ColorSpace): 47 | """ 48 | YPbPr color space. Uses ITU-R BT.601 standard by default. 
49 | """ 50 | 51 | def __init__(self, kr=0.299, kg=0.587, kb=0.114, luma_factor=1, 52 | chroma_factor=1): 53 | self.kr, self.kg, self.kb = kr, kg, kb 54 | self.luma_factor = luma_factor 55 | self.chroma_factor = chroma_factor 56 | 57 | def from_rgb(self, imgs): 58 | r, g, b = imgs.permute(1, 0, 2, 3) 59 | 60 | y = r * self.kr + g * self.kg + b * self.kb 61 | pb = (b - y) / (2 * (1 - self.kb)) 62 | pr = (r - y) / (2 * (1 - self.kr)) 63 | 64 | return torch.stack([y * self.luma_factor, 65 | pb * self.chroma_factor + 0.5, 66 | pr * self.chroma_factor + 0.5], 1) 67 | 68 | def to_rgb(self, imgs): 69 | y_prime, pb_prime, pr_prime = imgs.permute(1, 0, 2, 3) 70 | y = y_prime / self.luma_factor 71 | pb = (pb_prime - 0.5) / self.chroma_factor 72 | pr = (pr_prime - 0.5) / self.chroma_factor 73 | 74 | b = pb * 2 * (1 - self.kb) + y 75 | r = pr * 2 * (1 - self.kr) + y 76 | g = (y - r * self.kr - b * self.kb) / self.kg 77 | 78 | return torch.stack([r, g, b], 1).clamp(0, 1) 79 | 80 | 81 | class ApproxHSVColorSpace(ColorSpace): 82 | """ 83 | Converts from RGB to approximately the HSV cone using a much smoother 84 | transformation. 85 | """ 86 | 87 | def from_rgb(self, imgs): 88 | r, g, b = imgs.permute(1, 0, 2, 3) 89 | 90 | x = r * np.sqrt(2) / 3 - g / (np.sqrt(2) * 3) - b / (np.sqrt(2) * 3) 91 | y = g / np.sqrt(6) - b / np.sqrt(6) 92 | z, _ = imgs.max(1) 93 | 94 | return torch.stack([z, x + 0.5, y + 0.5], 1) 95 | 96 | def to_rgb(self, imgs): 97 | z, xp, yp = imgs.permute(1, 0, 2, 3) 98 | x, y = xp - 0.5, yp - 0.5 99 | 100 | rp = float(np.sqrt(2)) * x 101 | gp = -x / np.sqrt(2) + y * np.sqrt(3 / 2) 102 | bp = -x / np.sqrt(2) - y * np.sqrt(3 / 2) 103 | 104 | delta = z - torch.max(torch.stack([rp, gp, bp], 1), 1)[0] 105 | r, g, b = rp + delta, gp + delta, bp + delta 106 | 107 | return torch.stack([r, g, b], 1).clamp(0, 1) 108 | 109 | 110 | class HSVConeColorSpace(ColorSpace): 111 | """ 112 | Converts from RGB to the HSV "cone", where (x, y, z) = 113 | (s * v cos h, s * v sin h, v). Note that this cone is then squashed to fit 114 | in [0, 1]^3 by letting (x', y', z') = ((x + 1) / 2, (y + 1) / 2, z). 
115 | 116 | WARNING: has a very complex derivative, not very useful in practice 117 | """ 118 | 119 | def from_rgb(self, imgs): 120 | r, g, b = imgs.permute(1, 0, 2, 3) 121 | 122 | mx, argmx = imgs.max(1) 123 | mn, _ = imgs.min(1) 124 | chroma = mx - mn 125 | eps = 1e-10 126 | h_max_r = math.pi / 3 * (g - b) / (chroma + eps) 127 | h_max_g = math.pi / 3 * (b - r) / (chroma + eps) + math.pi * 2 / 3 128 | h_max_b = math.pi / 3 * (r - g) / (chroma + eps) + math.pi * 4 / 3 129 | 130 | h = (((argmx == 0) & (chroma != 0)).float() * h_max_r 131 | + ((argmx == 1) & (chroma != 0)).float() * h_max_g 132 | + ((argmx == 2) & (chroma != 0)).float() * h_max_b) 133 | 134 | x = torch.cos(h) * chroma 135 | y = torch.sin(h) * chroma 136 | z = mx 137 | 138 | return torch.stack([(x + 1) / 2, (y + 1) / 2, z], 1) 139 | 140 | def _to_rgb_part(self, h, chroma, v, n): 141 | """ 142 | Implements the function f(n) defined here: 143 | https://en.wikipedia.org/wiki/HSL_and_HSV#Alternative_HSV_to_RGB 144 | """ 145 | 146 | k = (n + h * math.pi / 3) % 6 147 | return v - chroma * torch.min(k, 4 - k).clamp(0, 1) 148 | 149 | def to_rgb(self, imgs): 150 | xp, yp, z = imgs.permute(1, 0, 2, 3) 151 | x, y = xp * 2 - 1, yp * 2 - 1 152 | 153 | # prevent NaN gradients when calculating atan2 154 | x_nonzero = (1 - 2 * (torch.sign(x) == -1).float()) * (torch.abs(x) + 1e-10) 155 | h = torch.atan2(y, x_nonzero) 156 | v = z.clamp(0, 1) 157 | chroma = torch.min(torch.sqrt(x ** 2 + y ** 2 + 1e-10), v) 158 | 159 | r = self._to_rgb_part(h, chroma, v, 5) 160 | g = self._to_rgb_part(h, chroma, v, 3) 161 | b = self._to_rgb_part(h, chroma, v, 1) 162 | 163 | return torch.stack([r, g, b], 1).clamp(0, 1) 164 | 165 | 166 | class CIEXYZColorSpace(ColorSpace): 167 | """ 168 | The 1931 CIE XYZ color space (assuming input is in sRGB). 169 | 170 | Warning: may have values outside [0, 1] range. Should only be used in 171 | the process of converting to/from other color spaces. 172 | """ 173 | 174 | def from_rgb(self, imgs): 175 | # apply gamma correction 176 | small_values_mask = (imgs < 0.04045).float() 177 | imgs_corrected = ( 178 | (imgs / 12.92) * small_values_mask + 179 | ((imgs + 0.055) / 1.055) ** 2.4 * (1 - small_values_mask) 180 | ) 181 | 182 | # linear transformation to XYZ 183 | r, g, b = imgs_corrected.permute(1, 0, 2, 3) 184 | x = 0.4124 * r + 0.3576 * g + 0.1805 * b 185 | y = 0.2126 * r + 0.7152 * g + 0.0722 * b 186 | z = 0.0193 * r + 0.1192 * g + 0.9504 * b 187 | 188 | return torch.stack([x, y, z], 1) 189 | 190 | def to_rgb(self, imgs): 191 | # linear transformation 192 | x, y, z = imgs.permute(1, 0, 2, 3) 193 | r = 3.2406 * x - 1.5372 * y - 0.4986 * z 194 | g = -0.9689 * x + 1.8758 * y + 0.0415 * z 195 | b = 0.0557 * x - 0.2040 * y + 1.0570 * z 196 | 197 | imgs = torch.stack([r, g, b], 1) 198 | 199 | # apply gamma correction 200 | small_values_mask = (imgs < 0.0031308).float() 201 | imgs_clamped = imgs.clamp(min=1e-10) # prevent NaN gradients 202 | imgs_corrected = ( 203 | (12.92 * imgs) * small_values_mask + 204 | (1.055 * imgs_clamped ** (1 / 2.4) - 0.055) * 205 | (1 - small_values_mask) 206 | ) 207 | 208 | return imgs_corrected 209 | 210 | 211 | class CIELUVColorSpace(ColorSpace): 212 | """ 213 | Converts to the 1976 CIE L*u*v* color space. 
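    The default white point (up_white=0.1978, vp_white=0.4683, y_white=1) corresponds roughly
    to the D65 illuminant, and the output channels are rescaled to [0, 1] as
    (L*/100, (u*+100)/200, (v*+100)/200).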
214 | """ 215 | 216 | def __init__(self, up_white=0.1978, vp_white=0.4683, y_white=1, 217 | eps=1e-10): 218 | self.xyz_cspace = CIEXYZColorSpace() 219 | self.up_white = up_white 220 | self.vp_white = vp_white 221 | self.y_white = y_white 222 | self.eps = eps 223 | 224 | def from_rgb(self, imgs): 225 | x, y, z = self.xyz_cspace.from_rgb(imgs).permute(1, 0, 2, 3) 226 | 227 | # calculate u' and v' 228 | denom = x + 15 * y + 3 * z + self.eps 229 | up = 4 * x / denom 230 | vp = 9 * y / denom 231 | 232 | # calculate L*, u*, and v* 233 | small_values_mask = (y / self.y_white < (6 / 29) ** 3).float() 234 | y_clamped = y.clamp(min=self.eps) # prevent NaN gradients 235 | L = ( 236 | ((29 / 3) ** 3 * y / self.y_white) * small_values_mask + 237 | (116 * (y_clamped / self.y_white) ** (1 / 3) - 16) * 238 | (1 - small_values_mask) 239 | ) 240 | u = 13 * L * (up - self.up_white) 241 | v = 13 * L * (vp - self.vp_white) 242 | 243 | return torch.stack([L / 100, (u + 100) / 200, (v + 100) / 200], 1) 244 | 245 | def to_rgb(self, imgs): 246 | L = imgs[:, 0, :, :] * 100 247 | u = imgs[:, 1, :, :] * 200 - 100 248 | v = imgs[:, 2, :, :] * 200 - 100 249 | 250 | up = u / (13 * L + self.eps) + self.up_white 251 | vp = v / (13 * L + self.eps) + self.vp_white 252 | 253 | small_values_mask = (L <= 8).float() 254 | y = ( 255 | (self.y_white * L * (3 / 29) ** 3) * small_values_mask + 256 | (self.y_white * ((L + 16) / 116) ** 3) * (1 - small_values_mask) 257 | ) 258 | denom = 4 * vp + self.eps 259 | x = y * 9 * up / denom 260 | z = y * (12 - 3 * up - 20 * vp) / denom 261 | 262 | return self.xyz_cspace.to_rgb( 263 | torch.stack([x, y, z], 1).clamp(0, 1.1)).clamp(0, 1) 264 | -------------------------------------------------------------------------------- /train_eval_scripts/recoloradv/mister_ed/utils/discretization.py: -------------------------------------------------------------------------------- 1 | """ File that holds techniques for discretizing images -- 2 | In general, images of the form NxCxHxW will with values in the [0.,1.] range 3 | need to be converted to the [0, 255 (int)] range to be displayed as images. 4 | 5 | Sometimes the naive rounding scheme can mess up the classification, so this 6 | file holds techniques to discretize these images into tensors with values 7 | of the form i/255.0 for some integers i. 8 | """ 9 | 10 | import torch 11 | from torch.autograd import Variable 12 | from . import pytorch_utils as utils 13 | 14 | ############################################################################## 15 | # # 16 | # HELPER METHODS # 17 | # # 18 | ############################################################################## 19 | 20 | 21 | def discretize_image(img_tensor, zero_one=False): 22 | """ Discretizes an image tensor into a tensor filled with ints ranging 23 | between 0 and 255 24 | ARGS: 25 | img_tensor : floatTensor (NxCxHxW) - tensor to be discretized 26 | pixel_max : int - discretization bucket size 27 | zero_one : bool - if True divides output by 255 before returning it 28 | """ 29 | 30 | assert float(torch.min(img_tensor)) >= 0. 
31 | assert float(torch.max(img_tensor)) <= 1.0 32 | 33 | 34 | original_shape = img_tensor.shape 35 | if img_tensor.dim() != 4: 36 | img_tensor = img_tensor.unsqueeze(0) 37 | 38 | int_tensors = [] # actually floatTensor, but full of ints 39 | img_shape = original_shape[1:] 40 | for example in img_tensor: 41 | pixel_channel_tuples = zip(*list(smp.toimage(example).getdata())) 42 | int_tensors.append(img_tensor.new(pixel_channel_tuples).view(img_shape)) 43 | 44 | stacked_tensors = torch.stack(int_tensors) 45 | if zero_one: 46 | return stacked_tensors / 255.0 47 | return stacked_tensors 48 | 49 | 50 | 51 | ############################################################################## 52 | # # 53 | # MAIN DISCRETIZATION TECHNIQUES # 54 | # # 55 | ############################################################################## 56 | 57 | def discretized_adversarial(img_tensor, classifier_net, normalizer, 58 | flavor='greedy'): 59 | """ Takes in an image_tensor and classifier/normalizer pair and outputs a 60 | 'discretized' image_tensor [each val is i/255.0 for some integer i] 61 | with the same classification 62 | ARGS: 63 | img_tensor : tensor (NxCxHxW) - tensor of images with values between 64 | 0.0 and 1.0. 65 | classifier_net : NN - neural net with .forward method to classify 66 | normalized images 67 | normalizer : differentiableNormalizer object - normalizes 0,1 images 68 | into classifier_domain 69 | flavor : string - either 'random' or 'greedy', determining which 70 | 'next_pixel_to_flip' function we use 71 | RETURNS: 72 | img_tensor of the same shape, but no with values of the form i/255.0 73 | for integers i. 74 | """ 75 | 76 | img_tensor = utils.safe_tensor(img_tensor) 77 | 78 | nptf_map = {'random': flip_random_pixel, 79 | 'greedy': flip_greedy_pixel} 80 | next_pixel_to_flip = nptf_map[flavor](classifier_net, normalizer) 81 | 82 | ########################################################################## 83 | # First figure out 'correct' labels and the 'discretized' labels # 84 | ########################################################################## 85 | var_img = utils.safe_var(img_tensor) 86 | norm_var = normalizer.forward(var_img) 87 | norm_output = classifier_net.forward(norm_var) 88 | correct_targets = norm_output.max(1)[1] 89 | 90 | og_discretized = utils.safe_var(discretize_image(img_tensor, zero_one=True)) 91 | norm_discretized = normalizer.forward(og_discretized) 92 | discretized_output = classifier_net.forward(norm_discretized) 93 | discretized_targets = discretized_output.max(1)[1] 94 | 95 | ########################################################################## 96 | # Collect idxs for examples affected by discretization # 97 | ########################################################################## 98 | incorrect_idxs = set() 99 | 100 | for i, el in enumerate(correct_targets.ne(discretized_targets)): 101 | if float(el) != 0: 102 | incorrect_idxs.add(i) 103 | 104 | 105 | ########################################################################## 106 | # Fix all bad images # 107 | ########################################################################## 108 | 109 | corrected_imgs = [] 110 | for idx in incorrect_idxs: 111 | desired_target = correct_targets[idx] 112 | example = og_discretized[idx].data.clone() # tensor 113 | signs = torch.sign(var_img - og_discretized) 114 | bad_discretization = True 115 | pixels_changed_so_far = set() # populated with tuples of idxs 116 | 117 | while bad_discretization: 118 | pixel_idx, grad_sign = next_pixel_to_flip(example, 119 | 
pixels_changed_so_far, 120 | desired_target) 121 | pixels_changed_so_far.add(pixel_idx) 122 | 123 | if grad_sign == 0: 124 | grad_sign = utils.tuple_getter(signs[idx], pixel_idx) 125 | 126 | new_val = (grad_sign / 255. + utils.tuple_getter(example, pixel_idx)) 127 | utils.tuple_setter(example, pixel_idx, float(new_val)) 128 | 129 | new_out = classifier_net.forward(normalizer.forward(\ 130 | Variable(example.unsqueeze(0)))) 131 | bad_discretization = (int(desired_target) != int(new_out.max(1)[1])) 132 | corrected_imgs.append(example) 133 | 134 | # Stack up results 135 | output = [] 136 | 137 | for idx in range(len(img_tensor)): 138 | if idx in incorrect_idxs: 139 | output.append(corrected_imgs.pop(0)) 140 | else: 141 | output.append(og_discretized[idx].data) 142 | 143 | return torch.stack(output) # Variable 144 | 145 | 146 | 147 | 148 | 149 | ############################################################################# 150 | # # 151 | # FLIP TECHNIQUES # 152 | # # 153 | ############################################################################# 154 | ''' Flip techniques in general have the following specs: 155 | ARGS: 156 | classifier_net : NN - neural net with .forward method to classify 157 | normalized images 158 | normalizer : differentiableNormalizer object - normalizes 0,1 images 159 | into classifier_domain 160 | RETURNS: flip_function 161 | ''' 162 | 163 | ''' 164 | Flip function is a function that takes the following args: 165 | ARGS: 166 | img_tensor : Tensor (CxHxW) - image tensor in range 0.0 to 1.0 and is 167 | already discretized 168 | pixels_changed_so_far: set - set of index_tuples that have already been 169 | modified (we don't want to modify a pixel by 170 | more than 1/255 in any channel) 171 | correct_target : torch.LongTensor (1) - single element in a tensor that 172 | is the target class 173 | (e.g. int between 0 and 9 for CIFAR ) 174 | RETURNS: (idx_tuple, sign) 175 | index_tuple is a triple of indices indicating which pixel-channel needs 176 | to be modified, and sign is in {-1, 0, 1}. If +-1, we will modify the 177 | pixel-channel in that direction, otherwise we'll modify in the opposite 178 | of the direction that discretization rounded to. 
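    Here index_tuple indexes (channel, row, column) of the CxHxW image tensor.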
179 | ''' 180 | 181 | 182 | def flip_random_pixel(classifier_net, normalizer): 183 | def flip_fxn(img_tensor, pixels_changed_so_far, correct_target): 184 | numel = img_tensor.numel() 185 | if len(pixels_changed_so_far) > numel * .9: 186 | raise Exception("WHAT IS GOING ON???") 187 | 188 | while True: 189 | pixel_idx, _ = utils.random_element_index(img_tensor) 190 | if pixel_idx not in pixels_changed_so_far: 191 | return pixel_idx, 0 192 | 193 | return flip_fxn 194 | 195 | 196 | 197 | def flip_greedy_pixel(classifier_net, normalizer): 198 | def flip_fxn(img_tensor, pixels_changed_so_far, correct_target, 199 | classifier_net=classifier_net, normalizer=normalizer): 200 | # Computes gradient and figures out which px most affects class_out 201 | classifier_net.zero_grad() 202 | img_var = Variable(img_tensor.unsqueeze(0), requires_grad=True) 203 | class_out = classifier_net.forward(normalizer.forward(img_var)) 204 | 205 | criterion = torch.nn.CrossEntropyLoss() 206 | loss = criterion(class_out, correct_target) # RESHAPE HERE 207 | loss.backward() 208 | # Really inefficient algorithm here, can probably do better 209 | new_grad_data = img_var.grad.data.clone().squeeze() 210 | signs = new_grad_data.sign() 211 | for idx_tuple in pixels_changed_so_far: 212 | utils.tuple_setter(new_grad_data, idx_tuple, 0) 213 | 214 | argmax = utils.torch_argmax(new_grad_data.abs()) 215 | return argmax, -1 * utils.tuple_getter(signs, argmax) 216 | 217 | return flip_fxn 218 | 219 | 220 | 221 | -------------------------------------------------------------------------------- /SAM_segmentation/network/backbone/xception.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Xception is adapted from https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/xception.py 4 | 5 | Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch) 6 | @author: tstandley 7 | Adapted by cadene 8 | Creates an Xception Model as defined in: 9 | Francois Chollet 10 | Xception: Deep Learning with Depthwise Separable Convolutions 11 | https://arxiv.org/pdf/1610.02357.pdf 12 | This weights ported from the Keras implementation. 
Achieves the following performance on the validation set: 13 | Loss:0.9173 Prec@1:78.892 Prec@5:94.292 14 | REMEMBER to set your image size to 3x299x299 for both test and validation 15 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], 16 | std=[0.5, 0.5, 0.5]) 17 | The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 18 | """ 19 | from __future__ import print_function, division, absolute_import 20 | import math 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | import torch.utils.model_zoo as model_zoo 25 | from torch.nn import init 26 | 27 | __all__ = ['xception'] 28 | 29 | pretrained_settings = { 30 | 'xception': { 31 | 'imagenet': { 32 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth', 33 | 'input_space': 'RGB', 34 | 'input_size': [3, 299, 299], 35 | 'input_range': [0, 1], 36 | 'mean': [0.5, 0.5, 0.5], 37 | 'std': [0.5, 0.5, 0.5], 38 | 'num_classes': 1000, 39 | 'scale': 0.8975 # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 40 | } 41 | } 42 | } 43 | 44 | 45 | class SeparableConv2d(nn.Module): 46 | def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False): 47 | super(SeparableConv2d,self).__init__() 48 | 49 | self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias) 50 | self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias) 51 | 52 | def forward(self,x): 53 | x = self.conv1(x) 54 | x = self.pointwise(x) 55 | return x 56 | 57 | 58 | class Block(nn.Module): 59 | def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True, dilation=1): 60 | super(Block, self).__init__() 61 | 62 | if out_filters != in_filters or strides!=1: 63 | self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False) 64 | self.skipbn = nn.BatchNorm2d(out_filters) 65 | else: 66 | self.skip=None 67 | 68 | rep=[] 69 | 70 | filters=in_filters 71 | if grow_first: 72 | rep.append(nn.ReLU(inplace=True)) 73 | rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=dilation, dilation=dilation, bias=False)) 74 | rep.append(nn.BatchNorm2d(out_filters)) 75 | filters = out_filters 76 | 77 | for i in range(reps-1): 78 | rep.append(nn.ReLU(inplace=True)) 79 | rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=dilation,dilation=dilation,bias=False)) 80 | rep.append(nn.BatchNorm2d(filters)) 81 | 82 | if not grow_first: 83 | rep.append(nn.ReLU(inplace=True)) 84 | rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=dilation,dilation=dilation,bias=False)) 85 | rep.append(nn.BatchNorm2d(out_filters)) 86 | 87 | if not start_with_relu: 88 | rep = rep[1:] 89 | else: 90 | rep[0] = nn.ReLU(inplace=False) 91 | 92 | if strides != 1: 93 | rep.append(nn.MaxPool2d(3,strides,1)) 94 | self.rep = nn.Sequential(*rep) 95 | 96 | def forward(self,inp): 97 | x = self.rep(inp) 98 | 99 | if self.skip is not None: 100 | skip = self.skip(inp) 101 | skip = self.skipbn(skip) 102 | else: 103 | skip = inp 104 | x+=skip 105 | return x 106 | 107 | 108 | class Xception(nn.Module): 109 | """ 110 | Xception optimized for the ImageNet dataset, as specified in 111 | https://arxiv.org/pdf/1610.02357.pdf 112 | """ 113 | def __init__(self, num_classes=1000, replace_stride_with_dilation=None): 114 | """ Constructor 115 | Args: 116 | num_classes: number of classes 117 | """ 118 | super(Xception, 
self).__init__() 119 | 120 | self.num_classes = num_classes 121 | self.dilation = 1 122 | if replace_stride_with_dilation is None: 123 | # each element in the tuple indicates if we should replace 124 | # the 2x2 stride with a dilated convolution instead 125 | replace_stride_with_dilation = [False, False, False, False] 126 | if len(replace_stride_with_dilation) != 4: 127 | raise ValueError("replace_stride_with_dilation should be None " 128 | "or a 4-element tuple, got {}".format(replace_stride_with_dilation)) 129 | 130 | self.conv1 = nn.Conv2d(3, 32, 3,2, 0, bias=False) # 1 / 2 131 | self.bn1 = nn.BatchNorm2d(32) 132 | self.relu1 = nn.ReLU(inplace=True) 133 | 134 | self.conv2 = nn.Conv2d(32,64,3,bias=False) 135 | self.bn2 = nn.BatchNorm2d(64) 136 | self.relu2 = nn.ReLU(inplace=True) 137 | #do relu here 138 | 139 | self.block1=self._make_block(64,128,2,2,start_with_relu=False,grow_first=True, dilate=replace_stride_with_dilation[0]) # 1 / 4 140 | self.block2=self._make_block(128,256,2,2,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[1]) # 1 / 8 141 | self.block3=self._make_block(256,728,2,2,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2]) # 1 / 16 142 | 143 | self.block4=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2]) 144 | self.block5=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2]) 145 | self.block6=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2]) 146 | self.block7=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2]) 147 | 148 | self.block8=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2]) 149 | self.block9=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2]) 150 | self.block10=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2]) 151 | self.block11=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2]) 152 | 153 | self.block12=self._make_block(728,1024,2,2,start_with_relu=True,grow_first=False, dilate=replace_stride_with_dilation[3]) # 1 / 32 154 | 155 | self.conv3 = SeparableConv2d(1024,1536,3,1,1, dilation=self.dilation) 156 | self.bn3 = nn.BatchNorm2d(1536) 157 | self.relu3 = nn.ReLU(inplace=True) 158 | 159 | #do relu here 160 | self.conv4 = SeparableConv2d(1536,2048,3,1,1, dilation=self.dilation) 161 | self.bn4 = nn.BatchNorm2d(2048) 162 | 163 | self.fc = nn.Linear(2048, num_classes) 164 | 165 | # #------- init weights -------- 166 | # for m in self.modules(): 167 | # if isinstance(m, nn.Conv2d): 168 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 169 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 170 | # elif isinstance(m, nn.BatchNorm2d): 171 | # m.weight.data.fill_(1) 172 | # m.bias.data.zero_() 173 | # #----------------------------- 174 | 175 | def _make_block(self, in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True, dilate=False): 176 | if dilate: 177 | self.dilation *= strides 178 | strides = 1 179 | return Block(in_filters,out_filters,reps,strides,start_with_relu=start_with_relu,grow_first=grow_first, dilation=self.dilation) 180 | 181 | def features(self, input): 182 | x = self.conv1(input) 183 | x = self.bn1(x) 184 | x = self.relu1(x) 185 | 186 | x = self.conv2(x) 187 | x = self.bn2(x) 188 | x = self.relu2(x) 189 | 190 | x = self.block1(x) 191 | x = self.block2(x) 192 | x = self.block3(x) 193 | x = self.block4(x) 194 | x = self.block5(x) 195 | x = self.block6(x) 196 | x = self.block7(x) 197 | x = self.block8(x) 198 | x = self.block9(x) 199 | x = self.block10(x) 200 | x = self.block11(x) 201 | x = self.block12(x) 202 | 203 | x = self.conv3(x) 204 | x = self.bn3(x) 205 | x = self.relu3(x) 206 | 207 | x = self.conv4(x) 208 | x = self.bn4(x) 209 | return x 210 | 211 | def logits(self, features): 212 | x = nn.ReLU(inplace=True)(features) 213 | 214 | x = F.adaptive_avg_pool2d(x, (1, 1)) 215 | x = x.view(x.size(0), -1) 216 | x = self.last_linear(x) 217 | return x 218 | 219 | def forward(self, input): 220 | x = self.features(input) 221 | x = self.logits(x) 222 | return x 223 | 224 | 225 | def xception(num_classes=1000, pretrained='imagenet', replace_stride_with_dilation=None): 226 | model = Xception(num_classes=num_classes, replace_stride_with_dilation=replace_stride_with_dilation) 227 | if pretrained: 228 | settings = pretrained_settings['xception'][pretrained] 229 | assert num_classes == settings['num_classes'], \ 230 | "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) 231 | 232 | model = Xception(num_classes=num_classes, replace_stride_with_dilation=replace_stride_with_dilation) 233 | model.load_state_dict(model_zoo.load_url(settings['url'])) 234 | 235 | # TODO: ugly 236 | model.last_linear = model.fc 237 | del model.fc 238 | return model -------------------------------------------------------------------------------- /train_eval_scripts/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from deit import deit_tiny_patch16_224 5 | 6 | 7 | class PreActBlock(nn.Module): 8 | '''Pre-activation version of the BasicBlock.''' 9 | expansion = 1 10 | 11 | def __init__(self, in_planes, planes, stride=1): 12 | super(PreActBlock, self).__init__() 13 | self.bn1 = nn.BatchNorm2d(in_planes) 14 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 15 | self.bn2 = nn.BatchNorm2d(planes) 16 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 17 | 18 | if stride != 1 or in_planes != self.expansion*planes: 19 | self.shortcut = nn.Sequential( 20 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 21 | ) 22 | 23 | def forward(self, x): 24 | out = F.relu(self.bn1(x)) 25 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 26 | out = self.conv1(out) 27 | out = self.conv2(F.relu(self.bn2(out))) 28 | out += shortcut 29 | return out 30 | 31 | 32 | class PreActBottleneck(nn.Module): 33 | '''Pre-activation version of the original Bottleneck module.''' 34 | expansion = 4 35 | 36 | def __init__(self, 
in_planes, planes, stride=1): 37 | super(PreActBottleneck, self).__init__() 38 | self.bn1 = nn.BatchNorm2d(in_planes) 39 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 40 | self.bn2 = nn.BatchNorm2d(planes) 41 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 42 | self.bn3 = nn.BatchNorm2d(planes) 43 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 44 | 45 | if stride != 1 or in_planes != self.expansion*planes: 46 | self.shortcut = nn.Sequential( 47 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 48 | ) 49 | 50 | def forward(self, x): 51 | out = F.relu(self.bn1(x)) 52 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 53 | out = self.conv1(out) 54 | out = self.conv2(F.relu(self.bn2(out))) 55 | out = self.conv3(F.relu(self.bn3(out))) 56 | out += shortcut 57 | return out 58 | 59 | 60 | class PreActResNet(nn.Module): 61 | def __init__(self, block, num_blocks, num_classes=10): 62 | super(PreActResNet, self).__init__() 63 | self.in_planes = 64 64 | 65 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 66 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 67 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 68 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 69 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 70 | self.bn = nn.BatchNorm2d(512 * block.expansion) 71 | self.linear = nn.Linear(512*block.expansion, num_classes) 72 | 73 | def _make_layer(self, block, planes, num_blocks, stride): 74 | strides = [stride] + [1]*(num_blocks-1) 75 | layers = [] 76 | for stride in strides: 77 | layers.append(block(self.in_planes, planes, stride)) 78 | self.in_planes = planes * block.expansion 79 | return nn.Sequential(*layers) 80 | 81 | def forward(self, x): 82 | out = self.conv1(x) 83 | out = self.layer1(out) 84 | out = self.layer2(out) 85 | out = self.layer3(out) 86 | out = self.layer4(out) 87 | out = F.relu(self.bn(out)) 88 | out = F.avg_pool2d(out, 4) 89 | out = out.view(out.size(0), -1) 90 | out = self.linear(out) 91 | return out 92 | 93 | 94 | def PreActResNet18(num_classes=10): 95 | return PreActResNet(PreActBlock, [2,2,2,2], num_classes=num_classes) 96 | 97 | def initialize_weights(module): 98 | if isinstance(module, nn.Conv2d): 99 | nn.init.kaiming_normal_(module.weight.data, mode='fan_in') 100 | elif isinstance(module, nn.BatchNorm2d): 101 | module.weight.data.uniform_() 102 | module.bias.data.zero_() 103 | elif isinstance(module, nn.Linear): 104 | module.bias.data.zero_() 105 | 106 | 107 | class BasicBlock(nn.Module): 108 | def __init__(self, in_channels, out_channels, stride, drop_rate): 109 | super(BasicBlock, self).__init__() 110 | 111 | self.drop_rate = drop_rate 112 | 113 | self._preactivate_both = (in_channels != out_channels) 114 | 115 | self.bn1 = nn.BatchNorm2d(in_channels) 116 | self.conv1 = nn.Conv2d( 117 | in_channels, 118 | out_channels, 119 | kernel_size=3, 120 | stride=stride, # downsample with first conv 121 | padding=1, 122 | bias=False) 123 | 124 | self.bn2 = nn.BatchNorm2d(out_channels) 125 | self.conv2 = nn.Conv2d( 126 | out_channels, 127 | out_channels, 128 | kernel_size=3, 129 | stride=1, 130 | padding=1, 131 | bias=False) 132 | 133 | self.shortcut = nn.Sequential() 134 | if in_channels != out_channels: 135 | self.shortcut.add_module( 136 | 'conv', 137 | nn.Conv2d( 138 | in_channels, 139 | out_channels, 140 | 
kernel_size=1, 141 | stride=stride, # downsample 142 | padding=0, 143 | bias=False)) 144 | 145 | def forward(self, x): 146 | if self._preactivate_both: 147 | x = F.relu( 148 | self.bn1(x), inplace=True) # shortcut after preactivation 149 | y = self.conv1(x) 150 | else: 151 | y = F.relu( 152 | self.bn1(x), 153 | inplace=True) # preactivation only for residual path 154 | y = self.conv1(y) 155 | if self.drop_rate > 0: 156 | y = F.dropout( 157 | y, p=self.drop_rate, training=self.training, inplace=False) 158 | 159 | y = F.relu(self.bn2(y), inplace=True) 160 | y = self.conv2(y) 161 | y += self.shortcut(x) 162 | return y 163 | 164 | 165 | class Network(nn.Module): 166 | def __init__(self, config): 167 | super(Network, self).__init__() 168 | 169 | input_shape = config['input_shape'] 170 | n_classes = config['n_classes'] 171 | 172 | base_channels = config['base_channels'] 173 | widening_factor = config['widening_factor'] 174 | drop_rate = config['drop_rate'] 175 | depth = config['depth'] 176 | 177 | block = BasicBlock 178 | n_blocks_per_stage = (depth - 4) // 6 179 | assert n_blocks_per_stage * 6 + 4 == depth 180 | 181 | n_channels = [ 182 | base_channels, base_channels * widening_factor, 183 | base_channels * 2 * widening_factor, 184 | base_channels * 4 * widening_factor 185 | ] 186 | 187 | self.conv = nn.Conv2d( 188 | input_shape[1], 189 | n_channels[0], 190 | kernel_size=3, 191 | stride=1, 192 | padding=1, 193 | bias=False) 194 | 195 | self.stage1 = self._make_stage( 196 | n_channels[0], 197 | n_channels[1], 198 | n_blocks_per_stage, 199 | block, 200 | stride=1, 201 | drop_rate=drop_rate) 202 | self.stage2 = self._make_stage( 203 | n_channels[1], 204 | n_channels[2], 205 | n_blocks_per_stage, 206 | block, 207 | stride=2, 208 | drop_rate=drop_rate) 209 | self.stage3 = self._make_stage( 210 | n_channels[2], 211 | n_channels[3], 212 | n_blocks_per_stage, 213 | block, 214 | stride=2, 215 | drop_rate=drop_rate) 216 | self.bn = nn.BatchNorm2d(n_channels[3]) 217 | 218 | # compute conv feature size 219 | with torch.no_grad(): 220 | self.feature_size = self._forward_conv( 221 | torch.zeros(*input_shape)).view(-1).shape[0] 222 | 223 | self.fc = nn.Linear(self.feature_size, n_classes) 224 | 225 | # initialize weights 226 | self.apply(initialize_weights) 227 | 228 | def _make_stage(self, in_channels, out_channels, n_blocks, block, stride, 229 | drop_rate): 230 | stage = nn.Sequential() 231 | for index in range(n_blocks): 232 | block_name = 'block{}'.format(index + 1) 233 | if index == 0: 234 | stage.add_module( 235 | block_name, 236 | block( 237 | in_channels, 238 | out_channels, 239 | stride=stride, 240 | drop_rate=drop_rate)) 241 | else: 242 | stage.add_module( 243 | block_name, 244 | block( 245 | out_channels, 246 | out_channels, 247 | stride=1, 248 | drop_rate=drop_rate)) 249 | return stage 250 | 251 | def _forward_conv(self, x): 252 | x = self.conv(x) 253 | x = self.stage1(x) 254 | x = self.stage2(x) 255 | x = self.stage3(x) 256 | x = F.relu(self.bn(x), inplace=True) 257 | x = F.adaptive_avg_pool2d(x, output_size=1) 258 | return x 259 | 260 | def forward(self, x): 261 | x = self._forward_conv(x) 262 | x = x.view(x.size(0), -1) 263 | x = self.fc(x) 264 | return x 265 | 266 | 267 | def WRN28_10(num_classes=10): 268 | config = { 269 | 'input_shape': (1, 3, 32, 32), 270 | 'n_classes': num_classes, 271 | 'base_channels': 16, 272 | 'widening_factor': 10, 273 | 'drop_rate': 0.3, 274 | 'depth': 28 275 | } 276 | return Network(config) 277 | 278 | def DeiT(num_classes=10): 279 | model = 
deit_tiny_patch16_224(pretrained=False, img_size=32, patch_size=2, num_classes=num_classes) 280 | return model --------------------------------------------------------------------------------
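As a closing illustration (not part of the repository files above), a minimal sketch of how the ASAM wrapper from train_eval_scripts/sam.py and the WRN28_10 model from train_eval_scripts/model.py fit together in one training step could look as follows; the dummy data and hyperparameters are placeholders.

import torch
import torch.nn as nn
from sam import ASAM        # assumes train_eval_scripts/ is on the import path
from model import WRN28_10  # note: model.py itself expects a local deit module for its DeiT() helper

# placeholder data: two batches of random CIFAR-10-sized inputs
loader = [(torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))) for _ in range(2)]

model = WRN28_10(num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
minimizer = ASAM(optimizer, model, rho=0.5, eta=0.01)

for inputs, targets in loader:
    # ascent step: gradient at w, then perturb the weights toward higher loss
    criterion(model(inputs), targets).backward()
    minimizer.ascent_step()

    # descent step: gradient at the perturbed point w + eps, restore w, apply the optimizer update
    criterion(model(inputs), targets).backward()
    minimizer.descent_step()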