├── train_eval_scripts
│   ├── recoloradv
│   │   ├── __init__.py
│   │   ├── mister_ed
│   │   │   ├── __init__.py
│   │   │   ├── utils
│   │   │   │   ├── __init__.py
│   │   │   │   ├── pytorch_ssim.py
│   │   │   │   ├── image_utils.py
│   │   │   │   └── discretization.py
│   │   │   ├── cifar10
│   │   │   │   ├── __init__.py
│   │   │   │   ├── wide_resnets.py
│   │   │   │   ├── cifar_resnets.py
│   │   │   │   └── cifar_loader.py
│   │   │   ├── README.md
│   │   │   ├── config.py
│   │   │   └── scripts
│   │   │       └── setup_cifar.py
│   │   ├── norms.py
│   │   ├── examples
│   │   │   ├── evaluate_cifar10.py
│   │   │   └── evaluate_imagenet.py
│   │   ├── utils.py
│   │   ├── perturbations.py
│   │   └── color_spaces.py
│   ├── README.md
│   ├── recolor.py
│   ├── eval_cifar100.py
│   ├── corruption.py
│   ├── eval.py
│   ├── stadv.py
│   ├── attack.py
│   ├── train.py
│   ├── sam.py
│   └── model.py
├── SAM_segmentation
│   ├── checkpoints
│   │   └── exp_log_and_checkpoints_will_be_saved_here.txt
│   ├── metrics
│   │   ├── __init__.py
│   │   └── stream_metrics.py
│   ├── network
│   │   ├── __init__.py
│   │   ├── backbone
│   │   │   ├── __init__.py
│   │   │   ├── mobilenetv2.py
│   │   │   └── xception.py
│   │   ├── utils.py
│   │   └── _deeplab.py
│   ├── requirements.txt
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── iccv09.py
│   │   ├── utils.py
│   │   ├── voc.py
│   │   └── cityscapes.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── scheduler.py
│   │   ├── loss.py
│   │   ├── utils.py
│   │   ├── sam.py
│   │   ├── visualizer.py
│   │   └── attack.py
│   ├── .gitignore
│   ├── LICENSE
│   └── README.md
├── README.md
├── eval.py
├── eval_cifar100.py
├── sam.py
├── model.py
├── utils.py
└── train.py
/train_eval_scripts/recoloradv/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train_eval_scripts/recoloradv/mister_ed/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train_eval_scripts/recoloradv/mister_ed/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train_eval_scripts/recoloradv/mister_ed/cifar10/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/SAM_segmentation/checkpoints/exp_log_and_checkpoints_will_be_saved_here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/SAM_segmentation/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .stream_metrics import StreamSegMetrics, AverageMeter
2 |
3 |
--------------------------------------------------------------------------------
/SAM_segmentation/network/__init__.py:
--------------------------------------------------------------------------------
1 | from .modeling import *
2 | from ._deeplab import convert_to_separable_conv
--------------------------------------------------------------------------------
/SAM_segmentation/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | numpy
4 | pillow
5 | scikit-learn
6 | tqdm
7 | matplotlib
8 | visdom
--------------------------------------------------------------------------------
/SAM_segmentation/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .voc import VOCSegmentation
2 | from .cityscapes import Cityscapes
3 | from .iccv09 import Iccv2009Dataset
--------------------------------------------------------------------------------
/SAM_segmentation/network/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from . import resnet
2 | from . import mobilenetv2
3 | from . import hrnetv2
4 | from . import xception
5 |
--------------------------------------------------------------------------------
/train_eval_scripts/recoloradv/mister_ed/README.md:
--------------------------------------------------------------------------------
1 | Code in this directory is adapted from the [`mister_ed`](https://github.com/revbucket/mister_ed) library.
--------------------------------------------------------------------------------
/SAM_segmentation/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | from .visualizer import Visualizer
3 | from .scheduler import PolyLR
4 | from .loss import FocalLoss
5 | from .attack import PGD, normalize_voc, normalize_city, normalize_iccv09
6 | from .sam import SAM
--------------------------------------------------------------------------------
/SAM_segmentation/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | VOCdevkit
3 | checkpoints/*/*.pth
4 | .vscode
5 | *.pyc
6 | .idea/
7 | __pycache__
8 | results
9 | checkpoints_bak
10 | cityscapes
11 | test_results
12 | datasets/data
13 | samples/
14 | *.zip
15 | iccv09-celoss.csv
16 | iccv09-sam.csv
17 | iccv09.csv
18 | wandb/
--------------------------------------------------------------------------------
/train_eval_scripts/README.md:
--------------------------------------------------------------------------------
1 | This folder contains the code for the image classification project. The main files are as follows:
2 |
3 | 1. `train.py`: train a classification model on CIFAR-10/100 or TinyImageNet with SGD, Adam, SAM, or AT. For AWP we use the code from the [official repository](https://github.com/csdongxian/AWP).
4 | 2. `attack.py`: test the adversarial robustness of a model using torchattacks.
5 | 3. `corruption.py`: test the robustness of a model to common corruptions using robustbench.
6 | 4. `sam_trainer.py`: train and test a text classification model.
7 |
--------------------------------------------------------------------------------
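For orientation, here is a minimal sketch of the kind of torchattacks-based robustness check that `attack.py` performs. It is an illustration rather than a copy of `attack.py`: the attack hyperparameters are placeholders, and the `load_dataset` / `normalize_cifar` helpers are assumed to come from this folder's `utils.py` (they are used the same way in `eval.py` and `corruption.py`).

```python
import torch
import torch.nn as nn
import torchattacks
from model import PreActResNet18
from utils import load_dataset, normalize_cifar  # helpers assumed from this folder's utils.py

# Fold normalization into the model so the attack operates on [0, 1] images,
# mirroring the wrapper used in corruption.py and recolor.py.
class Wrapped(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(normalize_cifar(x))

model = Wrapped(PreActResNet18(10)).cuda().eval()
_, test_loader = load_dataset('cifar10', 128)

atk = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=10)
correct = total = 0
for x, y in test_loader:
    x, y = x.cuda(), y.cuda()
    adv = atk(x, y)  # adversarial examples, still in [0, 1]
    correct += (model(adv).argmax(1) == y).sum().item()
    total += y.numel()
print(f'robust accuracy: {correct / total:.2%}')
```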
/SAM_segmentation/utils/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler, StepLR
2 |
3 | class PolyLR(_LRScheduler):
4 | def __init__(self, optimizer, max_iters, power=0.9, last_epoch=-1, min_lr=1e-6):
5 | self.power = power
6 | self.max_iters = max_iters # avoid zero lr
7 | self.min_lr = min_lr
8 | super(PolyLR, self).__init__(optimizer, last_epoch)
9 |
10 | def get_lr(self):
11 | return [ max( base_lr * ( 1 - self.last_epoch/self.max_iters )**self.power, self.min_lr)
12 | for base_lr in self.base_lrs]
--------------------------------------------------------------------------------
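A minimal usage sketch for `PolyLR`; the model, learning rate, and iteration budget below are placeholders, and the import path assumes the `SAM_segmentation` root is on `sys.path`.

```python
import torch
from utils.scheduler import PolyLR  # import path assumed relative to SAM_segmentation

model = torch.nn.Linear(10, 2)  # toy stand-in for the real network
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = PolyLR(optimizer, max_iters=30000, power=0.9)

for it in range(30000):
    # ... forward pass, loss.backward(), optimizer.step() ...
    scheduler.step()  # lr = base_lr * (1 - it / max_iters) ** 0.9, floored at min_lr
```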
/SAM_segmentation/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 |
5 | class FocalLoss(nn.Module):
6 | def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=255):
7 | super(FocalLoss, self).__init__()
8 | self.alpha = alpha
9 | self.gamma = gamma
10 | self.ignore_index = ignore_index
11 | self.size_average = size_average
12 |
13 | def forward(self, inputs, targets):
14 | ce_loss = F.cross_entropy(
15 | inputs, targets, reduction='none', ignore_index=self.ignore_index)
16 | pt = torch.exp(-ce_loss)
17 | focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss
18 | if self.size_average:
19 | return focal_loss.mean()
20 | else:
21 | return focal_loss.sum()
--------------------------------------------------------------------------------
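A quick sanity check of `FocalLoss` on random data; the shapes mimic a 21-class segmentation batch. With the default `gamma=0` the loss reduces to plain cross-entropy, while `gamma=2` down-weights easy pixels.

```python
import torch
from utils.loss import FocalLoss  # import path assumed relative to SAM_segmentation

criterion = FocalLoss(alpha=1, gamma=2, ignore_index=255)
logits = torch.randn(4, 21, 64, 64, requires_grad=True)  # (N, C, H, W) class scores
targets = torch.randint(0, 21, (4, 64, 64))              # (N, H, W) integer labels
loss = criterion(logits, targets)
loss.backward()  # usable as a drop-in training loss
print(loss.item())
```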
/train_eval_scripts/recoloradv/mister_ed/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | config_dir = os.path.abspath(os.path.dirname(__file__))
4 |
5 | def path_resolver(path):
6 | if path.startswith('~/'):
7 | return os.path.expanduser(path)
8 |
9 | if path.startswith('./'):
10 | return os.path.join(*[config_dir] + path.split('/')[1:])
11 |
12 |
13 | DEFAULT_DATASETS_DIR = path_resolver('~/datasets')
14 | MODEL_PATH = path_resolver('./pretrained_models/')
15 | OUTPUT_IMAGE_PATH = path_resolver('./output_images/')
16 |
17 |
18 | DEFAULT_BATCH_SIZE = 128
19 | DEFAULT_WORKERS = 4
20 | CIFAR10_MEANS = [0.485, 0.456, 0.406]
21 | CIFAR10_STDS = [0.229, 0.224, 0.225]
22 |
23 | WIDE_CIFAR10_MEANS = [0.4914, 0.4822, 0.4465]
24 | WIDE_CIFAR10_STDS = [0.2023, 0.1994, 0.2010]
25 |
26 |
27 | IMAGENET_MEANS = [0.485, 0.456, 0.406]
28 | IMAGENET_STDS = [0.229, 0.224, 0.225]
29 |
--------------------------------------------------------------------------------
/SAM_segmentation/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Gongfan Fang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # On the Duality Between Sharpness-Aware Minimization and Adversarial Training
2 | ## ICML 2024
3 |
4 | Yihao Zhang\*, Hangzhou He\*, Jingyu Zhu\*, Huanran Chen, Yifei Wang, [Zeming Wei](https://weizeming.github.io)${}^\dagger$
5 |
6 |
7 | ## Sharpness-Aware Minimization Alone can Improve Adversarial Robustness (Workshop version)
8 | ### ICML 2023 AdvML-Frontiers Workshop
9 | [Zeming Wei](https://weizeming.github.io)${}^\dagger$\*, Jingyu Zhu\* and [Yihao Zhang](https://zhang-yihao.github.io/)\*
10 |
11 | ## Citation
12 | ```
13 | @InProceedings{zhang2024duality,
14 | author = {Zhang, Yihao and He, Hangzhou and Zhu, Jingyu and Chen, Huanran and Wang, Yifei and Wei, Zeming},
15 | title = {On the Duality Between Sharpness-Aware Minimization and Adversarial Training},
16 | booktitle = {ICML},
17 | year = {2024}
18 | }
19 | ```
20 | and/or
21 | ```
22 | @InProceedings{wei2023sharpness,
23 | author = {Wei, Zeming and Zhu, Jingyu and Zhang, Yihao},
24 | title = {Sharpness-Aware Minimization Alone can Improve Adversarial Robustness},
25 | booktitle = {ICML 2023 Workshop on New Frontiers in Adversarial Machine Learning},
26 | year = {2023}
27 | }
28 | ```
29 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import argparse
4 | import torch.nn.functional as F
5 | import os
6 | from model import PreActResNet18
7 | from utils import *
8 |
9 |
10 | if __name__ == '__main__':
11 | file_list = os.listdir('models')
12 | model = PreActResNet18()
13 |
14 | PGD1 = PGD(10, 0.25/255., 1./255., 'linf')
15 | PGD2 = PGD(10, 0.5/255., 2./255., 'linf')
16 |
17 | PGD16 = PGD(10, 2./255., 16./255., 'l2')
18 | PGD32 = PGD(10, 4./255., 32./255., 'l2')
19 |
20 | _, loader = load_dataset('cifar10', 1024)
21 |
22 | for m in file_list:
23 | ckpt = torch.load('models/' + m, map_location='cpu')
24 | model.load_state_dict(ckpt)
25 | model.eval()
26 | model.cuda()
27 | accs = []
28 | for id, attack in enumerate([PGD1, PGD2, PGD16, PGD32]):
29 | acc = 0
30 | for x,y in loader:
31 | x, y = x.cuda(), y.cuda()
32 | delta = attack.perturb(model, x, y)
33 |                 pred = model(normalize_cifar(x+delta))
34 |                 acc += (pred.max(1)[1] == y).float().sum().item()
35 | acc /= 100
36 | accs.append(acc)
37 | print(m)
38 | print(' & '.join([str(a) for a in accs]))
--------------------------------------------------------------------------------
/train_eval_scripts/recolor.py:
--------------------------------------------------------------------------------
1 | import recoloradv.mister_ed.config as config
2 | from recoloradv.mister_ed.utils.pytorch_utils import DifferentiableNormalize
3 |
4 | # ReColorAdv
5 | from recoloradv.utils import get_attack_from_name
6 | from model import PreActResNet18
7 | from utils import *
8 |
9 |
10 | class Model(nn.Module):
11 | def __init__(self, model, norm):
12 | super(Model, self).__init__()
13 | self.model = model
14 | self.norm = norm
15 |
16 | def forward(self, x):
17 | return self.model(self.norm(x))
18 |
19 |
20 | model = PreActResNet18(10)
21 | model.load_state_dict(torch.load('./cifar10_models/cifar10_prn_sgd_sub.pth'))
22 | model.eval()
23 | model.cuda()
24 |
25 | # PGD attack
26 | # Mod = Model(model, normalize_cifar)
27 | # Mod.eval()
28 | # Mod.cuda()
29 |
30 | # get imgs and labels
31 | train_loader, test_loader = load_dataset('cifar10', 1024)
32 | normalizer = DifferentiableNormalize(
33 | mean=config.CIFAR10_MEANS,
34 | std=config.CIFAR10_STDS,
35 | )
36 | attack = get_attack_from_name('recoloradv', model, normalizer, verbose=True)
37 | acc = 0
38 | for x, y in test_loader:
39 | x, y = x.cuda(), y.cuda()
40 | adv_x = attack.attack(x, y)[0]
41 | pred = model(normalizer(adv_x))
42 | acc += (pred.max(1)[1] == y).float().sum().item()
43 | break
44 | acc /= 1024
45 | print(acc)
46 |
--------------------------------------------------------------------------------
/eval_cifar100.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import argparse
4 | import torch.nn.functional as F
5 | import os
6 | from model import PreActResNet18
7 | from utils import *
8 |
9 |
10 | if __name__ == '__main__':
11 | file_list = os.listdir('cifar100_models')
12 | model = PreActResNet18(100)
13 |
14 | PGD1 = PGD(10, 0.25/255., 1./255., 'linf', False, normalize_cifar100)
15 | PGD2 = PGD(10, 0.5/255., 2./255., 'linf', False, normalize_cifar100)
16 |
17 | PGD16 = PGD(10, 2./255., 16./255., 'l2', False, normalize_cifar100)
18 | PGD32 = PGD(10, 4./255., 32./255., 'l2', False, normalize_cifar100)
19 |
20 | _, loader = load_dataset('cifar100', 1024)
21 |
22 | for m in file_list:
23 | ckpt = torch.load('cifar100_models/' + m, map_location='cpu')
24 | model.load_state_dict(ckpt)
25 | model.eval()
26 | model.cuda()
27 | accs = []
28 | for id, attack in enumerate([PGD1, PGD2, PGD16, PGD32]):
29 | acc = 0
30 | for x,y in loader:
31 | x, y = x.cuda(), y.cuda()
32 | delta = attack.perturb(model, x, y)
33 |                 pred = model(normalize_cifar100(x+delta))
34 |                 acc += (pred.max(1)[1] == y).float().sum().item()
35 | acc /= 100
36 | accs.append(acc)
37 | print(m)
38 | print(' & '.join([str(a) for a in accs]))
--------------------------------------------------------------------------------
/train_eval_scripts/eval_cifar100.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import argparse
4 | import torch.nn.functional as F
5 | import os
6 | from model import PreActResNet18
7 | from utils import *
8 |
9 |
10 | if __name__ == '__main__':
11 | file_list = os.listdir('cifar100_models')
12 | model = PreActResNet18(100)
13 |
14 | PGD1 = PGD(10, 0.25/255., 1./255., 'linf', False, normalize_cifar100)
15 | PGD2 = PGD(10, 0.5/255., 2./255., 'linf', False, normalize_cifar100)
16 |
17 | PGD16 = PGD(10, 2./255., 16./255., 'l2', False, normalize_cifar100)
18 | PGD32 = PGD(10, 4./255., 32./255., 'l2', False, normalize_cifar100)
19 |
20 | _, loader = load_dataset('cifar100', 1024)
21 |
22 | for m in file_list:
23 | ckpt = torch.load('cifar100_models/' + m, map_location='cpu')
24 | model.load_state_dict(ckpt)
25 | model.eval()
26 | model.cuda()
27 | accs = []
28 | for id, attack in enumerate([PGD1, PGD2, PGD16, PGD32]):
29 | acc = 0
30 | for x,y in loader:
31 | x, y = x.cuda(), y.cuda()
32 | delta = attack.perturb(model, x, y)
33 |                 pred = model(normalize_cifar100(x+delta))
34 |                 acc += (pred.max(1)[1] == y).float().sum().item()
35 | acc /= 100
36 | accs.append(acc)
37 | print(m)
38 | print(' & '.join([str(a) for a in accs]))
--------------------------------------------------------------------------------
/train_eval_scripts/recoloradv/norms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 |
4 |
5 | def smoothness(grid):
6 | """
7 | Given a variable of dimensions (N, X, Y, [Z], C), computes the sum of
8 | the differences between adjacent points in the grid formed by the
9 | dimensions X, Y, and (optionally) Z. Returns a tensor of dimension N.
10 | """
11 |
12 | num_dims = len(grid.size()) - 2
13 | batch_size = grid.size()[0]
14 | norm = Variable(torch.zeros(batch_size, dtype=grid.data.dtype,
15 | device=grid.data.device))
16 |
17 | for dim in range(num_dims):
18 | slice_before = (slice(None),) * (dim + 1)
19 | slice_after = (slice(None),) * (num_dims - dim)
20 | shifted_grids = [
21 | # left
22 | torch.cat([
23 | grid[slice_before + (slice(1, None),) + slice_after],
24 | grid[slice_before + (slice(-1, None),) + slice_after],
25 | ], dim + 1),
26 | # right
27 | torch.cat([
28 | grid[slice_before + (slice(None, 1),) + slice_after],
29 | grid[slice_before + (slice(None, -1),) + slice_after],
30 | ], dim + 1)
31 | ]
32 | for shifted_grid in shifted_grids:
33 | delta = shifted_grid - grid
34 | norm_components = (delta.pow(2).sum(-1) + 1e-10).pow(0.5)
35 | norm.add_(norm_components.sum(
36 | tuple(range(1, len(norm_components.size())))))
37 |
38 | return norm
39 |
--------------------------------------------------------------------------------
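As the docstring notes, `smoothness` returns one total-variation-style value per batch element and is differentiable, so it can serve as a regularizer inside an attack objective. A small sketch on a random grid (shapes are illustrative):

```python
import torch
from recoloradv.norms import smoothness  # path assumed relative to train_eval_scripts

grid = torch.rand(2, 8, 8, 2, requires_grad=True)  # (N, X, Y, C) toy flow field
tv = smoothness(grid)                               # shape (2,): one value per batch element
tv.sum().backward()                                 # gradients flow back to the grid
print(tv)
```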
/SAM_segmentation/utils/utils.py:
--------------------------------------------------------------------------------
1 | from torchvision.transforms.functional import normalize
2 | import torch.nn as nn
3 | import numpy as np
4 | import os
5 | import sys
6 |
7 | def denormalize(tensor, mean, std):
8 | mean = np.array(mean)
9 | std = np.array(std)
10 |
11 | _mean = -mean/std
12 | _std = 1/std
13 | return normalize(tensor, _mean, _std)
14 |
15 | class Logger(object):
16 |     # Tee printed output to a log file while still showing it on screen; flush the file after every write (the screen is not flushed).
17 | def __init__(self, filename="log.txt"):
18 | self.terminal = sys.stdout
19 | self.log = open(filename, 'a')
20 |
21 | def write(self, message):
22 | self.terminal.write(message)
23 | self.log.write(message)
24 | self.log.flush()
25 |
26 | def flush(self):
27 | pass
28 |
29 | class Denormalize(object):
30 | def __init__(self, mean, std):
31 | mean = np.array(mean)
32 | std = np.array(std)
33 | self._mean = -mean/std
34 | self._std = 1/std
35 |
36 | def __call__(self, tensor):
37 | if isinstance(tensor, np.ndarray):
38 | return (tensor - self._mean.reshape(-1,1,1)) / self._std.reshape(-1,1,1)
39 | return normalize(tensor, self._mean, self._std)
40 |
41 | def set_bn_momentum(model, momentum=0.1):
42 | for m in model.modules():
43 | if isinstance(m, nn.BatchNorm2d):
44 | m.momentum = momentum
45 |
46 | def fix_bn(model):
47 | for m in model.modules():
48 | if isinstance(m, nn.BatchNorm2d):
49 | m.eval()
50 |
51 | def mkdir(path):
52 | if not os.path.exists(path):
53 | os.mkdir(path)
54 |
--------------------------------------------------------------------------------
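A small example of the `Denormalize` helper; the ImageNet statistics below are purely illustrative, since each dataset in this repo defines its own mean and std.

```python
import torch
from utils import Denormalize  # re-exported via SAM_segmentation/utils/__init__.py

denorm = Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
normalized = torch.randn(3, 64, 64)  # a tensor normalized with the same statistics
image = denorm(normalized)           # undoes the normalization channel-wise
```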
/sam.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class SAM(torch.optim.Optimizer):
5 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
6 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
7 |
8 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
9 | super(SAM, self).__init__(params, defaults)
10 |
11 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
12 | self.param_groups = self.base_optimizer.param_groups
13 | self.defaults.update(self.base_optimizer.defaults)
14 |
15 | @torch.no_grad()
16 | def first_step(self, zero_grad=False):
17 | grad_norm = self._grad_norm()
18 | for group in self.param_groups:
19 | scale = group["rho"] / (grad_norm + 1e-12)
20 |
21 | for p in group["params"]:
22 | if p.grad is None: continue
23 | self.state[p]["old_p"] = p.data.clone()
24 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
25 | p.add_(e_w) # climb to the local maximum "w + e(w)"
26 |
27 | if zero_grad: self.zero_grad()
28 |
29 | @torch.no_grad()
30 | def second_step(self, zero_grad=False):
31 | for group in self.param_groups:
32 | for p in group["params"]:
33 | if p.grad is None: continue
34 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)"
35 |
36 | self.base_optimizer.step() # do the actual "sharpness-aware" update
37 |
38 | if zero_grad: self.zero_grad()
39 |
40 | @torch.no_grad()
41 | def step(self, closure=None):
42 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
43 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass
44 |
45 | self.first_step(zero_grad=True)
46 | closure()
47 | self.second_step()
48 |
49 | def _grad_norm(self):
50 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism
51 | norm = torch.norm(
52 | torch.stack([
53 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
54 | for group in self.param_groups for p in group["params"]
55 | if p.grad is not None
56 | ]),
57 | p=2
58 | )
59 | return norm
60 |
61 | def load_state_dict(self, state_dict):
62 | super().load_state_dict(state_dict)
63 | self.base_optimizer.param_groups = self.param_groups
64 |
--------------------------------------------------------------------------------
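For reference, a minimal sketch of how this `SAM` optimizer is typically driven; the toy model, data, and hyperparameters are placeholders. Each update needs two forward-backward passes: one at the current weights and one at the perturbed point `w + e(w)`.

```python
import torch
import torch.nn as nn
from sam import SAM  # the class defined above

model = nn.Linear(10, 2)  # toy stand-in for PreActResNet18 etc.
criterion = nn.CrossEntropyLoss()
optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.1, momentum=0.9)

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))

criterion(model(x), y).backward()
optimizer.first_step(zero_grad=True)   # ascend to w + e(w)
criterion(model(x), y).backward()      # gradient at the perturbed weights
optimizer.second_step(zero_grad=True)  # restore w and apply the base SGD update
```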
/SAM_segmentation/utils/sam.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class SAM(torch.optim.Optimizer):
4 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
5 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
6 |
7 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
8 | super().__init__(params, defaults)
9 | if isinstance(base_optimizer, torch.optim.Optimizer):
10 | self.base_optimizer = base_optimizer
11 | print("SAM is applied to inner optimizer: ", base_optimizer.__class__.__name__)
12 | else:
13 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
14 | self.param_groups = self.base_optimizer.param_groups
15 | self.defaults.update(self.base_optimizer.defaults)
16 |
17 | @torch.no_grad()
18 | def first_step(self, zero_grad=False):
19 | grad_norm = self._grad_norm()
20 | for group in self.param_groups:
21 | scale = group["rho"] / (grad_norm + 1e-12)
22 |
23 | for p in group["params"]:
24 | if p.grad is None: continue
25 | self.state[p]["old_p"] = p.data.clone()
26 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
27 | p.add_(e_w) # climb to the local maximum "w + e(w)"
28 |
29 | if zero_grad: self.zero_grad()
30 |
31 | @torch.no_grad()
32 | def second_step(self, zero_grad=False):
33 | for group in self.param_groups:
34 | for p in group["params"]:
35 | if p.grad is None: continue
36 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)"
37 |
38 | self.base_optimizer.step() # do the actual "sharpness-aware" update
39 |
40 | if zero_grad: self.zero_grad()
41 |
42 | @torch.no_grad()
43 | def step(self, closure=None):
44 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
45 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass
46 |
47 | self.first_step(zero_grad=True)
48 | closure()
49 | self.second_step()
50 |
51 | def _grad_norm(self):
52 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism
53 | norm = torch.norm(
54 | torch.stack([
55 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
56 | for group in self.param_groups for p in group["params"]
57 | if p.grad is not None
58 | ]),
59 | p=2
60 | )
61 | return norm
62 |
63 | def load_state_dict(self, state_dict):
64 | super().load_state_dict(state_dict)
65 | self.base_optimizer.param_groups = self.param_groups
--------------------------------------------------------------------------------
/train_eval_scripts/recoloradv/mister_ed/utils/pytorch_ssim.py:
--------------------------------------------------------------------------------
1 | """ Implementation directly lifted from Po-Hsun-Su for pytorch ssim
2 | See github repo here: https://github.com/Po-Hsun-Su/pytorch-ssim
3 | """
4 | import torch
5 | import torch.nn.functional as F
6 | from torch.autograd import Variable
7 | import numpy as np
8 | from math import exp
9 |
10 | def gaussian(window_size, sigma):
11 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
12 | return gauss/gauss.sum()
13 |
14 | def create_window(window_size, channel):
15 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
16 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
17 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
18 | return window
19 |
20 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
21 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
22 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
23 |
24 | mu1_sq = mu1.pow(2)
25 | mu2_sq = mu2.pow(2)
26 | mu1_mu2 = mu1*mu2
27 |
28 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
29 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
30 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
31 |
32 | C1 = 0.01**2
33 | C2 = 0.03**2
34 |
35 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
36 |
37 | if size_average:
38 | return ssim_map.mean()
39 | else:
40 | return ssim_map.mean(1).mean(1).mean(1)
41 |
42 | class SSIM(torch.nn.Module):
43 | def __init__(self, window_size = 11, size_average = True):
44 | super(SSIM, self).__init__()
45 | self.window_size = window_size
46 | self.size_average = size_average
47 | self.channel = 1
48 | self.window = create_window(window_size, self.channel)
49 |
50 | def forward(self, img1, img2):
51 | (_, channel, _, _) = img1.size()
52 |
53 | if channel == self.channel and self.window.data.type() == img1.data.type():
54 | window = self.window
55 | else:
56 | window = create_window(self.window_size, channel)
57 |
58 | if img1.is_cuda:
59 | window = window.cuda(img1.get_device())
60 | window = window.type_as(img1)
61 |
62 | self.window = window
63 | self.channel = channel
64 |
65 |
66 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
67 |
68 | def ssim(img1, img2, window_size = 11, size_average = True):
69 | (_, channel, _, _) = img1.size()
70 | window = create_window(window_size, channel)
71 |
72 | if img1.is_cuda:
73 | window = window.cuda(img1.get_device())
74 | window = window.type_as(img1)
75 |
76 | return _ssim(img1, img2, window, window_size, channel, size_average)
--------------------------------------------------------------------------------
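A quick sanity check of the module-level `ssim` helper: identical inputs should score close to 1, unrelated noise noticeably lower.

```python
import torch
from recoloradv.mister_ed.utils.pytorch_ssim import ssim

img = torch.rand(1, 3, 64, 64)
print(ssim(img, img.clone()))           # ~1.0
print(ssim(img, torch.rand_like(img)))  # substantially lower
```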
/SAM_segmentation/utils/visualizer.py:
--------------------------------------------------------------------------------
1 | from visdom import Visdom
2 | import json
3 |
4 | class Visualizer(object):
5 | """ Visualizer
6 | """
7 | def __init__(self, port='13579', env='main', id=None):
8 | #self.cur_win = {}
9 | self.vis = Visdom(port=port, env=env)
10 | self.id = id
11 | self.env = env
12 | # Restore
13 | #ori_win = self.vis.get_window_data()
14 | #ori_win = json.loads(ori_win)
15 | #print(ori_win)
16 | #self.cur_win = { v['title']: k for k, v in ori_win.items() }
17 |
18 | def vis_scalar(self, name, x, y, opts=None):
19 | if not isinstance(x, list):
20 | x = [x]
21 | if not isinstance(y, list):
22 | y = [y]
23 |
24 | if self.id is not None:
25 | name = "[%s]"%self.id + name
26 | default_opts = { 'title': name }
27 | if opts is not None:
28 | default_opts.update(opts)
29 |
30 | #win = self.cur_win.get(name, None)
31 | #if win is not None:
32 | self.vis.line( X=x, Y=y, win=name, opts=default_opts, update='append')
33 | #else:
34 | # self.cur_win[name] = self.vis.line( X=x, Y=y, opts=default_opts)
35 |
36 | def vis_image(self, name, img, env=None, opts=None):
37 | """ vis image in visdom
38 | """
39 | if env is None:
40 | env = self.env
41 | if self.id is not None:
42 | name = "[%s]"%self.id + name
43 | #win = self.cur_win.get(name, None)
44 | default_opts = { 'title': name }
45 | if opts is not None:
46 | default_opts.update(opts)
47 | #if win is not None:
48 | self.vis.image( img=img, win=name, opts=opts, env=env )
49 | #else:
50 | # self.cur_win[name] = self.vis.image( img=img, opts=default_opts, env=env )
51 |
52 | def vis_table(self, name, tbl, opts=None):
53 | #win = self.cur_win.get(name, None)
54 |
55 |         tbl_str = "<table width=100%> "
56 |         tbl_str += "<tr> \
57 |                     <th>Term</th> \
58 |                     <th>Value</th> \
59 |                     </tr>"
60 |         for k, v in tbl.items():
61 |             tbl_str += "<tr> \
62 |                         <td>%s</td> \
63 |                         <td>%s</td> \
64 |                         </tr>" % (k, v)
65 | 
66 |         tbl_str += "</table>"
67 | 
68 | default_opts = { 'title': name }
69 | if opts is not None:
70 | default_opts.update(opts)
71 | #if win is not None:
72 | self.vis.text(tbl_str, win=name, opts=default_opts)
73 | #else:
74 | #self.cur_win[name] = self.vis.text(tbl_str, opts=default_opts)
75 |
76 |
77 | if __name__=='__main__':
78 | import numpy as np
79 | vis = Visualizer(port=35588, env='main')
80 | tbl = {"lr": 214, "momentum": 0.9}
81 | vis.vis_table("test_table", tbl)
82 | tbl = {"lr": 244444, "momentum": 0.9, "haha": "hoho"}
83 | vis.vis_table("test_table", tbl)
84 |
85 | vis.vis_scalar(name='loss', x=0, y=1)
86 | vis.vis_scalar(name='loss', x=2, y=4)
87 | vis.vis_scalar(name='loss', x=4, y=6)
--------------------------------------------------------------------------------
/SAM_segmentation/README.md:
--------------------------------------------------------------------------------
1 | # Sharpness-Aware Minimization Alone can Improve Adversarial Robustness in Semantic Segmentation
2 |
3 | The semantic segmentation code is adapted from [VainF](https://github.com/VainF/DeepLabV3Plus-Pytorch)
4 |
5 | ## Pascal VOC
6 |
7 | ### 1. Requirements
8 |
9 | ```bash
10 | pip install -r requirements.txt
11 | ```
12 |
13 | ### 2. Prepare Datasets
14 |
15 | #### 2.1 Standard Pascal VOC
16 | You can run train.py with the "--download" option to download and extract the Pascal VOC dataset. The default path is './datasets/data':
17 |
18 | ```
19 | /datasets
20 | /data
21 | /VOCdevkit
22 | /VOC2012
23 | /SegmentationClass
24 | /JPEGImages
25 | ...
26 | ...
27 | /VOCtrainval_11-May-2012.tar
28 | ...
29 | ```
30 |
31 | #### 2.2 Pascal VOC trainaug (Recommended!!)
32 |
33 | See chapter 4 of [2]
34 |
35 | The original dataset contains 1464 (train), 1449 (val), and 1456 (test) pixel-level annotated images. We augment the dataset by the extra annotations provided by [76], resulting in 10582 (trainaug) training images. The performance is measured in terms of pixel intersection-over-union averaged across the 21 classes (mIOU).
36 |
37 | *./datasets/data/train_aug.txt* includes the file names of 10582 trainaug images (val images are excluded). Please download their labels from [Dropbox](https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip?dl=0) or [Tencent Weiyun](https://share.weiyun.com/5NmJ6Rk). Those labels come from [DrSleep's repo](https://github.com/DrSleep/tensorflow-deeplab-resnet).
38 |
39 | Extract trainaug labels (SegmentationClassAug) to the VOC2012 directory.
40 |
41 | ```
42 | /datasets
43 | /data
44 | /VOCdevkit
45 | /VOC2012
46 | /SegmentationClass
47 | /SegmentationClassAug # <= the trainaug labels
48 | /JPEGImages
49 | ...
50 | ...
51 | /VOCtrainval_11-May-2012.tar
52 | ...
53 | ```
54 |
55 | ### 3. Training on Pascal VOC2012 Aug
56 |
57 | #### 3.2 Training with OS=16
58 |
59 | Run main.py with *"--year 2012_aug"* to train your model on Pascal VOC2012 Aug. You can also parallelize training across 4 GPUs with '--gpu_id 0,1,2,3'.
60 | 
61 | **Note: There is no SyncBN in this repo, so training with *multiple GPUs and a small batch size* may degrade performance. See [PyTorch-Encoding](https://hangzhang.org/PyTorch-Encoding/tutorials/syncbn.html) for more details about SyncBN.**
62 |
63 | ```bash
64 | python main.py --model deeplabv3_mobilenet --gpu_id 3 --year 2012_aug --lr 0.01 --crop_size 513 --batch_size 16 --output_stride 16 --optimizer SAM --rho 0.02 --exp_name voc-SAM
65 | ```
66 |
67 |
68 | ## Reference
69 |
70 | [1] [Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587)
71 |
72 | [2] [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611)
73 |
74 | [3] [VainF/DeepLabV3Plus-Pytorch](https://github.com/VainF/DeepLabV3Plus-Pytorch)
75 |
76 | [4] [SAM implementation](https://github.com/weizeming/SAM_AT)
--------------------------------------------------------------------------------
/train_eval_scripts/recoloradv/examples/evaluate_cifar10.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import argparse
4 | import sys
5 | import os
6 | from torch import optim
7 | from torch.utils.data import DataLoader
8 | from torchvision.models import resnet50
9 | from torchvision.datasets import ImageNet
10 | from torchvision import transforms
11 |
12 | # mister_ed
13 | from recoloradv.mister_ed import loss_functions as lf
14 | from recoloradv.mister_ed import adversarial_training as advtrain
15 | from recoloradv.mister_ed import adversarial_perturbations as ap
16 | from recoloradv.mister_ed import adversarial_attacks as aa
17 | from recoloradv.mister_ed import spatial_transformers as st
18 | from recoloradv.mister_ed.utils import pytorch_utils as utils
19 | from recoloradv.mister_ed.cifar10 import cifar_loader
20 |
21 | # ReColorAdv
22 | from recoloradv import perturbations as pt
23 | from recoloradv import color_transformers as ct
24 | from recoloradv import color_spaces as cs
25 | from recoloradv.utils import get_attack_from_name, load_pretrained_cifar10_model
26 |
27 |
28 | if __name__ == '__main__':
29 | parser = argparse.ArgumentParser(
30 | description='Evaluate a model trained on CIFAR-10 '
31 | 'against ReColorAdv and other attacks'
32 | )
33 |
34 | parser.add_argument('--checkpoint', type=str,
35 | help='checkpoint to evaluate')
36 | parser.add_argument('--attack', type=str,
37 | help='attack to run, such as "recoloradv" or '
38 | '"stadv+delta"')
39 | parser.add_argument('--batch_size', type=int, default=100,
40 | help='number of examples/minibatch')
41 | parser.add_argument('--num_batches', type=int, required=False,
42 | help='number of batches (default entire dataset)')
43 | args = parser.parse_args()
44 |
45 | model, normalizer = load_pretrained_cifar10_model(args.checkpoint)
46 | val_loader = cifar_loader.load_cifar_data('val', batch_size=args.batch_size)
47 |
48 | model.eval()
49 | if torch.cuda.is_available():
50 | model.cuda()
51 |
52 | attack = get_attack_from_name(args.attack, model, normalizer)
53 |
54 | batches_correct = []
55 | for batch_index, (inputs, labels) in enumerate(val_loader):
56 | if (
57 | args.num_batches is not None and
58 | batch_index >= args.num_batches
59 | ):
60 | break
61 |
62 | if torch.cuda.is_available():
63 | inputs = inputs.cuda()
64 | labels = labels.cuda()
65 |
66 | adv_inputs = attack.attack(
67 | inputs,
68 | labels,
69 | )[0]
70 | with torch.no_grad():
71 | adv_logits = model(normalizer(adv_inputs))
72 | batch_correct = (adv_logits.argmax(1) == labels).detach()
73 |
74 | batch_accuracy = batch_correct.float().mean().item()
75 | print(f'BATCH {batch_index:05d}',
76 | f'accuracy = {batch_accuracy * 100:.1f}',
77 | sep='\t')
78 | batches_correct.append(batch_correct)
79 |
80 | accuracy = torch.cat(batches_correct).float().mean().item()
81 | print('OVERALL ',
82 | f'accuracy = {accuracy * 100:.1f}',
83 | sep='\t')
84 |
--------------------------------------------------------------------------------
/train_eval_scripts/corruption.py:
--------------------------------------------------------------------------------
1 | import torchattacks
2 | from model import PreActResNet18, WRN28_10, DeiT
3 | from utils import *
4 | import recoloradv.mister_ed.config as config
5 | from recoloradv.mister_ed.utils.pytorch_utils import DifferentiableNormalize
6 | from recoloradv.utils import get_attack_from_name
7 | from argparse import ArgumentParser
8 |
9 | from robustbench.data import load_cifar10c, load_cifar100c
10 | from robustbench.utils import clean_accuracy
11 |
12 | parser = ArgumentParser()
13 | parser.add_argument('--model_path', default='put filename here', type=str)
14 | args = parser.parse_args()
15 | file_name = args.model_path
16 |
17 | class Model(nn.Module):
18 | def __init__(self, model, norm):
19 | super(Model, self).__init__()
20 | self.model = model
21 | self.norm = norm
22 |
23 | def forward(self, x):
24 | return self.model(self.norm(x))
25 |
26 | label_dim = 10
27 | if 'cifar10_' in file_name:
28 | label_dim = 10
29 | normalizer = DifferentiableNormalize(
30 | mean=config.CIFAR10_MEANS,
31 | std=config.CIFAR10_STDS,
32 | )
33 | norm = normalize_cifar
34 | train_loader, test_loader = load_dataset('cifar10', 1000)
35 | elif 'cifar100_' in file_name:
36 | label_dim = 100
37 | normalizer = DifferentiableNormalize(
38 | mean=CIFAR100_MEAN,
39 | std=CIFAR100_STD,
40 | )
41 | norm = normalize_cifar100
42 | train_loader, test_loader = load_dataset('cifar100', 1000)
43 | elif 'tiny' in file_name:
44 | label_dim = 200
45 | normalizer = DifferentiableNormalize(
46 | mean=TINYIMAGENET_MEAN,
47 | std=TINYIMAGENET_STD,
48 | )
49 | norm = normalize_tinyimagenet
50 | train_loader, test_loader = load_dataset('tiny-imagenet-200', 1000)
51 | else:
52 | raise ValueError('Unknown dataset')
53 |
54 | if 'prn' in file_name and 'deit' not in file_name and 'wrn' not in file_name:
55 | model = PreActResNet18(label_dim)
56 | elif 'wrn' in file_name:
57 | model = WRN28_10(label_dim)
58 | elif 'deit' in file_name:
59 | model = DeiT(label_dim)
60 |
61 | corruption_test_types = [['brightness'], ['fog'], ['frost'], ['gaussian_blur'], ['impulse_noise'], ['jpeg_compression'], ['shot_noise'], ['snow'], ['speckle_noise']]
62 | for corruptions in corruption_test_types:
63 | print(f'\n##### corruption type: {corruptions}\n')
64 | x_test, y_test = load_cifar10c(n_examples=1000, corruptions=corruptions, severity=3)
65 | for model_name in ['put file name here', 'put file name here']:
66 | model = PreActResNet18(label_dim)
67 | if 'awp' in model_name:
68 | d = torch.load('./models/' + model_name, map_location='cuda:0')
69 | for k in list(d.keys()):
70 | if k.startswith('module.'):
71 | d[k[7:]] = d[k]
72 | del d[k]
73 | model.load_state_dict(d)
74 |
75 | else:
76 | model.load_state_dict(torch.load('./models/' + model_name, map_location='cuda:0'))
77 | model.eval()
78 | model.cuda()
79 | acc = clean_accuracy(model, x_test, y_test, device=torch.device('cuda'))
80 | print(f'Model: {model_name}, CIFAR-10-C accuracy: {acc:.1%}')
--------------------------------------------------------------------------------
/train_eval_scripts/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import argparse
4 | import torch.nn.functional as F
5 | import os
6 | from model import PreActResNet18, WRN28_10, DeiT
7 | from autoattack import AutoAttack
8 | from utils import *
9 | import argparse
10 |
11 | def get_args():
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--model_path', type=str, required=True)
14 | parser.add_argument('--model', type=str, default='PRN', choices=['PRN', 'WRN', 'DeiT'])
15 | parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100', 'tiny-imagenet-200'])
16 | parser.add_argument('--attacker', default='PGD', choices=['PGD', 'FGSM', 'CW', 'AutoAttack'])
17 | parser.add_argument('--eps', default=8./255., type=float)
18 | parser.add_argument('--batch-size', default=1024, type=int)
19 | parser.add_argument('--norm', default='linf', choices=['linf', 'l2'])
20 | return parser.parse_args()
21 |
22 | args = get_args()
23 |
24 | if __name__ == '__main__':
25 | model_path = args.model_path
26 | dataset = args.dataset
27 | model_name = args.model
28 | label_dim = {'cifar10': 10, 'cifar100': 100, 'tiny-imagenet-200': 200}[dataset]
29 | model = {'PRN': PreActResNet18(label_dim), 'WRN': WRN28_10(label_dim), 'DeiT': DeiT(label_dim)}[model_name]
30 | normalizer = {'cifar10': normalize_cifar, 'cifar100': normalize_cifar, 'tiny-imagenet-200': normalize_tinyimagenet}[dataset]
31 | attacker = args.attacker
32 |
33 | #PGD1 = PGD(10, 0.25/255., 1./255., 'linf')
34 | #PGD2 = PGD(10, 0.5/255., 2./255., 'linf')
35 |
36 | #PGD16 = PGD(10, 2./255., 16./255., 'l2')
37 | #PGD32 = PGD(10, 4./255., 32./255., 'l2')
38 | #FGSM1 = PGD(1, 0.25/255., 1./255., 'linf')
39 | #FGSM16 = PGD(1, 2./255., 16./255., 'l2')
40 |
41 | pgd_iters = 10 if attacker == 'PGD' else 1
42 | eps = args.eps
43 | alpha = eps / 4
44 | norm = args.norm
45 | pgd = PGD(pgd_iters, alpha, eps, norm, False, normalizer)
46 |
47 | _, loader = load_dataset(dataset, args.batch_size)
48 |
49 | ckpt = torch.load(model_path, map_location='cpu')
50 | model.load_state_dict(ckpt)
51 | model.eval()
52 | model.cuda()
53 | acc = 0
54 | if args.attacker in ['PGD', 'FGSM']:
55 | for x,y in loader:
56 | x, y = x.cuda(), y.cuda()
57 | delta = pgd.perturb(model, x, y)
58 | pred = model((normalizer(x+delta)))
59 | acc += (pred.max(1)[1] == y).float().sum().item()
60 | acc /= 100
61 | elif args.attacker == 'CW':
62 | for x,y in loader:
63 | x, y = x.cuda(), y.cuda()
64 | x = normalizer(x)
65 | attacked_images = cw_l2_attack(model, x, y)
66 | pred = model(attacked_images)
67 | acc += (pred.max(1)[1] == y).float().sum().item()
68 | acc /= 100
69 | elif args.attacker == 'AutoAttack':
70 | norm = 'Linf' if args.norm == 'linf' else 'L2'
71 | adversary = AutoAttack(model, norm=norm, eps=args.eps, version='standard')
72 | for x,y in loader:
73 | x, y = x.cuda(), y.cuda()
74 | x = normalizer(x)
75 | adv_images = adversary.run_standard_evaluation(x, y, bs=64)
76 | pred = model(adv_images)
77 | acc += (pred.max(1)[1] == y).float().sum().item()
78 | acc /= 100
79 | print("Model: {}, Dataset: {}, Attack: {}, Accuracy: {}".format(model_name, dataset, args.attacker, acc))
--------------------------------------------------------------------------------
/train_eval_scripts/recoloradv/mister_ed/cifar10/wide_resnets.py:
--------------------------------------------------------------------------------
1 | """ Wide Resnet architecture implementation taken from this repo:
2 | https://github.com/meliketoy/wide-resnet.pytorch
3 | """
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.init as init
8 | import torch.nn.functional as F
9 | from torch.autograd import Variable
10 |
11 | import sys
12 | import numpy as np
13 |
14 | def conv3x3(in_planes, out_planes, stride=1):
15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
16 |
17 | def conv_init(m):
18 | classname = m.__class__.__name__
19 | if classname.find('Conv') != -1:
20 | init.xavier_uniform(m.weight, gain=np.sqrt(2))
21 | init.constant(m.bias, 0)
22 | elif classname.find('BatchNorm') != -1:
23 | init.constant(m.weight, 1)
24 | init.constant(m.bias, 0)
25 |
26 | class wide_basic(nn.Module):
27 | def __init__(self, in_planes, planes, dropout_rate, stride=1):
28 | super(wide_basic, self).__init__()
29 | self.bn1 = nn.BatchNorm2d(in_planes)
30 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
31 | self.dropout = nn.Dropout(p=dropout_rate)
32 | self.bn2 = nn.BatchNorm2d(planes)
33 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
34 |
35 | self.shortcut = nn.Sequential()
36 | if stride != 1 or in_planes != planes:
37 | self.shortcut = nn.Sequential(
38 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
39 | )
40 |
41 | def forward(self, x):
42 | out = self.dropout(self.conv1(F.relu(self.bn1(x))))
43 | out = self.conv2(F.relu(self.bn2(out)))
44 | out += self.shortcut(x)
45 |
46 | return out
47 |
48 | class Wide_ResNet(nn.Module):
49 | def __init__(self, depth, widen_factor, dropout_rate, num_classes):
50 | super(Wide_ResNet, self).__init__()
51 | self.in_planes = 16
52 |
53 | assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4'
54 |         n = (depth-4)//6  # integer division so the block count stays an int under Python 3
55 | k = widen_factor
56 |
57 | print('| Wide-Resnet %dx%d' %(depth, k))
58 | nStages = [16, 16*k, 32*k, 64*k]
59 |
60 | self.conv1 = conv3x3(3,nStages[0])
61 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1)
62 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2)
63 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2)
64 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)
65 | self.linear = nn.Linear(nStages[3], num_classes)
66 |
67 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
68 | strides = [stride] + [1]*(num_blocks-1)
69 | layers = []
70 |
71 | for stride in strides:
72 | layers.append(block(self.in_planes, planes, dropout_rate, stride))
73 | self.in_planes = planes
74 |
75 | return nn.Sequential(*layers)
76 |
77 | def forward(self, x):
78 | out = self.conv1(x)
79 | out = self.layer1(out)
80 | out = self.layer2(out)
81 | out = self.layer3(out)
82 | out = F.relu(self.bn1(out))
83 | out = F.avg_pool2d(out, 8)
84 | out = out.view(out.size(0), -1)
85 | out = self.linear(out)
86 |
87 | return out
88 |
89 | if __name__ == '__main__':
90 | net=Wide_ResNet(28, 10, 0.3, 10)
91 | y = net(Variable(torch.randn(1,3,32,32)))
92 |
93 | print(y.size())
--------------------------------------------------------------------------------
/SAM_segmentation/datasets/iccv09.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import tarfile
4 | import collections
5 | import torch.utils.data as data
6 | import shutil
7 | import numpy as np
8 |
9 | from PIL import Image
10 | from torchvision.datasets.utils import download_url, check_integrity
11 |
12 | """
13 | class_names,r,g,b
14 | sky,68,1,84
15 | tree,72,40,140
16 | road,62,74,137
17 | grass,38,130,142
18 | water,31,158,137
19 | building,53,183,121
20 | mountain,109,205,89
21 | foreground,180,222,44
22 | unknown,49,104,142
23 | """
24 |
25 | mean = [0.4813, 0.4901, 0.4747] # rgb
26 | std = [0.2495, 0.2492, 0.2748] # rgb
27 |
28 | class Iccv2009Dataset(data.Dataset):
29 |
30 | rgb2id = {
31 | (68, 1, 84): 0,
32 | (72, 40, 140): 1,
33 | (62, 74, 137): 2,
34 | (38, 130, 142): 3,
35 | (31, 158, 137): 4,
36 | (53, 183, 121): 5,
37 | (109, 205, 89): 6,
38 | (180, 222, 44): 7,
39 | (49, 104, 142): 8,
40 | }
41 |
42 | def __init__(self, root, split, transform=None):
43 |
44 | self.image_root = os.path.join(root, 'images')
45 | self.mask_root = os.path.join(root, 'labels_colored')
46 | self.split = split
47 | self.images = []
48 | self.targets = []
49 | self.transform = transform
50 |
51 | for filename in os.listdir(self.image_root):
52 | if filename.endswith('.jpg'):
53 | self.images.append(os.path.join(self.image_root, filename))
54 | self.targets.append(os.path.join(self.mask_root, filename[:-4] + '.png'))
55 |
56 | if self.split == 'train':
57 | self.images = self.images[:int(0.7*len(self.images))]
58 | self.targets = self.targets[:int(0.7*len(self.targets))]
59 | elif self.split == 'val':
60 | self.images = self.images[int(0.7*len(self.images)):]
61 | self.targets = self.targets[int(0.7*len(self.targets)):]
62 | else:
63 | raise ValueError('Invalid split name: {}'.format(self.split))
64 |
65 | def __getitem__(self, index):
66 | image = Image.open(self.images[index]).convert('RGB')
67 | target = Image.open(self.targets[index])
68 | target = self.encode_mask(np.array(target))
69 | target = Image.fromarray(target)
70 | if self.transform is not None:
71 | image, target = self.transform(image, target)
72 |
73 |         # min-max normalize the image tensor to [0, 1] (image is a Tensor after the transform)
74 | image = (image - image.min())/(image.max() - image.min())
75 |
76 | return image, target
77 |
78 | def __len__(self):
79 | return len(self.images)
80 |
81 | @classmethod
82 | def encode_mask(cls, mask):
83 | for k in cls.rgb2id:
84 | mask[(mask == k).all(axis=2)] = cls.rgb2id[k]
85 | return mask[:, :, 0]
86 |
87 | @classmethod
88 | def decode_target(cls, target):
89 | target_rgb = np.zeros((target.shape[0], target.shape[1], 3), dtype=np.uint8)
90 | for k in cls.rgb2id:
91 | target_rgb[(target == cls.rgb2id[k])] = k
92 | return target_rgb
93 |
94 | if __name__ == "__main__":
95 | dataset = Iccv2009Dataset('/mnt/nasv2/hhz/DeepLabV3Plus-Pytorch-master/datasets/data/iccv09', 'train')
96 | # test mask shape and value
97 | for i in range(len(dataset)):
98 | img, mask = dataset[i]
99 | img = np.array(img)
100 | mask = np.array(mask)
101 | print(img.shape, mask.shape)
102 | print(np.unique(mask))
103 | if i == 10:
104 | break
105 |
--------------------------------------------------------------------------------
/SAM_segmentation/metrics/stream_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import confusion_matrix
3 |
4 | class _StreamMetrics(object):
5 | def __init__(self):
6 | """ Overridden by subclasses """
7 | raise NotImplementedError()
8 |
9 | def update(self, gt, pred):
10 | """ Overridden by subclasses """
11 | raise NotImplementedError()
12 |
13 | def get_results(self):
14 | """ Overridden by subclasses """
15 | raise NotImplementedError()
16 |
17 | def to_str(self, metrics):
18 | """ Overridden by subclasses """
19 | raise NotImplementedError()
20 |
21 | def reset(self):
22 | """ Overridden by subclasses """
23 | raise NotImplementedError()
24 |
25 | class StreamSegMetrics(_StreamMetrics):
26 | """
27 | Stream Metrics for Semantic Segmentation Task
28 | """
29 | def __init__(self, n_classes):
30 | self.n_classes = n_classes
31 | self.confusion_matrix = np.zeros((n_classes, n_classes))
32 |
33 | def update(self, label_trues, label_preds):
34 | for lt, lp in zip(label_trues, label_preds):
35 | self.confusion_matrix += self._fast_hist( lt.flatten(), lp.flatten() )
36 |
37 | @staticmethod
38 | def to_str(results):
39 | string = "\n"
40 | for k, v in results.items():
41 | if k!="Class IoU":
42 | string += "%s: %f\n"%(k, v)
43 |
44 | #string+='Class IoU:\n'
45 | #for k, v in results['Class IoU'].items():
46 | # string += "\tclass %d: %f\n"%(k, v)
47 | return string
48 |
49 | def _fast_hist(self, label_true, label_pred):
50 | mask = (label_true >= 0) & (label_true < self.n_classes)
51 | hist = np.bincount(
52 | self.n_classes * label_true[mask].astype(int) + label_pred[mask],
53 | minlength=self.n_classes ** 2,
54 | ).reshape(self.n_classes, self.n_classes)
55 | return hist
56 |
57 | def get_results(self):
58 | """Returns accuracy score evaluation result.
59 | - overall accuracy
60 | - mean accuracy
61 | - mean IU
62 | - fwavacc
63 | """
64 | hist = self.confusion_matrix
65 | acc = np.diag(hist).sum() / hist.sum()
66 | acc_cls = np.diag(hist) / hist.sum(axis=1)
67 | acc_cls = np.nanmean(acc_cls)
68 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
69 | mean_iu = np.nanmean(iu)
70 | freq = hist.sum(axis=1) / hist.sum()
71 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
72 | cls_iu = dict(zip(range(self.n_classes), iu))
73 |
74 | return {
75 | "Overall Acc": acc,
76 | "Mean Acc": acc_cls,
77 | "FreqW Acc": fwavacc,
78 | "Mean IoU": mean_iu,
79 | "Class IoU": cls_iu,
80 | }
81 |
82 | def reset(self):
83 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
84 |
85 | class AverageMeter(object):
86 | """Computes average values"""
87 | def __init__(self):
88 | self.book = dict()
89 |
90 | def reset_all(self):
91 | self.book.clear()
92 |
93 | def reset(self, id):
94 | item = self.book.get(id, None)
95 | if item is not None:
96 | item[0] = 0
97 | item[1] = 0
98 |
99 | def update(self, id, val):
100 | record = self.book.get(id, None)
101 | if record is None:
102 | self.book[id] = [val, 1]
103 | else:
104 | record[0]+=val
105 | record[1]+=1
106 |
107 | def get_results(self, id):
108 | record = self.book.get(id, None)
109 | assert record is not None
110 | return record[0] / record[1]
111 |
--------------------------------------------------------------------------------
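A tiny end-to-end example of `StreamSegMetrics` with made-up labels for three classes:

```python
import numpy as np
from metrics import StreamSegMetrics  # re-exported via metrics/__init__.py

metrics = StreamSegMetrics(n_classes=3)
gt   = np.array([[0, 1, 2, 1]])  # ground-truth label map (a single toy "image")
pred = np.array([[0, 1, 1, 1]])  # predicted label map
metrics.update(gt, pred)
print(metrics.to_str(metrics.get_results()))  # Overall Acc, Mean Acc, FreqW Acc, Mean IoU
```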
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class PreActBlock(nn.Module):
7 | '''Pre-activation version of the BasicBlock.'''
8 | expansion = 1
9 |
10 | def __init__(self, in_planes, planes, stride=1):
11 | super(PreActBlock, self).__init__()
12 | self.bn1 = nn.BatchNorm2d(in_planes)
13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
14 | self.bn2 = nn.BatchNorm2d(planes)
15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
16 |
17 | if stride != 1 or in_planes != self.expansion*planes:
18 | self.shortcut = nn.Sequential(
19 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
20 | )
21 |
22 | def forward(self, x):
23 | out = F.relu(self.bn1(x))
24 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
25 | out = self.conv1(out)
26 | out = self.conv2(F.relu(self.bn2(out)))
27 | out += shortcut
28 | return out
29 |
30 |
31 | class PreActBottleneck(nn.Module):
32 | '''Pre-activation version of the original Bottleneck module.'''
33 | expansion = 4
34 |
35 | def __init__(self, in_planes, planes, stride=1):
36 | super(PreActBottleneck, self).__init__()
37 | self.bn1 = nn.BatchNorm2d(in_planes)
38 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
39 | self.bn2 = nn.BatchNorm2d(planes)
40 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
41 | self.bn3 = nn.BatchNorm2d(planes)
42 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
43 |
44 | if stride != 1 or in_planes != self.expansion*planes:
45 | self.shortcut = nn.Sequential(
46 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
47 | )
48 |
49 | def forward(self, x):
50 | out = F.relu(self.bn1(x))
51 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
52 | out = self.conv1(out)
53 | out = self.conv2(F.relu(self.bn2(out)))
54 | out = self.conv3(F.relu(self.bn3(out)))
55 | out += shortcut
56 | return out
57 |
58 |
59 | class PreActResNet(nn.Module):
60 | def __init__(self, block, num_blocks, num_classes=10):
61 | super(PreActResNet, self).__init__()
62 | self.in_planes = 64
63 |
64 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
65 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
66 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
67 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
68 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
69 | self.bn = nn.BatchNorm2d(512 * block.expansion)
70 | self.linear = nn.Linear(512*block.expansion, num_classes)
71 |
72 | def _make_layer(self, block, planes, num_blocks, stride):
73 | strides = [stride] + [1]*(num_blocks-1)
74 | layers = []
75 | for stride in strides:
76 | layers.append(block(self.in_planes, planes, stride))
77 | self.in_planes = planes * block.expansion
78 | return nn.Sequential(*layers)
79 |
80 | def forward(self, x):
81 | out = self.conv1(x)
82 | out = self.layer1(out)
83 | out = self.layer2(out)
84 | out = self.layer3(out)
85 | out = self.layer4(out)
86 | out = F.relu(self.bn(out))
87 | out = F.avg_pool2d(out, 4)
88 | out = out.view(out.size(0), -1)
89 | out = self.linear(out)
90 | return out
91 |
92 |
93 | def PreActResNet18(num_classes=10):
94 | return PreActResNet(PreActBlock, [2,2,2,2], num_classes=num_classes)
95 |
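
A minimal usage sketch for the model above, assuming CIFAR-sized 3x32x32 inputs:

    import torch
    from model import PreActResNet18

    net = PreActResNet18(num_classes=10)
    x = torch.randn(8, 3, 32, 32)   # batch of 8 CIFAR-sized images
    logits = net(x)                 # -> shape (8, 10)
    assert logits.shape == (8, 10)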
--------------------------------------------------------------------------------
/SAM_segmentation/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import hashlib
4 | import errno
5 | from tqdm import tqdm
6 |
7 |
8 | def gen_bar_updater(pbar):
9 | def bar_update(count, block_size, total_size):
10 | if pbar.total is None and total_size:
11 | pbar.total = total_size
12 | progress_bytes = count * block_size
13 | pbar.update(progress_bytes - pbar.n)
14 |
15 | return bar_update
16 |
17 |
18 | def check_integrity(fpath, md5=None):
19 | if md5 is None:
20 | return True
21 | if not os.path.isfile(fpath):
22 | return False
23 | md5o = hashlib.md5()
24 | with open(fpath, 'rb') as f:
25 | # read in 1MB chunks
26 | for chunk in iter(lambda: f.read(1024 * 1024), b''):
27 | md5o.update(chunk)
28 | md5c = md5o.hexdigest()
29 | if md5c != md5:
30 | return False
31 | return True
32 |
33 |
34 | def makedir_exist_ok(dirpath):
35 | """
36 | Python2 support for os.makedirs(.., exist_ok=True)
37 | """
38 | try:
39 | os.makedirs(dirpath)
40 | except OSError as e:
41 | if e.errno == errno.EEXIST:
42 | pass
43 | else:
44 | raise
45 |
46 |
47 | def download_url(url, root, filename=None, md5=None):
48 | """Download a file from a url and place it in root.
49 | Args:
50 | url (str): URL to download file from
51 | root (str): Directory to place downloaded file in
52 | filename (str): Name to save the file under. If None, use the basename of the URL
53 | md5 (str): MD5 checksum of the download. If None, do not check
54 | """
55 | from six.moves import urllib
56 |
57 | root = os.path.expanduser(root)
58 | if not filename:
59 | filename = os.path.basename(url)
60 | fpath = os.path.join(root, filename)
61 |
62 | makedir_exist_ok(root)
63 |
64 | # downloads file
65 | if os.path.isfile(fpath) and check_integrity(fpath, md5):
66 | print('Using downloaded and verified file: ' + fpath)
67 | else:
68 | try:
69 | print('Downloading ' + url + ' to ' + fpath)
70 | urllib.request.urlretrieve(
71 | url, fpath,
72 | reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True))
73 | )
74 | except OSError:
75 | if url[:5] == 'https':
76 | url = url.replace('https:', 'http:')
77 | print('Failed download. Trying https -> http instead.'
78 | ' Downloading ' + url + ' to ' + fpath)
79 | urllib.request.urlretrieve(
80 | url, fpath,
81 | reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True))
82 | )
83 |
84 |
85 | def list_dir(root, prefix=False):
86 | """List all directories at a given root
87 | Args:
88 | root (str): Path to directory whose folders need to be listed
89 | prefix (bool, optional): If true, prepends the path to each result, otherwise
90 | only returns the name of the directories found
91 | """
92 | root = os.path.expanduser(root)
93 | directories = list(
94 | filter(
95 | lambda p: os.path.isdir(os.path.join(root, p)),
96 | os.listdir(root)
97 | )
98 | )
99 |
100 | if prefix is True:
101 | directories = [os.path.join(root, d) for d in directories]
102 |
103 | return directories
104 |
105 |
106 | def list_files(root, suffix, prefix=False):
107 | """List all files ending with a suffix at a given root
108 | Args:
109 |         root (str): Path to directory whose files need to be listed
110 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
111 | It uses the Python "str.endswith" method and is passed directly
112 | prefix (bool, optional): If true, prepends the path to each result, otherwise
113 | only returns the name of the files found
114 | """
115 | root = os.path.expanduser(root)
116 | files = list(
117 | filter(
118 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
119 | os.listdir(root)
120 | )
121 | )
122 |
123 | if prefix is True:
124 | files = [os.path.join(root, d) for d in files]
125 |
126 | return files
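
A short usage sketch of the helpers above; the URL and paths are placeholders, not taken from the repo:

    from datasets.utils import download_url, list_files

    # Download into ./data, optionally verifying an MD5 checksum.
    download_url('http://example.com/archive.tar', root='./data',
                 filename='archive.tar', md5=None)

    # List all .png files directly under ./data, with full paths prepended.
    pngs = list_files('./data', suffix='.png', prefix=True)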
--------------------------------------------------------------------------------
/train_eval_scripts/recoloradv/examples/evaluate_imagenet.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import argparse
4 | import sys
5 | import os
6 | from torch import optim
7 | from torch.utils.data import DataLoader
8 | from torchvision.models import resnet50
9 | from torchvision.datasets import ImageNet
10 | from torchvision import transforms
11 |
12 | # mister_ed
13 | from recoloradv.mister_ed import loss_functions as lf
14 | from recoloradv.mister_ed import adversarial_training as advtrain
15 | from recoloradv.mister_ed import adversarial_perturbations as ap
16 | from recoloradv.mister_ed import adversarial_attacks as aa
17 | from recoloradv.mister_ed import spatial_transformers as st
18 | from recoloradv.mister_ed.utils import pytorch_utils as utils
19 |
20 | # ReColorAdv
21 | from recoloradv import perturbations as pt
22 | from recoloradv import color_transformers as ct
23 | from recoloradv import color_spaces as cs
24 |
25 |
26 | if __name__ == '__main__':
27 | parser = argparse.ArgumentParser(
28 |         description='Evaluate a ResNet-50 trained on ImageNet '
29 | 'against ReColorAdv'
30 | )
31 |
32 | parser.add_argument('--imagenet_path', type=str, required=True,
33 | help='path to ImageNet dataset')
34 | parser.add_argument('--batch_size', type=int, default=100,
35 | help='number of examples/minibatch')
36 | parser.add_argument('--num_batches', type=int, required=False,
37 | help='number of batches (default entire dataset)')
38 | args = parser.parse_args()
39 |
40 | model = resnet50(pretrained=True, progress=True)
41 | normalizer = utils.DifferentiableNormalize(mean=[0.485, 0.456, 0.406],
42 | std=[0.229, 0.224, 0.225])
43 |
44 | dataset = ImageNet(
45 | args.imagenet_path,
46 | split='val',
47 | transform=transforms.Compose([
48 | transforms.CenterCrop(224),
49 | transforms.ToTensor(),
50 | ]),
51 | )
52 | val_loader = DataLoader(
53 | dataset,
54 | batch_size=args.batch_size,
55 | shuffle=True,
56 | )
57 |
58 | model.eval()
59 | if torch.cuda.is_available():
60 | model.cuda()
61 |
62 | cw_loss = lf.CWLossF6(model, normalizer, kappa=float('inf'))
63 | perturbation_loss = lf.PerturbationNormLoss(lp=2)
64 | adv_loss = lf.RegularizedLoss(
65 | {'cw': cw_loss, 'pert': perturbation_loss},
66 | {'cw': 1.0, 'pert': 0.05},
67 | negate=True,
68 | )
69 |
70 | pgd_attack = aa.PGD(
71 | model,
72 | normalizer,
73 | ap.ThreatModel(pt.ReColorAdv, {
74 | 'xform_class': ct.FullSpatial,
75 | 'cspace': cs.CIELUVColorSpace(),
76 | 'lp_style': 'inf',
77 | 'lp_bound': 0.06,
78 | 'xform_params': {
79 | 'resolution_x': 16,
80 | 'resolution_y': 32,
81 | 'resolution_z': 32,
82 | },
83 | 'use_smooth_loss': True,
84 | }),
85 | adv_loss,
86 | )
87 |
88 | batches_correct = []
89 | for batch_index, (inputs, labels) in enumerate(val_loader):
90 | if (
91 | args.num_batches is not None and
92 | batch_index >= args.num_batches
93 | ):
94 | break
95 |
96 | if torch.cuda.is_available():
97 | inputs = inputs.cuda()
98 | labels = labels.cuda()
99 |
100 | adv_inputs = pgd_attack.attack(
101 | inputs,
102 | labels,
103 | optimizer=optim.Adam,
104 | optimizer_kwargs={'lr': 0.001},
105 | signed=False,
106 | verbose=False,
107 | num_iterations=(100, 300),
108 | ).adversarial_tensors()
109 | with torch.no_grad():
110 | adv_logits = model(normalizer(adv_inputs))
111 | batch_correct = (adv_logits.argmax(1) == labels).detach()
112 |
113 | batch_accuracy = batch_correct.float().mean().item()
114 | print(f'BATCH {batch_index:05d}',
115 | f'accuracy = {batch_accuracy * 100:.1f}',
116 | sep='\t')
117 | batches_correct.append(batch_correct)
118 |
119 | accuracy = torch.cat(batches_correct).float().mean().item()
120 | print('OVERALL ',
121 | f'accuracy = {accuracy * 100:.1f}',
122 | sep='\t')
123 |
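
A typical invocation, with a placeholder dataset path: `python evaluate_imagenet.py --imagenet_path /path/to/imagenet --batch_size 50 --num_batches 10` evaluates the pretrained ResNet-50 against ReColorAdv on 10 validation batches.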
--------------------------------------------------------------------------------
/train_eval_scripts/recoloradv/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | from torch import nn
4 | from torch import optim
5 | from typing import Tuple
6 |
7 | from .mister_ed.cifar10 import cifar_resnets
8 | from .mister_ed.utils.pytorch_utils import DifferentiableNormalize
9 | from .mister_ed import config
10 | from .mister_ed import adversarial_perturbations as ap
11 | from .mister_ed import adversarial_attacks as aa
12 | from .mister_ed import spatial_transformers as st
13 | from .mister_ed import loss_functions as lf
14 | from .mister_ed import adversarial_training as advtrain
15 |
16 | from . import perturbations as pt
17 | from . import color_transformers as ct
18 | from . import color_spaces as cs
19 |
20 |
21 | def load_pretrained_cifar10_model(
22 | path: str, resnet_size: int = 32,
23 | ) -> Tuple[nn.Module, DifferentiableNormalize]:
24 | """
25 | Loads a pretrained CIFAR-10 ResNet from the given path along with its
26 | associated normalizer.
27 | """
28 |
29 | model: nn.Module = getattr(cifar_resnets, f'resnet{resnet_size}')()
30 | model_state = torch.load(path, map_location=torch.device('cpu'))
31 | model.load_state_dict({re.sub(r'^module\.', '', k): v for k, v in
32 | model_state['state_dict'].items()})
33 |
34 | normalizer = DifferentiableNormalize(
35 | mean=config.CIFAR10_MEANS,
36 | std=config.CIFAR10_STDS,
37 | )
38 |
39 | return model, normalizer
40 |
41 |
42 | def get_attack_from_name(
43 | name: str,
44 | classifier: nn.Module,
45 | normalizer: DifferentiableNormalize,
46 | verbose: bool = False,
47 | ) -> advtrain.AdversarialAttackParameters:
48 | """
49 | Builds an attack from a name like "recoloradv" or "stadv+delta" or
50 | "recoloradv+stadv+delta".
51 | """
52 |
53 | threats = []
54 | norm_weights = []
55 |
56 | for attack_part in name.split('+'):
57 | if attack_part == 'delta':
58 | threats.append(ap.ThreatModel(
59 | ap.DeltaAddition,
60 | ap.PerturbationParameters(
61 | lp_style='inf',
62 | lp_bound=1.0 / 255,
63 | ),
64 | ))
65 | norm_weights.append(0.0)
66 | elif attack_part == 'stadv':
67 | threats.append(ap.ThreatModel(
68 | ap.ParameterizedXformAdv,
69 | ap.PerturbationParameters(
70 | lp_style='inf',
71 | lp_bound=1.0 / 255,
72 | xform_class=st.FullSpatial,
73 | use_stadv=True,
74 | ),
75 | ))
76 | norm_weights.append(1.0)
77 | elif attack_part == 'recoloradv':
78 | threats.append(ap.ThreatModel(
79 | pt.ReColorAdv,
80 | ap.PerturbationParameters(
81 | lp_style='inf',
82 | lp_bound=[8.0/255, 8.0/255, 8.0/255],
83 | xform_params={
84 | 'resolution_x': 16,
85 | 'resolution_y': 32,
86 | 'resolution_z': 32,
87 | },
88 | xform_class=ct.FullSpatial,
89 | use_smooth_loss=True,
90 | cspace=cs.CIELUVColorSpace(),
91 | ),
92 | ))
93 | norm_weights.append(1.0)
94 | else:
95 | raise ValueError(f'Invalid attack "{attack_part}"')
96 |
97 | sequence_threat = ap.ThreatModel(
98 | ap.SequentialPerturbation,
99 | threats,
100 | ap.PerturbationParameters(norm_weights=norm_weights),
101 | )
102 |
103 | # use PGD attack
104 | adv_loss = lf.CWLossF6(classifier, normalizer, kappa=float('inf'))
105 | st_loss = lf.PerturbationNormLoss(lp=2)
106 | loss_fxn = lf.RegularizedLoss({'adv': adv_loss, 'pert': st_loss},
107 | {'adv': 1.0, 'pert': 0.05},
108 | negate=True)
109 |
110 | pgd_attack = aa.PGD(classifier, normalizer, sequence_threat, loss_fxn)
111 | return advtrain.AdversarialAttackParameters(
112 | pgd_attack,
113 | 1.0,
114 | attack_specific_params={'attack_kwargs': {
115 | 'num_iterations': 10,
116 | 'optimizer': optim.Adam,
117 | 'optimizer_kwargs': {'lr': 0.001},
118 | 'signed': False,
119 | 'verbose': verbose,
120 | }},
121 | )
122 |
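
A minimal sketch of how the two helpers above compose; the checkpoint path is a placeholder:

    from recoloradv.utils import load_pretrained_cifar10_model, get_attack_from_name

    model, normalizer = load_pretrained_cifar10_model('models/cifar10_resnet32.th')
    attack_params = get_attack_from_name('recoloradv+stadv+delta', model, normalizer, verbose=True)
    # attack_params.attack(inputs, labels) then runs the combined 10-step PGD attack,
    # which is how attack.py below uses it.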
--------------------------------------------------------------------------------
/SAM_segmentation/network/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from collections import OrderedDict
6 |
7 | class _SimpleSegmentationModel(nn.Module):
8 | def __init__(self, backbone, classifier):
9 | super(_SimpleSegmentationModel, self).__init__()
10 | self.backbone = backbone
11 | self.classifier = classifier
12 |
13 | def forward(self, x):
14 | input_shape = x.shape[-2:]
15 | features = self.backbone(x)
16 | x = self.classifier(features)
17 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
18 | return x
19 |
20 |
21 | class IntermediateLayerGetter(nn.ModuleDict):
22 | """
23 | Module wrapper that returns intermediate layers from a model
24 |
25 | It has a strong assumption that the modules have been registered
26 | into the model in the same order as they are used.
27 | This means that one should **not** reuse the same nn.Module
28 | twice in the forward if you want this to work.
29 |
30 | Additionally, it is only able to query submodules that are directly
31 | assigned to the model. So if `model` is passed, `model.feature1` can
32 | be returned, but not `model.feature1.layer2`.
33 |
34 | Arguments:
35 | model (nn.Module): model on which we will extract the features
36 | return_layers (Dict[name, new_name]): a dict containing the names
37 | of the modules for which the activations will be returned as
38 | the key of the dict, and the value of the dict is the name
39 | of the returned activation (which the user can specify).
40 |
41 | Examples::
42 |
43 | >>> m = torchvision.models.resnet18(pretrained=True)
44 |         >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
45 | >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
46 | >>> {'layer1': 'feat1', 'layer3': 'feat2'})
47 | >>> out = new_m(torch.rand(1, 3, 224, 224))
48 | >>> print([(k, v.shape) for k, v in out.items()])
49 | >>> [('feat1', torch.Size([1, 64, 56, 56])),
50 | >>> ('feat2', torch.Size([1, 256, 14, 14]))]
51 | """
52 | def __init__(self, model, return_layers, hrnet_flag=False):
53 | if not set(return_layers).issubset([name for name, _ in model.named_children()]):
54 | raise ValueError("return_layers are not present in model")
55 |
56 | self.hrnet_flag = hrnet_flag
57 |
58 | orig_return_layers = return_layers
59 | return_layers = {k: v for k, v in return_layers.items()}
60 | layers = OrderedDict()
61 | for name, module in model.named_children():
62 | layers[name] = module
63 | if name in return_layers:
64 | del return_layers[name]
65 | if not return_layers:
66 | break
67 |
68 | super(IntermediateLayerGetter, self).__init__(layers)
69 | self.return_layers = orig_return_layers
70 |
71 | def forward(self, x):
72 | out = OrderedDict()
73 | for name, module in self.named_children():
74 |             if self.hrnet_flag and name.startswith('transition'): # HRNet: the transition modules need special handling
75 |                 if name == 'transition1': # transition1 splits the single stream into a list of streams
76 |                     x = [trans(x) for trans in module]
77 |                 else: # each later transition appends one extra stream
78 |                     x.append(module(x[-1]))
79 |             else: # other backbones (e.g. resnet, mobilenet) are a plain sequence of modules
80 | x = module(x)
81 |
82 | if name in self.return_layers:
83 | out_name = self.return_layers[name]
84 |                 if name == 'stage4' and self.hrnet_flag: # In HRNetV2, we upsample and concat all output streams together
85 | output_h, output_w = x[0].size(2), x[0].size(3) # Upsample to size of highest resolution stream
86 | x1 = F.interpolate(x[1], size=(output_h, output_w), mode='bilinear', align_corners=False)
87 | x2 = F.interpolate(x[2], size=(output_h, output_w), mode='bilinear', align_corners=False)
88 | x3 = F.interpolate(x[3], size=(output_h, output_w), mode='bilinear', align_corners=False)
89 | x = torch.cat([x[0], x1, x2, x3], dim=1)
90 | out[out_name] = x
91 | else:
92 | out[out_name] = x
93 | return out
94 |
--------------------------------------------------------------------------------
/train_eval_scripts/recoloradv/perturbations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .mister_ed import adversarial_perturbations as ap
5 | from .mister_ed.adversarial_perturbations import initialized
6 | from .mister_ed.utils import pytorch_utils as utils
7 |
8 | from . import color_transformers as ct
9 | from . import norms
10 | from . import color_spaces as cs
11 |
12 |
13 | class ReColorAdv(ap.AdversarialPerturbation):
14 | """
15 | Puts the color at each pixel in the image through the same transformation.
16 |
17 | Parameters:
18 | - lp_style: number or 'inf'
19 | - lp_bound: maximum norm of color transformation. Can be a tensor of size
20 | (num_channels,), in which case each channel will be bounded by the
21 |       corresponding bound in the tensor. For instance, passing
22 | [0.1, 0.15, 0.05] would allow a norm of 0.1 for R, 0.15 for G, and 0.05
23 | for B. Not supported by all transformations.
24 | - use_smooth_loss: whether to optimize using the loss function
25 | for FullSpatial that rewards smooth vector fields
26 | - xform_class: a subclass of
27 | color_transformers.ParameterizedTransformation
28 | - xform_params: dict of parameters to pass to the xform_class.
29 | - cspace_class: a subclass of color_spaces.ColorSpace that indicates
30 | in which color space the transformation should be performed
31 | (RGB by default)
32 | """
33 |
34 | def __init__(self, threat_model, perturbation_params, *other_args):
35 | super().__init__(threat_model, perturbation_params)
36 | assert issubclass(perturbation_params.xform_class,
37 | ct.ParameterizedTransformation)
38 |
39 | self.lp_style = perturbation_params.lp_style
40 | self.lp_bound = perturbation_params.lp_bound
41 | self.use_smooth_loss = perturbation_params.use_smooth_loss
42 | self.scalar_step = perturbation_params.scalar_step or 1.0
43 | self.cspace = perturbation_params.cspace or cs.RGBColorSpace()
44 |
45 | def _merge_setup(self, num_examples, new_xform):
46 | """ DANGEROUS TO BE CALLED OUTSIDE OF THIS FILE!!!"""
47 | self.num_examples = num_examples
48 | self.xform = new_xform
49 | self.initialized = True
50 |
51 | def setup(self, originals):
52 | super().setup(originals)
53 | self.xform = self.perturbation_params.xform_class(
54 | shape=originals.shape, manual_gpu=self.use_gpu,
55 | cspace=self.cspace,
56 | **(self.perturbation_params.xform_params or {}),
57 | )
58 | self.initialized = True
59 |
60 | @initialized
61 | def perturbation_norm(self, x=None, lp_style=None):
62 | lp_style = lp_style or self.lp_style
63 | if self.use_smooth_loss:
64 | assert isinstance(self.xform, ct.FullSpatial)
65 | return self.xform.smoothness_norm()
66 | else:
67 | return self.xform.norm(lp=lp_style)
68 |
69 | @initialized
70 | def constrain_params(self, x=None):
71 | # Do lp projections
72 | if isinstance(self.lp_style, int) or self.lp_style == 'inf':
73 | self.xform.project_params(self.lp_style, self.lp_bound)
74 |
75 | @initialized
76 | def update_params(self, step_fxn):
77 | param_list = list(self.xform.parameters())
78 | assert len(param_list) == 1
79 | params = param_list[0]
80 | assert params.grad.data is not None
81 | self.add_to_params(step_fxn(params.grad.data) * self.scalar_step)
82 |
83 | @initialized
84 | def add_to_params(self, grad_data):
85 | """ Assumes only one parameters object in the Spatial Transform """
86 | param_list = list(self.xform.parameters())
87 | assert len(param_list) == 1
88 | params = param_list[0]
89 | params.data.add_(grad_data)
90 |
91 | @initialized
92 | def random_init(self):
93 | param_list = list(self.xform.parameters())
94 | assert len(param_list) == 1
95 | param = param_list[0]
96 | random_perturb = utils.random_from_lp_ball(param.data,
97 | self.lp_style,
98 | self.lp_bound)
99 |
100 | param.data.add_(self.xform.identity_params +
101 | random_perturb - self.xform.xform_params.data)
102 |
103 | @initialized
104 | def merge_perturbation(self, other, self_mask):
105 | super().merge_perturbation(other, self_mask)
106 | new_perturbation = ReColorAdv(self.threat_model,
107 | self.perturbation_params)
108 |
109 | new_xform = self.xform.merge_xform(other.xform, self_mask)
110 | new_perturbation._merge_setup(self.num_examples, new_xform)
111 |
112 | return new_perturbation
113 |
114 | def forward(self, x):
115 | if not self.initialized:
116 | self.setup(x)
117 | self.constrain_params()
118 |
119 | return self.cspace.to_rgb(
120 | self.xform.forward(self.cspace.from_rgb(x)))
121 |
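
In practice a ReColorAdv perturbation is instantiated through a ThreatModel; a minimal sketch, mirroring the configuration used in examples/evaluate_imagenet.py:

    from recoloradv.mister_ed import adversarial_perturbations as ap
    from recoloradv import perturbations as pt
    from recoloradv import color_transformers as ct
    from recoloradv import color_spaces as cs

    threat = ap.ThreatModel(pt.ReColorAdv, {
        'xform_class': ct.FullSpatial,
        'cspace': cs.CIELUVColorSpace(),
        'lp_style': 'inf',
        'lp_bound': 0.06,
        'xform_params': {'resolution_x': 16, 'resolution_y': 32, 'resolution_z': 32},
        'use_smooth_loss': True,
    })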
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.data import dataset, dataloader
4 | from torchvision import datasets, transforms
5 |
6 | cifar10_mean = (0.4914, 0.4822, 0.4465) # equals np.mean(train_set.train_data, axis=(0,1,2))/255
7 | cifar10_std = (0.2471, 0.2435, 0.2616) # equals np.std(train_set.train_data, axis=(0,1,2))/255
8 |
9 | mu = torch.tensor(cifar10_mean).view(3,1,1)
10 | std = torch.tensor(cifar10_std).view(3,1,1)
11 |
12 | def normalize_cifar(x):
13 | return (x - mu.to(x.device))/(std.to(x.device))
14 |
15 | CIFAR100_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
16 | CIFAR100_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
17 |
18 | mu_cifar100 = torch.tensor(CIFAR100_MEAN).view(3,1,1)
19 | std_cifar100 = torch.tensor(CIFAR100_STD).view(3,1,1)
20 |
21 | def normalize_cifar100(x):
22 | return (x - mu_cifar100.to(x.device))/(std_cifar100.to(x.device))
23 |
24 | def load_dataset(dataset='cifar10', batch_size=128):
25 | if dataset == 'cifar10':
26 | transform_ = transforms.Compose([transforms.ToTensor()])
27 | train_transform_ = transforms.Compose([
28 | transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
29 | transforms.RandomHorizontalFlip(),
30 | transforms.ToTensor()])
31 |
32 | train_loader = torch.utils.data.DataLoader(
33 | datasets.CIFAR10('/data/cifar_data', train=True, download=True, transform=train_transform_),
34 | batch_size=batch_size, shuffle=True)
35 |
36 | test_loader = torch.utils.data.DataLoader(
37 | datasets.CIFAR10('/data/cifar_data', train=False, download=True, transform=transform_),
38 | batch_size=batch_size, shuffle=False)
39 |
40 | return train_loader, test_loader
41 |
42 | elif dataset == 'cifar100':
43 | transform_ = transforms.Compose([transforms.ToTensor()])
44 | train_transform_ = transforms.Compose([
45 | transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
46 | transforms.RandomHorizontalFlip(),
47 | transforms.ToTensor()])
48 | train_loader = torch.utils.data.DataLoader(
49 | datasets.CIFAR100('/data/cifar_data', train=True, download=True, transform=train_transform_),
50 | batch_size=batch_size, shuffle=True)
51 | test_loader = torch.utils.data.DataLoader(
52 | datasets.CIFAR100('/data/cifar_data', train=False, download=True, transform=transform_),
53 | batch_size=batch_size, shuffle=False)
54 |
55 | return train_loader, test_loader
56 |
57 | class Attack():
58 | def __init__(self, iters, alpha, eps, norm, criterion, rand_init, rand_perturb, targeted, normalize=normalize_cifar):
59 | self.iters = iters
60 | self.alpha = alpha
61 | self.eps = eps
62 | self.norm = norm
63 | assert norm in ['linf', 'l2']
64 | self.criterion = criterion # loss function for perturb
65 | self.rand_init = rand_init # random initialization before perturb
66 | self.rand_perturb = rand_perturb # add random noise in each step
67 |         self.targeted = targeted # targeted attack
68 | self.normalize = normalize # normalize_cifar
69 |
70 | def perturb(self, model, x, y):
71 | delta = torch.zeros_like(x).to(x.device)
72 | if self.rand_init:
73 |
74 | if self.norm == "linf":
75 | delta.uniform_(-self.eps, self.eps)
76 | elif self.norm == "l2":
77 | delta.normal_()
78 | d_flat = delta.view(delta.size(0),-1)
79 | n = d_flat.norm(p=2,dim=1).view(delta.size(0),1,1,1)
80 | r = torch.zeros_like(n).uniform_(0, 1)
81 | delta *= r/n*self.eps
82 | else:
83 | raise ValueError
84 |
85 | delta = torch.clamp(delta, 0-x, 1-x)
86 | delta.requires_grad = True
87 |
88 | for _ in range(self.iters):
89 | output = model(self.normalize(x+delta))
90 | loss = self.criterion(output, y)
91 |             if self.targeted:
92 | loss *= -1
93 | loss.backward()
94 | g = delta.grad.detach()
95 | if self.norm == "linf":
96 | d = torch.clamp(delta + self.alpha * torch.sign(g), min=-self.eps, max=self.eps).detach()
97 | elif self.norm == "l2":
98 | g_norm = torch.norm(g.view(g.shape[0],-1),dim=1).view(-1,1,1,1)
99 | scaled_g = g/(g_norm + 1e-10)
100 | d = (delta + scaled_g*self.alpha).view(delta.size(0),-1).renorm(p=2,dim=0,maxnorm=self.eps).view_as(delta).detach()
101 | d = torch.clamp(d, 0 - x, 1 - x)
102 | delta.data = d
103 | delta.grad.zero_()
104 |
105 | return delta.detach()
106 |
107 | class PGD(Attack):
108 | def __init__(self, iters, alpha, eps, norm, targeted=False, normalize=normalize_cifar):
109 | super().__init__(iters, alpha, eps, norm, nn.CrossEntropyLoss(), True, False, targeted, normalize=normalize)
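
A minimal sketch of the PGD wrapper above under the common 8/255 L-inf budget, assuming a CUDA device and the PreActResNet18 from model.py:

    import torch
    from utils import PGD, load_dataset, normalize_cifar
    from model import PreActResNet18

    model = PreActResNet18(10).cuda().eval()
    attack = PGD(iters=10, alpha=2/255, eps=8/255, norm='linf', normalize=normalize_cifar)
    _, test_loader = load_dataset('cifar10', batch_size=128)
    x, y = next(iter(test_loader))
    x, y = x.cuda(), y.cuda()
    delta = attack.perturb(model, x, y)   # perturbation with the same shape as x
    adv_acc = (model(normalize_cifar(x + delta)).argmax(1) == y).float().mean()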
--------------------------------------------------------------------------------
/train_eval_scripts/recoloradv/mister_ed/cifar10/cifar_resnets.py:
--------------------------------------------------------------------------------
1 | '''
2 | MISTER_ED_NOTE: I blatantly copied this code from this github repository:
3 | https://github.com/akamaster/pytorch_resnet_cifar10
4 |
5 | Huge kudos to Yerlan Idelbayev.
6 | '''
7 |
8 |
9 |
10 | '''
11 | Properly implemented ResNet-s for CIFAR10 as described in paper [1].
12 |
13 | The implementation and structure of this file is hugely influenced by [2]
14 | which is implemented for ImageNet and doesn't have option A for identity.
15 | Moreover, most of the implementations on the web are copy-pasted from
16 | torchvision's resnet and have the wrong number of params.
17 |
18 | Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following
19 | number of layers and parameters:
20 |
21 | name | layers | params
22 | ResNet20 | 20 | 0.27M
23 | ResNet32 | 32 | 0.46M
24 | ResNet44 | 44 | 0.66M
25 | ResNet56 | 56 | 0.85M
26 | ResNet110 | 110 | 1.7M
27 | ResNet1202| 1202 | 19.4M
28 |
29 | which this implementation indeed has.
30 |
31 | Reference:
32 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
33 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
34 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
35 |
36 | If you use this implementation in your work, please don't forget to mention the
37 | author, Yerlan Idelbayev.
38 | '''
39 | import torch
40 | import torch.nn as nn
41 | import torch.nn.functional as F
42 | import torch.nn.init as init
43 |
44 | from torch.autograd import Variable
45 |
46 | __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']
47 |
48 | def _weights_init(m):
49 | classname = m.__class__.__name__
50 | # print(classname)
51 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
52 | try:
53 | init.kaiming_normal_(m.weight)
54 | except AttributeError:
55 | init.kaiming_normal(m.weight)
56 |
57 | class LambdaLayer(nn.Module):
58 | def __init__(self, lambd):
59 | super(LambdaLayer, self).__init__()
60 | self.lambd = lambd
61 |
62 | def forward(self, x):
63 | return self.lambd(x)
64 |
65 |
66 | class BasicBlock(nn.Module):
67 | expansion = 1
68 |
69 | def __init__(self, in_planes, planes, stride=1, option='A'):
70 | super(BasicBlock, self).__init__()
71 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
72 | self.bn1 = nn.BatchNorm2d(planes)
73 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
74 | self.bn2 = nn.BatchNorm2d(planes)
75 |
76 | self.shortcut = nn.Sequential()
77 | if stride != 1 or in_planes != planes:
78 | if option == 'A':
79 | """
80 | For CIFAR10 ResNet paper uses option A.
81 | """
82 | self.shortcut = LambdaLayer(lambda x:
83 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
84 | elif option == 'B':
85 | self.shortcut = nn.Sequential(
86 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
87 | nn.BatchNorm2d(self.expansion * planes)
88 | )
89 |
90 | def forward(self, x):
91 | out = F.relu(self.bn1(self.conv1(x)))
92 | out = self.bn2(self.conv2(out))
93 | out += self.shortcut(x)
94 | out = F.relu(out)
95 | return out
96 |
97 |
98 | class ResNet(nn.Module):
99 | def __init__(self, block, num_blocks, num_classes=10):
100 | super(ResNet, self).__init__()
101 | self.in_planes = 16
102 |
103 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
104 | self.bn1 = nn.BatchNorm2d(16)
105 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
106 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
107 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
108 | self.linear = nn.Linear(64, num_classes)
109 | self.apply(_weights_init)
110 |
111 |
112 | def _make_layer(self, block, planes, num_blocks, stride):
113 | strides = [stride] + [1]*(num_blocks-1)
114 | layers = []
115 | for stride in strides:
116 | layers.append(block(self.in_planes, planes, stride))
117 | self.in_planes = planes * block.expansion
118 |
119 | return nn.Sequential(*layers)
120 |
121 | def forward(self, x):
122 | out = F.relu(self.bn1(self.conv1(x)))
123 | out = self.layer1(out)
124 | out = self.layer2(out)
125 | out = self.layer3(out)
126 | out = F.avg_pool2d(out, out.size()[3])
127 | out = out.view(out.size(0), -1)
128 | out = self.linear(out)
129 | return out
130 |
131 |
132 | def resnet20():
133 | return ResNet(BasicBlock, [3, 3, 3])
134 |
135 |
136 | def resnet32():
137 | return ResNet(BasicBlock, [5, 5, 5])
138 |
139 |
140 | def resnet44():
141 | return ResNet(BasicBlock, [7, 7, 7])
142 |
143 |
144 | def resnet56():
145 | return ResNet(BasicBlock, [9, 9, 9])
146 |
147 |
148 | def resnet110():
149 | return ResNet(BasicBlock, [18, 18, 18])
150 |
151 |
152 | def resnet1202():
153 | return ResNet(BasicBlock, [200, 200, 200])
154 |
155 |
156 | def test(net):
157 | import numpy as np
158 | total_params = 0
159 |
160 | for x in filter(lambda p: p.requires_grad, net.parameters()):
161 | total_params += np.prod(x.data.numpy().shape)
162 | print("Total number of params", total_params)
163 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters()))))
164 |
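
A quick sanity-check sketch, relying on the module's own imports; test() should report roughly the parameter counts quoted in the docstring:

    net = resnet32()
    test(net)                               # ~0.46M trainable params for resnet32
    out = net(torch.randn(2, 3, 32, 32))    # CIFAR-sized input -> logits of shape (2, 10)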
--------------------------------------------------------------------------------
/train_eval_scripts/stadv.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 | from torchvision import transforms, datasets
8 | from scipy import optimize
9 | from utils import *
10 | from model import PreActResNet18
11 |
12 |
13 |
14 | def flow_st(images, flows):
15 | images_shape = images.size()
16 | flows_shape = flows.size()
17 | batch_size = images_shape[0]
18 | H = images_shape[2]
19 | W = images_shape[3]
20 | basegrid = torch.stack(torch.meshgrid(torch.arange(0, H), torch.arange(0, W))) # (2,H,W)
21 | sampling_grid = basegrid.unsqueeze(0).type(torch.float32).cuda() + flows.cuda()
22 | sampling_grid_x = torch.clamp(sampling_grid[:, 1], 0.0, W - 1.0).type(torch.float32)
23 | sampling_grid_y = torch.clamp(sampling_grid[:, 0], 0.0, H - 1.0).type(torch.float32)
24 |
25 | x0 = torch.floor(sampling_grid_x).type(torch.int64)
26 | x1 = x0 + 1
27 | y0 = torch.floor(sampling_grid_y).type(torch.int64)
28 | y1 = y0 + 1
29 |
30 | x0 = torch.clamp(x0, 0, W - 2)
31 | x1 = torch.clamp(x1, 0, W - 1)
32 | y0 = torch.clamp(y0, 0, H - 2)
33 | y1 = torch.clamp(y1, 0, H - 1)
34 |
35 | Ia = images[:, :, y0[0, :, :], x0[0, :, :]]
36 | Ib = images[:, :, y1[0, :, :], x0[0, :, :]]
37 | Ic = images[:, :, y0[0, :, :], x1[0, :, :]]
38 | Id = images[:, :, y1[0, :, :], x1[0, :, :]]
39 |
40 | x0 = x0.type(torch.float32)
41 | x1 = x1.type(torch.float32)
42 | y0 = y0.type(torch.float32)
43 | y1 = y1.type(torch.float32)
44 |
45 | wa = (x1 - sampling_grid_x) * (y1 - sampling_grid_y)
46 | wb = (x1 - sampling_grid_x) * (sampling_grid_y - y0)
47 | wc = (sampling_grid_x - x0) * (y1 - sampling_grid_y)
48 | wd = (sampling_grid_x - x0) * (sampling_grid_y - y0)
49 |
50 | perturbed_image = wa.unsqueeze(0) * Ia + wb.unsqueeze(0) * Ib + wc.unsqueeze(0) * Ic + wd.unsqueeze(0) * Id
51 |
52 | return perturbed_image.type(torch.float32).cuda()
53 |
54 |
55 | def flow_loss(flows, padding_mode='constant', epsilon=1e-8):
56 | paddings = (1, 1, 1, 1)
57 | padded_flows = F.pad(flows, paddings, mode=padding_mode, value=0)
58 | shifted_flows = [
59 | padded_flows[:, :, 2:, 2:], # bottom right (+1,+1)
60 | padded_flows[:, :, 2:, :-2], # bottom left (+1,-1)
61 | padded_flows[:, :, :-2, 2:], # top right (-1,+1)
62 | padded_flows[:, :, :-2, :-2] # top left (-1,-1)
63 | ]
64 |     # ||\Delta u^{(p)} - \Delta u^{(q)}||_2^2 + ||\Delta v^{(p)} - \Delta v^{(q)}||_2^2
65 | loss = 0
66 | for shifted_flow in shifted_flows:
67 | loss += torch.sum(torch.square(flows[:, 1] - shifted_flow[:, 1]) + torch.square(
68 | flows[:, 0] - shifted_flow[:, 0]) + epsilon).cuda()
69 | return loss.type(torch.float32)
70 |
71 |
72 | def adv_loss(logits, targets, confidence=0.0):
73 | confidence = torch.tensor(confidence).cuda()
74 | real = torch.sum(logits * targets, -1)
75 | other = torch.max((1 - targets) * logits - (targets * 10000), -1)[0]
76 | return torch.max(other - real, confidence)[0].type(torch.float32)
77 |
78 |
79 | def func(flows, input, target, model, const=0.05):
80 | input = torch.from_numpy(input).cuda()
81 | target = torch.from_numpy(target).cuda()
82 | flows = torch.from_numpy(flows).view((1, 2,) + input.size()[2:]).cuda()
83 | flows.requires_grad = True
84 | pert_out = flow_st(input, flows)
85 | output = model(pert_out)
86 | L_flow = flow_loss(flows)
87 | L_adv = adv_loss(output, target)
88 | L_final = L_adv + const * L_flow
89 | model.zero_grad()
90 | L_final.backward()
91 | gradient = flows.grad.data.view(-1).detach().cpu().numpy()
92 | return L_final.item(), gradient
93 |
94 |
95 | def attack(input, target, model):
96 | init_flows = np.zeros((1, 2,) + input.size()[2:]).reshape(-1)
97 | results = optimize.fmin_l_bfgs_b(func, init_flows, args=(input.cpu().numpy(), target.cpu().numpy(), model))
98 | flows = torch.from_numpy(results[0]).view((1, 2,) + input.size()[2:])
99 | pert_out = flow_st(input, flows)
100 | return pert_out
101 |
102 |
103 | class Model(nn.Module):
104 | def __init__(self, model, norm):
105 | super(Model, self).__init__()
106 | self.model = model
107 | self.norm = norm
108 |
109 | def forward(self, x):
110 | return self.model(self.norm(x))
111 |
112 | if __name__ == '__main__':
113 | np.random.seed(42)
114 | torch.manual_seed(42)
115 | model = PreActResNet18(10)
116 | model.load_state_dict(torch.load('./cifar10_models/cifar10_prn_sgd_sub.pth'))
117 | Mod = Model(model, normalize_cifar)
118 | Mod.eval()
119 | Mod.cuda()
120 |
121 | use_cuda = True
122 | device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")
123 |
124 | train_loader, test_loader = load_dataset('cifar10', 1)
125 | norm = normalize_cifar
126 |
127 | adv = []
128 | adv_label = []
129 | correct_label = []
130 | sample = 10000
131 | success = 0
132 | target_s = 0
133 | for i, (x, y) in enumerate(test_loader):
134 | x, y = x.cuda(), y.cuda()
135 | # y : [x] -> [x+1 mod 10]
136 | target = (y + 1) % 10
137 | pert_out = attack(x, target, model)
138 | if pert_out is not None:
139 | output = model(pert_out)
140 | success += (output.max(1)[1] != y).float().sum().item()
141 | target_s += (output.max(1)[1] == target).float().sum().item()
142 | print(output, y, target)
143 | else:
144 | break
145 |
146 | print('success: ', success, 'sample: ', i, 'target: ', target_s)
147 | print(success / sample)
148 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import pandas as pd
6 |
7 | import argparse
8 | from time import time
9 |
10 | from utils import *
11 | from model import PreActResNet18
12 | from sam import SAM
13 |
14 |
15 | def get_args():
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--fname', type=str, required=True)
18 | parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100'])
19 | parser.add_argument('--epochs', default=100, type=int)
20 | parser.add_argument('--max-lr', default=0.1, type=float)
21 | parser.add_argument('--opt', default='SAM', choices=['SAM', 'SGD'])
22 | parser.add_argument('--batch-size', default=128, type=int)
23 | parser.add_argument('--device', default=0, type=int)
24 | parser.add_argument('--adv', action='store_true')
25 | parser.add_argument('--rho', default=0.05, type=float) # for SAM
26 |
27 | parser.add_argument('--norm', default='linf', choices=['linf', 'l2'])
28 | parser.add_argument('--train-eps', default=8., type=float)
29 | parser.add_argument('--train-alpha', default=2., type=float)
30 | parser.add_argument('--train-step', default=5, type=int)
31 |
32 | parser.add_argument('--test-eps', default=1., type=float)
33 | parser.add_argument('--test-alpha', default=0.5, type=float)
34 | parser.add_argument('--test-step', default=5, type=int)
35 | return parser.parse_args()
36 |
37 | args = get_args()
38 |
39 | def lr_schedule(epoch):
40 | if epoch < args.epochs * 0.75:
41 | return args.max_lr
42 | elif epoch < args.epochs * 0.9:
43 | return args.max_lr * 0.1
44 | else:
45 | return args.max_lr * 0.01
46 |
47 | if __name__ == '__main__':
48 |
49 | dataset = args.dataset
50 | device = f'cuda:{args.device}'
51 | model = PreActResNet18(10 if dataset == 'cifar10' else 100).to(device)
52 | train_loader, test_loader = load_dataset(dataset, args.batch_size)
53 | params = model.parameters()
54 | criterion = nn.CrossEntropyLoss()
55 |
56 |
57 | if args.opt == 'SGD':
58 | opt = torch.optim.SGD(params, lr=args.max_lr, momentum=0.9, weight_decay=5e-4)
59 | elif args.opt == 'SAM':
60 | base_opt = torch.optim.SGD
61 | opt = SAM(params, base_opt,lr=args.max_lr, momentum=0.9, weight_decay=5e-4, rho=args.rho)
62 | normalize = normalize_cifar if dataset == 'cifar10' else normalize_cifar100
63 |
64 | all_log_data = []
65 | train_pgd = PGD(args.train_step, args.train_alpha / 255., args.train_eps / 255., args.norm, False, normalize)
66 | test_pgd = PGD(args.test_step, args.test_alpha / 255., args.test_eps / 255., args.norm, False, normalize)
67 |
68 | for epoch in range(args.epochs):
69 | start_time = time()
70 |         log_data = [0,0,0,0,0,0] # train_loss, train_acc, test_loss, test_acc, test_robust_loss, test_robust_acc
71 | # train
72 | model.train()
73 | lr = lr_schedule(epoch)
74 | opt.param_groups[0].update(lr=lr)
75 | for x, y in train_loader:
76 | x, y = x.to(device), y.to(device)
77 | if args.adv:
78 | delta = train_pgd.perturb(model, x, y)
79 | else:
80 | delta = torch.zeros_like(x).to(x.device)
81 |
82 | output = model(normalize(x + delta))
83 | loss = criterion(output, y)
84 |
85 | if args.opt == 'SGD':
86 | opt.zero_grad()
87 | loss.backward()
88 | opt.step()
89 |
90 | elif args.opt == 'SAM':
91 | loss.backward()
92 | opt.first_step(zero_grad=True)
93 |
94 | output_2 = model(normalize(x + delta))
95 | criterion(output_2, y).backward()
96 | opt.second_step(zero_grad=True)
97 |
98 | log_data[0] += (loss * len(y)).item()
99 | log_data[1] += (output.max(1)[1] == y).float().sum().item()
100 |
101 | # test
102 | model.eval()
103 | for x, y in test_loader:
104 |
105 | x, y = x.to(device), y.to(device)
106 | # clean
107 | output = model(normalize(x)).detach()
108 | loss = criterion(output, y)
109 |
110 | log_data[2] += (loss * len(y)).item()
111 | log_data[3] += (output.max(1)[1] == y).float().sum().item()
112 |             continue  # NOTE: skips the robust (PGD) test evaluation below, leaving log_data[4] and log_data[5] at zero
113 | delta = test_pgd.perturb(model, x, y)
114 | output = model(normalize(x + delta)).detach()
115 | loss = criterion(output, y)
116 |
117 | log_data[4] += (loss * len(y)).item()
118 | log_data[5] += (output.max(1)[1] == y).float().sum().item()
119 |
120 | log_data = np.array(log_data)
121 |         log_data[0] /= 50000
122 |         log_data[1] /= 50000
123 | log_data[2] /= 10000
124 | log_data[3] /= 10000
125 | log_data[4] /= 10000
126 | log_data[5] /= 10000
127 | all_log_data.append(log_data)
128 |
129 | print(f'Epoch {epoch}:\t',log_data,f'\tTime {time()-start_time:.1f}s')
130 | torch.save(model.state_dict(), f'models/{args.fname}.pth' if args.dataset == 'cifar10' else f'cifar100_models/{args.fname}.pth')
131 |
132 | all_log_data = np.stack(all_log_data,axis=0)
133 |
134 | df = pd.DataFrame(all_log_data)
135 | df.to_csv(f'logs/{args.fname}.csv')
136 |
137 |
138 | plt.plot(all_log_data[:, [2,4]])
139 | plt.grid()
140 | # plt.title(f'{dataset} {args.opt}{" adv" if args.adv else ""} Loss', fontsize=16)
141 | plt.legend(['clean', 'robust'], fontsize=16)
142 | plt.savefig(f'figs/{args.fname}_loss.png', dpi=200)
143 | plt.clf()
144 |
145 | plt.plot(all_log_data[:, [3,5]])
146 | plt.grid()
147 | #plt.title(f'{dataset} {args.opt}{" adv" if args.adv else ""} Acc', fontsize=16)
148 | plt.legend(['clean', 'robust'], fontsize=16)
149 | plt.savefig(f'figs/{args.fname}_acc.png', dpi=200)
150 | plt.clf()
151 |
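
Typical invocations, following the argparse flags and output paths above: `python train.py --fname cifar10_prn_sgd --dataset cifar10 --opt SGD` for standard training, or `python train.py --fname cifar10_prn_sam_adv --opt SAM --rho 0.05 --adv` for SAM with PGD adversarial training. Checkpoints are written to models/ (or cifar100_models/), logs to logs/, and figures to figs/.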
--------------------------------------------------------------------------------
/train_eval_scripts/recoloradv/mister_ed/scripts/setup_cifar.py:
--------------------------------------------------------------------------------
1 | """ Script to ensure that:
2 | 1) all dependencies are installed correctly
3 | 2) CIFAR data can be accessed locally
4 | 3) a functional classifier for CIFAR has been loaded.
5 |
6 | """
7 |
8 |
9 | ##############################################################################
10 | # #
11 | # STEP ONE: DEPENDENCIES ARE INSTALLED #
12 | # #
13 | ##############################################################################
14 | from __future__ import print_function
15 | print("Checking imports...")
16 | import sys
17 | import os
18 | sys.path.append(os.path.abspath(os.path.split(os.path.split(__file__)[0])[0]))
19 |
20 | import torch
21 | import glob
22 | import numpy as np
23 | import math
24 | import config
25 | import torchvision.datasets as datasets
26 |
27 | try: #This block from: https://stackoverflow.com/a/17510727
28 | # For Python 3.0 and later
29 | from urllib.request import urlopen
30 | except ImportError:
31 | # Fall back to Python 2's urllib2
32 | from urllib2 import urlopen
33 |
34 | import hashlib
35 | print("...imports look okay!")
36 |
37 |
38 | ##############################################################################
39 | # #
40 | # STEP TWO: CIFAR DATA HAS BEEN LOADED #
41 | # #
42 | ##############################################################################
43 |
44 | def check_cifar_data_loaded():
45 | print("Checking CIFAR10 data loaded...")
46 | dataset_dir = config.DEFAULT_DATASETS_DIR
47 |
48 | train_set = datasets.CIFAR10(root=dataset_dir, train=True, download=True)
49 | val_set = datasets.CIFAR10(root=dataset_dir, train=False, download=True)
50 |
51 | print("...CIFAR10 data looks okay!")
52 |
53 |
54 | check_cifar_data_loaded()
55 |
56 | ##############################################################################
57 | # #
58 | # STEP THREE: LOAD CLASSIFIER FOR CIFAR10 #
59 | # #
60 | ##############################################################################
61 |
62 |
63 | # https://stackoverflow.com/a/44873382
64 | def file_hash(filename):
65 | h = hashlib.sha256()
66 | with open(filename, 'rb', buffering=0) as f:
67 | for b in iter(lambda : f.read(128*1024), b''):
68 | h.update(b)
69 | return h.hexdigest()
70 |
71 |
72 |
73 | def load_cifar_classifiers():
74 | print("Checking CIFAR10 classifier exists...")
75 |
76 | # NOTE: pretrained models are produced by Yerlan Idelbayev
77 | # https://github.com/akamaster/pytorch_resnet_cifar10
78 | # I'm just hosting these on my dropbox for stability purposes
79 |
80 | # Check which models already exist in model directory
81 | resnet_name = lambda flavor: 'cifar10_resnet%s.th' % flavor
82 | total_cifar_files = set([resnet_name(flavor) for flavor in
83 | [1202, 110, 56, 44, 32, 20]])
84 | total_cifar_files.add('Wide-Resnet28x10')
85 |
86 | try:
87 | os.makedirs(config.MODEL_PATH)
88 | except OSError as err:
89 | if not os.path.isdir(config.MODEL_PATH):
90 | raise err
91 |
92 | extant_models = set([_.split('/')[-1] for _ in
93 | glob.glob(os.path.join(*[config.MODEL_PATH, '*']))])
94 |
95 | lacking_models = total_cifar_files - extant_models
96 |
97 | LINK_DEPOT = {resnet_name(20) : 'https://www.dropbox.com/s/glchyr9ljnpgvb5/cifar10_resnet20.th?dl=1',
98 | resnet_name(32) : 'https://www.dropbox.com/s/kis991c5w2qtgpq/cifar10_resnet32.th?dl=1',
99 | resnet_name(44) : 'https://www.dropbox.com/s/sigj56ysrti6s6a/cifar10_resnet44.th?dl=1',
100 | resnet_name(56) : 'https://www.dropbox.com/s/3p6d5tkvdgcbru5/c7ifar10_resnet56.th?dl=1',
101 | resnet_name(110) : 'https://www.dropbox.com/s/sp172x5vjlypfw6/cifar10_resnet110.th?dl=1',
102 | resnet_name(1202): 'https://www.dropbox.com/s/4qxfa6dmdliw9ko/cifar10_resnet1202.th?dl=1',
103 | 'Wide-Resnet28x10': 'https://www.dropbox.com/s/5ln2gow7mnxub29/cifar10_wide-resnet28x10.th?dl=1'
104 | }
105 |
106 |
107 | HASH_DEPOT = {resnet_name(20) : '12fca82f0bebc4135bf1f32f6e3710e61d5108578464b84fd6d7f5c1b04036c8',
108 | resnet_name(32) : 'd509ac1820d7f25398913559d7e81a13229b1e7adc5648e3bfa5e22dc137f850',
109 | resnet_name(44) : '014dd6541728a1c700b1642ab640e211dc6eb8ed507d70697458dc8f8a0ae2e4',
110 | resnet_name(56) : '4bfd97631478d6b638d2764fd2baff3edb1d7d82252d54439343b6596b9b5367',
111 | resnet_name(110) : '1d1ed7c27571399c1fef66969bc4df68d6a92c8e6c41170f444e120e5354e3bc',
112 | resnet_name(1202): 'f3b1deed382cd4c986ff8aa090c805d99a646e99d1f9227d7178183648844f62',
113 | 'Wide-Resnet28x10': 'd6a68ec2135294d91f9014abfdb45232d07fda0cdcd67f8c3b3653b28f08a88f'}
114 |
115 | for name in lacking_models:
116 | link = LINK_DEPOT[name]
117 | print("Downloading %s..." % name)
118 | u = urlopen(link)
119 | data = u.read()
120 | u.close()
121 | filename = os.path.join(config.MODEL_PATH, name)
122 | with open(filename, 'wb') as f:
123 | f.write(data)
124 |
125 | try:
126 | assert file_hash(filename) == HASH_DEPOT[name]
127 | except AssertionError as err:
128 | print("Something went wrong downloading %s" % name)
129 | os.remove(filename)
130 | raise err
131 |
132 | # Then load up all that doesn't already exist
133 |
134 | print("...CIFAR10 classifier looks okay")
135 |
136 |
137 |
138 | load_cifar_classifiers()
139 |
140 |
141 | print("\n Okay, you should be good to go now! ")
142 | print("Try running tutorial_{1,2,3}.ipynb in notebooks/")
143 |
144 |
--------------------------------------------------------------------------------
/SAM_segmentation/datasets/voc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import tarfile
4 | import collections
5 | import torch.utils.data as data
6 | import shutil
7 | import numpy as np
8 |
9 | from PIL import Image
10 | from torchvision.datasets.utils import download_url, check_integrity
11 |
12 | DATASET_YEAR_DICT = {
13 | '2012': {
14 | 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
15 | 'filename': 'VOCtrainval_11-May-2012.tar',
16 | 'md5': '6cd6e144f989b92b3379bac3b3de84fd',
17 | 'base_dir': 'VOCdevkit/VOC2012'
18 | },
19 | '2011': {
20 | 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
21 | 'filename': 'VOCtrainval_25-May-2011.tar',
22 | 'md5': '6c3384ef61512963050cb5d687e5bf1e',
23 | 'base_dir': 'TrainVal/VOCdevkit/VOC2011'
24 | },
25 | '2010': {
26 | 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
27 | 'filename': 'VOCtrainval_03-May-2010.tar',
28 | 'md5': 'da459979d0c395079b5c75ee67908abb',
29 | 'base_dir': 'VOCdevkit/VOC2010'
30 | },
31 | '2009': {
32 | 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
33 | 'filename': 'VOCtrainval_11-May-2009.tar',
34 | 'md5': '59065e4b188729180974ef6572f6a212',
35 | 'base_dir': 'VOCdevkit/VOC2009'
36 | },
37 | '2008': {
38 | 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
39 |         'filename': 'VOCtrainval_14-Jul-2008.tar',
40 | 'md5': '2629fa636546599198acfcfbfcf1904a',
41 | 'base_dir': 'VOCdevkit/VOC2008'
42 | },
43 | '2007': {
44 | 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
45 | 'filename': 'VOCtrainval_06-Nov-2007.tar',
46 | 'md5': 'c52e279531787c972589f7e41ab4ae64',
47 | 'base_dir': 'VOCdevkit/VOC2007'
48 | }
49 | }
50 |
51 |
52 | def voc_cmap(N=256, normalized=False):
53 | def bitget(byteval, idx):
54 | return ((byteval & (1 << idx)) != 0)
55 |
56 | dtype = 'float32' if normalized else 'uint8'
57 | cmap = np.zeros((N, 3), dtype=dtype)
58 | for i in range(N):
59 | r = g = b = 0
60 | c = i
61 | for j in range(8):
62 | r = r | (bitget(c, 0) << 7-j)
63 | g = g | (bitget(c, 1) << 7-j)
64 | b = b | (bitget(c, 2) << 7-j)
65 | c = c >> 3
66 |
67 | cmap[i] = np.array([r, g, b])
68 |
69 | cmap = cmap/255 if normalized else cmap
70 | return cmap
71 |
72 | class VOCSegmentation(data.Dataset):
73 |     """Pascal VOC Segmentation Dataset.
74 | Args:
75 | root (string): Root directory of the VOC Dataset.
76 | year (string, optional): The dataset year, supports years 2007 to 2012.
77 | image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
78 | download (bool, optional): If true, downloads the dataset from the internet and
79 | puts it in root directory. If dataset is already downloaded, it is not
80 | downloaded again.
81 | transform (callable, optional): A function/transform that takes in an PIL image
82 | and returns a transformed version. E.g, ``transforms.RandomCrop``
83 | """
84 | cmap = voc_cmap()
85 | def __init__(self,
86 | root,
87 | year='2012',
88 | image_set='train',
89 | download=False,
90 | transform=None):
91 |
92 | is_aug=False
93 | if year=='2012_aug':
94 | is_aug = True
95 | year = '2012'
96 |
97 | self.root = os.path.expanduser(root)
98 | self.year = year
99 | self.url = DATASET_YEAR_DICT[year]['url']
100 | self.filename = DATASET_YEAR_DICT[year]['filename']
101 | self.md5 = DATASET_YEAR_DICT[year]['md5']
102 | self.transform = transform
103 |
104 | self.image_set = image_set
105 | base_dir = DATASET_YEAR_DICT[year]['base_dir']
106 | voc_root = os.path.join(self.root, base_dir)
107 | image_dir = os.path.join(voc_root, 'JPEGImages')
108 |
109 | if download:
110 | download_extract(self.url, self.root, self.filename, self.md5)
111 |
112 | if not os.path.isdir(voc_root):
113 | raise RuntimeError('Dataset not found or corrupted.' +
114 | ' You can use download=True to download it')
115 |
116 | if is_aug and image_set=='train':
117 | mask_dir = os.path.join(voc_root, 'SegmentationClassAug')
118 | assert os.path.exists(mask_dir), "SegmentationClassAug not found, please refer to README.md and prepare it manually"
119 | split_f = os.path.join( self.root, 'train_aug.txt')#'./datasets/data/train_aug.txt'
120 | else:
121 | mask_dir = os.path.join(voc_root, 'SegmentationClass')
122 | splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
123 | split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
124 |
125 | if not os.path.exists(split_f):
126 | raise ValueError(
127 | 'Wrong image_set entered! Please use image_set="train" '
128 | 'or image_set="trainval" or image_set="val"')
129 |
130 | with open(os.path.join(split_f), "r") as f:
131 | file_names = [x.strip() for x in f.readlines()]
132 |
133 | self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
134 | self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
135 | assert (len(self.images) == len(self.masks))
136 |
137 | def __getitem__(self, index):
138 | """
139 | Args:
140 | index (int): Index
141 | Returns:
142 | tuple: (image, target) where target is the image segmentation.
143 | """
144 | img = Image.open(self.images[index]).convert('RGB')
145 | target = Image.open(self.masks[index])
146 | if self.transform is not None:
147 | img, target = self.transform(img, target)
148 |
149 | if img.max() > 1 or img.min() < 0:
150 | img = (img - img.min()) / (img.max() - img.min())
151 |
152 | return img, target
153 |
154 |
155 | def __len__(self):
156 | return len(self.images)
157 |
158 | @classmethod
159 | def decode_target(cls, mask):
160 | """decode semantic mask to RGB image"""
161 | return cls.cmap[mask]
162 |
163 | def download_extract(url, root, filename, md5):
164 | download_url(url, root, filename, md5)
165 | with tarfile.open(os.path.join(root, filename), "r") as tar:
166 | tar.extractall(path=root)
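
A usage sketch; the root path is a placeholder and assumes the VOC data is already extracted there:

    import numpy as np
    from datasets.voc import VOCSegmentation

    dataset = VOCSegmentation(root='./datasets/data', year='2012',
                              image_set='val', download=False)
    print(len(dataset))
    # decode_target maps an (H, W) integer label mask to an (H, W, 3) RGB image.
    rgb = VOCSegmentation.decode_target(np.zeros((4, 4), dtype=np.uint8))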
--------------------------------------------------------------------------------
/train_eval_scripts/attack.py:
--------------------------------------------------------------------------------
1 | import torchattacks
2 | from model import PreActResNet18, WRN28_10, DeiT
3 | from utils import *
4 | import recoloradv.mister_ed.config as config
5 | from recoloradv.mister_ed.utils.pytorch_utils import DifferentiableNormalize
6 | from recoloradv.utils import get_attack_from_name
7 | from argparse import ArgumentParser
8 |
9 | parser = ArgumentParser()
10 | parser.add_argument('--model_path', default='cifar10_prn_sam_0_1.pth', type=str)
11 | args = parser.parse_args()
12 | file_name = args.model_path
13 |
14 | class Model(nn.Module):
15 | def __init__(self, model, norm):
16 | super(Model, self).__init__()
17 | self.model = model
18 | self.norm = norm
19 |
20 | def forward(self, x):
21 | return self.model(self.norm(x))
22 |
23 | label_dim = 10
24 | if 'cifar10_' in file_name:
25 | label_dim = 10
26 | normalizer = DifferentiableNormalize(
27 | mean=config.CIFAR10_MEANS,
28 | std=config.CIFAR10_STDS,
29 | )
30 | norm = normalize_cifar
31 | train_loader, test_loader = load_dataset('cifar10', 1000)
32 | elif 'cifar100_' in file_name:
33 | label_dim = 100
34 | normalizer = DifferentiableNormalize(
35 | mean=CIFAR100_MEAN,
36 | std=CIFAR100_STD,
37 | )
38 | norm = normalize_cifar100
39 | train_loader, test_loader = load_dataset('cifar100', 1000)
40 | elif 'tiny' in file_name:
41 | label_dim = 200
42 | normalizer = DifferentiableNormalize(
43 | mean=TINYIMAGENET_MEAN,
44 | std=TINYIMAGENET_STD,
45 | )
46 | norm = normalize_tinyimagenet
47 | train_loader, test_loader = load_dataset('tiny-imagenet-200', 1000)
48 | else:
49 | raise ValueError('Unknown dataset')
50 |
51 | if 'prn' in file_name and 'deit' not in file_name and 'wrn' not in file_name:
52 | model = PreActResNet18(label_dim)
53 | elif 'wrn' in file_name:
54 | model = WRN28_10(label_dim)
55 | elif 'deit' in file_name:
56 | model = DeiT(label_dim)
57 |
58 | d = torch.load('./models/' + file_name, map_location='cuda:0')
59 | for k in list(d.keys()):
60 | if k.startswith('module.'):
61 | d[k[7:]] = d[k]
62 | del d[k]
63 |
64 | model.load_state_dict(d)
65 | model.eval()
66 | model.cuda()
67 |
68 | normed_model = Model(model, norm)
69 | normed_model.eval()
70 | normed_model.cuda()
71 |
72 | # test clean accuracy on the whole test set
73 | acc = 0.
74 | for x, y in test_loader:
75 | x, y = x.cuda(), y.cuda()
76 | with torch.no_grad():
77 | pred = normed_model(x)
78 | acc += (pred.max(1)[1] == y).float().sum().item()
79 | acc /= len(test_loader.dataset)
80 | print('Model: {}, Clean Accuracy: {:.4f}'.format(file_name, acc))
81 | with open(f'./logs/{file_name}-attack_log.txt', 'a') as f:
82 | f.write('Model: {}, Clean Accuracy: {:.4f}\n'.format(file_name, acc))
83 |
84 | # attackers
85 | pgd_1 = torchattacks.PGD(normed_model, eps=1 / 255, alpha=0.25 / 255, steps=10)
86 | pgd_2 = torchattacks.PGD(normed_model, eps=2 / 255, alpha=0.5 / 255, steps=10)
87 | pgd_4 = torchattacks.PGD(normed_model, eps=4 / 255, alpha=1 / 255, steps=10)
88 | pgd_8 = torchattacks.PGD(normed_model, eps=8 / 255, alpha=2 / 255, steps=10)
89 | fgsm_1 = torchattacks.FGSM(normed_model, eps=1 / 255)
90 | fgsm_8 = torchattacks.FGSM(normed_model, eps=8 / 255)
91 | cw = torchattacks.CW(normed_model, c=1, kappa=0, steps=10)
92 | autoattack = torchattacks.APGDT(normed_model, norm='Linf', eps=4/255, steps=5, n_restarts=1, seed=0
93 | , eot_iter=1, rho=.75, verbose=False, n_classes=label_dim)
94 | pgd_l2_32 = torchattacks.PGDL2(normed_model, eps=32 / 255, alpha=8 / 255, steps=10)
95 | pgd_l2_64 = torchattacks.PGDL2(normed_model, eps=64 / 255, alpha=16 / 255, steps=10)
96 | pixle = torchattacks.Pixle(normed_model, max_iterations=5, restarts=5)
97 | fab = torchattacks.FAB(normed_model, eps=8 / 255, norm='L2')
98 |
99 | recolor_attack = get_attack_from_name('recoloradv+stadv+delta', model, normalizer, verbose=True)
100 | stadv_attack = get_attack_from_name('stadv', model, normalizer, verbose=True)
101 |
102 | lib_attacker_list = [pgd_1, pgd_2, pgd_4, pgd_8, fgsm_1, fgsm_8, cw, autoattack, pgd_l2_32,
103 | pgd_l2_64, pixle, fab]
104 | lib_atkname_list = ['pgd_1', 'pgd_2', 'pgd_4', 'pgd_8', 'fgsm_1', 'fgsm_8', 'cw', 'autoattack',
105 | 'pgd_l2_32', 'pgd_l2_64', 'pixle', 'fab']
106 | sem_attacker_list = [recolor_attack, stadv_attack]
107 | sem_atkname_list = ['recolor', 'stadv']
108 | for i in range(len(lib_attacker_list)):
109 | try:
110 | lib_attacker = lib_attacker_list[i]
111 | lib_atkname = lib_atkname_list[i]
112 | acc = 0
113 | # get first 1000 imgs and calculate acc
114 | for x, y in test_loader:
115 | x, y = x.cuda(), y.cuda()
116 | adv_x = lib_attacker(x, y)
117 | pred = normed_model(adv_x)
118 | acc += (pred.max(1)[1] == y).float().sum().item()
119 | break
120 | acc /= 1000
121 | print('Model: {}, Attack: {}, Accuracy: {}'.format(file_name, lib_atkname, acc))
122 | # write to log in ./logs/attack_log.txt
123 | with open(f'./logs/{file_name}-attack_log.txt', 'a') as f:
124 | f.write('Model: {}, Attack: {}, Accuracy: {}\n'.format(file_name, lib_atkname, acc))
125 | except:
126 | print('Model: {}, Attack: {}, Accuracy: {}'.format(file_name, lib_atkname, 'failed'))
127 | # write to log in ./logs/attack_log.txt
128 | with open(f'./logs/{file_name}-attack_log.txt', 'a') as f:
129 | f.write('Model: {}, Attack: {}, Accuracy: {}\n'.format(file_name, lib_atkname, 'failed'))
130 |
131 | for i in range(len(sem_attacker_list)):
132 | try:
133 | sem_attacker = sem_attacker_list[i]
134 | sem_atkname = sem_atkname_list[i]
135 | acc = 0
136 | # get first 1000 imgs and calculate acc
137 | for x, y in test_loader:
138 | x, y = x.cuda(), y.cuda()
139 | adv_x = sem_attacker.attack(x, y)[0]
140 | pred = normed_model(adv_x)
141 | acc += (pred.max(1)[1] == y).float().sum().item()
142 | break
143 | acc /= 1000
144 | print('Model: {}, Attack: {}, Accuracy: {}'.format(file_name, sem_atkname, acc))
145 | # write to log in ./logs/attack_log.txt
146 | with open(f'./logs/{file_name}-attack_log.txt', 'a') as f:
147 | f.write('Model: {}, Attack: {}, Accuracy: {}\n'.format(file_name, sem_atkname, acc))
148 | except:
149 | print('Model: {}, Attack: {}, Accuracy: {}'.format(file_name, sem_atkname, 'failed'))
150 | # write to log in ./logs/attack_log.txt
151 | with open(f'./logs/{file_name}-attack_log.txt', 'a') as f:
152 | f.write('Model: {}, Attack: {}, Accuracy: {}\n'.format(file_name, sem_atkname, 'failed'))
153 |
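154 | # --- Illustrative note (not part of the original evaluation script) ---
155 | # The torchattacks attackers above expect inputs in [0, 1]; normalization is
156 | # folded into `normed_model`, while the semantic attacks returned by
157 | # get_attack_from_name take the raw model together with the differentiable
158 | # normalizer. A further torchattacks attacker could be registered the same way
159 | # (hyperparameters below are assumptions, not tuned values):
160 | # mifgsm = torchattacks.MIFGSM(normed_model, eps=8 / 255, steps=10)
161 | # lib_attacker_list.append(mifgsm)
162 | # lib_atkname_list.append('mifgsm')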
--------------------------------------------------------------------------------
/SAM_segmentation/utils/attack.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import transforms
4 | import numpy as np
5 | import torch.nn.functional as F
6 |
7 | voc_mu = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
8 | voc_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
9 | def normalize_voc(x):
10 | return (x - voc_mu.to(x.device))/(voc_std.to(x.device))
11 |
12 | city_mu = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
13 | city_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
14 | def normalize_city(x):
15 | return (x - city_mu.to(x.device))/(city_std.to(x.device))
16 |
17 | iccv09_mu = torch.tensor([0.4813, 0.4901, 0.4747]).view(3, 1, 1)
18 | iccv09_std = torch.tensor([0.2495, 0.2492, 0.2748]).view(3, 1, 1)
19 | def normalize_iccv09(x):
20 | return (x - iccv09_mu.to(x.device))/(iccv09_std.to(x.device))
21 |
22 | class Attack():
23 | def __init__(self, iters, alpha, eps, norm, criterion, rand_init, rand_perturb, targeted, normalize):
24 | self.iters = iters
25 | self.alpha = alpha
26 | self.eps = eps
27 | self.norm = norm
28 | assert norm in ['linf', 'l2']
29 | self.criterion = criterion # loss function for perturb
30 | self.rand_init = rand_init # random initialization before perturb
31 | self.rand_perturb = rand_perturb # add random noise in each step
32 |         self.targeted = targeted # targeted attack
33 | self.normalize = normalize # normalize_cifar
34 |
35 | def perturb(self, model, x, y):
36 | assert x.min() >= 0 and x.max() <= 1
37 | delta = torch.zeros_like(x, device=x.device)
38 | if self.rand_init:
39 | if self.norm == "linf":
40 | delta.uniform_(-self.eps, self.eps)
41 | elif self.norm == "l2":
42 | delta.normal_()
43 | d_flat = delta.view(delta.size(0), -1)
44 | n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
45 | r = torch.zeros_like(n).uniform_(0, 1)
46 | delta *= r/n*self.eps
47 | else:
48 | raise NotImplementedError("Only linf and l2 norms are implemented.")
49 |
50 | delta = torch.clamp(delta, 0-x, 1-x)
51 | delta.requires_grad = True
52 |
53 | for i in range(self.iters):
54 | # output = model(self.normalize(x+delta))
55 | output = model(x+delta)
56 | loss = self.criterion(output, y)
57 |             if self.targeted:
58 | loss = -loss
59 | loss.backward()
60 | grad = delta.grad.detach()
61 | if self.norm == "linf":
62 | d = torch.clamp(delta + self.alpha * torch.sign(grad), min=-self.eps, max=self.eps).detach()
63 | elif self.norm == "l2":
64 | grad_norm = torch.norm(grad.view(grad.size(0), -1), dim=1).view(-1, 1, 1, 1)
65 | scaled_grad = grad / (grad_norm + 1e-10)
66 | d = (delta + scaled_grad * self.alpha).view(delta.size(0), -1).renorm(p=2, dim=0, maxnorm=self.eps).view_as(delta).detach()
67 |
68 | d = torch.clamp(d, 0-x, 1-x)
69 | delta.data = d
70 | delta.grad.zero_()
71 |
72 | return delta.detach()
73 |
74 | def make_one_hot(input, num_classes):
75 | """Convert class index tensor to one hot encoding tensor.
76 |
77 | Args:
78 | input: A tensor of shape [N, 1, *]
79 | num_classes: An int of number of class
80 | Returns:
81 | A tensor of shape [N, num_classes, *]
82 | """
83 | shape = np.array(input.shape)
84 | shape[1] = num_classes
85 | shape = tuple(shape)
86 | result = torch.zeros(shape)
87 | result = result.scatter_(1, input.cpu(), 1)
88 |
89 | return result
90 |
91 | class BinaryDiceLoss(nn.Module):
92 | """Dice loss of binary class
93 | Args:
94 | smooth: A float number to smooth loss, and avoid NaN error, default: 1
95 | p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
96 | predict: A tensor of shape [N, *]
97 | target: A tensor of shape same with predict
98 | reduction: Reduction method to apply, return mean over batch if 'mean',
99 | return sum if 'sum', return a tensor of shape [N,] if 'none'
100 | Returns:
101 | Loss tensor according to arg reduction
102 | Raise:
103 | Exception if unexpected reduction
104 | """
105 | def __init__(self, smooth=1, p=2, reduction='mean'):
106 | super(BinaryDiceLoss, self).__init__()
107 | self.smooth = smooth
108 | self.p = p
109 | self.reduction = reduction
110 |
111 | def forward(self, predict, target):
112 | assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
113 | predict = predict.contiguous().view(predict.shape[0], -1)
114 | target = target.contiguous().view(target.shape[0], -1)
115 |
116 | num = torch.sum(torch.mul(predict, target), dim=1) + self.smooth
117 | den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth
118 |
119 | loss = 1 - num / den
120 |
121 | if self.reduction == 'mean':
122 | return loss.mean()
123 | elif self.reduction == 'sum':
124 | return loss.sum()
125 | elif self.reduction == 'none':
126 | return loss
127 | else:
128 | raise Exception('Unexpected reduction {}'.format(self.reduction))
129 |
130 |
131 | class DiceLoss(nn.Module):
132 | def __init__(self, weight=None, ignore_index=None, **kwargs):
133 | super(DiceLoss, self).__init__()
134 | self.kwargs = kwargs
135 | self.weight = weight
136 | self.ignore_index = ignore_index
137 |
138 |
139 | def forward(self, predict, target):
140 | target = self._convert_target(target)
141 | assert predict.shape == target.shape, 'predict & target shape do not match'
142 | dice = BinaryDiceLoss(**self.kwargs)
143 | total_loss = 0
144 | predict = F.softmax(predict, dim=1)
145 |
146 | for i in range(target.shape[1]):
147 | if i != self.ignore_index:
148 | dice_loss = dice(predict[:, i], target[:, i])
149 | if self.weight is not None:
150 | assert self.weight.shape[0] == target.shape[1], \
151 | 'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0])
152 |                     dice_loss *= self.weight[i]
153 | total_loss += dice_loss
154 |
155 | return total_loss/target.shape[1]
156 |
157 | def _convert_target(self, target):
158 | device = target.device
159 | target = make_one_hot(target.unsqueeze(1), 9)
160 | target = target.to(device)
161 | return target
162 |
163 | class PGD(Attack):
164 | def __init__(self, iters, alpha, eps, norm, rand_init, targeted=False, normalize=normalize_voc):
165 | # super().__init__(iters, alpha, eps, norm, DiceLoss(ignore_index=255), rand_init=rand_init, rand_perturb=False, targeted=targeted, normalize=normalize)
166 | super().__init__(iters, alpha, eps, norm, nn.CrossEntropyLoss(ignore_index=255, reduction='mean'), rand_init=rand_init, rand_perturb=False, targeted=targeted, normalize=normalize)
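167 | 
168 | 
169 | if __name__ == '__main__':
170 |     # Minimal smoke test (an illustrative sketch, not part of the training
171 |     # pipeline): a 1x1 conv stands in for a segmentation network so the PGD
172 |     # perturbation loop above can be exercised end to end.
173 |     toy_model = nn.Conv2d(3, 21, kernel_size=1)
174 |     x = torch.rand(2, 3, 32, 32)            # inputs must already lie in [0, 1]
175 |     y = torch.randint(0, 21, (2, 32, 32))   # per-pixel class labels
176 |     pgd = PGD(iters=3, alpha=2 / 255, eps=8 / 255, norm='linf', rand_init=True)
177 |     delta = pgd.perturb(toy_model, x, y)
178 |     print('max |delta|:', delta.abs().max().item())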
--------------------------------------------------------------------------------
/SAM_segmentation/network/_deeplab.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | from .utils import _SimpleSegmentationModel
6 |
7 |
8 | __all__ = ["DeepLabV3"]
9 |
10 |
11 | class DeepLabV3(_SimpleSegmentationModel):
12 | """
13 | Implements DeepLabV3 model from
14 | `"Rethinking Atrous Convolution for Semantic Image Segmentation"
15 |     <https://arxiv.org/abs/1706.05587>`_.
16 |
17 | Arguments:
18 | backbone (nn.Module): the network used to compute the features for the model.
19 | The backbone should return an OrderedDict[Tensor], with the key being
20 | "out" for the last feature map used, and "aux" if an auxiliary classifier
21 | is used.
22 | classifier (nn.Module): module that takes the "out" element returned from
23 | the backbone and returns a dense prediction.
24 | aux_classifier (nn.Module, optional): auxiliary classifier used during training
25 | """
26 | pass
27 |
28 | class DeepLabHeadV3Plus(nn.Module):
29 | def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]):
30 | super(DeepLabHeadV3Plus, self).__init__()
31 | self.project = nn.Sequential(
32 | nn.Conv2d(low_level_channels, 48, 1, bias=False),
33 | nn.BatchNorm2d(48),
34 | nn.ReLU(inplace=True),
35 | )
36 |
37 | self.aspp = ASPP(in_channels, aspp_dilate)
38 |
39 | self.classifier = nn.Sequential(
40 | nn.Conv2d(304, 256, 3, padding=1, bias=False),
41 | nn.BatchNorm2d(256),
42 | nn.ReLU(inplace=True),
43 | nn.Conv2d(256, num_classes, 1)
44 | )
45 | self._init_weight()
46 |
47 | def forward(self, feature):
48 | low_level_feature = self.project( feature['low_level'] )
49 | output_feature = self.aspp(feature['out'])
50 | output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False)
51 | return self.classifier( torch.cat( [ low_level_feature, output_feature ], dim=1 ) )
52 |
53 | def _init_weight(self):
54 | for m in self.modules():
55 | if isinstance(m, nn.Conv2d):
56 | nn.init.kaiming_normal_(m.weight)
57 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
58 | nn.init.constant_(m.weight, 1)
59 | nn.init.constant_(m.bias, 0)
60 |
61 | class DeepLabHead(nn.Module):
62 | def __init__(self, in_channels, num_classes, aspp_dilate=[12, 24, 36]):
63 | super(DeepLabHead, self).__init__()
64 |
65 | self.classifier = nn.Sequential(
66 | ASPP(in_channels, aspp_dilate),
67 | nn.Conv2d(256, 256, 3, padding=1, bias=False),
68 | nn.BatchNorm2d(256),
69 | nn.ReLU(inplace=True),
70 | nn.Conv2d(256, num_classes, 1)
71 | )
72 | self._init_weight()
73 |
74 | def forward(self, feature):
75 | return self.classifier( feature['out'] )
76 |
77 | def _init_weight(self):
78 | for m in self.modules():
79 | if isinstance(m, nn.Conv2d):
80 | nn.init.kaiming_normal_(m.weight)
81 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
82 | nn.init.constant_(m.weight, 1)
83 | nn.init.constant_(m.bias, 0)
84 |
85 | class AtrousSeparableConvolution(nn.Module):
86 | """ Atrous Separable Convolution
87 | """
88 | def __init__(self, in_channels, out_channels, kernel_size,
89 | stride=1, padding=0, dilation=1, bias=True):
90 | super(AtrousSeparableConvolution, self).__init__()
91 | self.body = nn.Sequential(
92 | # Separable Conv
93 | nn.Conv2d( in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, groups=in_channels ),
94 | # PointWise Conv
95 | nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
96 | )
97 |
98 | self._init_weight()
99 |
100 | def forward(self, x):
101 | return self.body(x)
102 |
103 | def _init_weight(self):
104 | for m in self.modules():
105 | if isinstance(m, nn.Conv2d):
106 | nn.init.kaiming_normal_(m.weight)
107 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
108 | nn.init.constant_(m.weight, 1)
109 | nn.init.constant_(m.bias, 0)
110 |
111 | class ASPPConv(nn.Sequential):
112 | def __init__(self, in_channels, out_channels, dilation):
113 | modules = [
114 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
115 | nn.BatchNorm2d(out_channels),
116 | nn.ReLU(inplace=True)
117 | ]
118 | super(ASPPConv, self).__init__(*modules)
119 |
120 | class ASPPPooling(nn.Sequential):
121 | def __init__(self, in_channels, out_channels):
122 | super(ASPPPooling, self).__init__(
123 | nn.AdaptiveAvgPool2d(1),
124 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
125 | nn.BatchNorm2d(out_channels),
126 | nn.ReLU(inplace=True))
127 |
128 | def forward(self, x):
129 | size = x.shape[-2:]
130 | x = super(ASPPPooling, self).forward(x)
131 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
132 |
133 | class ASPP(nn.Module):
134 | def __init__(self, in_channels, atrous_rates):
135 | super(ASPP, self).__init__()
136 | out_channels = 256
137 | modules = []
138 | modules.append(nn.Sequential(
139 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
140 | nn.BatchNorm2d(out_channels),
141 | nn.ReLU(inplace=True)))
142 |
143 | rate1, rate2, rate3 = tuple(atrous_rates)
144 | modules.append(ASPPConv(in_channels, out_channels, rate1))
145 | modules.append(ASPPConv(in_channels, out_channels, rate2))
146 | modules.append(ASPPConv(in_channels, out_channels, rate3))
147 | modules.append(ASPPPooling(in_channels, out_channels))
148 |
149 | self.convs = nn.ModuleList(modules)
150 |
151 | self.project = nn.Sequential(
152 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
153 | nn.BatchNorm2d(out_channels),
154 | nn.ReLU(inplace=True),
155 | nn.Dropout(0.1),)
156 |
157 | def forward(self, x):
158 | res = []
159 | for conv in self.convs:
160 | res.append(conv(x))
161 | res = torch.cat(res, dim=1)
162 | return self.project(res)
163 |
164 |
165 |
166 | def convert_to_separable_conv(module):
167 | new_module = module
168 | if isinstance(module, nn.Conv2d) and module.kernel_size[0]>1:
169 | new_module = AtrousSeparableConvolution(module.in_channels,
170 | module.out_channels,
171 | module.kernel_size,
172 | module.stride,
173 | module.padding,
174 | module.dilation,
175 | module.bias)
176 | for name, child in module.named_children():
177 | new_module.add_module(name, convert_to_separable_conv(child))
178 | return new_module
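179 | 
180 | if __name__ == '__main__':
181 |     # Illustrative shape check (not part of the original module): run a dummy
182 |     # feature map through ASPP and convert a plain 3x3 conv into an atrous
183 |     # separable convolution. eval() avoids batch-norm statistics on the 1x1
184 |     # global-pooling branch with a batch of one.
185 |     aspp = ASPP(in_channels=64, atrous_rates=[12, 24, 36]).eval()
186 |     with torch.no_grad():
187 |         print(aspp(torch.randn(1, 64, 33, 33)).shape)   # -> [1, 256, 33, 33]
188 |     block = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1, bias=False))
189 |     print(convert_to_separable_conv(block))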
--------------------------------------------------------------------------------
/train_eval_scripts/recoloradv/mister_ed/cifar10/cifar_loader.py:
--------------------------------------------------------------------------------
1 | """ Code to build a cifar10 data loader """
2 |
3 |
4 | import torch
5 | import torchvision.transforms as transforms
6 | import torchvision.datasets as datasets
7 | from . import cifar_resnets
8 | from . import wide_resnets
9 | from ..utils import pytorch_utils as utils
10 | from .. import config
11 | import os
12 | import re
13 |
14 |
15 | ###############################################################################
16 | # PARSE CONFIGS #
17 | ###############################################################################
18 |
19 | DEFAULT_DATASETS_DIR = config.DEFAULT_DATASETS_DIR
20 | RESNET_WEIGHT_PATH = config.MODEL_PATH
21 | DEFAULT_BATCH_SIZE = config.DEFAULT_BATCH_SIZE
22 | DEFAULT_WORKERS = config.DEFAULT_WORKERS
23 | CIFAR10_MEANS = config.CIFAR10_MEANS
24 | CIFAR10_STDS = config.CIFAR10_STDS
25 | WIDE_CIFAR10_MEANS = config.WIDE_CIFAR10_MEANS
26 | WIDE_CIFAR10_STDS = config.WIDE_CIFAR10_STDS
27 | ###############################################################################
28 | # END PARSE CONFIGS #
29 | ###############################################################################
30 |
31 |
32 | ##############################################################################
33 | # #
34 | # MODEL LOADER #
35 | # #
36 | ##############################################################################
37 |
38 | def load_pretrained_cifar_resnet(flavor=32,
39 | return_normalizer=False,
40 | manual_gpu=None):
41 | """ Helper fxn to initialize/load the pretrained cifar resnet
42 | """
43 |
44 | # Resolve load path
45 | valid_flavor_numbers = [110, 1202, 20, 32, 44, 56]
46 | assert flavor in valid_flavor_numbers
47 | weight_path = os.path.join(RESNET_WEIGHT_PATH,
48 | 'cifar10_resnet%s.th' % flavor)
49 |
50 |
51 | # Resolve CPU/GPU stuff
52 | if manual_gpu is not None:
53 | use_gpu = manual_gpu
54 | else:
55 | use_gpu = utils.use_gpu()
56 |
57 | if use_gpu:
58 | map_location = None
59 | else:
60 | map_location = (lambda s, l: s)
61 |
62 |
63 | # need to modify the resnet state dict to be proper
64 | # TODO: LOAD THESE INTO MODEL ZOO
65 | bad_state_dict = torch.load(weight_path, map_location=map_location)
66 | correct_state_dict = {re.sub(r'^module\.', '', k): v for k, v in
67 | bad_state_dict['state_dict'].items()}
68 |
69 |
70 | classifier_net = eval("cifar_resnets.resnet%s" % flavor)()
71 | classifier_net.load_state_dict(correct_state_dict)
72 |
73 | if return_normalizer:
74 | normalizer = utils.DifferentiableNormalize(mean=CIFAR10_MEANS,
75 | std=CIFAR10_STDS)
76 | return classifier_net, normalizer
77 |
78 | return classifier_net
79 |
80 |
81 | def load_pretrained_cifar_wide_resnet(use_gpu=False, return_normalizer=False):
82 | """ Helper fxn to initialize/load a pretrained 28x10 CIFAR resnet """
83 |
84 | weight_path = os.path.join(RESNET_WEIGHT_PATH,
85 | 'cifar10_wide-resnet28x10.th')
86 | state_dict = torch.load(weight_path)
87 | classifier_net = wide_resnets.Wide_ResNet(28, 10, 0, 10)
88 |
89 | classifier_net.load_state_dict(state_dict)
90 |
91 | if return_normalizer:
92 | normalizer = utils.DifferentiableNormalize(mean=WIDE_CIFAR10_MEANS,
93 | std=WIDE_CIFAR10_STDS)
94 | return classifier_net, normalizer
95 |
96 | return classifier_net
97 |
98 |
99 |
100 |
101 |
102 |
103 | ##############################################################################
104 | # #
105 | # DATA LOADER #
106 | # #
107 | ##############################################################################
108 |
109 | def load_cifar_data(train_or_val, extra_args=None, dataset_dir=None,
110 | normalize=False, batch_size=None, manual_gpu=None,
111 | shuffle=True, no_transform=False):
112 | """ Builds a CIFAR10 data loader for either training or evaluation of
113 | CIFAR10 data. See the 'DEFAULTS' section in the fxn for default args
114 | ARGS:
115 | train_or_val: string - one of 'train' or 'val' for whether we should
116 |                       load training or validation data
117 | extra_args: dict - if not None is the kwargs to be passed to DataLoader
118 | constructor
119 | dataset_dir: string - if not None is a directory to load the data from
120 | normalize: boolean - if True, we normalize the data by subtracting out
121 | means and dividing by standard devs
122 | manual_gpu : boolean or None- if None, we use the GPU if we can
123 | else, we use the GPU iff this is True
124 | shuffle: boolean - if True, we load the data in a shuffled order
125 | no_transform: boolean - if True, we don't do any random cropping/
126 | reflections of the data
127 | """
128 |
129 | ##################################################################
130 | # DEFAULTS #
131 | ##################################################################
132 | # dataset directory
133 | dataset_dir = dataset_dir or DEFAULT_DATASETS_DIR
134 | batch_size = batch_size or DEFAULT_BATCH_SIZE
135 |
136 | # Extra arguments for DataLoader constructor
137 | if manual_gpu is not None:
138 | use_gpu = manual_gpu
139 | else:
140 | use_gpu = utils.use_gpu()
141 |
142 | constructor_kwargs = {'batch_size': batch_size,
143 | 'shuffle': shuffle,
144 | 'num_workers': DEFAULT_WORKERS,
145 | 'pin_memory': use_gpu}
146 | constructor_kwargs.update(extra_args or {})
147 |
148 | # transform chain
149 | transform_list = []
150 | if no_transform is False:
151 | transform_list.extend([transforms.RandomHorizontalFlip(),
152 | transforms.RandomCrop(32, 4)])
153 | transform_list.append(transforms.ToTensor())
154 |
155 | if normalize:
156 | normalizer = transforms.Normalize(mean=CIFAR10_MEANS,
157 | std=CIFAR10_STDS)
158 | transform_list.append(normalizer)
159 |
160 |
161 | transform_chain = transforms.Compose(transform_list)
162 | # train_or_val validation
163 | assert train_or_val in ['train', 'val']
164 |
165 | ##################################################################
166 | # Build DataLoader #
167 | ##################################################################
168 | return torch.utils.data.DataLoader(
169 | datasets.CIFAR10(root=dataset_dir, train=train_or_val=='train',
170 | transform=transform_chain, download=True),
171 | **constructor_kwargs)
172 |
173 |
174 |
175 |
176 |
177 |
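178 | if __name__ == '__main__':
179 |     # Hedged usage sketch: build a CIFAR-10 validation loader with the helper
180 |     # above (this downloads CIFAR-10 into DEFAULT_DATASETS_DIR on first use).
181 |     # Loading the pretrained ResNet additionally assumes the checkpoint under
182 |     # config.MODEL_PATH is present locally.
183 |     val_loader = load_cifar_data('val', batch_size=128, shuffle=False,
184 |                                  no_transform=True)
185 |     images, labels = next(iter(val_loader))
186 |     print(images.shape, labels.shape)   # -> [128, 3, 32, 32], [128]
187 |     # classifier, normalizer = load_pretrained_cifar_resnet(
188 |     #     flavor=32, return_normalizer=True)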
--------------------------------------------------------------------------------
/train_eval_scripts/recoloradv/mister_ed/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | """ Specific utilities for image classification
2 | (i.e. RGB images i.e. tensors of the form NxCxHxW )
3 | """
4 | from __future__ import print_function
5 | from . import pytorch_utils as utils
6 | import torch
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | import torchvision.transforms as transforms
10 | import random
11 |
12 | def nhwc255_xform(img_np_array):
13 | """ Takes in a numpy array and transposes it so that the channel is the last
14 | axis. Also multiplies all values by 255.0
15 | ARGS:
16 | img_np_array : np.ndarray - array of shape (NxHxWxC) or (NxCxHxW)
17 | [assumes that we're in NCHW by default,
18 | but if not ambiguous will handle NHWC too ]
19 | RETURNS:
20 | array of form NHWC
21 | """
22 | assert isinstance(img_np_array, np.ndarray)
23 | shape = img_np_array.shape
24 | assert len(shape) == 4
25 |
26 | # determine which configuration we're in
27 | ambiguous = (shape[1] == shape[3] == 3)
28 |     nhwc = (shape[3] == 3)
29 |
30 | # transpose unless we're unambiguously in nhwc case
31 | if nhwc and not ambiguous:
32 | return img_np_array * 255.0
33 | else:
34 | return np.transpose(img_np_array, (0, 2, 3, 1)) * 255.0
35 |
36 |
37 | def show_images(images, normalize=None, ipython=True,
38 | margin_height=2, margin_color='red',
39 | figsize=(18,16)):
40 | """ Shows pytorch tensors/variables as images """
41 |
42 |
43 | # first format the first arg to be hz-stacked numpy arrays
44 | if not isinstance(images, list):
45 | images = [images]
46 | images = [np.dstack(image.cpu().numpy()) for image in images]
47 | image_shape = images[0].shape
48 | assert all(image.shape == image_shape for image in images)
49 | assert all(image.ndim == 3 for image in images) # CxHxW
50 |
51 | # now build the list of final rows
52 | rows = []
53 | if margin_height >0:
54 | assert margin_color in ['red', 'black']
55 | margin_shape = list(image_shape)
56 | margin_shape[1] = margin_height
57 | margin = np.zeros(margin_shape)
58 | if margin_color == 'red':
59 | margin[0] = 1
60 | else:
61 | margin = None
62 |
63 | for image_row in images:
64 | rows.append(margin)
65 | rows.append(image_row)
66 |
67 | rows = [_ for _ in rows[1:] if _ is not None]
68 | plt.figure(figsize=figsize, dpi=80, facecolor='w', edgecolor='k')
69 |
70 | cat_rows = np.concatenate(rows, 1).transpose(1, 2, 0)
71 | imshow_kwargs = {}
72 | if cat_rows.shape[-1] == 1: # 1 channel: greyscale
73 | cat_rows = cat_rows.squeeze()
74 | imshow_kwargs['cmap'] = 'gray'
75 |
76 | plt.imshow(cat_rows, **imshow_kwargs)
77 |
78 | plt.show()
79 |
80 |
81 |
82 |
83 | def display_adversarial_2row(classifier_net, normalizer, original_images,
84 | adversarial_images, num_to_show=4, which='incorrect',
85 | ipython=False, margin_width=2):
86 | """ Displays adversarial images side-by-side with their unperturbed
87 | counterparts. Opens a window displaying two rows: top row is original
88 | images, bottom row is perturbed
89 | ARGS:
90 | classifier_net : nn - with a .forward method that takes normalized
91 | variables and outputs logits
92 | normalizer : object w/ .forward method - should probably be an instance
93 | of utils.DifferentiableNormalize or utils.IdentityNormalize
94 | original_images: Variable or Tensor (NxCxHxW) - original images to
95 | display. Images in [0., 1.] range
96 | adversarial_images: Variable or Tensor (NxCxHxW) - perturbed images to
97 | display. Should be same shape as original_images
98 | num_to_show : int - number of images to show
99 | which : string in ['incorrect', 'random', 'correct'] - which images to
100 | show.
101 | -- 'incorrect' means successfully attacked images,
102 | -- 'random' means some random selection of images
103 | -- 'correct' means unsuccessfully attacked images
104 | ipython: bool - if True, we use in an ipython notebook so slightly
105 | different way to show Images
106 | margin_width - int : height in pixels of the red margin separating top
107 | and bottom rows. Set to 0 for no margin
108 | RETURNS:
109 | None, but displays images
110 | """
111 | assert which in ['incorrect', 'random', 'correct']
112 |
113 |
114 | # If not 'random' selection, prune to only the valid things
115 | to_sample_idxs = []
116 | if which != 'random':
117 | classifier_net.eval() # can never be too safe =)
118 |
119 | # classify the originals with top1
120 | original_norm_var = normalizer.forward(original_images)
121 | original_out_logits = classifier_net.forward(original_norm_var)
122 | _, original_out_classes = original_out_logits.max(1)
123 |
124 | # classify the adversarials with top1
125 | adv_norm_var = normalizer.forward(adversarial_images)
126 | adv_out_logits = classifier_net.forward(adv_norm_var)
127 | _, adv_out_classes = adv_out_logits.max(1)
128 |
129 |
130 | # collect indices of matching
131 | selector = lambda var: (which == 'correct') == bool(float(var))
132 | for idx, var_el in enumerate(original_out_classes == adv_out_classes):
133 | if selector(var_el):
134 | to_sample_idxs.append(idx)
135 | else:
136 | to_sample_idxs = list(range(original_images.shape[0]))
137 |
138 | # Now select some indices to show
139 | if to_sample_idxs == []:
140 | print("Couldn't show anything. Try changing the 'which' argument here")
141 | return
142 |
143 | to_show_idxs = random.sample(to_sample_idxs, min([num_to_show,
144 | len(to_sample_idxs)]))
145 |
146 | # Now start building up the images : first horizontally, then vertically
147 | top_row = torch.cat([original_images[idx] for idx in to_show_idxs], dim=2)
148 | bottom_row = torch.cat([adversarial_images[idx] for idx in to_show_idxs],
149 | dim=2)
150 |
151 | if margin_width > 0:
152 | margin = torch.zeros(3, margin_width, top_row.shape[-1])
153 | margin[0] = 1.0 # make it red
154 | margin = margin.type(type(top_row))
155 | stack = [top_row, margin, bottom_row]
156 | else:
157 | stack = [top_row, bottom_row]
158 |
159 | plt.imshow(torch.cat(stack, dim=1).cpu().numpy().transpose(1, 2, 0))
160 | plt.show()
161 |
162 |
163 | def display_adversarial_notebook():
164 | pass
165 |
166 | def nchw_l2(x, y, squared=True):
167 | """ Computes l2 norm between two NxCxHxW images
168 | ARGS:
169 | x, y: Tensor/Variable (NxCxHxW) - x, y must be same type & shape.
170 | squared : bool - if True we return squared loss, otherwise we return
171 | square root of l2
172 | RETURNS:
173 | ||x - y ||_2 ^2 (no exponent if squared == False),
174 | shape is (Nx1x1x1)
175 | """
176 | temp = torch.pow(x - y, 2) # square diff
177 |
178 |
179 | for i in range(1, temp.dim()): # reduce on all but first dimension
180 | temp = torch.sum(temp, i, keepdim=True)
181 |
182 | if not squared:
183 | temp = torch.pow(temp, 0.5)
184 |
185 | return temp.squeeze()
186 |
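187 | 
188 | if __name__ == '__main__':
189 |     # Small illustrative check of nchw_l2 (not part of the original module):
190 |     # per-image squared L2 distance between two random batches.
191 |     a, b = torch.rand(4, 3, 32, 32), torch.rand(4, 3, 32, 32)
192 |     print(nchw_l2(a, b, squared=True).shape)   # -> torch.Size([4])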
--------------------------------------------------------------------------------
/SAM_segmentation/network/backbone/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | try: # for torchvision<0.4
3 | from torchvision.models.utils import load_state_dict_from_url
4 | except: # for torchvision>=0.4
5 | from torch.hub import load_state_dict_from_url
6 | import torch.nn.functional as F
7 |
8 | __all__ = ['MobileNetV2', 'mobilenet_v2']
9 |
10 |
11 | model_urls = {
12 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
13 | }
14 |
15 |
16 | def _make_divisible(v, divisor, min_value=None):
17 | """
18 | This function is taken from the original tf repo.
19 | It ensures that all layers have a channel number that is divisible by 8
20 | It can be seen here:
21 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
22 | :param v:
23 | :param divisor:
24 | :param min_value:
25 | :return:
26 | """
27 | if min_value is None:
28 | min_value = divisor
29 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
30 | # Make sure that round down does not go down by more than 10%.
31 | if new_v < 0.9 * v:
32 | new_v += divisor
33 | return new_v
34 |
35 |
36 | class ConvBNReLU(nn.Sequential):
37 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, dilation=1, groups=1):
38 | #padding = (kernel_size - 1) // 2
39 | super(ConvBNReLU, self).__init__(
40 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, 0, dilation=dilation, groups=groups, bias=False),
41 | nn.BatchNorm2d(out_planes),
42 | nn.ReLU6(inplace=True)
43 | )
44 |
45 | def fixed_padding(kernel_size, dilation):
46 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
47 | pad_total = kernel_size_effective - 1
48 | pad_beg = pad_total // 2
49 | pad_end = pad_total - pad_beg
50 | return (pad_beg, pad_end, pad_beg, pad_end)
51 |
52 | class InvertedResidual(nn.Module):
53 | def __init__(self, inp, oup, stride, dilation, expand_ratio):
54 | super(InvertedResidual, self).__init__()
55 | self.stride = stride
56 | assert stride in [1, 2]
57 |
58 | hidden_dim = int(round(inp * expand_ratio))
59 | self.use_res_connect = self.stride == 1 and inp == oup
60 |
61 | layers = []
62 | if expand_ratio != 1:
63 | # pw
64 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
65 |
66 | layers.extend([
67 | # dw
68 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, dilation=dilation, groups=hidden_dim),
69 | # pw-linear
70 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
71 | nn.BatchNorm2d(oup),
72 | ])
73 | self.conv = nn.Sequential(*layers)
74 |
75 | self.input_padding = fixed_padding( 3, dilation )
76 |
77 | def forward(self, x):
78 | x_pad = F.pad(x, self.input_padding)
79 | if self.use_res_connect:
80 | return x + self.conv(x_pad)
81 | else:
82 | return self.conv(x_pad)
83 |
84 | class MobileNetV2(nn.Module):
85 | def __init__(self, num_classes=1000, output_stride=8, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
86 | """
87 | MobileNet V2 main class
88 |
89 | Args:
90 | num_classes (int): Number of classes
91 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
92 | inverted_residual_setting: Network structure
93 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number
94 | Set to 1 to turn off rounding
95 | """
96 | super(MobileNetV2, self).__init__()
97 | block = InvertedResidual
98 | input_channel = 32
99 | last_channel = 1280
100 | self.output_stride = output_stride
101 | current_stride = 1
102 | if inverted_residual_setting is None:
103 | inverted_residual_setting = [
104 | # t, c, n, s
105 | [1, 16, 1, 1],
106 | [6, 24, 2, 2],
107 | [6, 32, 3, 2],
108 | [6, 64, 4, 2],
109 | [6, 96, 3, 1],
110 | [6, 160, 3, 2],
111 | [6, 320, 1, 1],
112 | ]
113 |
114 | # only check the first element, assuming user knows t,c,n,s are required
115 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
116 | raise ValueError("inverted_residual_setting should be non-empty "
117 | "or a 4-element list, got {}".format(inverted_residual_setting))
118 |
119 | # building first layer
120 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
121 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
122 | features = [ConvBNReLU(3, input_channel, stride=2)]
123 | current_stride *= 2
124 | dilation=1
125 | previous_dilation = 1
126 |
127 | # building inverted residual blocks
128 | for t, c, n, s in inverted_residual_setting:
129 | output_channel = _make_divisible(c * width_mult, round_nearest)
130 | previous_dilation = dilation
131 | if current_stride == output_stride:
132 | stride = 1
133 | dilation *= s
134 | else:
135 | stride = s
136 | current_stride *= s
137 | output_channel = int(c * width_mult)
138 |
139 | for i in range(n):
140 | if i==0:
141 | features.append(block(input_channel, output_channel, stride, previous_dilation, expand_ratio=t))
142 | else:
143 | features.append(block(input_channel, output_channel, 1, dilation, expand_ratio=t))
144 | input_channel = output_channel
145 | # building last several layers
146 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
147 | # make it nn.Sequential
148 | self.features = nn.Sequential(*features)
149 |
150 | # building classifier
151 | self.classifier = nn.Sequential(
152 | nn.Dropout(0.2),
153 | nn.Linear(self.last_channel, num_classes),
154 | )
155 |
156 | # weight initialization
157 | for m in self.modules():
158 | if isinstance(m, nn.Conv2d):
159 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
160 | if m.bias is not None:
161 | nn.init.zeros_(m.bias)
162 | elif isinstance(m, nn.BatchNorm2d):
163 | nn.init.ones_(m.weight)
164 | nn.init.zeros_(m.bias)
165 | elif isinstance(m, nn.Linear):
166 | nn.init.normal_(m.weight, 0, 0.01)
167 | nn.init.zeros_(m.bias)
168 |
169 | def forward(self, x):
170 | x = self.features(x)
171 | x = x.mean([2, 3])
172 | x = self.classifier(x)
173 | return x
174 |
175 |
176 | def mobilenet_v2(pretrained=False, progress=True, **kwargs):
177 | """
178 | Constructs a MobileNetV2 architecture from
179 |     `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
180 |
181 | Args:
182 | pretrained (bool): If True, returns a model pre-trained on ImageNet
183 | progress (bool): If True, displays a progress bar of the download to stderr
184 | """
185 | model = MobileNetV2(**kwargs)
186 | if pretrained:
187 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
188 | progress=progress)
189 | model.load_state_dict(state_dict)
190 | return model
191 |
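192 | 
193 | if __name__ == '__main__':
194 |     # Quick shape check (illustrative): a randomly initialised backbone at
195 |     # output_stride=16; pretrained=True would instead fetch ImageNet weights
196 |     # from model_urls.
197 |     import torch
198 |     net = mobilenet_v2(pretrained=False, output_stride=16)
199 |     print(net(torch.randn(2, 3, 224, 224)).shape)   # -> torch.Size([2, 1000])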
--------------------------------------------------------------------------------
/train_eval_scripts/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import pandas as pd
6 |
7 | import argparse
8 | from time import time
9 |
10 | from utils import *
11 | from model import PreActResNet18, WRN28_10, DeiT
12 | from sam import SAM, ASAM, ESAM
13 |
14 |
15 | def get_args():
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--fname', type=str, required=True)
18 |     parser.add_argument('--model', type=str, default='PRN', choices=['PRN', 'WRN', 'DeiT'])
19 | parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100', 'tiny-imagenet-200'])
20 | parser.add_argument('--epochs', default=100, type=int)
21 | parser.add_argument('--max-lr', default=0.1, type=float)
22 | parser.add_argument('--opt', default='SGD', choices=['Adam', 'SGD'])
23 | parser.add_argument('--sam', default='NO', choices=['SAM', 'ASAM', 'ESAM', 'NO'])
24 | parser.add_argument('--batch-size', default=128, type=int)
25 | parser.add_argument('--device', default=0, type=int)
26 | parser.add_argument('--adv', action='store_true')
27 | parser.add_argument('--rho', default=0.05, type=float) # for SAM
28 |
29 | parser.add_argument('--norm', default='linf', choices=['linf', 'l2'])
30 | parser.add_argument('--train-eps', default=8., type=float)
31 | parser.add_argument('--train-alpha', default=2., type=float)
32 | parser.add_argument('--train-step', default=5, type=int)
33 |
34 | parser.add_argument('--test-eps', default=1., type=float)
35 | parser.add_argument('--test-alpha', default=0.5, type=float)
36 | parser.add_argument('--test-step', default=5, type=int)
37 | return parser.parse_args()
38 |
39 |
40 | args = get_args()
41 |
42 |
43 | def lr_schedule(epoch):
44 | if epoch < args.epochs * 0.75:
45 | return args.max_lr
46 | elif epoch < args.epochs * 0.9:
47 | return args.max_lr * 0.1
48 | else:
49 | return args.max_lr * 0.01
50 |
51 |
52 | if __name__ == '__main__':
53 | dataset = args.dataset
54 | device = f'cuda:{args.device}'
55 | model_name = args.model
56 | label_dim = {'cifar10': 10, 'cifar100': 100, 'tiny-imagenet-200': 200}[dataset]
57 | model = {'PRN': PreActResNet18(label_dim), 'WRN': WRN28_10(label_dim), 'DeiT': DeiT(label_dim)}[model_name].to(
58 | device)
59 | train_loader, test_loader = load_dataset(dataset, args.batch_size)
60 | params = model.parameters()
61 | criterion = nn.CrossEntropyLoss()
62 |
63 | if args.sam == 'NO':
64 | if args.opt == 'SGD':
65 | opt = torch.optim.SGD(params, lr=args.max_lr, momentum=0.9, weight_decay=5e-4)
66 | elif args.opt == 'Adam':
67 | opt = torch.optim.Adam(params, lr=args.max_lr, weight_decay=5e-4)
68 | else:
69 |             raise ValueError("Invalid optimizer")
70 | else:
71 | if args.sam == 'SAM':
72 | base_opt = torch.optim.SGD
73 | opt = SAM(params, base_opt, lr=args.max_lr, momentum=0.9, weight_decay=5e-4, rho=args.rho)
74 | elif args.sam == 'ASAM':
75 | base_opt = torch.optim.SGD(params, lr=args.max_lr, momentum=0.9, weight_decay=5e-4)
76 | opt = ASAM(base_opt, model, rho=args.rho)
77 | elif args.sam == 'ESAM':
78 | base_opt = torch.optim.SGD(model.parameters(), lr=args.max_lr, momentum=0.9, weight_decay=5e-4)
79 | opt = ESAM(params, base_opt, rho=args.rho)
80 | else:
81 |             raise ValueError("Invalid SAM optimizer")
82 |
83 | normalize = \
84 | {'cifar10': normalize_cifar, 'cifar100': normalize_cifar100, 'tiny-imagenet-200': normalize_tinyimagenet}[dataset]
85 |
86 | all_log_data = []
87 | train_pgd = PGD(args.train_step, args.train_alpha / 255., args.train_eps / 255., args.norm, False, normalize)
88 | test_pgd = PGD(args.test_step, args.test_alpha / 255., args.test_eps / 255., args.norm, False, normalize)
89 |
90 | for epoch in range(args.epochs):
91 | start_time = time()
92 |         log_data = [0, 0, 0, 0, 0, 0] # train_loss, train_acc, test_loss, test_acc, test_robust_loss, test_robust_acc
93 | # train
94 | model.train()
95 | lr = lr_schedule(epoch)
96 | if args.sam == 'ASAM':
97 | opt.optimizer.param_groups[0].update(lr=lr)
98 | else:
99 | opt.param_groups[0].update(lr=lr)
100 | for x, y in train_loader:
101 | x, y = x.to(device), y.to(device)
102 | if args.adv:
103 | delta = train_pgd.perturb(model, x, y)
104 | else:
105 | delta = torch.zeros_like(x).to(x.device)
106 |
107 | if args.sam == 'NO':
108 | output = model(normalize(x + delta))
109 | loss = criterion(output, y)
110 | opt.zero_grad()
111 | loss.backward()
112 | opt.step()
113 |
114 | else:
115 | if args.sam == 'SAM':
116 | output = model(normalize(x + delta))
117 | loss = criterion(output, y)
118 | loss.backward()
119 | opt.first_step(zero_grad=True)
120 | output_2 = model(normalize(x + delta))
121 | criterion(output_2, y).backward()
122 | opt.second_step(zero_grad=True)
123 | elif args.sam == 'ASAM':
124 | output = model(normalize(x + delta))
125 | loss = criterion(output, y)
126 | loss.backward()
127 | opt.ascent_step()
128 | output_2 = model(normalize(x + delta))
129 | criterion(output_2, y).backward()
130 | opt.descent_step()
131 | elif args.sam == 'ESAM':
132 | def defined_backward(loss):
133 | loss.backward()
134 | paras = [normalize(x + delta), y, criterion, model, defined_backward]
135 | opt.paras = paras
136 | opt.step()
137 | output, loss = opt.returnthings
138 |
139 | log_data[0] += (loss * len(y)).item()
140 | log_data[1] += (output.max(1)[1] == y).float().sum().item()
141 |
142 | # test
143 | model.eval()
144 | for x, y in test_loader:
145 | x, y = x.to(device), y.to(device)
146 | # clean
147 | output = model(normalize(x)).detach()
148 | loss = criterion(output, y)
149 |
150 | log_data[2] += (loss * len(y)).item()
151 | log_data[3] += (output.max(1)[1] == y).float().sum().item()
152 | delta = test_pgd.perturb(model, x, y)
153 | output = model(normalize(x + delta)).detach()
154 | loss = criterion(output, y)
155 |
156 | log_data[4] += (loss * len(y)).item()
157 | log_data[5] += (output.max(1)[1] == y).float().sum().item()
158 |
159 | log_data = np.array(log_data)
160 |         num_train = 50000 if 'cifar' in dataset else 100000
161 | num_test = 10000 if 'cifar' in dataset else 10000
162 | log_data[0] /= num_train
163 | log_data[1] /= num_train
164 | log_data[2] /= num_test
165 | log_data[3] /= num_test
166 | log_data[4] /= num_test
167 | log_data[5] /= num_test
168 | all_log_data.append(log_data)
169 |
170 | print(f'Epoch {epoch}:\t', log_data, f'\tTime {time() - start_time:.1f}s')
171 | save_path = '{dataset}_models/{fname}.pth'
172 | torch.save(model.state_dict(), save_path.format(dataset=dataset, fname=args.fname))
173 |
174 | all_log_data = np.stack(all_log_data, axis=0)
175 |
176 | df = pd.DataFrame(all_log_data)
177 | df.to_csv(f'logs/{args.fname}.csv')
178 |
179 | plt.plot(all_log_data[:, [2, 4]])
180 | plt.grid()
181 | # plt.title(f'{dataset} {args.opt}{" adv" if args.adv else ""} Loss', fontsize=16)
182 | plt.legend(['clean', 'robust'], fontsize=16)
183 | plt.savefig(f'figs/{args.fname}_loss.png', dpi=200)
184 | plt.clf()
185 |
186 | plt.plot(all_log_data[:, [3, 5]])
187 | plt.grid()
188 | # plt.title(f'{dataset} {args.opt}{" adv" if args.adv else ""} Acc', fontsize=16)
189 | plt.legend(['clean', 'robust'], fontsize=16)
190 | plt.savefig(f'figs/{args.fname}_acc.png', dpi=200)
191 | plt.clf()
192 |
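193 | # Example invocation (illustrative; every flag below is defined in get_args):
194 | #   python train.py --fname cifar10_prn_sam --model PRN --dataset cifar10 \
195 | #       --sam SAM --rho 0.05 --adv --epochs 100
196 | # Checkpoints are written to {dataset}_models/{fname}.pth, the per-epoch log to
197 | # logs/{fname}.csv, and loss/accuracy curves to figs/.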
--------------------------------------------------------------------------------
/SAM_segmentation/datasets/cityscapes.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from collections import namedtuple
4 |
5 | import torch
6 | import torch.utils.data as data
7 | from PIL import Image
8 | import numpy as np
9 |
10 |
11 | class Cityscapes(data.Dataset):
12 | """Cityscapes Dataset.
13 |
14 | **Parameters:**
15 | - **root** (string): Root directory of dataset where directory 'leftImg8bit' and 'gtFine' or 'gtCoarse' are located.
16 | - **split** (string, optional): The image split to use, 'train', 'test' or 'val' if mode="gtFine" otherwise 'train', 'train_extra' or 'val'
17 | - **mode** (string, optional): The quality mode to use, 'gtFine' or 'gtCoarse' or 'color'. Can also be a list to output a tuple with all specified target types.
18 | - **transform** (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``
19 | - **target_transform** (callable, optional): A function/transform that takes in the target and transforms it.
20 | """
21 |
22 | # Based on https://github.com/mcordts/cityscapesScripts
23 | CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
24 | 'has_instances', 'ignore_in_eval', 'color'])
25 | classes = [
26 | CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
27 | CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
28 | CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
29 | CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
30 | CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
31 | CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
32 | CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
33 | CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
34 | CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
35 | CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
36 | CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
37 | CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
38 | CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
39 | CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
40 | CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
41 | CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
42 | CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
43 | CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
44 | CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
45 | CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
46 | CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
47 | CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
48 | CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
49 | CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
50 | CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
51 | CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
52 | CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
53 | CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
54 | CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
55 | CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
56 | CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
57 | CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
58 | CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
59 | CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
60 | CityscapesClass('license plate', -1, 255, 'vehicle', 7, False, True, (0, 0, 142)),
61 | ]
62 |
63 | train_id_to_color = [c.color for c in classes if (c.train_id != -1 and c.train_id != 255)]
64 | train_id_to_color.append([0, 0, 0])
65 | train_id_to_color = np.array(train_id_to_color)
66 | id_to_train_id = np.array([c.train_id for c in classes])
67 |
68 | #train_id_to_color = [(0, 0, 0), (128, 64, 128), (70, 70, 70), (153, 153, 153), (107, 142, 35),
69 | # (70, 130, 180), (220, 20, 60), (0, 0, 142)]
70 | #train_id_to_color = np.array(train_id_to_color)
71 | #id_to_train_id = np.array([c.category_id for c in classes], dtype='uint8') - 1
72 |
73 | def __init__(self, root, split='train', mode='fine', target_type='semantic', transform=None):
74 | self.root = os.path.expanduser(root)
75 | self.mode = 'gtFine'
76 | self.target_type = target_type
77 | self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
78 |
79 | self.targets_dir = os.path.join(self.root, self.mode, split)
80 | self.transform = transform
81 |
82 | self.split = split
83 | self.images = []
84 | self.targets = []
85 |
86 | if split not in ['train', 'test', 'val']:
87 | raise ValueError('Invalid split for mode! Please use split="train", split="test"'
88 | ' or split="val"')
89 |
90 | if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
91 | raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
92 | ' specified "split" and "mode" are inside the "root" directory')
93 |
94 | for city in os.listdir(self.images_dir):
95 | img_dir = os.path.join(self.images_dir, city)
96 | target_dir = os.path.join(self.targets_dir, city)
97 |
98 | for file_name in os.listdir(img_dir):
99 | self.images.append(os.path.join(img_dir, file_name))
100 | target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
101 | self._get_target_suffix(self.mode, self.target_type))
102 | self.targets.append(os.path.join(target_dir, target_name))
103 |
104 | @classmethod
105 | def encode_target(cls, target):
106 | return cls.id_to_train_id[np.array(target)]
107 |
108 | @classmethod
109 | def decode_target(cls, target):
110 | target[target == 255] = 19
111 | #target = target.astype('uint8') + 1
112 | return cls.train_id_to_color[target]
113 |
114 | def __getitem__(self, index):
115 | """
116 | Args:
117 | index (int): Index
118 | Returns:
119 | tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
120 | than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
121 | """
122 | image = Image.open(self.images[index]).convert('RGB')
123 | target = Image.open(self.targets[index])
124 | if self.transform:
125 | image, target = self.transform(image, target)
126 | target = self.encode_target(target)
127 | if image.max() > 1 or image.min() < 0:
128 | image = (image - image.min()) / (image.max() - image.min())
129 | return image, target
130 |
131 | def __len__(self):
132 | return len(self.images)
133 |
134 | def _load_json(self, path):
135 | with open(path, 'r') as file:
136 | data = json.load(file)
137 | return data
138 |
139 | def _get_target_suffix(self, mode, target_type):
140 | if target_type == 'instance':
141 | return '{}_instanceIds.png'.format(mode)
142 | elif target_type == 'semantic':
143 | return '{}_labelIds.png'.format(mode)
144 | elif target_type == 'color':
145 | return '{}_color.png'.format(mode)
146 | elif target_type == 'polygon':
147 | return '{}_polygons.json'.format(mode)
148 | elif target_type == 'depth':
149 | return '{}_disparity.png'.format(mode)
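150 | 
151 | 
152 | if __name__ == '__main__':
153 |     # Hedged usage sketch: assumes the standard Cityscapes layout
154 |     # (leftImg8bit/<split>/<city>/ and gtFine/<split>/<city>/) under the root
155 |     # path below, which is only an example. A paired image/target transform
156 |     # that returns tensors should be supplied before indexing the dataset.
157 |     dataset = Cityscapes(root='./datasets/data/cityscapes', split='val')
158 |     print('validation images found:', len(dataset))
159 |     print('train-id palette shape:', Cityscapes.train_id_to_color.shape)  # (20, 3)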
--------------------------------------------------------------------------------
/train_eval_scripts/sam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | from collections import defaultdict
4 |
5 | class SAM(torch.optim.Optimizer):
6 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
7 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
8 |
9 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
10 | super(SAM, self).__init__(params, defaults)
11 |
12 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
13 | self.param_groups = self.base_optimizer.param_groups
14 | self.defaults.update(self.base_optimizer.defaults)
15 |
16 | @torch.no_grad()
17 | def first_step(self, zero_grad=False):
18 | grad_norm = self._grad_norm()
19 | for group in self.param_groups:
20 | scale = group["rho"] / (grad_norm + 1e-12)
21 |
22 | for p in group["params"]:
23 | if p.grad is None: continue
24 | self.state[p]["old_p"] = p.data.clone()
25 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
26 | p.add_(e_w) # climb to the local maximum "w + e(w)"
27 |
28 | if zero_grad: self.zero_grad()
29 |
30 | @torch.no_grad()
31 | def second_step(self, zero_grad=False):
32 | for group in self.param_groups:
33 | for p in group["params"]:
34 | if p.grad is None: continue
35 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)"
36 |
37 | self.base_optimizer.step() # do the actual "sharpness-aware" update
38 |
39 | if zero_grad: self.zero_grad()
40 |
41 | @torch.no_grad()
42 | def step(self, closure=None):
43 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
44 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass
45 |
46 | self.first_step(zero_grad=True)
47 | closure()
48 | self.second_step()
49 |
50 | def _grad_norm(self):
51 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism
52 | norm = torch.norm(
53 | torch.stack([
54 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
55 | for group in self.param_groups for p in group["params"]
56 | if p.grad is not None
57 | ]),
58 | p=2
59 | )
60 | return norm
61 |
62 | def load_state_dict(self, state_dict):
63 | super().load_state_dict(state_dict)
64 | self.base_optimizer.param_groups = self.param_groups
65 |
66 |
67 | class ASAM:
68 | def __init__(self, optimizer, model, rho=0.5, eta=0.01):
69 | self.optimizer = optimizer
70 | self.model = model
71 | self.rho = rho
72 | self.eta = eta
73 | self.state = defaultdict(dict)
74 |
75 | @torch.no_grad()
76 | def ascent_step(self):
77 | wgrads = []
78 | for n, p in self.model.named_parameters():
79 | if p.grad is None:
80 | continue
81 | t_w = self.state[p].get("eps")
82 | if t_w is None:
83 | t_w = torch.clone(p).detach()
84 | self.state[p]["eps"] = t_w
85 | if 'weight' in n:
86 | t_w[...] = p[...]
87 | t_w.abs_().add_(self.eta)
88 | p.grad.mul_(t_w)
89 | wgrads.append(torch.norm(p.grad, p=2))
90 | wgrad_norm = torch.norm(torch.stack(wgrads), p=2) + 1.e-16
91 | for n, p in self.model.named_parameters():
92 | if p.grad is None:
93 | continue
94 | t_w = self.state[p].get("eps")
95 | if 'weight' in n:
96 | p.grad.mul_(t_w)
97 | eps = t_w
98 | eps[...] = p.grad[...]
99 | eps.mul_(self.rho / wgrad_norm)
100 | p.add_(eps)
101 | self.optimizer.zero_grad()
102 |
103 | @torch.no_grad()
104 | def descent_step(self):
105 | for n, p in self.model.named_parameters():
106 | if p.grad is None:
107 | continue
108 | p.sub_(self.state[p]["eps"])
109 | self.optimizer.step()
110 | self.optimizer.zero_grad()
111 |
112 |
113 | class ESAM(torch.optim.Optimizer):
114 | def __init__(self, params, base_optimizer, rho=0.05, beta=1.0, gamma=1.0, adaptive=False, **kwargs):
115 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
116 | self.beta = beta
117 | self.gamma = gamma
118 |
119 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
120 | super(ESAM, self).__init__(params, defaults)
121 |
122 | self.base_optimizer = base_optimizer
123 | self.param_groups = self.base_optimizer.param_groups
124 | for group in self.param_groups:
125 | group["rho"] = rho
126 | group["adaptive"] = adaptive
127 | self.paras = None
128 |
129 | @torch.no_grad()
130 | def first_step(self, zero_grad=False):
131 | # first order sum
132 | grad_norm = self._grad_norm()
133 | for group in self.param_groups:
134 | scale = group["rho"] / (grad_norm + 1e-7) / self.beta
135 | for p in group["params"]:
136 | p.requires_grad = True
137 | if p.grad is None: continue
138 | # original sam
139 | # e_w = p.grad * scale.to(p)
140 | # asam
141 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
142 | p.add_(e_w * 1) # climb to the local maximum "w + e(w)"
143 | self.state[p]["e_w"] = e_w
144 |
145 | if zero_grad: self.zero_grad()
146 |
147 | '''
148 | @torch.no_grad()
149 | def first_half(self, zero_grad=False):
150 | #first order sum
151 | for group in self.param_groups:
152 | for p in group["params"]:
153 | if self.state[p]:
154 | p.add_(self.state[p]["e_w"]*0.90) # climb to the local maximum "w + e(w)"
155 | '''
156 |
157 | @torch.no_grad()
158 | def second_step(self, zero_grad=False):
159 | for group in self.param_groups:
160 | for p in group["params"]:
161 | if p.grad is None or not self.state[p]: continue
162 | p.sub_(self.state[p]["e_w"]) # get back to "w" from "w + e(w)"
163 | self.state[p]["e_w"] = 0
164 |
165 | if random.random() > self.beta:
166 | p.requires_grad = False
167 |
168 | self.base_optimizer.step() # do the actual "sharpness-aware" update
169 |
170 | if zero_grad: self.zero_grad()
171 |
172 | def step(self):
173 | inputs, targets, loss_fct, model, defined_backward = self.paras
174 | assert defined_backward is not None, "Sharpness Aware Minimization requires defined_backward, but it was not provided"
175 |
176 | model.require_backward_grad_sync = False
177 | model.require_forward_param_sync = True
178 |
179 | logits = model(inputs)
180 | loss = loss_fct(logits, targets)
181 |
182 | l_before = loss.clone().detach()
183 | predictions = logits
184 | return_loss = loss.clone().detach()
185 | loss = loss.mean()
186 | defined_backward(loss)
187 |
188 | # first step to w + e(w)
189 | self.first_step(True)
190 |
191 | with torch.no_grad():
192 | l_after = loss_fct(model(inputs), targets)
193 | instance_sharpness = l_after - l_before
194 |
195 | # select the sharpest instances for the second pass
196 | prob = self.gamma
197 | if prob >= 0.99:
198 | indices = range(len(targets))
199 | else:
200 | position = int(len(targets) * prob)
201 | cutoff, _ = torch.topk(instance_sharpness, position)
202 | cutoff = cutoff[-1]
203 |
204 | # cutoff = 0
205 | # select top k%
206 |
207 | indices = [instance_sharpness > cutoff]
208 |
209 | # second forward-backward step
210 | # self.first_half()
211 |
212 | model.require_backward_grad_sync = True
213 | model.require_forward_param_sync = False
214 |
215 | loss = loss_fct(model(inputs[indices]), targets[indices])
216 | loss = loss.mean()
217 | defined_backward(loss)
218 | self.second_step(True)
219 |
220 | self.returnthings = (predictions, return_loss)
221 |
222 | def _grad_norm(self):
223 | shared_device = self.param_groups[0]["params"][
224 | 0].device # put everything on the same device, in case of model parallelism
225 | norm = torch.norm(
226 | torch.stack([
227 | # original sam
228 | # p.grad.norm(p=2).to(shared_device)
229 | # asam
230 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
231 | for group in self.param_groups for p in group["params"]
232 | if p.grad is not None
233 | ]),
234 | p=2
235 | )
236 | return norm
--------------------------------------------------------------------------------
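A minimal training-loop sketch (not part of the repository) showing how the ESAM class above is meant to be driven: `paras` is set to the current batch, a per-sample loss function, the model, and a backward callback before every call to `step()`. The tiny model, data, and hyper-parameters below are placeholders.

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 2)                     # placeholder model
    base = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    optimizer = ESAM(model.parameters(), base, rho=0.05, beta=1.0, gamma=1.0)
    # reduction='none' keeps per-sample losses, which step() needs for instance selection
    loss_fct = nn.CrossEntropyLoss(reduction='none')

    def defined_backward(loss):
        loss.backward()

    inputs = torch.randn(8, 10)                  # placeholder batch
    targets = torch.randint(0, 2, (8,))
    optimizer.paras = [inputs, targets, loss_fct, model, defined_backward]
    optimizer.step()                             # runs both SAM forward-backward passes
    predictions, per_sample_loss = optimizer.returnthings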
/train_eval_scripts/recoloradv/color_spaces.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains classes that convert from RGB to various other color spaces and back.
3 | """
4 |
5 | import torch
6 | import torch.nn as nn
7 | from .mister_ed.utils import pytorch_utils as utils
8 | from torch.autograd import Variable
9 | import numpy as np
10 | from recoloradv import norms
11 | import math
12 |
13 |
14 | class ColorSpace(object):
15 | """
16 | Base class for color spaces.
17 | """
18 |
19 | def from_rgb(self, imgs):
20 | """
21 | Converts an Nx3xWxH tensor in RGB color space to an Nx3xWxH tensor in
22 | this color space. All outputs should be in the 0-1 range.
23 | """
24 | raise NotImplementedError()
25 |
26 | def to_rgb(self, imgs):
27 | """
28 | Converts an Nx3xWxH tensor in this color space to an Nx3xWxH tensor in
29 | RGB color space.
30 | """
31 | raise NotImplementedError()
32 |
33 |
34 | class RGBColorSpace(ColorSpace):
35 | """
36 | RGB color space. Just applies identity transformation.
37 | """
38 |
39 | def from_rgb(self, imgs):
40 | return imgs
41 |
42 | def to_rgb(self, imgs):
43 | return imgs
44 |
45 |
46 | class YPbPrColorSpace(ColorSpace):
47 | """
48 | YPbPr color space. Uses ITU-R BT.601 standard by default.
49 | """
50 |
51 | def __init__(self, kr=0.299, kg=0.587, kb=0.114, luma_factor=1,
52 | chroma_factor=1):
53 | self.kr, self.kg, self.kb = kr, kg, kb
54 | self.luma_factor = luma_factor
55 | self.chroma_factor = chroma_factor
56 |
57 | def from_rgb(self, imgs):
58 | r, g, b = imgs.permute(1, 0, 2, 3)
59 |
60 | y = r * self.kr + g * self.kg + b * self.kb
61 | pb = (b - y) / (2 * (1 - self.kb))
62 | pr = (r - y) / (2 * (1 - self.kr))
63 |
64 | return torch.stack([y * self.luma_factor,
65 | pb * self.chroma_factor + 0.5,
66 | pr * self.chroma_factor + 0.5], 1)
67 |
68 | def to_rgb(self, imgs):
69 | y_prime, pb_prime, pr_prime = imgs.permute(1, 0, 2, 3)
70 | y = y_prime / self.luma_factor
71 | pb = (pb_prime - 0.5) / self.chroma_factor
72 | pr = (pr_prime - 0.5) / self.chroma_factor
73 |
74 | b = pb * 2 * (1 - self.kb) + y
75 | r = pr * 2 * (1 - self.kr) + y
76 | g = (y - r * self.kr - b * self.kb) / self.kg
77 |
78 | return torch.stack([r, g, b], 1).clamp(0, 1)
79 |
80 |
81 | class ApproxHSVColorSpace(ColorSpace):
82 | """
83 | Converts from RGB to approximately the HSV cone using a much smoother
84 | transformation.
85 | """
86 |
87 | def from_rgb(self, imgs):
88 | r, g, b = imgs.permute(1, 0, 2, 3)
89 |
90 | x = r * np.sqrt(2) / 3 - g / (np.sqrt(2) * 3) - b / (np.sqrt(2) * 3)
91 | y = g / np.sqrt(6) - b / np.sqrt(6)
92 | z, _ = imgs.max(1)
93 |
94 | return torch.stack([z, x + 0.5, y + 0.5], 1)
95 |
96 | def to_rgb(self, imgs):
97 | z, xp, yp = imgs.permute(1, 0, 2, 3)
98 | x, y = xp - 0.5, yp - 0.5
99 |
100 | rp = float(np.sqrt(2)) * x
101 | gp = -x / np.sqrt(2) + y * np.sqrt(3 / 2)
102 | bp = -x / np.sqrt(2) - y * np.sqrt(3 / 2)
103 |
104 | delta = z - torch.max(torch.stack([rp, gp, bp], 1), 1)[0]
105 | r, g, b = rp + delta, gp + delta, bp + delta
106 |
107 | return torch.stack([r, g, b], 1).clamp(0, 1)
108 |
109 |
110 | class HSVConeColorSpace(ColorSpace):
111 | """
112 | Converts from RGB to the HSV "cone", where (x, y, z) =
113 | (s * v cos h, s * v sin h, v). Note that this cone is then squashed to fit
114 | in [0, 1]^3 by letting (x', y', z') = ((x + 1) / 2, (y + 1) / 2, z).
115 |
116 | WARNING: has a very complex derivative, not very useful in practice
117 | """
118 |
119 | def from_rgb(self, imgs):
120 | r, g, b = imgs.permute(1, 0, 2, 3)
121 |
122 | mx, argmx = imgs.max(1)
123 | mn, _ = imgs.min(1)
124 | chroma = mx - mn
125 | eps = 1e-10
126 | h_max_r = math.pi / 3 * (g - b) / (chroma + eps)
127 | h_max_g = math.pi / 3 * (b - r) / (chroma + eps) + math.pi * 2 / 3
128 | h_max_b = math.pi / 3 * (r - g) / (chroma + eps) + math.pi * 4 / 3
129 |
130 | h = (((argmx == 0) & (chroma != 0)).float() * h_max_r
131 | + ((argmx == 1) & (chroma != 0)).float() * h_max_g
132 | + ((argmx == 2) & (chroma != 0)).float() * h_max_b)
133 |
134 | x = torch.cos(h) * chroma
135 | y = torch.sin(h) * chroma
136 | z = mx
137 |
138 | return torch.stack([(x + 1) / 2, (y + 1) / 2, z], 1)
139 |
140 | def _to_rgb_part(self, h, chroma, v, n):
141 | """
142 | Implements the function f(n) defined here:
143 | https://en.wikipedia.org/wiki/HSL_and_HSV#Alternative_HSV_to_RGB
144 | """
145 |
146 | k = (n + h / (math.pi / 3)) % 6 # convert h from radians to 60-degree sectors, matching from_rgb
147 | return v - chroma * torch.min(k, 4 - k).clamp(0, 1)
148 |
149 | def to_rgb(self, imgs):
150 | xp, yp, z = imgs.permute(1, 0, 2, 3)
151 | x, y = xp * 2 - 1, yp * 2 - 1
152 |
153 | # prevent NaN gradients when calculating atan2
154 | x_nonzero = (1 - 2 * (torch.sign(x) == -1).float()) * (torch.abs(x) + 1e-10)
155 | h = torch.atan2(y, x_nonzero)
156 | v = z.clamp(0, 1)
157 | chroma = torch.min(torch.sqrt(x ** 2 + y ** 2 + 1e-10), v)
158 |
159 | r = self._to_rgb_part(h, chroma, v, 5)
160 | g = self._to_rgb_part(h, chroma, v, 3)
161 | b = self._to_rgb_part(h, chroma, v, 1)
162 |
163 | return torch.stack([r, g, b], 1).clamp(0, 1)
164 |
165 |
166 | class CIEXYZColorSpace(ColorSpace):
167 | """
168 | The 1931 CIE XYZ color space (assuming input is in sRGB).
169 |
170 | Warning: may have values outside [0, 1] range. Should only be used in
171 | the process of converting to/from other color spaces.
172 | """
173 |
174 | def from_rgb(self, imgs):
175 | # apply gamma correction
176 | small_values_mask = (imgs < 0.04045).float()
177 | imgs_corrected = (
178 | (imgs / 12.92) * small_values_mask +
179 | ((imgs + 0.055) / 1.055) ** 2.4 * (1 - small_values_mask)
180 | )
181 |
182 | # linear transformation to XYZ
183 | r, g, b = imgs_corrected.permute(1, 0, 2, 3)
184 | x = 0.4124 * r + 0.3576 * g + 0.1805 * b
185 | y = 0.2126 * r + 0.7152 * g + 0.0722 * b
186 | z = 0.0193 * r + 0.1192 * g + 0.9504 * b
187 |
188 | return torch.stack([x, y, z], 1)
189 |
190 | def to_rgb(self, imgs):
191 | # linear transformation
192 | x, y, z = imgs.permute(1, 0, 2, 3)
193 | r = 3.2406 * x - 1.5372 * y - 0.4986 * z
194 | g = -0.9689 * x + 1.8758 * y + 0.0415 * z
195 | b = 0.0557 * x - 0.2040 * y + 1.0570 * z
196 |
197 | imgs = torch.stack([r, g, b], 1)
198 |
199 | # apply gamma correction
200 | small_values_mask = (imgs < 0.0031308).float()
201 | imgs_clamped = imgs.clamp(min=1e-10) # prevent NaN gradients
202 | imgs_corrected = (
203 | (12.92 * imgs) * small_values_mask +
204 | (1.055 * imgs_clamped ** (1 / 2.4) - 0.055) *
205 | (1 - small_values_mask)
206 | )
207 |
208 | return imgs_corrected
209 |
210 |
211 | class CIELUVColorSpace(ColorSpace):
212 | """
213 | Converts to the 1976 CIE L*u*v* color space.
214 | """
215 |
216 | def __init__(self, up_white=0.1978, vp_white=0.4683, y_white=1,
217 | eps=1e-10):
218 | self.xyz_cspace = CIEXYZColorSpace()
219 | self.up_white = up_white
220 | self.vp_white = vp_white
221 | self.y_white = y_white
222 | self.eps = eps
223 |
224 | def from_rgb(self, imgs):
225 | x, y, z = self.xyz_cspace.from_rgb(imgs).permute(1, 0, 2, 3)
226 |
227 | # calculate u' and v'
228 | denom = x + 15 * y + 3 * z + self.eps
229 | up = 4 * x / denom
230 | vp = 9 * y / denom
231 |
232 | # calculate L*, u*, and v*
233 | small_values_mask = (y / self.y_white < (6 / 29) ** 3).float()
234 | y_clamped = y.clamp(min=self.eps) # prevent NaN gradients
235 | L = (
236 | ((29 / 3) ** 3 * y / self.y_white) * small_values_mask +
237 | (116 * (y_clamped / self.y_white) ** (1 / 3) - 16) *
238 | (1 - small_values_mask)
239 | )
240 | u = 13 * L * (up - self.up_white)
241 | v = 13 * L * (vp - self.vp_white)
242 |
243 | return torch.stack([L / 100, (u + 100) / 200, (v + 100) / 200], 1)
244 |
245 | def to_rgb(self, imgs):
246 | L = imgs[:, 0, :, :] * 100
247 | u = imgs[:, 1, :, :] * 200 - 100
248 | v = imgs[:, 2, :, :] * 200 - 100
249 |
250 | up = u / (13 * L + self.eps) + self.up_white
251 | vp = v / (13 * L + self.eps) + self.vp_white
252 |
253 | small_values_mask = (L <= 8).float()
254 | y = (
255 | (self.y_white * L * (3 / 29) ** 3) * small_values_mask +
256 | (self.y_white * ((L + 16) / 116) ** 3) * (1 - small_values_mask)
257 | )
258 | denom = 4 * vp + self.eps
259 | x = y * 9 * up / denom
260 | z = y * (12 - 3 * up - 20 * vp) / denom
261 |
262 | return self.xyz_cspace.to_rgb(
263 | torch.stack([x, y, z], 1).clamp(0, 1.1)).clamp(0, 1)
264 |
--------------------------------------------------------------------------------
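A small round-trip sketch (not from the repository), assuming the color-space classes above are in scope; it checks that an RGB batch survives a trip through CIELUV with only small numerical error.

    import torch

    cspace = CIELUVColorSpace()          # any ColorSpace subclass above works the same way
    imgs = torch.rand(4, 3, 32, 32)      # random RGB batch in [0, 1]
    luv = cspace.from_rgb(imgs)          # still Nx3xHxW, squashed into [0, 1]
    recon = cspace.to_rgb(luv)
    print((recon - imgs).abs().max())    # round-trip error is expected to be small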
/train_eval_scripts/recoloradv/mister_ed/utils/discretization.py:
--------------------------------------------------------------------------------
1 | """ File that holds techniques for discretizing images --
2 | In general, images of the form NxCxHxW with values in the [0., 1.] range
3 | need to be converted to the integer [0, 255] range to be displayed as images.
4 |
5 | Sometimes the naive rounding scheme can mess up the classification, so this
6 | file holds techniques to discretize these images into tensors with values
7 | of the form i/255.0 for some integers i.
8 | """
9 |
10 | import torch
11 | from torch.autograd import Variable
12 | from . import pytorch_utils as utils
13 |
14 | ##############################################################################
15 | # #
16 | # HELPER METHODS #
17 | # #
18 | ##############################################################################
19 |
20 |
21 | def discretize_image(img_tensor, zero_one=False):
22 | """ Discretizes an image tensor into a tensor filled with ints ranging
23 | between 0 and 255
24 | ARGS:
25 | img_tensor : floatTensor (NxCxHxW) - tensor to be discretized
26 | zero_one : bool - if True, divides the output by 255 so every value
27 | has the form i/255.0; otherwise returns integer values in [0, 255]
28 | """
29 |
30 | assert float(torch.min(img_tensor)) >= 0.
31 | assert float(torch.max(img_tensor)) <= 1.0
32 |
33 |
34 | original_shape = img_tensor.shape
35 | if img_tensor.dim() != 4:
36 | img_tensor = img_tensor.unsqueeze(0)
37 |
38 | int_tensors = [] # actually floatTensor, but full of ints
39 | img_shape = original_shape[1:]
40 | for example in img_tensor:
41 | # round to the nearest i/255 value (the original used scipy.misc's toimage here, but the `smp` import is missing and that API has been removed from scipy)
42 | int_tensors.append((example * 255.0).round().view(img_shape))
43 |
44 | stacked_tensors = torch.stack(int_tensors)
45 | if zero_one:
46 | return stacked_tensors / 255.0
47 | return stacked_tensors
48 |
49 |
50 |
51 | ##############################################################################
52 | # #
53 | # MAIN DISCRETIZATION TECHNIQUES #
54 | # #
55 | ##############################################################################
56 |
57 | def discretized_adversarial(img_tensor, classifier_net, normalizer,
58 | flavor='greedy'):
59 | """ Takes in an image_tensor and classifier/normalizer pair and outputs a
60 | 'discretized' image_tensor [each val is i/255.0 for some integer i]
61 | with the same classification
62 | ARGS:
63 | img_tensor : tensor (NxCxHxW) - tensor of images with values between
64 | 0.0 and 1.0.
65 | classifier_net : NN - neural net with .forward method to classify
66 | normalized images
67 | normalizer : differentiableNormalizer object - normalizes 0,1 images
68 | into classifier_domain
69 | flavor : string - either 'random' or 'greedy', determining which
70 | 'next_pixel_to_flip' function we use
71 | RETURNS:
72 | img_tensor of the same shape, but now with values of the form i/255.0
73 | for integers i.
74 | """
75 |
76 | img_tensor = utils.safe_tensor(img_tensor)
77 |
78 | nptf_map = {'random': flip_random_pixel,
79 | 'greedy': flip_greedy_pixel}
80 | next_pixel_to_flip = nptf_map[flavor](classifier_net, normalizer)
81 |
82 | ##########################################################################
83 | # First figure out 'correct' labels and the 'discretized' labels #
84 | ##########################################################################
85 | var_img = utils.safe_var(img_tensor)
86 | norm_var = normalizer.forward(var_img)
87 | norm_output = classifier_net.forward(norm_var)
88 | correct_targets = norm_output.max(1)[1]
89 |
90 | og_discretized = utils.safe_var(discretize_image(img_tensor, zero_one=True))
91 | norm_discretized = normalizer.forward(og_discretized)
92 | discretized_output = classifier_net.forward(norm_discretized)
93 | discretized_targets = discretized_output.max(1)[1]
94 |
95 | ##########################################################################
96 | # Collect idxs for examples affected by discretization #
97 | ##########################################################################
98 | incorrect_idxs = set()
99 |
100 | for i, el in enumerate(correct_targets.ne(discretized_targets)):
101 | if float(el) != 0:
102 | incorrect_idxs.add(i)
103 |
104 |
105 | ##########################################################################
106 | # Fix all bad images #
107 | ##########################################################################
108 |
109 | corrected_imgs = []
110 | for idx in incorrect_idxs:
111 | desired_target = correct_targets[idx]
112 | example = og_discretized[idx].data.clone() # tensor
113 | signs = torch.sign(var_img - og_discretized)
114 | bad_discretization = True
115 | pixels_changed_so_far = set() # populated with tuples of idxs
116 |
117 | while bad_discretization:
118 | pixel_idx, grad_sign = next_pixel_to_flip(example,
119 | pixels_changed_so_far,
120 | desired_target)
121 | pixels_changed_so_far.add(pixel_idx)
122 |
123 | if grad_sign == 0:
124 | grad_sign = utils.tuple_getter(signs[idx], pixel_idx)
125 |
126 | new_val = (grad_sign / 255. + utils.tuple_getter(example, pixel_idx))
127 | utils.tuple_setter(example, pixel_idx, float(new_val))
128 |
129 | new_out = classifier_net.forward(normalizer.forward(\
130 | Variable(example.unsqueeze(0))))
131 | bad_discretization = (int(desired_target) != int(new_out.max(1)[1]))
132 | corrected_imgs.append(example)
133 |
134 | # Stack up results
135 | output = []
136 |
137 | for idx in range(len(img_tensor)):
138 | if idx in incorrect_idxs:
139 | output.append(corrected_imgs.pop(0))
140 | else:
141 | output.append(og_discretized[idx].data)
142 |
143 | return torch.stack(output) # Variable
144 |
145 |
146 |
147 |
148 |
149 | #############################################################################
150 | # #
151 | # FLIP TECHNIQUES #
152 | # #
153 | #############################################################################
154 | ''' Flip techniques in general have the following specs:
155 | ARGS:
156 | classifier_net : NN - neural net with .forward method to classify
157 | normalized images
158 | normalizer : differentiableNormalizer object - normalizes 0,1 images
159 | into classifier_domain
160 | RETURNS: flip_function
161 | '''
162 |
163 | '''
164 | Flip function is a function that takes the following args:
165 | ARGS:
166 | img_tensor : Tensor (CxHxW) - image tensor in range 0.0 to 1.0 and is
167 | already discretized
168 | pixels_changed_so_far: set - set of index_tuples that have already been
169 | modified (we don't want to modify a pixel by
170 | more than 1/255 in any channel)
171 | correct_target : torch.LongTensor (1) - single element in a tensor that
172 | is the target class
173 | (e.g. int between 0 and 9 for CIFAR )
174 | RETURNS: (idx_tuple, sign)
175 | index_tuple is a triple of indices indicating which pixel-channel needs
176 | to be modified, and sign is in {-1, 0, 1}. If +-1, we will modify the
177 | pixel-channel in that direction, otherwise we'll modify in the opposite
178 | of the direction that discretization rounded to.
179 | '''
180 |
181 |
182 | def flip_random_pixel(classifier_net, normalizer):
183 | def flip_fxn(img_tensor, pixels_changed_so_far, correct_target):
184 | numel = img_tensor.numel()
185 | if len(pixels_changed_so_far) > numel * .9:
186 | raise Exception("Flipped over 90% of pixels without recovering the original classification")
187 |
188 | while True:
189 | pixel_idx, _ = utils.random_element_index(img_tensor)
190 | if pixel_idx not in pixels_changed_so_far:
191 | return pixel_idx, 0
192 |
193 | return flip_fxn
194 |
195 |
196 |
197 | def flip_greedy_pixel(classifier_net, normalizer):
198 | def flip_fxn(img_tensor, pixels_changed_so_far, correct_target,
199 | classifier_net=classifier_net, normalizer=normalizer):
200 | # Computes gradient and figures out which px most affects class_out
201 | classifier_net.zero_grad()
202 | img_var = Variable(img_tensor.unsqueeze(0), requires_grad=True)
203 | class_out = classifier_net.forward(normalizer.forward(img_var))
204 |
205 | criterion = torch.nn.CrossEntropyLoss()
206 | loss = criterion(class_out, correct_target.view(1)) # reshape the 0-dim target into a batch of one for CrossEntropyLoss
207 | loss.backward()
208 | # Really inefficient algorithm here, can probably do better
209 | new_grad_data = img_var.grad.data.clone().squeeze()
210 | signs = new_grad_data.sign()
211 | for idx_tuple in pixels_changed_so_far:
212 | utils.tuple_setter(new_grad_data, idx_tuple, 0)
213 |
214 | argmax = utils.torch_argmax(new_grad_data.abs())
215 | return argmax, -1 * utils.tuple_getter(signs, argmax)
216 |
217 | return flip_fxn
218 |
219 |
220 |
221 |
--------------------------------------------------------------------------------
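A short sketch (not part of the repository), assuming discretize_image above is in scope; it illustrates the promise made in the docstring that every output value is i/255.0 for some integer i.

    import torch

    imgs = torch.rand(2, 3, 32, 32)                  # values in [0, 1]
    disc = discretize_image(imgs, zero_one=True)
    # every value should now be i / 255 for some integer i
    print(torch.allclose(disc * 255.0, (disc * 255.0).round()))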
/SAM_segmentation/network/backbone/xception.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | Xception is adapted from https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/xception.py
4 |
5 | Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch)
6 | @author: tstandley
7 | Adapted by cadene
8 | Creates an Xception Model as defined in:
9 | Francois Chollet
10 | Xception: Deep Learning with Depthwise Separable Convolutions
11 | https://arxiv.org/pdf/1610.02357.pdf
12 | These weights were ported from the Keras implementation. They achieve the following performance on the ImageNet validation set:
13 | Loss:0.9173 Prec@1:78.892 Prec@5:94.292
14 | REMEMBER to set your image size to 3x299x299 for both test and validation
15 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
16 | std=[0.5, 0.5, 0.5])
17 | The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
18 | """
19 | from __future__ import print_function, division, absolute_import
20 | import math
21 | import torch
22 | import torch.nn as nn
23 | import torch.nn.functional as F
24 | import torch.utils.model_zoo as model_zoo
25 | from torch.nn import init
26 |
27 | __all__ = ['xception']
28 |
29 | pretrained_settings = {
30 | 'xception': {
31 | 'imagenet': {
32 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth',
33 | 'input_space': 'RGB',
34 | 'input_size': [3, 299, 299],
35 | 'input_range': [0, 1],
36 | 'mean': [0.5, 0.5, 0.5],
37 | 'std': [0.5, 0.5, 0.5],
38 | 'num_classes': 1000,
39 | 'scale': 0.8975 # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
40 | }
41 | }
42 | }
43 |
44 |
45 | class SeparableConv2d(nn.Module):
46 | def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False):
47 | super(SeparableConv2d,self).__init__()
48 |
49 | self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias)
50 | self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias)
51 |
52 | def forward(self,x):
53 | x = self.conv1(x)
54 | x = self.pointwise(x)
55 | return x
56 |
57 |
58 | class Block(nn.Module):
59 | def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True, dilation=1):
60 | super(Block, self).__init__()
61 |
62 | if out_filters != in_filters or strides!=1:
63 | self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False)
64 | self.skipbn = nn.BatchNorm2d(out_filters)
65 | else:
66 | self.skip=None
67 |
68 | rep=[]
69 |
70 | filters=in_filters
71 | if grow_first:
72 | rep.append(nn.ReLU(inplace=True))
73 | rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=dilation, dilation=dilation, bias=False))
74 | rep.append(nn.BatchNorm2d(out_filters))
75 | filters = out_filters
76 |
77 | for i in range(reps-1):
78 | rep.append(nn.ReLU(inplace=True))
79 | rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=dilation,dilation=dilation,bias=False))
80 | rep.append(nn.BatchNorm2d(filters))
81 |
82 | if not grow_first:
83 | rep.append(nn.ReLU(inplace=True))
84 | rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=dilation,dilation=dilation,bias=False))
85 | rep.append(nn.BatchNorm2d(out_filters))
86 |
87 | if not start_with_relu:
88 | rep = rep[1:]
89 | else:
90 | rep[0] = nn.ReLU(inplace=False)
91 |
92 | if strides != 1:
93 | rep.append(nn.MaxPool2d(3,strides,1))
94 | self.rep = nn.Sequential(*rep)
95 |
96 | def forward(self,inp):
97 | x = self.rep(inp)
98 |
99 | if self.skip is not None:
100 | skip = self.skip(inp)
101 | skip = self.skipbn(skip)
102 | else:
103 | skip = inp
104 | x+=skip
105 | return x
106 |
107 |
108 | class Xception(nn.Module):
109 | """
110 | Xception optimized for the ImageNet dataset, as specified in
111 | https://arxiv.org/pdf/1610.02357.pdf
112 | """
113 | def __init__(self, num_classes=1000, replace_stride_with_dilation=None):
114 | """ Constructor
115 | Args:
116 | num_classes: number of classes
117 | """
118 | super(Xception, self).__init__()
119 |
120 | self.num_classes = num_classes
121 | self.dilation = 1
122 | if replace_stride_with_dilation is None:
123 | # each element in the tuple indicates if we should replace
124 | # the 2x2 stride with a dilated convolution instead
125 | replace_stride_with_dilation = [False, False, False, False]
126 | if len(replace_stride_with_dilation) != 4:
127 | raise ValueError("replace_stride_with_dilation should be None "
128 | "or a 4-element tuple, got {}".format(replace_stride_with_dilation))
129 |
130 | self.conv1 = nn.Conv2d(3, 32, 3,2, 0, bias=False) # 1 / 2
131 | self.bn1 = nn.BatchNorm2d(32)
132 | self.relu1 = nn.ReLU(inplace=True)
133 |
134 | self.conv2 = nn.Conv2d(32,64,3,bias=False)
135 | self.bn2 = nn.BatchNorm2d(64)
136 | self.relu2 = nn.ReLU(inplace=True)
137 | #do relu here
138 |
139 | self.block1=self._make_block(64,128,2,2,start_with_relu=False,grow_first=True, dilate=replace_stride_with_dilation[0]) # 1 / 4
140 | self.block2=self._make_block(128,256,2,2,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[1]) # 1 / 8
141 | self.block3=self._make_block(256,728,2,2,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2]) # 1 / 16
142 |
143 | self.block4=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2])
144 | self.block5=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2])
145 | self.block6=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2])
146 | self.block7=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2])
147 |
148 | self.block8=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2])
149 | self.block9=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2])
150 | self.block10=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2])
151 | self.block11=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2])
152 |
153 | self.block12=self._make_block(728,1024,2,2,start_with_relu=True,grow_first=False, dilate=replace_stride_with_dilation[3]) # 1 / 32
154 |
155 | self.conv3 = SeparableConv2d(1024,1536,3,1,1, dilation=self.dilation)
156 | self.bn3 = nn.BatchNorm2d(1536)
157 | self.relu3 = nn.ReLU(inplace=True)
158 |
159 | #do relu here
160 | self.conv4 = SeparableConv2d(1536,2048,3,1,1, dilation=self.dilation)
161 | self.bn4 = nn.BatchNorm2d(2048)
162 |
163 | self.fc = nn.Linear(2048, num_classes)
164 |
165 | # #------- init weights --------
166 | # for m in self.modules():
167 | # if isinstance(m, nn.Conv2d):
168 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
169 | # m.weight.data.normal_(0, math.sqrt(2. / n))
170 | # elif isinstance(m, nn.BatchNorm2d):
171 | # m.weight.data.fill_(1)
172 | # m.bias.data.zero_()
173 | # #-----------------------------
174 |
175 | def _make_block(self, in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True, dilate=False):
176 | if dilate:
177 | self.dilation *= strides
178 | strides = 1
179 | return Block(in_filters,out_filters,reps,strides,start_with_relu=start_with_relu,grow_first=grow_first, dilation=self.dilation)
180 |
181 | def features(self, input):
182 | x = self.conv1(input)
183 | x = self.bn1(x)
184 | x = self.relu1(x)
185 |
186 | x = self.conv2(x)
187 | x = self.bn2(x)
188 | x = self.relu2(x)
189 |
190 | x = self.block1(x)
191 | x = self.block2(x)
192 | x = self.block3(x)
193 | x = self.block4(x)
194 | x = self.block5(x)
195 | x = self.block6(x)
196 | x = self.block7(x)
197 | x = self.block8(x)
198 | x = self.block9(x)
199 | x = self.block10(x)
200 | x = self.block11(x)
201 | x = self.block12(x)
202 |
203 | x = self.conv3(x)
204 | x = self.bn3(x)
205 | x = self.relu3(x)
206 |
207 | x = self.conv4(x)
208 | x = self.bn4(x)
209 | return x
210 |
211 | def logits(self, features):
212 | x = nn.ReLU(inplace=True)(features)
213 |
214 | x = F.adaptive_avg_pool2d(x, (1, 1))
215 | x = x.view(x.size(0), -1)
216 | x = self.last_linear(x)
217 | return x
218 |
219 | def forward(self, input):
220 | x = self.features(input)
221 | x = self.logits(x)
222 | return x
223 |
224 |
225 | def xception(num_classes=1000, pretrained='imagenet', replace_stride_with_dilation=None):
226 | model = Xception(num_classes=num_classes, replace_stride_with_dilation=replace_stride_with_dilation)
227 | if pretrained:
228 | settings = pretrained_settings['xception'][pretrained]
229 | assert num_classes == settings['num_classes'], \
230 | "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)
231 |
232 | # the model is already constructed above; just load the pretrained weights
233 | model.load_state_dict(model_zoo.load_url(settings['url']))
234 |
235 | # TODO: ugly
236 | model.last_linear = model.fc
237 | del model.fc
238 | return model
--------------------------------------------------------------------------------
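A sketch (not from the repository) of the preprocessing the docstring above asks for, plus one plausible way to build a dilated backbone for segmentation: dilating only the last down-sampling block (as DeepLab-style backbones do) keeps the deepest features at a higher resolution than the default 1/32. The transform parameters follow the docstring; everything else is a placeholder.

    import torch
    from torchvision import transforms

    preprocess = transforms.Compose([
        transforms.Resize(333),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # randomly-initialised backbone (pretrained=None skips the weight download)
    backbone = xception(num_classes=1000, pretrained=None,
                        replace_stride_with_dilation=[False, False, False, True])
    backbone.eval()
    feats = backbone.features(torch.randn(1, 3, 299, 299))
    print(feats.shape)    # coarse 2048-channel feature map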
/train_eval_scripts/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from deit import deit_tiny_patch16_224
5 |
6 |
7 | class PreActBlock(nn.Module):
8 | '''Pre-activation version of the BasicBlock.'''
9 | expansion = 1
10 |
11 | def __init__(self, in_planes, planes, stride=1):
12 | super(PreActBlock, self).__init__()
13 | self.bn1 = nn.BatchNorm2d(in_planes)
14 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
15 | self.bn2 = nn.BatchNorm2d(planes)
16 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
17 |
18 | if stride != 1 or in_planes != self.expansion*planes:
19 | self.shortcut = nn.Sequential(
20 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
21 | )
22 |
23 | def forward(self, x):
24 | out = F.relu(self.bn1(x))
25 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
26 | out = self.conv1(out)
27 | out = self.conv2(F.relu(self.bn2(out)))
28 | out += shortcut
29 | return out
30 |
31 |
32 | class PreActBottleneck(nn.Module):
33 | '''Pre-activation version of the original Bottleneck module.'''
34 | expansion = 4
35 |
36 | def __init__(self, in_planes, planes, stride=1):
37 | super(PreActBottleneck, self).__init__()
38 | self.bn1 = nn.BatchNorm2d(in_planes)
39 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
40 | self.bn2 = nn.BatchNorm2d(planes)
41 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
42 | self.bn3 = nn.BatchNorm2d(planes)
43 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
44 |
45 | if stride != 1 or in_planes != self.expansion*planes:
46 | self.shortcut = nn.Sequential(
47 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
48 | )
49 |
50 | def forward(self, x):
51 | out = F.relu(self.bn1(x))
52 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
53 | out = self.conv1(out)
54 | out = self.conv2(F.relu(self.bn2(out)))
55 | out = self.conv3(F.relu(self.bn3(out)))
56 | out += shortcut
57 | return out
58 |
59 |
60 | class PreActResNet(nn.Module):
61 | def __init__(self, block, num_blocks, num_classes=10):
62 | super(PreActResNet, self).__init__()
63 | self.in_planes = 64
64 |
65 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
66 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
67 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
68 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
69 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
70 | self.bn = nn.BatchNorm2d(512 * block.expansion)
71 | self.linear = nn.Linear(512*block.expansion, num_classes)
72 |
73 | def _make_layer(self, block, planes, num_blocks, stride):
74 | strides = [stride] + [1]*(num_blocks-1)
75 | layers = []
76 | for stride in strides:
77 | layers.append(block(self.in_planes, planes, stride))
78 | self.in_planes = planes * block.expansion
79 | return nn.Sequential(*layers)
80 |
81 | def forward(self, x):
82 | out = self.conv1(x)
83 | out = self.layer1(out)
84 | out = self.layer2(out)
85 | out = self.layer3(out)
86 | out = self.layer4(out)
87 | out = F.relu(self.bn(out))
88 | out = F.avg_pool2d(out, 4)
89 | out = out.view(out.size(0), -1)
90 | out = self.linear(out)
91 | return out
92 |
93 |
94 | def PreActResNet18(num_classes=10):
95 | return PreActResNet(PreActBlock, [2,2,2,2], num_classes=num_classes)
96 |
97 | def initialize_weights(module):
98 | if isinstance(module, nn.Conv2d):
99 | nn.init.kaiming_normal_(module.weight.data, mode='fan_in')
100 | elif isinstance(module, nn.BatchNorm2d):
101 | module.weight.data.uniform_()
102 | module.bias.data.zero_()
103 | elif isinstance(module, nn.Linear):
104 | module.bias.data.zero_()
105 |
106 |
107 | class BasicBlock(nn.Module):
108 | def __init__(self, in_channels, out_channels, stride, drop_rate):
109 | super(BasicBlock, self).__init__()
110 |
111 | self.drop_rate = drop_rate
112 |
113 | self._preactivate_both = (in_channels != out_channels)
114 |
115 | self.bn1 = nn.BatchNorm2d(in_channels)
116 | self.conv1 = nn.Conv2d(
117 | in_channels,
118 | out_channels,
119 | kernel_size=3,
120 | stride=stride, # downsample with first conv
121 | padding=1,
122 | bias=False)
123 |
124 | self.bn2 = nn.BatchNorm2d(out_channels)
125 | self.conv2 = nn.Conv2d(
126 | out_channels,
127 | out_channels,
128 | kernel_size=3,
129 | stride=1,
130 | padding=1,
131 | bias=False)
132 |
133 | self.shortcut = nn.Sequential()
134 | if in_channels != out_channels:
135 | self.shortcut.add_module(
136 | 'conv',
137 | nn.Conv2d(
138 | in_channels,
139 | out_channels,
140 | kernel_size=1,
141 | stride=stride, # downsample
142 | padding=0,
143 | bias=False))
144 |
145 | def forward(self, x):
146 | if self._preactivate_both:
147 | x = F.relu(
148 | self.bn1(x), inplace=True) # shortcut after preactivation
149 | y = self.conv1(x)
150 | else:
151 | y = F.relu(
152 | self.bn1(x),
153 | inplace=True) # preactivation only for residual path
154 | y = self.conv1(y)
155 | if self.drop_rate > 0:
156 | y = F.dropout(
157 | y, p=self.drop_rate, training=self.training, inplace=False)
158 |
159 | y = F.relu(self.bn2(y), inplace=True)
160 | y = self.conv2(y)
161 | y += self.shortcut(x)
162 | return y
163 |
164 |
165 | class Network(nn.Module):
166 | def __init__(self, config):
167 | super(Network, self).__init__()
168 |
169 | input_shape = config['input_shape']
170 | n_classes = config['n_classes']
171 |
172 | base_channels = config['base_channels']
173 | widening_factor = config['widening_factor']
174 | drop_rate = config['drop_rate']
175 | depth = config['depth']
176 |
177 | block = BasicBlock
178 | n_blocks_per_stage = (depth - 4) // 6
179 | assert n_blocks_per_stage * 6 + 4 == depth
180 |
181 | n_channels = [
182 | base_channels, base_channels * widening_factor,
183 | base_channels * 2 * widening_factor,
184 | base_channels * 4 * widening_factor
185 | ]
186 |
187 | self.conv = nn.Conv2d(
188 | input_shape[1],
189 | n_channels[0],
190 | kernel_size=3,
191 | stride=1,
192 | padding=1,
193 | bias=False)
194 |
195 | self.stage1 = self._make_stage(
196 | n_channels[0],
197 | n_channels[1],
198 | n_blocks_per_stage,
199 | block,
200 | stride=1,
201 | drop_rate=drop_rate)
202 | self.stage2 = self._make_stage(
203 | n_channels[1],
204 | n_channels[2],
205 | n_blocks_per_stage,
206 | block,
207 | stride=2,
208 | drop_rate=drop_rate)
209 | self.stage3 = self._make_stage(
210 | n_channels[2],
211 | n_channels[3],
212 | n_blocks_per_stage,
213 | block,
214 | stride=2,
215 | drop_rate=drop_rate)
216 | self.bn = nn.BatchNorm2d(n_channels[3])
217 |
218 | # compute conv feature size
219 | with torch.no_grad():
220 | self.feature_size = self._forward_conv(
221 | torch.zeros(*input_shape)).view(-1).shape[0]
222 |
223 | self.fc = nn.Linear(self.feature_size, n_classes)
224 |
225 | # initialize weights
226 | self.apply(initialize_weights)
227 |
228 | def _make_stage(self, in_channels, out_channels, n_blocks, block, stride,
229 | drop_rate):
230 | stage = nn.Sequential()
231 | for index in range(n_blocks):
232 | block_name = 'block{}'.format(index + 1)
233 | if index == 0:
234 | stage.add_module(
235 | block_name,
236 | block(
237 | in_channels,
238 | out_channels,
239 | stride=stride,
240 | drop_rate=drop_rate))
241 | else:
242 | stage.add_module(
243 | block_name,
244 | block(
245 | out_channels,
246 | out_channels,
247 | stride=1,
248 | drop_rate=drop_rate))
249 | return stage
250 |
251 | def _forward_conv(self, x):
252 | x = self.conv(x)
253 | x = self.stage1(x)
254 | x = self.stage2(x)
255 | x = self.stage3(x)
256 | x = F.relu(self.bn(x), inplace=True)
257 | x = F.adaptive_avg_pool2d(x, output_size=1)
258 | return x
259 |
260 | def forward(self, x):
261 | x = self._forward_conv(x)
262 | x = x.view(x.size(0), -1)
263 | x = self.fc(x)
264 | return x
265 |
266 |
267 | def WRN28_10(num_classes=10):
268 | config = {
269 | 'input_shape': (1, 3, 32, 32),
270 | 'n_classes': num_classes,
271 | 'base_channels': 16,
272 | 'widening_factor': 10,
273 | 'drop_rate': 0.3,
274 | 'depth': 28
275 | }
276 | return Network(config)
277 |
278 | def DeiT(num_classes=10):
279 | model = deit_tiny_patch16_224(pretrained=False, img_size = 32, patch_size = 2, num_classes=num_classes)
280 | return model
--------------------------------------------------------------------------------
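A quick shape-check sketch (not part of the repository), assuming the model constructors above are importable; DeiT is omitted because it depends on the external deit module.

    import torch

    x = torch.randn(2, 3, 32, 32)            # CIFAR-sized batch

    net = PreActResNet18(num_classes=10)
    print(net(x).shape)                      # torch.Size([2, 10])

    wrn = WRN28_10(num_classes=100)
    print(wrn(x).shape)                      # torch.Size([2, 100])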