├── evaluation ├── __init__.py └── evaluate_cdvae.py ├── networks ├── __init__.py ├── __pycache__ │ ├── vae.cpython-38.pyc │ └── __init__.cpython-38.pyc ├── nearest_embed.py ├── adv_vae.py └── vae.py ├── detection ├── lib │ ├── __init__.py │ ├── __pycache__ │ │ ├── tvm.cpython-38.pyc │ │ ├── util.cpython-38.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── attacks.cpython-38.pyc │ │ ├── runutils.cpython-38.pyc │ │ ├── adversary.cpython-38.pyc │ │ └── transforms.cpython-38.pyc │ ├── findseam.h │ ├── quilting.h │ ├── findseam.cpp │ ├── tvm.py │ ├── kmedoid.py │ ├── runutils.py │ ├── transformation_helper.py │ ├── quilting_fast.py │ ├── util.py │ ├── transforms.py │ ├── quilting.cpp │ ├── quilting.py │ ├── attacks.py │ └── _tv_bregman.patch ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── vae.cpython-38.pyc │ │ ├── resnet.cpython-38.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── densenet.cpython-38.pyc │ │ └── wide_resnet.cpython-38.pyc │ ├── wide_resnet.py │ ├── vae.py │ ├── densenet.py │ └── resnet.py ├── detection.sh ├── lib_regression.py ├── ADV_Regression_Subspace.py ├── calculate_log.py ├── data_loader.py ├── ADV_Generate_Mahalanobis_Subspace.py └── ADV_Samples_Subspace.py ├── cd_vae.png ├── utils ├── ._set.py ├── .DS_Store ├── ._.DS_Store ├── ._randaugment4fixmatch.py ├── __pycache__ │ ├── set.cpython-37.pyc │ ├── set.cpython-38.pyc │ ├── normalize.cpython-38.pyc │ ├── randAugment.cpython-38.pyc │ ├── randaugment4fixmatch.cpython-37.pyc │ └── randaugment4fixmatch.cpython-38.pyc ├── normalize.py ├── randAugment.py └── randaugment4fixmatch.py ├── advex └── __pycache__ │ └── attacks.cpython-38.pyc ├── toolkits ├── disentangle_cifar.sh ├── adv_train_cifar.sh └── adv_test_cifar.sh ├── LICENSE ├── tools ├── adv_test_cifar.py ├── disentangle_cifar.py └── adv_train_cifar.py └── README.md /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /detection/lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cd_vae.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/cd_vae.png -------------------------------------------------------------------------------- /utils/._set.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/utils/._set.py -------------------------------------------------------------------------------- /utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/utils/.DS_Store -------------------------------------------------------------------------------- /utils/._.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/utils/._.DS_Store -------------------------------------------------------------------------------- /utils/._randaugment4fixmatch.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/utils/._randaugment4fixmatch.py -------------------------------------------------------------------------------- /utils/__pycache__/set.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/utils/__pycache__/set.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/set.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/utils/__pycache__/set.cpython-38.pyc -------------------------------------------------------------------------------- /networks/__pycache__/vae.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/networks/__pycache__/vae.cpython-38.pyc -------------------------------------------------------------------------------- /advex/__pycache__/attacks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/advex/__pycache__/attacks.cpython-38.pyc -------------------------------------------------------------------------------- /detection/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .wide_resnet import * 2 | from .resnet import * 3 | from .densenet import * 4 | from .vae import * 5 | -------------------------------------------------------------------------------- /detection/lib/__pycache__/tvm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/detection/lib/__pycache__/tvm.cpython-38.pyc -------------------------------------------------------------------------------- /networks/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/networks/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/normalize.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/utils/__pycache__/normalize.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/randAugment.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/utils/__pycache__/randAugment.cpython-38.pyc -------------------------------------------------------------------------------- /detection/lib/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/detection/lib/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /detection/lib/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/detection/lib/__pycache__/__init__.cpython-38.pyc 
-------------------------------------------------------------------------------- /detection/lib/__pycache__/attacks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/detection/lib/__pycache__/attacks.cpython-38.pyc -------------------------------------------------------------------------------- /detection/lib/__pycache__/runutils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/detection/lib/__pycache__/runutils.cpython-38.pyc -------------------------------------------------------------------------------- /detection/models/__pycache__/vae.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/detection/models/__pycache__/vae.cpython-38.pyc -------------------------------------------------------------------------------- /detection/lib/__pycache__/adversary.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/detection/lib/__pycache__/adversary.cpython-38.pyc -------------------------------------------------------------------------------- /detection/lib/__pycache__/transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/detection/lib/__pycache__/transforms.cpython-38.pyc -------------------------------------------------------------------------------- /detection/models/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/detection/models/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /detection/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/detection/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /detection/models/__pycache__/densenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/detection/models/__pycache__/densenet.cpython-38.pyc -------------------------------------------------------------------------------- /toolkits/disentangle_cifar.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python tools/disentangle_cifar.py --save_dir results/disentangle_cifar_ce0.2 --ce 0.2 --optim cosine 2 | -------------------------------------------------------------------------------- /utils/__pycache__/randaugment4fixmatch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/utils/__pycache__/randaugment4fixmatch.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/randaugment4fixmatch.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/utils/__pycache__/randaugment4fixmatch.cpython-38.pyc -------------------------------------------------------------------------------- /detection/models/__pycache__/wide_resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/CD-VAE/HEAD/detection/models/__pycache__/wide_resnet.cpython-38.pyc -------------------------------------------------------------------------------- /toolkits/adv_train_cifar.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python tools/adv_train_cifar.py --batch_size 100 --lr 1 --cr 0.1 --cg 0.1 --save_dir ./results/defense_0.1_0.1 2 | -------------------------------------------------------------------------------- /detection/lib/findseam.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __cplusplus 4 | extern "C" { 5 | #endif 6 | 7 | double findseam( 8 | int numnodes, // number of nodes 9 | int numedges, // number of edges 10 | int* from, // from indices 11 | int* to, // to indices 12 | float* values, // values on edges 13 | float* tvalues, // values for terminal edges 14 | int* labels // memory in which to write the labels 15 | ); 16 | 17 | #ifdef __cplusplus 18 | } 19 | #endif 20 | -------------------------------------------------------------------------------- /toolkits/adv_test_cifar.sh: -------------------------------------------------------------------------------- 1 | python tools/adv_test_cifar.py --model_path ./results/defense_0.1_0.1/robust_model_g_epoch82.pth --vae_path ./results/defense_0.1_0.1/robust_vae_epoch82.pth --batch_size 256 \ 2 | "NoAttack()" \ 3 | "AutoLinfAttack(cd_vae, 'cifar', bound=8/255)" \ 4 | "AutoL2Attack(cd_vae, 'cifar', bound=1.0)" \ 5 | "JPEGLinfAttack(cd_vae, 'cifar', bound=0.125, num_iterations=100)" \ 6 | "StAdvAttack(cd_vae, num_iterations=100)" \ 7 | "ReColorAdvAttack(cd_vae, num_iterations=100)" -------------------------------------------------------------------------------- /detection/lib/quilting.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __cplusplus 4 | extern "C" { 5 | #endif 6 | 7 | void generatePatches( 8 | float* result, // N x (C x P x P) 9 | float* img, // C x H x W 10 | unsigned int imgH, 11 | unsigned int imgW, 12 | unsigned int patchSize, 13 | unsigned int overlap); 14 | 15 | void generateQuiltedImages( 16 | float* result, // C x H x W 17 | long* neighbors, // M 18 | float* patchDict, // N x (C x P x P) 19 | unsigned int imgH, 20 | unsigned int imgW, 21 | unsigned int patchSize, 22 | unsigned int overlap, 23 | bool graphcut); 24 | 25 | #ifdef __cplusplus 26 | } 27 | #endif 28 | -------------------------------------------------------------------------------- /detection/lib/findseam.cpp: -------------------------------------------------------------------------------- 1 | #include "findseam.h" 2 | #include "graph.h" 3 | 4 | double findseam( 5 | int numnodes, // number of nodes 6 | int numedges, // number of edges 7 | int* from, // from indices 8 | int* to, // to indices 9 | float* values, // values on edges 10 | float* tvalues, // values for terminal edges 11 | int* labels // memory in which to write the labels 12 | ) { 13 | // initialize graph: 14 | Graph* g = 15 | new Graph(numnodes, numedges); 16 | g->add_node(numnodes); 17 | 18 | // add edges: 19 | for (unsigned 
int i = 0; i < numedges; i++) { 20 | g->add_edge(from[i], to[i], values[i], 0.0f); 21 | } 22 | 23 | // add terminal nodes: 24 | for (unsigned int i = 0; i < numnodes; i++) { 25 | g->add_tweights(i, tvalues[i * 2], tvalues[i * 2 + 1]); 26 | } 27 | 28 | // run maxflow algorithm: 29 | double flow = g->maxflow(); 30 | for (unsigned int i = 0; i < numnodes; i++) { 31 | labels[i] = g->what_segment(i); 32 | } 33 | 34 | // return results: 35 | delete g; 36 | return flow; 37 | } 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 kai-wen-yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /detection/detection.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python ADV_Samples_Subspace.py --dataset cifar10 --net_type resnet --adv_type FGSM --gpu 0 --outf ./data/cd-vae-1/ --vae_path ../pretrained/cd-vae-1.pth; 2 | CUDA_VISIBLE_DEVICES=1 python ADV_Samples_Subspace.py --dataset cifar10 --net_type resnet --adv_type BIM --gpu 0 --outf ./data/cd-vae-1/ --vae_path ../pretrained/cd-vae-1.pth; 3 | CUDA_VISIBLE_DEVICES=1 python ADV_Samples_Subspace.py --dataset cifar10 --net_type resnet --adv_type PGD --gpu 0 --outf ./data/cd-vae-1/ --vae_path ../pretrained/cd-vae-1.pth; 4 | CUDA_VISIBLE_DEVICES=1 python ADV_Samples_Subspace.py --dataset cifar10 --net_type resnet --adv_type CW --gpu 0 --outf ./data/cd-vae-1/ --vae_path ../pretrained/cd-vae-1.pth; 5 | CUDA_VISIBLE_DEVICES=1 python ADV_Samples_Subspace.py --dataset cifar10 --net_type resnet --adv_type PGD-L2 --gpu 0 --outf ./data/cd-vae-1/ --vae_path ../pretrained/cd-vae-1.pth; 6 | 7 | CUDA_VISIBLE_DEVICES=1 python ADV_Generate_Mahalanobis_Subspace.py --dataset cifar10 --net_type resnet --adv_type FGSM --gpu 0 --outf ./data/cd-vae-1/ --vae_path ../pretrained/cd-vae-1.pth; 8 | CUDA_VISIBLE_DEVICES=1 python ADV_Generate_Mahalanobis_Subspace.py --dataset cifar10 --net_type resnet --adv_type BIM --gpu 0 --outf ./data/cd-vae-1/ --vae_path ../pretrained/cd-vae-1.pth; 9 | CUDA_VISIBLE_DEVICES=1 python ADV_Generate_Mahalanobis_Subspace.py --dataset cifar10 --net_type resnet --adv_type PGD --gpu 0 --outf ./data/cd-vae-1/ --vae_path ../pretrained/cd-vae-1.pth; 10 | CUDA_VISIBLE_DEVICES=1 python ADV_Generate_Mahalanobis_Subspace.py --dataset cifar10 --net_type resnet --adv_type CW --gpu 0 --outf ./data/cd-vae-1/ --vae_path ../pretrained/cd-vae-1.pth; 11 | CUDA_VISIBLE_DEVICES=1 python ADV_Generate_Mahalanobis_Subspace.py --dataset cifar10 --net_type resnet --adv_type PGD-L2 --gpu 0 --outf ./data/cd-vae-1/ --vae_path ../pretrained/cd-vae-1.pth; 12 | 13 | CUDA_VISIBLE_DEVICES=1 python ADV_Regression_Subspace.py --net_type resnet --outf ./data/cd-vae-1/; 14 | -------------------------------------------------------------------------------- /detection/lib/tvm.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | import torch 6 | import numpy as np 7 | from skimage.restoration import denoise_tv_chambolle, denoise_tv_bregman 8 | from skimage.util import random_noise 9 | from scipy.optimize import minimize 10 | from skimage.util import img_as_float 11 | from skimage import color 12 | 13 | def tv(x, p): 14 | f = np.linalg.norm(x[1:, :] - x[:-1, :], p, axis=1).sum() 15 | f += np.linalg.norm(x[:, 1:] - x[:, :-1], p, axis=0).sum() 16 | return f 17 | 18 | 19 | def tv_dx(x, p): 20 | if p == 1: 21 | x_diff0 = np.sign(x[1:, :] - x[:-1, :]) 22 | x_diff1 = np.sign(x[:, 1:] - x[:, :-1]) 23 | elif p > 1: 24 | x_diff0_norm = np.power(np.linalg.norm(x[1:, :] - x[:-1, :], p, axis=1), p - 1) 25 | x_diff1_norm = np.power(np.linalg.norm(x[:, 1:] - x[:, :-1], p, axis=0), p - 1) 26 | x_diff0_norm[x_diff0_norm < 1e-3] = 1e-3 27 | x_diff1_norm[x_diff1_norm < 1e-3] = 1e-3 28 | x_diff0_norm = np.repeat(x_diff0_norm[:, np.newaxis], x.shape[1], axis=1) 29 | x_diff1_norm = np.repeat(x_diff1_norm[np.newaxis, :], x.shape[0], axis=0) 30 | 
x_diff0 = p * np.power(x[1:, :] - x[:-1, :], p - 1) / x_diff0_norm 31 | x_diff1 = p * np.power(x[:, 1:] - x[:, :-1], p - 1) / x_diff1_norm 32 | df = np.zeros(x.shape) 33 | df[:-1, :] = -x_diff0 34 | df[1:, :] += x_diff0 35 | df[:, :-1] -= x_diff1 36 | df[:, 1:] += x_diff1 37 | return df 38 | 39 | 40 | def tv_l2(x, y, w, lam, p): 41 | f = 0.5 * np.power(x - y.flatten(), 2).dot(w.flatten()) 42 | x = np.reshape(x, y.shape) 43 | return f + lam * tv(x, p) 44 | 45 | 46 | def tv_l2_dx(x, y, w, lam, p): 47 | x = np.reshape(x, y.shape) 48 | df = (x - y) * w 49 | return df.flatten() + lam * tv_dx(x, p).flatten() 50 | 51 | 52 | def tv_inf(x, y, lam, p, tau): 53 | x = np.reshape(x, y.shape) 54 | return tau + lam * tv(x, p) 55 | 56 | 57 | def tv_inf_dx(x, y, lam, p, tau): 58 | x = np.reshape(x, y.shape) 59 | return lam * tv_dx(x, p).flatten() 60 | 61 | -------------------------------------------------------------------------------- /detection/lib/kmedoid.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/letiantian/kmedoids/blob/master/kmedoids.py 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import numpy as np 8 | import random 9 | 10 | def kMedoids(D, k, tmax=100): 11 | # determine dimensions of distance matrix D 12 | m, n = D.shape 13 | 14 | if k > n: 15 | raise Exception('too many medoids') 16 | 17 | # find a set of valid initial cluster medoid indices since we 18 | # can't seed different clusters with two points at the same location 19 | valid_medoid_inds = set(range(n)) 20 | invalid_medoid_inds = set([]) 21 | rs,cs = np.where(D==0) 22 | # the rows, cols must be shuffled because we will keep the first duplicate below 23 | index_shuf = list(range(len(rs))) 24 | np.random.shuffle(index_shuf) 25 | rs = rs[index_shuf] 26 | cs = cs[index_shuf] 27 | for r,c in zip(rs,cs): 28 | # if there are two points with a distance of 0... 29 | # keep the first one for cluster init 30 | if r < c and r not in invalid_medoid_inds: 31 | invalid_medoid_inds.add(c) 32 | valid_medoid_inds = list(valid_medoid_inds - invalid_medoid_inds) 33 | 34 | if k > len(valid_medoid_inds): 35 | raise Exception('too many medoids (after removing {} duplicate points)'.format( 36 | len(invalid_medoid_inds))) 37 | 38 | # randomly initialize an array of k medoid indices 39 | M = np.array(valid_medoid_inds) 40 | np.random.shuffle(M) 41 | M = np.sort(M[:k]) 42 | 43 | # create a copy of the array of medoid indices 44 | Mnew = np.copy(M) 45 | 46 | # initialize a dictionary to represent clusters 47 | C = {} 48 | for t in range(tmax): 49 | # determine clusters, i. e. 
arrays of data indices 50 | J = np.argmin(D[:,M], axis=1) 51 | for kappa in range(k): 52 | C[kappa] = np.where(J==kappa)[0] 53 | # update cluster medoids 54 | for kappa in range(k): 55 | J = np.mean(D[np.ix_(C[kappa],C[kappa])],axis=1) 56 | j = np.argmin(J) 57 | Mnew[kappa] = C[kappa][j] 58 | np.sort(Mnew) 59 | # check for convergence 60 | if np.array_equal(M, Mnew): 61 | break 62 | M = np.copy(Mnew) 63 | else: 64 | # final update of cluster memberships 65 | J = np.argmin(D[:,M], axis=1) 66 | for kappa in range(k): 67 | C[kappa] = np.where(J==kappa)[0] 68 | 69 | # return results 70 | return M, C -------------------------------------------------------------------------------- /utils/normalize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from utils.normalize import * 5 | CIFAR_MEAN = [0.4914, 0.4822, 0.4465] 6 | CIFAR_STD = [0.2470, 0.2435, 0.2616] 7 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 8 | IMAGENET_STD = [0.229, 0.224, 0.225] 9 | 10 | def get_cifar_params(resol): 11 | mean_list = [] 12 | std_list = [] 13 | for i in range(3): 14 | mean_list.append(torch.full((resol, resol), CIFAR_MEAN[i], device='cuda')) 15 | std_list.append(torch.full((resol, resol), CIFAR_STD[i], device='cuda')) 16 | return torch.unsqueeze(torch.stack(mean_list), 0), torch.unsqueeze(torch.stack(std_list), 0) 17 | 18 | def get_imagenet_params(resol): 19 | mean_list = [] 20 | std_list = [] 21 | for i in range(3): 22 | mean_list.append(torch.full((resol, resol), IMAGENET_MEAN[i], device='cuda')) 23 | std_list.append(torch.full((resol, resol), IMAGENET_STD[i], device='cuda')) 24 | return torch.unsqueeze(torch.stack(mean_list), 0), torch.unsqueeze(torch.stack(std_list), 0) 25 | 26 | class CIFARNORMALIZE(nn.Module): 27 | def __init__(self, resol): 28 | super().__init__() 29 | self.mean, self.std = get_cifar_params(resol) 30 | 31 | def forward(self, x): 32 | ''' 33 | Parameters: 34 | x: input image with pixels normalized to ([0, 1] - IMAGENET_MEAN) / IMAGENET_STD 35 | ''' 36 | x = x.sub(self.mean) 37 | x = x.div(self.std) 38 | return x 39 | 40 | class CIFARINNORMALIZE(nn.Module): 41 | def __init__(self, resol): 42 | super().__init__() 43 | self.mean, self.std = get_cifar_params(resol) 44 | 45 | def forward(self, x): 46 | ''' 47 | Parameters: 48 | x: input image with pixels normalized to ([0, 1] - IMAGENET_MEAN) / IMAGENET_STD 49 | ''' 50 | x = x.mul(self.std) 51 | x = x.add(*self.mean) 52 | return x 53 | 54 | class IMAGENETNORMALIZE(nn.Module): 55 | def __init__(self, resol): 56 | super().__init__() 57 | self.mean, self.std = get_imagenet_params(resol) 58 | 59 | def forward(self, x): 60 | ''' 61 | Parameters: 62 | x: input image with pixels normalized to ([0, 1] - IMAGENET_MEAN) / IMAGENET_STD 63 | ''' 64 | x = x.sub(self.mean) 65 | x = x.div(self.std) 66 | return x 67 | 68 | class IMAGENETINNORMALIZE(nn.Module): 69 | def __init__(self, resol): 70 | super().__init__() 71 | self.mean, self.std = get_imagenet_params(resol) 72 | 73 | def forward(self, x): 74 | ''' 75 | Parameters: 76 | x: input image with pixels normalized to ([0, 1] - IMAGENET_MEAN) / IMAGENET_STD 77 | ''' 78 | x = x.mul(self.std) 79 | x = x.add(*self.mean) 80 | return x 81 | -------------------------------------------------------------------------------- /detection/lib_regression.py: -------------------------------------------------------------------------------- 1 | # several functions are from 
https://github.com/xingjunm/lid_adversarial_subspace_detection 2 | from __future__ import print_function 3 | import numpy as np 4 | import os 5 | import calculate_log as callog 6 | 7 | from scipy.spatial.distance import pdist, cdist, squareform 8 | 9 | 10 | def block_split(X, Y, out): 11 | """ 12 | Split the data training and testing 13 | :return: X (data) and Y (label) for training / testing 14 | """ 15 | num_samples = X.shape[0] 16 | if out == 'svhn': 17 | partition = 26032 18 | else: 19 | partition = 10000 20 | X_adv, Y_adv = X[:partition], Y[:partition] 21 | X_norm, Y_norm = X[partition: :], Y[partition: :] 22 | num_train = 1000 23 | 24 | X_train = np.concatenate((X_norm[:num_train], X_adv[:num_train])) 25 | Y_train = np.concatenate((Y_norm[:num_train], Y_adv[:num_train])) 26 | 27 | X_test = np.concatenate((X_norm[num_train:], X_adv[num_train:])) 28 | Y_test = np.concatenate((Y_norm[num_train:], Y_adv[num_train:])) 29 | 30 | return X_train, Y_train, X_test, Y_test 31 | 32 | 33 | def block_split_adv(X, Y): 34 | """ 35 | Split the data training and testing 36 | :return: X (data) and Y (label) for training / testing 37 | """ 38 | num_samples = X.shape[0] 39 | partition = int(num_samples / 3) 40 | X_adv, Y_adv = X[:partition], Y[:partition] 41 | X_norm, Y_norm = X[partition: 2*partition], Y[partition: 2*partition] 42 | X_noisy, Y_noisy = X[2*partition:], Y[2*partition:] 43 | num_train = int(partition*0.1) 44 | X_train = np.concatenate((X_norm[:num_train], X_noisy[:num_train], X_adv[:num_train])) 45 | Y_train = np.concatenate((Y_norm[:num_train], Y_noisy[:num_train], Y_adv[:num_train])) 46 | 47 | X_test = np.concatenate((X_norm[num_train:], X_noisy[num_train:], X_adv[num_train:])) 48 | Y_test = np.concatenate((Y_norm[num_train:], Y_noisy[num_train:], Y_adv[num_train:])) 49 | 50 | return X_train, Y_train, X_test, Y_test 51 | 52 | def detection_performance(regressor, X, Y, outf): 53 | """ 54 | Measure the detection performance 55 | return: detection metrics 56 | """ 57 | num_samples = X.shape[0] 58 | l1 = open('%s/confidence_TMP_In.txt'%outf, 'w') 59 | l2 = open('%s/confidence_TMP_Out.txt'%outf, 'w') 60 | y_pred = regressor.predict_proba(X)[:, 1] 61 | 62 | for i in range(num_samples): 63 | if Y[i] == 0: 64 | l1.write("{}\n".format(-y_pred[i])) 65 | else: 66 | l2.write("{}\n".format(-y_pred[i])) 67 | l1.close() 68 | l2.close() 69 | results = callog.metric(outf, ['TMP']) 70 | return results 71 | 72 | def load_characteristics(score, dataset, out, outf): 73 | """ 74 | Load the calculated scores 75 | return: data and label of input score 76 | """ 77 | X, Y = None, None 78 | 79 | file_name = os.path.join(outf, "%s_%s_%s.npy" % (score, dataset, out)) 80 | data = np.load(file_name) 81 | 82 | if X is None: 83 | X = data[:, :-1] 84 | else: 85 | X = np.concatenate((X, data[:, :-1]), axis=1) 86 | if Y is None: 87 | Y = data[:, -1] # labels only need to load once 88 | 89 | return X, Y -------------------------------------------------------------------------------- /evaluation/evaluate_cdvae.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | import torch.backends.cudnn as cudnn 8 | from tqdm import tqdm 9 | from copy import deepcopy 10 | import torchvision 11 | import torchvision.transforms as transforms 12 | import wandb 13 | import os 14 | import time 15 | import argparse 16 | import datetime 17 | from 
torch.autograd import Variable 18 | import pdb 19 | import sys 20 | 21 | sys.path.append('.') 22 | 23 | from networks import * 24 | from utils.set import * 25 | from utils.normalize import * 26 | from advex.attacks import * 27 | normalize = CIFARNORMALIZE(32) 28 | 29 | def run_iter(wandb, batch_idx, x, y, model_r, model_g, vae, attack): 30 | attack_name = attack.__class__.__name__ 31 | x, y = x.cuda(), y.cuda().view(-1, ) 32 | adv_x = attack(x, y) 33 | rescont_bs = 64 34 | 35 | gx, _, _ = vae(normalize(adv_x)) 36 | logits_r = model_r(normalize(adv_x) - gx) 37 | logits_g = model_g(gx) 38 | 39 | prec1_g, _, _, _ = accuracy(logits_g.data, y.data, topk=(1, 5)) 40 | prec1_r, _, _, _ = accuracy(logits_r.data, y.data, topk=(1, 5)) 41 | 42 | if batch_idx <= 1: 43 | grid_X = torchvision.utils.make_grid(adv_x[:rescont_bs].data, nrow=8, padding=2, normalize=True) 44 | wandb.log({"_{attack}/_{batch}_X.jpg".format(batch=batch_idx, attack=attack_name): [ 45 | wandb.Image(grid_X)]}, commit=False) 46 | grid_Xi = torchvision.utils.make_grid(gx[:rescont_bs].data, nrow=8, padding=2, normalize=True) 47 | wandb.log({"_{attack}/_{batch}_GX.jpg".format(batch=batch_idx, attack=attack_name): [ 48 | wandb.Image(grid_Xi)]}, commit=False) 49 | grid_X_Xi = torchvision.utils.make_grid((normalize(adv_x)[:rescont_bs] - gx[:rescont_bs]).data, nrow=8, 50 | padding=2, 51 | normalize=True) 52 | wandb.log({"_{attack}/_{batch}_RX.jpg".format(batch=batch_idx, attack=attack_name): [ 53 | wandb.Image(grid_X_Xi)]}, commit=False) 54 | return prec1_r, prec1_g 55 | 56 | def test(wandb, model_r, model_g, vae, testloader, attack, val_num=None): 57 | attack_name = attack.__class__.__name__ 58 | model_r.eval() 59 | model_g.eval() 60 | vae.eval() 61 | 62 | top1_g = AverageMeter() 63 | top1_r = AverageMeter() 64 | 65 | for batch_idx, (x , y) in enumerate(testloader): 66 | bs = x.size(0) 67 | if val_num: 68 | if batch_idx >= val_num: 69 | break 70 | prec1_r, prec1_g = run_iter(wandb, batch_idx, x, y, model_r, model_g, vae, attack) 71 | top1_g.update(prec1_g.item(), bs) 72 | top1_r.update(prec1_r.item(), bs) 73 | 74 | wandb.log({f'{attack_name}/test-XG-acc': top1_g.avg, \ 75 | f'{attack_name}/test-XR-acc': top1_r.avg}, commit=False) 76 | # plot progress 77 | print('Attack:{}'.format(attack_name)) 78 | print("| XG: %.2f%% XR: %.2f%%" % (top1_g.avg, top1_r.avg)) 79 | 80 | def evaluate(wandb, model_r, model_g, vae, testloader, validation_attacks, val_num=None): 81 | for val_attack in validation_attacks: 82 | test(wandb, model_r, model_g, vae, testloader, val_attack, val_num) 83 | -------------------------------------------------------------------------------- /detection/lib/runutils.py: -------------------------------------------------------------------------------- 1 | from operator import methodcaller 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | 7 | 8 | def get_cuda_state(obj): 9 | """ 10 | Get cuda state of any object. 
11 | 12 | :param obj: an object (a tensor or an `torch.nn.Module`) 13 | :raise TypeError: 14 | :return: True if the object or the parameter set of the object 15 | is on GPU 16 | """ 17 | if isinstance(obj, nn.Module): 18 | try: 19 | return next(obj.parameters()).is_cuda 20 | except StopIteration: 21 | return None 22 | elif hasattr(obj, 'is_cuda'): 23 | return obj.is_cuda 24 | else: 25 | raise TypeError('unrecognized type ({}) in args'.format(type(obj))) 26 | 27 | 28 | def is_cuda_consistent(*args): 29 | """ 30 | See if the cuda states are consistent among variables (of type either 31 | tensors or torch.autograd.Variable). For example, 32 | 33 | import torch 34 | from torch.autograd import Variable 35 | import torch.nn as nn 36 | 37 | net = nn.Linear(512, 10) 38 | tensor = torch.rand(10, 10).cuda() 39 | assert not is_cuda_consistent(net=net, tensor=tensor) 40 | 41 | :param args: the variables to test 42 | :return: True if len(args) == 0 or the cuda states of all elements in args 43 | are consistent; False otherwise 44 | """ 45 | result = dict() 46 | for v in args: 47 | cur_cuda_state = get_cuda_state(v) 48 | cuda_state = result.get('cuda', cur_cuda_state) 49 | if cur_cuda_state is not cuda_state: 50 | return False 51 | result['cuda'] = cur_cuda_state 52 | return True 53 | 54 | def make_cuda_consistent(refobj, *args): 55 | """ 56 | Attempt to make the cuda states of args consistent with that of ``refobj``. 57 | If any element of args is a Variable and the cuda state of the element is 58 | inconsistent with ``refobj``, raise ValueError, since changing the cuda state 59 | of a Variable involves rewrapping it in a new Variable, which changes the 60 | semantics of the code. 61 | 62 | :param refobj: either the referential object or the cuda state of the 63 | referential object 64 | :param args: the variables to test 65 | :return: tuple of the same data as ``args`` but on the same device as 66 | ``refobj`` 67 | """ 68 | ref_cuda_state = refobj if type(refobj) is bool else get_cuda_state(refobj) 69 | if ref_cuda_state is None: 70 | raise ValueError('cannot determine the cuda state of `refobj` ({})' 71 | .format(refobj)) 72 | move_to_device = methodcaller('cuda' if ref_cuda_state else 'cpu') 73 | 74 | result_args = list() 75 | for v in args: 76 | cuda_state = get_cuda_state(v) 77 | if cuda_state != ref_cuda_state: 78 | if isinstance(v, Variable): 79 | raise ValueError('cannot change cuda state of a Variable') 80 | elif isinstance(v, nn.Module): 81 | move_to_device(v) 82 | else: 83 | v = move_to_device(v) 84 | result_args.append(v) 85 | return tuple(result_args) 86 | 87 | def predict(net, inputs): 88 | """ 89 | Predict labels. The cuda state of `net` decides that of the returned 90 | prediction tensor. 
91 | 92 | :param net: the network 93 | :param inputs: the input tensor (non Variable), of dimension [B x C x W x H] 94 | :return: prediction tensor (LongTensor), of dimension [B] 95 | """ 96 | inputs = make_cuda_consistent(net, inputs)[0] 97 | inputs_var = Variable(inputs) 98 | outputs_var = net(inputs_var) 99 | predictions = torch.max(outputs_var.data, dim=1)[1] 100 | return predictions 101 | -------------------------------------------------------------------------------- /tools/adv_test_cifar.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | import torch.backends.cudnn as cudnn 9 | from tqdm import tqdm 10 | from copy import deepcopy 11 | import torchvision 12 | import torchvision.transforms as transforms 13 | from perceptual_advex.attacks import StAdvAttack 14 | import wandb 15 | import os 16 | import time 17 | import argparse 18 | import datetime 19 | from torch.autograd import Variable 20 | import pdb 21 | import sys 22 | import wandb 23 | sys.path.append('.') 24 | from typing import Dict, List 25 | from networks.adv_vae import * 26 | from utils.set import * 27 | from utils.randaugment4fixmatch import RandAugmentMC 28 | from utils.normalize import * 29 | from advex.attacks import * 30 | from perceptual_advex.attacks import * 31 | 32 | normalize = CIFARNORMALIZE(32) 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser(description='PyTorch CIFAR-10 Training') 36 | parser.add_argument('attacks', metavar='attack', type=str, nargs='+', 37 | help='attack names') 38 | parser.add_argument('--dim', default=2048, type=int, help='CNN_embed_dim') 39 | parser.add_argument('--fdim', default=32, type=int, help='featdim') 40 | parser.add_argument('--batch_size', default=256, type=int, help='batch_size') 41 | parser.add_argument("--model_path", type=str, default="./results/v3cr1.0_cg1.0_kl0.1/model_g_epoch42.pth") 42 | parser.add_argument("--vae_path", type=str, default="./results/v3cr1.0_cg1.0_kl0.1/vae_epoch42.pth") 43 | args = parser.parse_args() 44 | use_cuda = torch.cuda.is_available() 45 | 46 | transform_test = transforms.Compose([ 47 | transforms.ToTensor(), 48 | ]) 49 | print("| Preparing CIFAR-10 dataset...") 50 | sys.stdout.write("| ") 51 | testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=False, transform=transform_test) 52 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=0) 53 | 54 | cd_vae = CD_VAE(args.vae_path, args.model_path) 55 | wandb.init(config=args) 56 | if use_cuda: 57 | cd_vae.cuda() 58 | cudnn.benchmark = True 59 | 60 | cd_vae.eval() 61 | 62 | attack_names: List[str] = args.attacks 63 | attacks = [eval(attack_name) for attack_name in attack_names] 64 | batches_correct: Dict[str, List[torch.Tensor]] = \ 65 | {attack_name: [] for attack_name in attack_names} 66 | for batch_index, (inputs, labels) in enumerate(testloader): 67 | print(f'BATCH {batch_index:05d}') 68 | 69 | if torch.cuda.is_available(): 70 | inputs = inputs.cuda() 71 | labels = labels.cuda() 72 | 73 | for attack_name, attack in zip(attack_names, attacks): 74 | adv_inputs = attack(inputs, labels) 75 | with torch.no_grad(): 76 | adv_logits = cd_vae(adv_inputs) 77 | batch_correct = (adv_logits.argmax(1) == labels).detach() 78 | 79 | batch_accuracy = batch_correct.float().mean().item() 80 | print(f'ATTACK 
{attack_name}', 81 | f'accuracy = {batch_accuracy * 100:.1f}', 82 | sep='\t') 83 | batches_correct[attack_name].append(batch_correct) 84 | 85 | print('OVERALL') 86 | accuracies = [] 87 | attacks_correct: Dict[str, torch.Tensor] = {} 88 | for attack_name in attack_names: 89 | attacks_correct[attack_name] = torch.cat(batches_correct[attack_name]) 90 | accuracy = attacks_correct[attack_name].float().mean().item() 91 | print(f'ATTACK {attack_name}', 92 | f'accuracy = {accuracy * 100:.1f}', 93 | sep='\t') 94 | accuracies.append(accuracy) 95 | -------------------------------------------------------------------------------- /detection/lib/transformation_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import torchvision.transforms as torch_trans 9 | import lib.transformations.transforms as transforms 10 | from lib.datasets.transform_dataset import TransformDataset 11 | 12 | 13 | # Initialize transformations to be applied to dataset 14 | def setup_transformations(args, data_type, defense, crop=None): 15 | if 'preprocessed_data' in args and args.preprocessed_data: 16 | assert defense is not None, ( 17 | "If data is already pre processed for defenses then " 18 | "defenses can't be None") 19 | if crop: 20 | assert callable(crop), "crop should be a callable method" 21 | 22 | transform = [] 23 | # setup transformation without adversary 24 | if 'adversary' not in args or args.adversary is None: 25 | if (data_type == 'train'): 26 | if 'preprocessed_data' in args and args.preprocessed_data: 27 | # Defenses are already applied on randomly cropped images 28 | transform.append(torch_trans.Scale(args.data_params['IMAGE_SIZE'])) 29 | else: 30 | transform.append( 31 | torch_trans.RandomSizedCrop(args.data_params['IMAGE_SIZE'])) 32 | 33 | transform.append(torch_trans.RandomHorizontalFlip()) 34 | transform.append(torch_trans.ToTensor()) 35 | 36 | else: # validation 37 | # No augmentation for validation 38 | if 'preprocessed_data' not in args or not args.preprocessed_data: 39 | transform.append(torch_trans.Scale(args.data_params['IMAGE_SCALE_SIZE'])) 40 | transform.append(torch_trans.CenterCrop( 41 | args.data_params['IMAGE_SIZE'])) 42 | transform.append(torch_trans.ToTensor()) 43 | if crop: 44 | transform.append(crop) 45 | 46 | transform.append(transforms.Scale(args.data_params['IMAGE_SIZE'])) 47 | 48 | # Apply defenses at runtime (VERY SLOW) 49 | # Prefer pre-processing and saving data, and then using it 50 | if ('preprocessed_data' in args and not args.preprocessed_data and 51 | defense is not None): 52 | transform = transform + [defense] 53 | 54 | else: # Adversarial images 55 | if crop is not None: 56 | transform.append(crop) 57 | 58 | transform.append(transforms.Scale(args.data_params['IMAGE_SIZE'], 59 | args.data_params['MEAN_STD'])) 60 | 61 | # Apply defenses at runtime (VERY SLOW) 62 | # Prefer pre-processing and saving data, and then using it 63 | if not args.preprocessed_data and defense is not None: 64 | transform.append(defense) 65 | 66 | if 'normalize' in args and args.normalize: 67 | transform.append( 68 | torch_trans.Normalize(mean=args.data_params['MEAN_STD']['MEAN'], 69 | std=args.data_params['MEAN_STD']['STD'])) 70 | 71 | if len(transform) == 0: 72 | transform = None 73 | else: 74 | transform = torch_trans.Compose(transform) 75 | 76 | return transform 
77 | 78 | 79 | # Update dataset 80 | def update_dataset_transformation(dataset, args, data_type, 81 | defense, crop): 82 | 83 | # only supported for TransformDataset at the moment 84 | assert isinstance(dataset, TransformDataset), ( 85 | "updating datase transformation is only supported for TransformDataset" 86 | "for adversaries") 87 | 88 | assert data_type is not 'train', \ 89 | "updating datase transformation is not supported in training" 90 | 91 | transform = setup_transformations(args, data_type, defense, crop) 92 | dataset.update_transformation(transform=transform) 93 | -------------------------------------------------------------------------------- /detection/ADV_Regression_Subspace.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Sun Oct 25 2018 3 | @author: Kimin Lee 4 | """ 5 | from __future__ import print_function 6 | import numpy as np 7 | import os 8 | import lib_regression 9 | import argparse 10 | import pdb 11 | from sklearn.linear_model import LogisticRegressionCV 12 | 13 | parser = argparse.ArgumentParser(description='PyTorch code: Mahalanobis detector') 14 | parser.add_argument('--net_type', required=True, help='resnet | densenet') 15 | parser.add_argument('--outf', default='./adv_output/', help='folder to output results') 16 | args = parser.parse_args() 17 | print(args) 18 | 19 | def main(): 20 | # initial setup 21 | dataset_list = ['cifar10'] 22 | adv_test_list = ['BIM', 'FGSM', 'PGD', 'PGD-L2', 'CWL2'] 23 | print('evaluate the Mahalanobis estimator') 24 | score_list = ['Mahalanobis_0.0', 'Mahalanobis_0.01', 'Mahalanobis_0.005', \ 25 | 'Mahalanobis_0.002', 'Mahalanobis_0.0014', 'Mahalanobis_0.001', 'Mahalanobis_0.0005'] 26 | 27 | list_best_results_ours, list_best_results_index_ours = [], [] 28 | for dataset in dataset_list: 29 | print('load train data: ', dataset) 30 | outf = args.outf + args.net_type + '_' + dataset + '/' 31 | list_best_results_out, list_best_results_index_out = [], [] 32 | for out in adv_test_list: 33 | best_auroc, best_result, best_index = 0, 0, 0 34 | for score in score_list: 35 | print('load train data: ', out, ' of ', score) 36 | total_X, total_Y = lib_regression.load_characteristics(score, dataset, out, outf) 37 | X_val, Y_val, X_test, Y_test = lib_regression.block_split_adv(total_X, total_Y) 38 | pivot = int(X_val.shape[0] / 6) 39 | X_train = np.concatenate((X_val[:pivot], X_val[2*pivot:3*pivot], X_val[4*pivot:5*pivot])) 40 | Y_train = np.concatenate((Y_val[:pivot], Y_val[2*pivot:3*pivot], Y_val[4*pivot:5*pivot])) 41 | X_val_for_test = np.concatenate((X_val[pivot:2*pivot], X_val[3*pivot:4*pivot], X_val[5*pivot:])) 42 | Y_val_for_test = np.concatenate((Y_val[pivot:2*pivot], Y_val[3*pivot:4*pivot], Y_val[5*pivot:])) 43 | lr = LogisticRegressionCV(n_jobs=-1).fit(X_train, Y_train) 44 | y_pred = lr.predict_proba(X_train)[:, 1] 45 | #print('training mse: {:.4f}'.format(np.mean(y_pred - Y_train))) 46 | y_pred = lr.predict_proba(X_val_for_test)[:, 1] 47 | #print('test mse: {:.4f}'.format(np.mean(y_pred - Y_val_for_test))) 48 | results = lib_regression.detection_performance(lr, X_val_for_test, Y_val_for_test, outf) 49 | if best_auroc < results['TMP']['AUROC']: 50 | best_auroc = results['TMP']['AUROC'] 51 | best_index = score 52 | best_result = lib_regression.detection_performance(lr, X_test, Y_test, outf) 53 | list_best_results_out.append(best_result) 54 | list_best_results_index_out.append(best_index) 55 | list_best_results_ours.append(list_best_results_out) 56 | 
list_best_results_index_ours.append(list_best_results_index_out) 57 | 58 | count_in = 0 59 | mtypes = ['TNR', 'AUROC', 'DTACC', 'AUIN', 'AUOUT'] 60 | print("results of Mahalanobis") 61 | for in_list in list_best_results_ours: 62 | print('in_distribution: ' + dataset_list[count_in] + '==========') 63 | count_out = 0 64 | for results in in_list: 65 | print('out_distribution: '+ adv_test_list[count_out]) 66 | for mtype in mtypes: 67 | print(' {mtype:6s}'.format(mtype=mtype), end='') 68 | print('\n{val:6.2f}'.format(val=100.*results['TMP']['TNR']), end='') 69 | print(' {val:6.2f}'.format(val=100.*results['TMP']['AUROC']), end='') 70 | print(' {val:6.2f}'.format(val=100.*results['TMP']['DTACC']), end='') 71 | print(' {val:6.2f}'.format(val=100.*results['TMP']['AUIN']), end='') 72 | print(' {val:6.2f}\n'.format(val=100.*results['TMP']['AUOUT']), end='') 73 | print('Input noise: ' + list_best_results_index_ours[count_in][count_out]) 74 | print('') 75 | count_out += 1 76 | count_in += 1 77 | 78 | 79 | if __name__ == '__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /detection/lib/quilting_fast.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import ctypes 9 | import torch 10 | import random 11 | import numpy 12 | import os 13 | 14 | import pkgutil 15 | if pkgutil.find_loader("adversarial") is not None: 16 | # If adversarial module is created by pip install 17 | QUILTING_LIB = ctypes.cdll.LoadLibrary(os.path.join(os.path.dirname(__file__), "libquilting.so")) 18 | else: 19 | try: 20 | QUILTING_LIB = ctypes.cdll.LoadLibrary('libquilting.so') 21 | except ImportError: 22 | raise ImportError("libquilting.so not found. 
Check build script") 23 | 24 | 25 | def generate_patches(img, patch_size, overlap): 26 | assert torch.is_tensor(img) and img.dim() == 3 27 | assert type(patch_size) == int and patch_size > 0 28 | assert type(overlap) == int and overlap > 0 29 | assert patch_size > overlap 30 | 31 | y_range = range(0, img.size(1) - patch_size, patch_size - overlap) 32 | x_range = range(0, img.size(2) - patch_size, patch_size - overlap) 33 | num_patches = len(y_range) * len(x_range) 34 | patches = torch.FloatTensor(num_patches, 3 * patch_size * patch_size).zero_() 35 | 36 | QUILTING_LIB.generatePatches( 37 | ctypes.c_void_p(patches.data_ptr()), 38 | ctypes.c_void_p(img.data_ptr()), 39 | ctypes.c_uint(img.size(1)), 40 | ctypes.c_uint(img.size(2)), 41 | ctypes.c_uint(patch_size), 42 | ctypes.c_uint(overlap) 43 | ) 44 | 45 | return patches 46 | 47 | 48 | def generate_quilted_images(neighbors, patch_dict, img_h, img_w, patch_size, 49 | overlap, graphcut=False, random_stitch=False): 50 | assert torch.is_tensor(neighbors) and neighbors.dim() == 1 51 | assert torch.is_tensor(patch_dict) and patch_dict.dim() == 2 52 | assert type(img_h) == int and img_h > 0 53 | assert type(img_w) == int and img_w > 0 54 | assert type(patch_size) == int and patch_size > 0 55 | assert type(overlap) == int and overlap > 0 56 | assert patch_size > overlap 57 | 58 | result = torch.FloatTensor(3, img_h, img_w).zero_() 59 | 60 | QUILTING_LIB.generateQuiltedImages( 61 | ctypes.c_void_p(result.data_ptr()), 62 | ctypes.c_void_p(neighbors.data_ptr()), 63 | ctypes.c_void_p(patch_dict.data_ptr()), 64 | ctypes.c_uint(img_h), 65 | ctypes.c_uint(img_w), 66 | ctypes.c_uint(patch_size), 67 | ctypes.c_uint(overlap), 68 | ctypes.c_bool(graphcut) 69 | ) 70 | 71 | return result 72 | 73 | 74 | def select_random_neighbor(neighbors): 75 | if len(neighbors.shape) == 1: 76 | # If only 1 neighbor per path is available then return 77 | return neighbors 78 | else: 79 | # Pick a neighbor randomly from top k neighbors for all queries 80 | nrows = neighbors.shape[0] 81 | ncols = neighbors.shape[1] 82 | random_patched_neighbors = numpy.zeros(nrows).astype('int') 83 | for i in range(0, nrows): 84 | col = random.randint(0, ncols - 1) 85 | random_patched_neighbors[i] = neighbors[i, col] 86 | return random_patched_neighbors 87 | 88 | 89 | # main quilting function: 90 | def quilting(img, faiss_index, patch_dict, patch_size=9, overlap=2, 91 | graphcut=False, k=1, random_stitch=False): 92 | 93 | # assertions: 94 | assert torch.is_tensor(img) 95 | assert torch.is_tensor(patch_dict) and patch_dict.dim() == 2 96 | assert type(patch_size) == int and patch_size > 0 97 | assert type(overlap) == int and overlap > 0 98 | assert patch_size > overlap 99 | 100 | # generate image patches 101 | patches = generate_patches(img, patch_size, overlap) 102 | 103 | # find nearest patches in faiss index: 104 | faiss_index.nprobe = 5 105 | # get top k neighbors of all queries 106 | _, neighbors = faiss_index.search(patches.numpy(), k) 107 | neighbors = select_random_neighbor(neighbors) 108 | neighbors = torch.LongTensor(neighbors).squeeze() 109 | if (neighbors == -1).any(): 110 | print('WARNING: %d out of %d neighbor searches failed.' 
% 111 | ((neighbors == -1).sum(), neighbors.nelement())) 112 | 113 | # stitch nn patches in the dict 114 | quilted_img = generate_quilted_images(neighbors, patch_dict, img.size(1), 115 | img.size(2), patch_size, overlap, 116 | graphcut) 117 | 118 | return quilted_img 119 | -------------------------------------------------------------------------------- /detection/calculate_log.py: -------------------------------------------------------------------------------- 1 | ## Measure the detection performance - Kibok Lee 2 | from __future__ import print_function 3 | import torch 4 | from torch.autograd import Variable 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import torch.optim as optim 9 | import torchvision 10 | import torchvision.transforms as transforms 11 | import numpy as np 12 | import time 13 | from scipy import misc 14 | 15 | import matplotlib 16 | matplotlib.use('Agg') 17 | import matplotlib.pyplot as plt 18 | 19 | def get_curve(dir_name, stypes = ['Baseline', 'Gaussian_LDA']): 20 | tp, fp = dict(), dict() 21 | tnr_at_tpr95 = dict() 22 | for stype in stypes: 23 | known = np.loadtxt('{}/confidence_{}_In.txt'.format(dir_name, stype), delimiter='\n') 24 | novel = np.loadtxt('{}/confidence_{}_Out.txt'.format(dir_name, stype), delimiter='\n') 25 | known.sort() 26 | novel.sort() 27 | end = np.max([np.max(known), np.max(novel)]) 28 | start = np.min([np.min(known),np.min(novel)]) 29 | num_k = known.shape[0] 30 | num_n = novel.shape[0] 31 | tp[stype] = -np.ones([num_k+num_n+1], dtype=int) 32 | fp[stype] = -np.ones([num_k+num_n+1], dtype=int) 33 | tp[stype][0], fp[stype][0] = num_k, num_n 34 | k, n = 0, 0 35 | for l in range(num_k+num_n): 36 | if k == num_k: 37 | tp[stype][l+1:] = tp[stype][l] 38 | fp[stype][l+1:] = np.arange(fp[stype][l]-1, -1, -1) 39 | break 40 | elif n == num_n: 41 | tp[stype][l+1:] = np.arange(tp[stype][l]-1, -1, -1) 42 | fp[stype][l+1:] = fp[stype][l] 43 | break 44 | else: 45 | if novel[n] < known[k]: 46 | n += 1 47 | tp[stype][l+1] = tp[stype][l] 48 | fp[stype][l+1] = fp[stype][l] - 1 49 | else: 50 | k += 1 51 | tp[stype][l+1] = tp[stype][l] - 1 52 | fp[stype][l+1] = fp[stype][l] 53 | tpr95_pos = np.abs(tp[stype] / num_k - .95).argmin() 54 | tnr_at_tpr95[stype] = 1. 
- fp[stype][tpr95_pos] / num_n 55 | return tp, fp, tnr_at_tpr95 56 | 57 | def metric(dir_name, stypes = ['Bas', 'Gau'], verbose=False): 58 | tp, fp, tnr_at_tpr95 = get_curve(dir_name, stypes) 59 | results = dict() 60 | mtypes = ['TNR', 'AUROC', 'DTACC', 'AUIN', 'AUOUT'] 61 | if verbose: 62 | print(' ', end='') 63 | for mtype in mtypes: 64 | print(' {mtype:6s}'.format(mtype=mtype), end='') 65 | print('') 66 | 67 | for stype in stypes: 68 | if verbose: 69 | print('{stype:5s} '.format(stype=stype), end='') 70 | results[stype] = dict() 71 | 72 | # TNR 73 | mtype = 'TNR' 74 | results[stype][mtype] = tnr_at_tpr95[stype] 75 | if verbose: 76 | print(' {val:6.3f}'.format(val=100.*results[stype][mtype]), end='') 77 | 78 | # AUROC 79 | mtype = 'AUROC' 80 | tpr = np.concatenate([[1.], tp[stype]/tp[stype][0], [0.]]) 81 | fpr = np.concatenate([[1.], fp[stype]/fp[stype][0], [0.]]) 82 | results[stype][mtype] = -np.trapz(1.-fpr, tpr) 83 | if verbose: 84 | print(' {val:6.3f}'.format(val=100.*results[stype][mtype]), end='') 85 | 86 | # DTACC 87 | mtype = 'DTACC' 88 | results[stype][mtype] = .5 * (tp[stype]/tp[stype][0] + 1.-fp[stype]/fp[stype][0]).max() 89 | if verbose: 90 | print(' {val:6.3f}'.format(val=100.*results[stype][mtype]), end='') 91 | 92 | # AUIN 93 | mtype = 'AUIN' 94 | denom = tp[stype]+fp[stype] 95 | denom[denom == 0.] = -1. 96 | pin_ind = np.concatenate([[True], denom > 0., [True]]) 97 | pin = np.concatenate([[.5], tp[stype]/denom, [0.]]) 98 | results[stype][mtype] = -np.trapz(pin[pin_ind], tpr[pin_ind]) 99 | if verbose: 100 | print(' {val:6.3f}'.format(val=100.*results[stype][mtype]), end='') 101 | 102 | # AUOUT 103 | mtype = 'AUOUT' 104 | denom = tp[stype][0]-tp[stype]+fp[stype][0]-fp[stype] 105 | denom[denom == 0.] = -1. 106 | pout_ind = np.concatenate([[True], denom > 0., [True]]) 107 | pout = np.concatenate([[0.], (fp[stype][0]-fp[stype])/denom, [.5]]) 108 | results[stype][mtype] = np.trapz(pout[pout_ind], 1.-fpr[pout_ind]) 109 | if verbose: 110 | print(' {val:6.3f}'.format(val=100.*results[stype][mtype]), end='') 111 | print('') 112 | 113 | return results -------------------------------------------------------------------------------- /detection/lib/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | import os 14 | import tempfile 15 | 16 | import torch 17 | 18 | # constants: 19 | CHECKPOINT_FILE = 'checkpoint.torch' 20 | 21 | 22 | # function that measures top-k accuracy: 23 | def accuracy(output, target, topk=(1,)): 24 | maxk = max(topk) 25 | batch_size = target.size(0) 26 | _, pred = output.topk(maxk, 1, True, True) 27 | pred = pred.t() 28 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 29 | res = [] 30 | for k in topk: 31 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 32 | res.append(correct_k.mul_(100. 
/ batch_size)) 33 | return res 34 | 35 | 36 | # function that tries to load a checkpoint: 37 | def load_checkpoint(checkpoint_folder): 38 | 39 | # read what the latest model file is: 40 | filename = os.path.join(checkpoint_folder, CHECKPOINT_FILE) 41 | if not os.path.exists(filename): 42 | return None 43 | 44 | # load and return the checkpoint: 45 | return torch.load(filename) 46 | 47 | 48 | # function that saves checkpoint: 49 | def save_checkpoint(checkpoint_folder, state): 50 | 51 | # make sure that we have a checkpoint folder: 52 | if not os.path.isdir(checkpoint_folder): 53 | try: 54 | os.makedirs(checkpoint_folder) 55 | except BaseException: 56 | print('| WARNING: could not create directory %s' % checkpoint_folder) 57 | if not os.path.isdir(checkpoint_folder): 58 | return False 59 | 60 | # write checkpoint atomically: 61 | try: 62 | with tempfile.NamedTemporaryFile( 63 | 'w', dir=checkpoint_folder, delete=False) as fwrite: 64 | tmp_filename = fwrite.name 65 | torch.save(state, fwrite.name) 66 | os.rename(tmp_filename, os.path.join(checkpoint_folder, CHECKPOINT_FILE)) 67 | return True 68 | except BaseException: 69 | print('| WARNING: could not write checkpoint to %s.' % checkpoint_folder) 70 | return False 71 | 72 | 73 | # function that adjusts the learning rate: 74 | def adjust_learning_rate(base_lr, epoch, optimizer, lr_decay, lr_decay_stepsize): 75 | lr = base_lr * (lr_decay ** (epoch // lr_decay_stepsize)) 76 | for param_group in optimizer.param_groups: 77 | param_group['lr'] = lr 78 | 79 | 80 | # adversary functions 81 | # computes SSIM for a single block 82 | def SSIM(x, y): 83 | x = x.resize_(x.size(0), x.size(1) * x.size(2) * x.size(3)) 84 | y = y.resize_(y.size(0), y.size(1) * y.size(2) * y.size(3)) 85 | N = x.size(1) 86 | mu_x = x.mean(1) 87 | mu_y = y.mean(1) 88 | sigma_x = x.std(1) 89 | sigma_y = y.std(1) 90 | sigma_xy = ((x - mu_x.expand_as(x)) * (y - mu_y.expand_as(y))).sum(1) / (N - 1) 91 | ssim = (2 * mu_x * mu_y) * (2 * sigma_xy) 92 | ssim = ssim / (mu_x.pow(2) + mu_y.pow(2)) 93 | ssim = ssim / (sigma_x.pow(2) + sigma_y.pow(2)) 94 | return ssim 95 | 96 | 97 | # mean SSIM using local block averaging 98 | def MSSIM(x, y, window_size=16, stride=4): 99 | ssim = torch.zeros(x.size(0)) 100 | L = x.size(2) 101 | W = x.size(3) 102 | x_inds = torch.arange(0, L - window_size + 1, stride).long() 103 | y_inds = torch.arange(0, W - window_size + 1, stride).long() 104 | for i in x_inds: 105 | for j in y_inds: 106 | x_sub = x[:, :, i:(i + window_size), j:(j + window_size)] 107 | y_sub = y[:, :, i:(i + window_size), j:(j + window_size)] 108 | ssim = ssim + SSIM(x_sub, y_sub) 109 | return ssim / x_inds.size(0) / y_inds.size(0) 110 | 111 | 112 | # forwards input through model to get probabilities 113 | def get_probs(model, vae, imgs, output_prob=False): 114 | softmax = torch.nn.Softmax(1) 115 | # probs = torch.zeros(imgs.size(0), n_classes) 116 | imgsvar = torch.autograd.Variable(imgs.squeeze()) 117 | with torch.no_grad(): 118 | output = model(imgsvar - vae(imgsvar)) 119 | if output_prob: 120 | probs = output.data.cpu() 121 | else: 122 | probs = softmax.forward(output).data.cpu() 123 | 124 | return probs 125 | 126 | 127 | # calls get_probs to get predictions 128 | def get_labels(model, vae, input, output_prob=False): 129 | probs = get_probs(model, vae, input, output_prob) 130 | _, label = probs.max(1) 131 | return label.squeeze() 132 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # CD-VAE 2 | 3 | Official implementation: 4 | - Class-Disentanglement and Applications in Adversarial Detection and Defense, NeurIPS 2021. ([Paper](https://openreview.net/pdf?id=jFMzBeLyTc0)) 5 | 6 |
7 | ![CD-VAE](cd_vae.png) 8 | 9 |
10 | 11 | For any questions, contact (kwyang@mail.ustc.edu.cn). 12 | 13 | ## Requirements 14 | 15 | 1. [Python](https://www.python.org/) 16 | 2. [Pytorch](https://pytorch.org/) 17 | 3. [Wandb](https://wandb.ai/site) 18 | 4. [Torchvision](https://pytorch.org/vision/stable/index.html) 19 | 5. [Perceptual-advex](https://github.com/cassidylaidlaw/perceptual-advex) 20 | 6. [Robustness](https://github.com/MadryLab/robustness) 21 | 22 | ## Pretrained Models 23 | ``` 24 | cd CD-VAE 25 | mkdir pretrained 26 | ``` 27 | Download the pretrained models and put them in the directory ./pretrained: 28 | 1. [cd-vae-1](https://drive.google.com/file/d/1I2yuYQGEYRgqd1oQazq6goDbU2nwUvU_/view?usp=sharing) (for adversarial detection) 29 | 2. [cd-vae-2](https://drive.google.com/file/d/1b-S2BvHhV79t89M1oINyKvn_BaxYprib/view?usp=sharing) (for initializing the adversarial training model) 30 | 3. [wide_resnet](https://drive.google.com/file/d/1Lycbl4BUTxBzfTsLjj-m-_jnpDl8pCcP/view?usp=sharing) (trained on clean data x, for initializing the adversarial training model) 31 | 32 | ## Part 1. Class-Disentangled VAE 33 | Train a class-disentangled VAE, which is the basis of adversarial detection and defense. 34 | ``` 35 | cd CD-VAE 36 | python tools/disentangle_cifar.py --save_dir results/disentangle_cifar_ce0.2 --ce 0.2 --optim cosine 37 | ``` 38 | * **--ce** (float): Weight of the cross-entropy loss, i.e., gamma in the paper. You can try different values (e.g., ce=0.02, 0.2, 2) to control the reconstruction-classification trade-off. 39 | * **--save_dir** (str): Directory to save the model checkpoint and training log. 40 | * **--optim** (str): Learning-rate schedule; cosine decay and stage decay are currently supported. 41 | 42 | ## Part 2. Adversarial Detection 43 | Adversarial detection requires a CD-VAE model. You can use the pretrained CD-VAE or train a new one yourself as shown in Part 1. 44 | ``` 45 | cd CD-VAE/detection 46 | ``` 47 | Generate adversarial examples: 48 | ``` 49 | python ADV_Samples_Subspace.py --dataset cifar10 --net_type resnet --adv_type PGD --gpu 0 --outf ./data/cd-vae-1/ --vae_path ../pretrained/cd-vae-1.pth; 50 | ``` 51 | Compute Mahalanobis distances: 52 | ``` 53 | python ADV_Generate_Mahalanobis_Subspace.py --dataset cifar10 --net_type resnet --adv_type PGD --gpu 0 --outf ./data/cd-vae-1/ --vae_path ../pretrained/cd-vae-1.pth; 54 | ``` 55 | Evaluate the Mahalanobis estimator: 56 | ``` 57 | python ADV_Regression_Subspace.py --net_type resnet --outf ./data/cd-vae-1/; 58 | ``` 59 | * **--adv_type** (str): Adversarial attack type, e.g., FGSM, BIM, PGD, PGD-L2, CW. 60 | * **--outf** (str): Directory to save data and results. 61 | * **--vae_path** (str): CD-VAE checkpoint. 62 | 63 | ## Part 3. White-box Adversarial Defense 64 | Modified adversarial training based on CD-VAE (it needs a CD-VAE model and a model trained on clean data x for initialization; see the inference sketch below): 65 | ``` 66 | cd CD-VAE 67 | python tools/adv_train_cifar.py --batch_size 100 --lr 1 --cr 0.1 --cg 0.1 --margin 20 --save_dir ./results/defense_0.1_0.1 68 | ``` 69 | * **--cr, --cg** (float): Weight of the cross-entropy loss, i.e., gamma in the paper. 70 | * **--lr** (float): Learning rate. 71 | * **--save_dir** (str): Directory to save checkpoints and logs.
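At inference time the two networks act as one model: the classifier only ever sees the class-relevant part x - g(x), where g is the CD-VAE (this is the `model(x - vae(x))` pattern used in `detection/lib/util.py`). The sketch below is a minimal, illustrative wrapper, not the repository's `tools/adv_test_cifar.py`; the `Wide_ResNet(28, 10, 0.3, 10)` / `CVAE(d=32, z=2048)` settings and the single-checkpoint layout (classifier weights stored under a `classifier.` prefix) are assumptions copied from `detection/ADV_Generate_Mahalanobis_Subspace.py`, while the defense scripts instead load the VAE and classifier from separate `--vae_path` / `--model_path` files.
```
# Hedged sketch: classify on the class-relevant residual x - g(x).
import torch
import torch.nn as nn

import models  # detection/models: exposes Wide_ResNet and CVAE


class CDVAEClassifier(nn.Module):
    """Differentiable wrapper: the classifier sees only the residual x - g(x)."""

    def __init__(self, vae, classifier):
        super().__init__()
        self.vae = vae
        self.classifier = classifier

    def forward(self, x):
        # g(x) is the VAE reconstruction (class-irrelevant part); classify the rest.
        return self.classifier(x - self.vae(x))


def load_cd_vae(checkpoint_path, device='cuda'):
    """Assumed single-checkpoint layout (see note above); returns an eval-mode wrapper."""
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    vae = nn.DataParallel(models.CVAE(d=32, z=2048))
    clf = nn.DataParallel(models.Wide_ResNet(28, 10, 0.3, 10))
    vae_keys = set(vae.state_dict().keys())
    clf_keys = set(clf.state_dict().keys())
    # VAE weights are stored under their own names, classifier weights under a
    # 'classifier.' prefix, following ADV_Generate_Mahalanobis_Subspace.py.
    vae.load_state_dict({k: v for k, v in ckpt.items() if k in vae_keys},
                        strict=False)
    clf.load_state_dict({k.replace('classifier.', ''): v for k, v in ckpt.items()
                         if k.replace('classifier.', '') in clf_keys},
                        strict=False)
    return CDVAEClassifier(vae, clf).to(device).eval()
```
A wrapper of this kind is what the attack strings in the evaluation command below (e.g., `AutoLinfAttack(cd_vae, 'cifar', bound=8/255)`) would operate on, so white-box attacks differentiate through the VAE and the classifier jointly.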
72 | 73 | Evaluation of the trained model against various white-box attack: 74 | ``` 75 | python tools/adv_test_cifar.py --model_path ./results/defense_0.1_0.1/robust_model_g_epoch92.pth --vae_path ./results/defense_0.1_0.1/robust_vae_epoch92.pth --batch_size 256 \ 76 | "NoAttack()" \ 77 | "AutoLinfAttack(cd_vae, 'cifar', bound=8/255)" \ 78 | "AutoL2Attack(cd_vae, 'cifar', bound=1.0)" \ 79 | "JPEGLinfAttack(cd_vae, 'cifar', bound=0.125, num_iterations=100)" \ 80 | "StAdvAttack(cd_vae, num_iterations=100)" \ 81 | "ReColorAdvAttack(cd_vae, num_iterations=100)" 82 | ``` 83 | 84 | ## References 85 | The code of detection part is based on https://github.com/pokaxpoka/deep_Mahalanobis_detector. 86 | 87 | The code of defense part refers to https://github.com/cassidylaidlaw/perceptual-advex and https://github.com/MadryLab/robustness. 88 | 89 | ## Citation 90 | 91 | If you find this repo useful for your research, please consider citing the paper 92 | ``` 93 | @article{yang2021class, 94 | title={Class-Disentanglement and Applications in Adversarial Detection and Defense}, 95 | author={Yang, Kaiwen and Zhou, Tianyi and Tian, Xinmei and Tao, Dacheng and others}, 96 | journal={Advances in Neural Information Processing Systems}, 97 | volume={34}, 98 | year={2021} 99 | } 100 | ``` 101 | -------------------------------------------------------------------------------- /detection/models/wide_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | import sys 8 | import numpy as np 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 12 | 13 | def conv_init(m): 14 | classname = m.__class__.__name__ 15 | if classname.find('Conv') != -1: 16 | init.xavier_uniform_(m.weight, gain=np.sqrt(2)) 17 | init.constant_(m.bias, 0) 18 | elif classname.find('BatchNorm') != -1: 19 | init.constant_(m.weight, 1) 20 | init.constant_(m.bias, 0) 21 | 22 | class wide_basic(nn.Module): 23 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 24 | super(wide_basic, self).__init__() 25 | self.bn1 = nn.BatchNorm2d(in_planes) 26 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 27 | self.dropout = nn.Dropout(p=dropout_rate) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 30 | 31 | self.shortcut = nn.Sequential() 32 | if stride != 1 or in_planes != planes: 33 | self.shortcut = nn.Sequential( 34 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 35 | ) 36 | 37 | def forward(self, x): 38 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 39 | out = self.conv2(F.relu(self.bn2(out))) 40 | out += self.shortcut(x) 41 | 42 | return out 43 | 44 | class Wide_ResNet(nn.Module): 45 | def __init__(self, depth, widen_factor, dropout_rate, num_classes): 46 | super(Wide_ResNet, self).__init__() 47 | self.in_planes = 16 48 | 49 | assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4' 50 | n = (depth-4)/6 51 | k = widen_factor 52 | 53 | print('| Wide-Resnet %dx%d' %(depth, k)) 54 | nStages = [16, 16*k, 32*k, 64*k] 55 | 56 | self.conv1 = conv3x3(3,nStages[0]) 57 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 58 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, 
dropout_rate, stride=2) 59 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 60 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 61 | self.linear = nn.Linear(nStages[3], num_classes) 62 | 63 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 64 | strides = [stride] + [1]*(int(num_blocks)-1) 65 | layers = [] 66 | 67 | for stride in strides: 68 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 69 | self.in_planes = planes 70 | 71 | return nn.Sequential(*layers) 72 | 73 | def forward(self, x): 74 | out = self.conv1(x) 75 | out = self.layer1(out) 76 | out = self.layer2(out) 77 | out = self.layer3(out) 78 | out = F.relu(self.bn1(out)) 79 | out = F.avg_pool2d(out, 8) 80 | out = out.view(out.size(0), -1) 81 | out = self.linear(out) 82 | 83 | return out 84 | 85 | # function to extact the multiple features 86 | def feature_list(self, x): 87 | out_list = [] 88 | out = self.conv1(x) 89 | out_list.append(out) 90 | out = self.layer1(out) 91 | out_list.append(out) 92 | out = self.layer2(out) 93 | out_list.append(out) 94 | out = self.layer3(out) 95 | out_list.append(out) 96 | 97 | out = F.relu(self.bn1(out)) 98 | out = F.avg_pool2d(out, 8) 99 | out = out.view(out.size(0), -1) 100 | out = self.linear(out) 101 | return out, out_list 102 | 103 | # function to extact a specific feature 104 | def intermediate_forward(self, x, layer_index): 105 | out = self.conv1(x) 106 | if layer_index == 1: 107 | out = self.layer1(out) 108 | elif layer_index == 2: 109 | out = self.layer1(out) 110 | out = self.layer2(out) 111 | elif layer_index == 3: 112 | out = self.layer1(out) 113 | out = self.layer2(out) 114 | out = self.layer3(out) 115 | return out 116 | 117 | # function to extact the penultimate features 118 | def penultimate_forward(self, x): 119 | out = self.conv1(x) 120 | out = self.layer1(out) 121 | out = self.layer2(out) 122 | penultimate = self.layer3(out) 123 | 124 | out = F.relu(self.bn1(out)) 125 | out = F.avg_pool2d(out, 8) 126 | out = out.view(out.size(0), -1) 127 | y = self.linear(out) 128 | 129 | return y, penultimate 130 | 131 | 132 | if __name__ == '__main__': 133 | net=Wide_ResNet(28, 10, 0.3, 10) 134 | y = net(Variable(torch.randn(1,3,32,32))) 135 | 136 | print(y.size()) 137 | -------------------------------------------------------------------------------- /detection/data_loader.py: -------------------------------------------------------------------------------- 1 | # original code is from https://github.com/aaron-xichen/pytorch-playground 2 | # modified by Kimin Lee 3 | import torch 4 | from torchvision import datasets, transforms 5 | from torch.utils.data import DataLoader 6 | import os 7 | 8 | def getCIFAR10(batch_size, TF, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs): 9 | kwargs.pop('input_size', None) 10 | ds = [] 11 | if train: 12 | train_loader = torch.utils.data.DataLoader( 13 | datasets.CIFAR10( 14 | root=data_root, train=True, download=True, 15 | transform=TF), 16 | batch_size=batch_size, shuffle=True, **kwargs) 17 | ds.append(train_loader) 18 | if val: 19 | test_loader = torch.utils.data.DataLoader( 20 | datasets.CIFAR10( 21 | root=data_root, train=False, download=True, 22 | transform=TF), 23 | batch_size=batch_size, shuffle=False, **kwargs) 24 | ds.append(test_loader) 25 | ds = ds[0] if len(ds) == 1 else ds 26 | return ds 27 | 28 | def getCIFAR100(batch_size, TF, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs): 29 | data_root = 
os.path.expanduser(os.path.join(data_root, 'cifar100-data')) 30 | num_workers = kwargs.setdefault('num_workers', 1) 31 | kwargs.pop('input_size', None) 32 | ds = [] 33 | if train: 34 | train_loader = torch.utils.data.DataLoader( 35 | datasets.CIFAR100( 36 | root=data_root, train=True, download=True, 37 | transform=TF), 38 | batch_size=batch_size, shuffle=True, **kwargs) 39 | ds.append(train_loader) 40 | 41 | if val: 42 | test_loader = torch.utils.data.DataLoader( 43 | datasets.CIFAR100( 44 | root=data_root, train=False, download=True, 45 | transform=TF), 46 | batch_size=batch_size, shuffle=False, **kwargs) 47 | ds.append(test_loader) 48 | ds = ds[0] if len(ds) == 1 else ds 49 | return ds 50 | 51 | def getIMAGENET(batch_size, TF, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs): 52 | traindir = os.path.join(data_root, 'train') 53 | valdir = os.path.join(data_root, 'val') 54 | num_workers = kwargs.setdefault('num_workers', 1) 55 | kwargs.pop('input_size', None) 56 | ds = [] 57 | label_map = get_label_mapping('restricted_imagenet', 58 | constants.RESTRICTED_IMAGNET_RANGES) 59 | if train: 60 | train_dataset = folder.ImageFolder(root=traindir, 61 | transform=TF, 62 | label_mapping=label_map) 63 | train_loader = torch.utils.data.DataLoader( 64 | train_dataset, batch_size=batch_size, shuffle=True, 65 | num_workers=num_workers, pin_memory=True) 66 | ds.append(train_loader) 67 | if val: 68 | val_dataset = folder.ImageFolder(root=valdir, 69 | transform=TF, 70 | label_mapping=label_map) 71 | val_loader = torch.utils.data.DataLoader( 72 | val_dataset, 73 | batch_size=batch_size, shuffle=False, 74 | num_workers=num_workers, pin_memory=True, 75 | ) 76 | ds.append(val_loader) 77 | ds = ds[0] if len(ds) == 1 else ds 78 | return ds 79 | 80 | def getTargetDataSet(data_type, batch_size, input_TF, dataroot): 81 | if data_type == 'cifar10': 82 | train_loader, test_loader = getCIFAR10(batch_size=batch_size, TF=input_TF, data_root=dataroot, num_workers=1) 83 | elif data_type == 'imagenet': 84 | train_loader, test_loader = getIMAGENET(batch_size=batch_size, TF=input_TF, data_root=dataroot, num_workers=1) 85 | 86 | 87 | return train_loader, test_loader 88 | 89 | def getNonTargetDataSet(data_type, batch_size, input_TF, dataroot): 90 | if data_type == 'cifar10': 91 | _, test_loader = getCIFAR10(batch_size=batch_size, TF=input_TF, data_root=dataroot, num_workers=1) 92 | elif data_type == 'svhn': 93 | _, test_loader = getSVHN(batch_size=batch_size, TF=input_TF, data_root=dataroot, num_workers=1) 94 | elif data_type == 'cifar100': 95 | _, test_loader = getCIFAR100(batch_size=batch_size, TF=input_TF, data_root=dataroot, num_workers=1) 96 | elif data_type == 'imagenet_resize': 97 | dataroot = os.path.expanduser(os.path.join(dataroot, 'Imagenet_resize')) 98 | testsetout = datasets.ImageFolder(dataroot, transform=input_TF) 99 | test_loader = torch.utils.data.DataLoader(testsetout, batch_size=batch_size, shuffle=False, num_workers=1) 100 | elif data_type == 'lsun_resize': 101 | dataroot = os.path.expanduser(os.path.join(dataroot, 'LSUN_resize')) 102 | testsetout = datasets.ImageFolder(dataroot, transform=input_TF) 103 | test_loader = torch.utils.data.DataLoader(testsetout, batch_size=batch_size, shuffle=False, num_workers=1) 104 | return test_loader 105 | 106 | 107 | -------------------------------------------------------------------------------- /detection/lib/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 
| from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | # from PIL import Image 7 | import torchvision.transforms as torch_trans 8 | import random 9 | from torch import is_tensor 10 | 11 | CROP_TYPE = ['center', 'random', 'sliding'] 12 | 13 | 14 | class Crop(object): 15 | """Crops the given img tensor. 16 | Args: 17 | size (sequence or int): Desired output size of the crop. If size is an 18 | int instead of sequence like (h, w), a square crop (size, size) is 19 | made. 20 | crop_frac: crop fraction to crop from the image 21 | """ 22 | 23 | def __init__(self, crop_type=None, crop_frac=1.0, 24 | sliding_crop_position=None): 25 | assert crop_frac <= 1.0, \ 26 | "crop_frac can't be greater than 1.0" 27 | if sliding_crop_position is not None: 28 | # max positions are fixed to 9 29 | assert sliding_crop_position < 9 30 | 31 | assert (crop_type is None or crop_type in CROP_TYPE), ( 32 | "{} is not a valid crop_type".format(crop_type)) 33 | 34 | self.crop_type = crop_type 35 | self.crop_frac = crop_frac 36 | self.sliding_crop_position = sliding_crop_position 37 | 38 | def __call__(self, img): 39 | """ 40 | Args: 41 | img (Tensor): Image to be cropped. 42 | Returns: 43 | """ 44 | assert img is not None, "img should not be None" 45 | assert is_tensor(img), "Tensor expected" 46 | h = img.size(1) 47 | w = img.size(2) 48 | h2 = int(h * self.crop_frac) 49 | w2 = int(w * self.crop_frac) 50 | h_range = h - h2 51 | w_range = w - w2 52 | 53 | if self.crop_type == 'sliding': 54 | assert self.sliding_crop_position is not None 55 | row = int(self.sliding_crop_position / 3) 56 | col = self.sliding_crop_position % 3 57 | x = col * int(w_range / 2) 58 | y = row * int(h_range / 2) 59 | 60 | elif self.crop_type == 'random': 61 | x, y = random.randint(0, w_range), random.randint(0, h_range) 62 | 63 | elif self.crop_type == 'center': 64 | y = int(h_range / 2) 65 | x = int(w_range / 2) 66 | 67 | if self.crop_type is not None: 68 | img = img.narrow(1, y, h2).narrow(2, x, w2).clone() 69 | 70 | return img 71 | 72 | def update_sliding_position(self, sliding_crop_position): 73 | assert sliding_crop_position >= 0 and sliding_crop_position < 9, \ 74 | "Only 9 sliding positions supported" 75 | self.sliding_crop_position = sliding_crop_position 76 | 77 | 78 | class Scale(object): 79 | """Scale the given img tensor. 80 | Args: 81 | size (sequence or int): Desired output size of the crop. If size is an 82 | int instead of sequence like (h, w), a square crop (size, size) is 83 | made. 84 | """ 85 | def __init__(self, size, mean_std=None): 86 | 87 | if mean_std is not None: 88 | assert 'MEAN' in mean_std 89 | assert 'STD' in mean_std 90 | self.size = size 91 | self.mean_std = mean_std 92 | 93 | def __call__(self, img): 94 | """ 95 | Args: 96 | img (Tensor): Image to be cropped. 
97 | Returns: 98 | 99 | """ 100 | assert img is not None, "img should not be None" 101 | assert is_tensor(img), "Tensor expected" 102 | 103 | if not img.size(1) == self.size: 104 | # TODO: We should not need to Unnormalize for scaling(validate if its true) 105 | if self.mean_std: 106 | img = Unnormalize(mean=self.mean_std['MEAN'], 107 | std=self.mean_std['STD'])(img) 108 | img = torch_trans.ToPILImage()(img) 109 | img = torch_trans.Scale(self.size)(img) 110 | img = torch_trans.ToTensor()(img) 111 | if self.mean_std: 112 | img = torch_trans.Normalize(mean=self.mean_std['MEAN'], 113 | std=self.mean_std['STD'])(img) 114 | 115 | return img 116 | 117 | 118 | class Unnormalize(object): 119 | def __init__(self, mean, std): 120 | self.mean = mean 121 | self.std = std 122 | 123 | def __call__(self, imgs): 124 | assert imgs is not None, "img should not be None" 125 | assert is_tensor(imgs), "Tensor expected" 126 | imgs_trans = imgs.clone() 127 | if len(imgs.size()) == 3: 128 | for i in range(imgs.size(0)): 129 | imgs_trans[i, :, :] = imgs_trans[i, :, :] * self.std[i] + self.mean[i] 130 | else: 131 | for i in range(imgs.size(1)): 132 | imgs_trans[:, i, :, :] = ((imgs_trans[:, i, :, :] * self.std[i]) + 133 | self.mean[i]) 134 | return imgs_trans 135 | 136 | 137 | class Normalize(object): 138 | def __init__(self, mean, std): 139 | self.mean = mean 140 | self.std = std 141 | 142 | def __call__(self, imgs): 143 | assert imgs is not None, "img should not be None" 144 | assert is_tensor(imgs), "Tensor expected" 145 | imgs_trans = imgs.clone() 146 | if len(imgs.size()) == 3: 147 | for i in range(imgs.size(0)): 148 | imgs_trans[i, :, :] = (imgs_trans[i, :, :] - self.mean[i]) / self.std[i] 149 | else: 150 | for i in range(imgs.size(1)): 151 | imgs_trans[:, i, :, :] = ((imgs_trans[:, i, :, :] - self.mean[i]) / 152 | self.std[i]) 153 | return imgs_trans 154 | -------------------------------------------------------------------------------- /networks/nearest_embed.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.autograd import Function, Variable 5 | import torch.nn.functional as F 6 | 7 | 8 | class NearestEmbedFunc(Function): 9 | """ 10 | Input: 11 | ------ 12 | x - (batch_size, emb_dim, *) 13 | Last dimensions may be arbitrary 14 | emb - (emb_dim, num_emb) 15 | """ 16 | @staticmethod 17 | def forward(ctx, input, emb): 18 | if input.size(1) != emb.size(0): 19 | raise RuntimeError('invalid argument: input.size(1) ({}) must be equal to emb.size(0) ({})'. 
20 | format(input.size(1), emb.size(0))) 21 | 22 | # save sizes for backward 23 | ctx.batch_size = input.size(0) 24 | ctx.num_latents = int(np.prod(np.array(input.size()[2:]))) 25 | ctx.emb_dim = emb.size(0) 26 | ctx.num_emb = emb.size(1) 27 | ctx.input_type = type(input) 28 | ctx.dims = list(range(len(input.size()))) 29 | 30 | # expand to be broadcast-able 31 | x_expanded = input.unsqueeze(-1) 32 | num_arbitrary_dims = len(ctx.dims) - 2 33 | if num_arbitrary_dims: 34 | emb_expanded = emb.view(emb.shape[0], *([1] * num_arbitrary_dims), emb.shape[1]) 35 | else: 36 | emb_expanded = emb 37 | 38 | # find nearest neighbors 39 | dist = torch.norm(x_expanded - emb_expanded, 2, 1) 40 | _, argmin = dist.min(-1) 41 | shifted_shape = [input.shape[0], *list(input.shape[2:]) ,input.shape[1]] 42 | result = emb.t().index_select(0, argmin.view(-1)).view(shifted_shape).permute(0, ctx.dims[-1], *ctx.dims[1:-1]) 43 | 44 | ctx.save_for_backward(argmin) 45 | return result.contiguous(), argmin 46 | 47 | @staticmethod 48 | def backward(ctx, grad_output, argmin=None): 49 | grad_input = grad_emb = None 50 | if ctx.needs_input_grad[0]: 51 | grad_input = grad_output 52 | 53 | if ctx.needs_input_grad[1]: 54 | argmin, = ctx.saved_variables 55 | latent_indices = torch.arange(ctx.num_emb).type_as(argmin) 56 | idx_choices = (argmin.view(-1, 1) == latent_indices.view(1, -1)).type_as(grad_output.data) 57 | n_idx_choice = idx_choices.sum(0) 58 | n_idx_choice[n_idx_choice == 0] = 1 59 | idx_avg_choices = idx_choices / n_idx_choice 60 | grad_output = grad_output.permute(0, *ctx.dims[2:], 1).contiguous() 61 | grad_output = grad_output.view(ctx.batch_size * ctx.num_latents, ctx.emb_dim) 62 | grad_emb = torch.sum(grad_output.data.view(-1, ctx.emb_dim, 1) * 63 | idx_avg_choices.view(-1, 1, ctx.num_emb), 0) 64 | return grad_input, grad_emb, None, None 65 | 66 | 67 | def nearest_embed(x, emb): 68 | return NearestEmbedFunc().apply(x, emb) 69 | 70 | 71 | class NearestEmbed(nn.Module): 72 | def __init__(self, num_embeddings, embeddings_dim): 73 | super(NearestEmbed, self).__init__() 74 | self.weight = nn.Parameter(torch.rand(embeddings_dim, num_embeddings)) 75 | 76 | def forward(self, x, weight_sg=False): 77 | """Input: 78 | --------- 79 | x - (batch_size, emb_size, *) 80 | """ 81 | return nearest_embed(x, self.weight.detach() if weight_sg else self.weight) 82 | 83 | 84 | # adapted from https://github.com/rosinality/vq-vae-2-pytorch/blob/master/vqvae.py#L25 85 | # that adapted from https://github.com/deepmind/sonnet 86 | 87 | 88 | class NearestEmbedEMA(nn.Module): 89 | def __init__(self, n_emb, emb_dim, decay=0.99, eps=1e-5): 90 | super(NearestEmbedEMA, self).__init__() 91 | self.decay = decay 92 | self.eps = eps 93 | self.embeddings_dim = emb_dim 94 | self.n_emb = n_emb 95 | self.emb_dim = emb_dim 96 | embed = torch.rand(emb_dim, n_emb) 97 | self.register_buffer('weight', embed) 98 | self.register_buffer('cluster_size', torch.zeros(n_emb)) 99 | self.register_buffer('embed_avg', embed.clone()) 100 | 101 | def forward(self, x): 102 | """Input: 103 | --------- 104 | x - (batch_size, emb_size, *) 105 | """ 106 | 107 | dims = list(range(len(x.size()))) 108 | x_expanded = x.unsqueeze(-1) 109 | num_arbitrary_dims = len(dims) - 2 110 | if num_arbitrary_dims: 111 | emb_expanded = self.weight.view(self.emb_dim, *([1] * num_arbitrary_dims), self.n_emb) 112 | else: 113 | emb_expanded = self.weight 114 | 115 | # find nearest neighbors 116 | dist = torch.norm(x_expanded - emb_expanded, 2, 1) 117 | _, argmin = dist.min(-1) 118 | shifted_shape = 
[x.shape[0], *list(x.shape[2:]), x.shape[1]] 119 | result = self.weight.t().index_select(0, argmin.view(-1)).view(shifted_shape).permute(0, dims[-1], *dims[1:-1]) 120 | 121 | if self.training: 122 | latent_indices = torch.arange(self.n_emb).type_as(argmin) 123 | emb_onehot = (argmin.view(-1, 1) == latent_indices.view(1, -1)).type_as(x.data) 124 | n_idx_choice = emb_onehot.sum(0) 125 | n_idx_choice[n_idx_choice == 0] = 1 126 | flatten = x.permute(1, 0, *dims[-2:]).contiguous().view(x.shape[1], -1) 127 | 128 | self.cluster_size.data.mul_(self.decay).add_( 129 | 1 - self.decay, n_idx_choice 130 | ) 131 | embed_sum = flatten @ emb_onehot 132 | self.embed_avg.data.mul_(self.decay).add_(1 - self.decay, embed_sum) 133 | 134 | n = self.cluster_size.sum() 135 | cluster_size = ( 136 | (self.cluster_size + self.eps) / (n + self.n_emb * self.eps) * n 137 | ) 138 | embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) 139 | self.weight.data.copy_(embed_normalized) 140 | 141 | return result, argmin 142 | -------------------------------------------------------------------------------- /utils/randAugment.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import random 3 | import PIL.ImageOps 4 | import PIL.ImageEnhance 5 | import PIL.ImageDraw 6 | from PIL import Image 7 | import numpy as np 8 | import torchvision.transforms as transforms 9 | 10 | 11 | class myRandAugment(transforms.RandomOrder): 12 | """ Apply randomly N transformations from: 13 | - rotation, transX, transY, shearXY, brightness, contrast, saturation, hue 14 | 15 | """ 16 | 17 | def __init__(self, N, L=None): 18 | self.transforms = [ 19 | transforms.RandomAffine(30, translate=None, shear=None), 20 | transforms.RandomAffine(0, translate=(0.2, 0.0), shear=None), 21 | transforms.RandomAffine(0, translate=(0.0, 0.2), shear=None), 22 | transforms.RandomAffine(0, translate=None, shear=30), 23 | transforms.ColorJitter( 24 | brightness=0.5, contrast=0, saturation=0, hue=0), 25 | transforms.ColorJitter( 26 | brightness=0, contrast=0.5, saturation=0, hue=0), 27 | transforms.ColorJitter( 28 | brightness=0, contrast=0, saturation=0.5, hue=0), 29 | transforms.ColorJitter( 30 | brightness=0, contrast=0, saturation=0, hue=0.5) 31 | ] 32 | self.N = N 33 | 34 | def __call__(self, img): 35 | order = list(range(len(self.transforms))) 36 | random.shuffle(order) 37 | for i in order[:self.N]: 38 | img = self.transforms[i](img) 39 | return img 40 | 41 | 42 | def ShearX(img, v): 43 | assert -0.3 <= v <= 0.3 44 | if random.random() > 0.5: 45 | v = -v 46 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 47 | 48 | 49 | def ShearY(img, v): 50 | assert -0.3 <= v <= 0.3 51 | if random.random() > 0.5: 52 | v = -v 53 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 54 | 55 | 56 | def TranslateX(img, v): 57 | assert -0.45 <= v <= 0.45 58 | if random.random() > 0.5: 59 | v = -v 60 | v = v * img.size[0] 61 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 62 | 63 | 64 | def TranslateXabs(img, v): 65 | assert 0 <= v 66 | if random.random() > 0.5: 67 | v = -v 68 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 69 | 70 | 71 | def TranslateY(img, v): 72 | assert -0.45 <= v <= 0.45 73 | if random.random() > 0.5: 74 | v = -v 75 | v = v * img.size[1] 76 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 77 | 78 | 79 | def TranslateYabs(img, v): 80 | assert 0 <= v 81 | if random.random() > 0.5: 82 | v = -v 83 | 
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 84 | 85 | 86 | def Rotate(img, v): 87 | assert -30 <= v <= 30 88 | if random.random() > 0.5: 89 | v = -v 90 | return img.rotate(v) 91 | 92 | 93 | def AutoContrast(img, _): 94 | return PIL.ImageOps.autocontrast(img) 95 | 96 | 97 | def Invert(img, _): 98 | return PIL.ImageOps.invert(img) 99 | 100 | 101 | def Equalize(img, _): 102 | return PIL.ImageOps.equalize(img) 103 | 104 | 105 | def Flip(img, _): 106 | return PIL.ImageOps.mirror(img) 107 | 108 | 109 | def Solarize(img, v): 110 | assert 0 <= v <= 256 111 | return PIL.ImageOps.solarize(img, v) 112 | 113 | 114 | def SolarizeAdd(img, addition=0, threshold=128): 115 | img_np = np.array(img).astype(np.int) 116 | img_np = img_np + addition 117 | img_np = np.clip(img_np, 0, 255) 118 | img_np = img_np.astype(np.uint8) 119 | img = Image.fromarray(img_np) 120 | return PIL.ImageOps.solarize(img, threshold) 121 | 122 | 123 | def Posterize(img, v): 124 | v = int(v) 125 | v = max(1, v) 126 | return PIL.ImageOps.posterize(img, v) 127 | 128 | 129 | def Contrast(img, v): 130 | assert 0.1 <= v <= 1.9 131 | return PIL.ImageEnhance.Contrast(img).enhance(v) 132 | 133 | 134 | def Color(img, v): 135 | assert 0.1 <= v <= 1.9 136 | return PIL.ImageEnhance.Color(img).enhance(v) 137 | 138 | 139 | def Brightness(img, v): 140 | assert 0.1 <= v <= 1.9 141 | return PIL.ImageEnhance.Brightness(img).enhance(v) 142 | 143 | 144 | def Sharpness(img, v): 145 | assert 0.1 <= v <= 1.9 146 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 147 | 148 | 149 | def CutoutAbs(img, v): 150 | if v < 0: 151 | return img 152 | w, h = img.size 153 | x0 = np.random.uniform(w) 154 | y0 = np.random.uniform(h) 155 | 156 | x0 = int(max(0, x0 - v / 2.)) 157 | y0 = int(max(0, y0 - v / 2.)) 158 | x1 = min(w, x0 + v) 159 | y1 = min(h, y0 + v) 160 | 161 | xy = (x0, y0, x1, y1) 162 | color = (125, 123, 114) 163 | img = img.copy() 164 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 165 | return img 166 | 167 | 168 | def augment_list(): 169 | l = [ 170 | (AutoContrast, 0, 1), 171 | (Equalize, 0, 1), 172 | (Invert, 0, 1), 173 | (Rotate, 0, 30), 174 | (Posterize, 0, 4), 175 | (Solarize, 0, 256), 176 | (SolarizeAdd, 0, 110), 177 | (Color, 0.1, 1.9), 178 | (Contrast, 0.1, 1.9), 179 | (Brightness, 0.1, 1.9), 180 | (Sharpness, 0.1, 1.9), 181 | (ShearX, 0., 0.3), 182 | (ShearY, 0., 0.3), 183 | (CutoutAbs, 0, 16), 184 | (TranslateXabs, 0., 16), 185 | (TranslateYabs, 0., 16), 186 | ] 187 | return l 188 | 189 | 190 | class RandAugment: 191 | def __init__(self, n, m): 192 | self.n = n 193 | self.m = m 194 | self.augment_list = augment_list() 195 | 196 | def __call__(self, img): 197 | ops = random.choices(self.augment_list, k=self.n) 198 | for op, minval, maxval in ops: 199 | val = (float(self.m) / 30) * float(maxval - minval) + minval 200 | img = op(img, val) 201 | return img 202 | -------------------------------------------------------------------------------- /detection/models/vae.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import abc 3 | import os 4 | import math 5 | 6 | import numpy as np 7 | import logging 8 | import torch 9 | import torch.utils.data 10 | from torch import nn 11 | from torch.nn import init 12 | from torch.nn import functional as F 13 | from torch.autograd import Variable 14 | 15 | class AbstractAutoEncoder(nn.Module): 16 | __metaclass__ = abc.ABCMeta 17 | 18 | @abc.abstractmethod 19 | def encode(self, x): 20 | return 21 | 22 | 
@abc.abstractmethod 23 | def decode(self, z): 24 | return 25 | 26 | @abc.abstractmethod 27 | def forward(self, x): 28 | """model return (reconstructed_x, *)""" 29 | return 30 | 31 | @abc.abstractmethod 32 | def sample(self, size): 33 | """sample new images from model""" 34 | return 35 | 36 | @abc.abstractmethod 37 | def loss_function(self, **kwargs): 38 | """accepts (original images, *) where * is the same as returned from forward()""" 39 | return 40 | 41 | @abc.abstractmethod 42 | def latest_losses(self): 43 | """returns the latest losses in a dictionary. Useful for logging.""" 44 | return 45 | 46 | class ResBlock(nn.Module): 47 | def __init__(self, in_channels, out_channels, mid_channels=None, bn=False): 48 | super(ResBlock, self).__init__() 49 | 50 | if mid_channels is None: 51 | mid_channels = out_channels 52 | 53 | layers = [ 54 | nn.LeakyReLU(), 55 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1), 56 | nn.LeakyReLU(), 57 | nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1, padding=0)] 58 | if bn: 59 | layers.insert(2, nn.BatchNorm2d(out_channels)) 60 | self.convs = nn.Sequential(*layers) 61 | 62 | def forward(self, x): 63 | return x + self.convs(x) 64 | 65 | class CVAE(AbstractAutoEncoder): 66 | def __init__(self, d, z, **kwargs): 67 | super(CVAE, self).__init__() 68 | 69 | self.encoder = nn.Sequential( 70 | nn.Conv2d(3, d // 2, kernel_size=4, stride=2, padding=1, bias=False), 71 | nn.BatchNorm2d(d // 2), 72 | nn.ReLU(inplace=True), 73 | nn.Conv2d(d // 2, d, kernel_size=4, stride=2, padding=1, bias=False), 74 | nn.BatchNorm2d(d), 75 | nn.ReLU(inplace=True), 76 | ResBlock(d, d, bn=True), 77 | nn.BatchNorm2d(d), 78 | ResBlock(d, d, bn=True), 79 | ) 80 | 81 | self.decoder = nn.Sequential( 82 | ResBlock(d, d, bn=True), 83 | nn.BatchNorm2d(d), 84 | ResBlock(d, d, bn=True), 85 | nn.BatchNorm2d(d), 86 | 87 | nn.ConvTranspose2d(d, d // 2, kernel_size=4, stride=2, padding=1, bias=False), 88 | nn.BatchNorm2d(d // 2), 89 | nn.LeakyReLU(inplace=True), 90 | nn.ConvTranspose2d(d // 2, 3, kernel_size=4, stride=2, padding=1, bias=False), 91 | ) 92 | self.xi_bn = nn.BatchNorm2d(3) 93 | 94 | self.f = 8 95 | self.d = d 96 | self.z = z 97 | self.fc11 = nn.Linear(d * self.f ** 2, self.z) 98 | self.fc12 = nn.Linear(d * self.f ** 2, self.z) 99 | self.fc21 = nn.Linear(self.z, d * self.f ** 2) 100 | 101 | def encode(self, x): 102 | h = self.encoder(x) 103 | h1 = h.view(-1, self.d * self.f ** 2) 104 | return h, self.fc11(h1), self.fc12(h1) 105 | 106 | def reparameterize(self, mu, logvar): 107 | if self.training: 108 | std = logvar.mul(0.5).exp_() 109 | eps = std.new(std.size()).normal_() 110 | return eps.mul(std).add_(mu) 111 | else: 112 | return mu 113 | 114 | def decode(self, z): 115 | z = z.view(-1, self.d, self.f, self.f) 116 | h3 = self.decoder(z) 117 | return torch.tanh(h3) 118 | 119 | def forward(self, x): 120 | 121 | _, mu, logvar = self.encode(x) 122 | hi = self.reparameterize(mu, logvar) 123 | hi_projected = self.fc21(hi) 124 | xi = self.decode(hi_projected) 125 | xi = self.xi_bn(xi) 126 | 127 | return xi 128 | 129 | class CVAE_imagenet(nn.Module): 130 | def __init__(self, d, k=10, num_channels=3, **kwargs): 131 | super(CVAE_imagenet, self).__init__() 132 | 133 | self.encoder = nn.Sequential( 134 | nn.Conv2d(num_channels, d, kernel_size=4, stride=2, padding=1), 135 | nn.BatchNorm2d(d), 136 | nn.LeakyReLU(inplace=True), 137 | nn.Conv2d(d, d, kernel_size=4, stride=2, padding=1), 138 | nn.BatchNorm2d(d), 139 | nn.LeakyReLU(inplace=True), 140 | ResBlock(d, d), 141 | 
nn.BatchNorm2d(d), 142 | ResBlock(d, d), 143 | nn.BatchNorm2d(d), 144 | ) 145 | self.decoder = nn.Sequential( 146 | ResBlock(d, d), 147 | nn.BatchNorm2d(d), 148 | ResBlock(d, d), 149 | nn.ConvTranspose2d(d, d, kernel_size=4, stride=2, padding=1), 150 | nn.BatchNorm2d(d), 151 | nn.LeakyReLU(inplace=True), 152 | nn.ConvTranspose2d(d, num_channels, kernel_size=4, stride=2, padding=1), 153 | ) 154 | self.d = d 155 | self.emb = NearestEmbed(k, d) 156 | 157 | for l in self.modules(): 158 | if isinstance(l, nn.Linear) or isinstance(l, nn.Conv2d): 159 | l.weight.detach().normal_(0, 0.02) 160 | torch.fmod(l.weight, 0.04) 161 | nn.init.constant_(l.bias, 0) 162 | 163 | self.encoder[-1].weight.detach().fill_(1 / 40) 164 | 165 | self.emb.weight.detach().normal_(0, 0.02) 166 | torch.fmod(self.emb.weight, 0.04) 167 | self.L_bn = nn.BatchNorm2d(num_channels) 168 | 169 | 170 | def encode(self, x): 171 | return self.encoder(x) 172 | 173 | def decode(self, x): 174 | return torch.tanh(self.decoder(x)) 175 | 176 | def forward(self, x): 177 | 178 | z_e = self.encode(x) 179 | 180 | z_q, _ = self.emb(z_e, weight_sg=True) 181 | emb, _ = self.emb(z_e.detach()) 182 | 183 | l = self.decode(z_q) 184 | xi = self.L_bn(l) 185 | 186 | return xi -------------------------------------------------------------------------------- /utils/randaugment4fixmatch.py: -------------------------------------------------------------------------------- 1 | # code in this file is adpated from 2 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py 3 | # https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py 4 | # https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py 5 | import logging 6 | import random 7 | 8 | import numpy as np 9 | import PIL 10 | import PIL.ImageOps 11 | import PIL.ImageEnhance 12 | import PIL.ImageDraw 13 | from PIL import Image 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | PARAMETER_MAX = 10 18 | 19 | 20 | def AutoContrast(img, **kwarg): 21 | return PIL.ImageOps.autocontrast(img) 22 | 23 | 24 | def Brightness(img, v, max_v, bias=0): 25 | v = _float_parameter(v, max_v) + bias 26 | return PIL.ImageEnhance.Brightness(img).enhance(v) 27 | 28 | 29 | def Color(img, v, max_v, bias=0): 30 | v = _float_parameter(v, max_v) + bias 31 | return PIL.ImageEnhance.Color(img).enhance(v) 32 | 33 | 34 | def Contrast(img, v, max_v, bias=0): 35 | v = _float_parameter(v, max_v) + bias 36 | return PIL.ImageEnhance.Contrast(img).enhance(v) 37 | 38 | 39 | def Cutout(img, v, max_v, bias=0): 40 | if v == 0: 41 | return img 42 | v = _float_parameter(v, max_v) + bias 43 | v = int(v * min(img.size)) 44 | return CutoutAbs(img, v) 45 | 46 | 47 | def CutoutAbs(img, v, **kwarg): 48 | w, h = img.size 49 | x0 = np.random.uniform(0, w) 50 | y0 = np.random.uniform(0, h) 51 | x0 = int(max(0, x0 - v / 2.)) 52 | y0 = int(max(0, y0 - v / 2.)) 53 | x1 = int(min(w, x0 + v)) 54 | y1 = int(min(h, y0 + v)) 55 | xy = (x0, y0, x1, y1) 56 | # gray 57 | color = (127, 127, 127) 58 | img = img.copy() 59 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 60 | return img 61 | 62 | 63 | def Equalize(img, **kwarg): 64 | return PIL.ImageOps.equalize(img) 65 | 66 | 67 | def Identity(img, **kwarg): 68 | return img 69 | 70 | 71 | def Invert(img, **kwarg): 72 | return PIL.ImageOps.invert(img) 73 | 74 | 75 | def Posterize(img, v, max_v, bias=0): 76 | v = _int_parameter(v, max_v) + bias 77 | return PIL.ImageOps.posterize(img, v) 78 | 79 | 80 | def Rotate(img, v, max_v, 
bias=0): 81 | v = _int_parameter(v, max_v) + bias 82 | if random.random() < 0.5: 83 | v = -v 84 | return img.rotate(v) 85 | 86 | 87 | def Sharpness(img, v, max_v, bias=0): 88 | v = _float_parameter(v, max_v) + bias 89 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 90 | 91 | 92 | def ShearX(img, v, max_v, bias=0): 93 | v = _float_parameter(v, max_v) + bias 94 | if random.random() < 0.5: 95 | v = -v 96 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 97 | 98 | 99 | def ShearY(img, v, max_v, bias=0): 100 | v = _float_parameter(v, max_v) + bias 101 | if random.random() < 0.5: 102 | v = -v 103 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 104 | 105 | 106 | def Solarize(img, v, max_v, bias=0): 107 | v = _int_parameter(v, max_v) + bias 108 | return PIL.ImageOps.solarize(img, 256 - v) 109 | 110 | 111 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128): 112 | v = _int_parameter(v, max_v) + bias 113 | if random.random() < 0.5: 114 | v = -v 115 | img_np = np.array(img).astype(np.int) 116 | img_np = img_np + v 117 | img_np = np.clip(img_np, 0, 255) 118 | img_np = img_np.astype(np.uint8) 119 | img = Image.fromarray(img_np) 120 | return PIL.ImageOps.solarize(img, threshold) 121 | 122 | 123 | def TranslateX(img, v, max_v, bias=0): 124 | v = _float_parameter(v, max_v) + bias 125 | if random.random() < 0.5: 126 | v = -v 127 | v = int(v * img.size[0]) 128 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 129 | 130 | 131 | def TranslateY(img, v, max_v, bias=0): 132 | v = _float_parameter(v, max_v) + bias 133 | if random.random() < 0.5: 134 | v = -v 135 | v = int(v * img.size[1]) 136 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 137 | 138 | 139 | def _float_parameter(v, max_v): 140 | return float(v) * max_v / PARAMETER_MAX 141 | 142 | 143 | def _int_parameter(v, max_v): 144 | return int(v * max_v / PARAMETER_MAX) 145 | 146 | 147 | def fixmatch_augment_pool(): 148 | # FixMatch paper 149 | augs = [(AutoContrast, None, None), 150 | (Brightness, 0.9, 0.05), 151 | (Color, 0.9, 0.05), 152 | (Contrast, 0.9, 0.05), 153 | (Equalize, None, None), 154 | (Identity, None, None), 155 | (Posterize, 4, 4), 156 | (Rotate, 30, 0), 157 | (Sharpness, 0.9, 0.05), 158 | (ShearX, 0.3, 0), 159 | (ShearY, 0.3, 0), 160 | (Solarize, 256, 0), 161 | (TranslateX, 0.3, 0), 162 | (TranslateY, 0.3, 0)] 163 | return augs 164 | 165 | 166 | def my_augment_pool(): 167 | # Test 168 | augs = [(AutoContrast, None, None), 169 | (Brightness, 1.8, 0.1), 170 | (Color, 1.8, 0.1), 171 | (Contrast, 1.8, 0.1), 172 | (Cutout, 0.2, 0), 173 | (Equalize, None, None), 174 | (Invert, None, None), 175 | (Posterize, 4, 4), 176 | (Rotate, 30, 0), 177 | (Sharpness, 1.8, 0.1), 178 | (ShearX, 0.3, 0), 179 | (ShearY, 0.3, 0), 180 | (Solarize, 256, 0), 181 | (SolarizeAdd, 110, 0), 182 | (TranslateX, 0.45, 0), 183 | (TranslateY, 0.45, 0)] 184 | return augs 185 | 186 | 187 | class RandAugmentPC(object): 188 | def __init__(self, n, m): 189 | assert n >= 1 190 | assert 1 <= m <= 10 191 | self.n = n 192 | self.m = m 193 | self.augment_pool = my_augment_pool() 194 | 195 | def __call__(self, img): 196 | ops = random.choices(self.augment_pool, k=self.n) 197 | for op, max_v, bias in ops: 198 | prob = np.random.uniform(0.2, 0.8) 199 | if random.random() + prob >= 1: 200 | img = op(img, v=self.m, max_v=max_v, bias=bias) 201 | img = CutoutAbs(img, 16) 202 | return img 203 | 204 | 205 | class RandAugmentMC(object): 206 | def __init__(self, n, m): 207 | assert n >= 1 208 | assert 1 
<= m <= 10 209 | self.n = n 210 | self.m = m 211 | self.augment_pool = fixmatch_augment_pool() 212 | 213 | def __call__(self, img): 214 | ops = random.choices(self.augment_pool, k=self.n) 215 | for op, max_v, bias in ops: 216 | v = np.random.randint(1, self.m) 217 | if random.random() < 0.5: 218 | img = op(img, v=v, max_v=max_v, bias=bias) 219 | img = CutoutAbs(img, 16) 220 | return img -------------------------------------------------------------------------------- /detection/ADV_Generate_Mahalanobis_Subspace.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import data_loader 5 | import numpy as np 6 | import calculate_log as callog 7 | import models 8 | import os 9 | import lib_generation 10 | from torch import nn 11 | from torchvision import transforms 12 | from torch.autograd import Variable 13 | import pdb 14 | 15 | parser = argparse.ArgumentParser(description='PyTorch code: Mahalanobis detector') 16 | parser.add_argument('--batch_size', type=int, default=200, metavar='N', help='batch size for data loader') 17 | parser.add_argument('--dataset', required=True, help='cifar10 | cifar100 | svhn') 18 | parser.add_argument('--dataroot', default='../../data', help='path to dataset') 19 | parser.add_argument('--outf', default='./adv_output/', help='folder to output results') 20 | parser.add_argument('--num_classes', type=int, default=10, help='the # of classes') 21 | parser.add_argument('--net_type', required=True, help='resnet | densenet') 22 | parser.add_argument('--gpu', type=int, default=0, help='gpu index') 23 | parser.add_argument('--adv_type', required=True, help='FGSM | BIM | PGD | CW') 24 | parser.add_argument('--vae_path', default='./data/96.32/model_epoch252.pth', help='folder to output results') 25 | args = parser.parse_args() 26 | print(args) 27 | 28 | def main(): 29 | # set the path to pre-trained model and output 30 | args.outf = args.outf + args.net_type + '_' + args.dataset + '/' 31 | if os.path.isdir(args.outf) == False: 32 | os.mkdir(args.outf) 33 | torch.cuda.manual_seed(0) 34 | torch.cuda.set_device(args.gpu) 35 | 36 | in_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),]) 37 | model = models.Wide_ResNet(28, 10, 0.3, 10) 38 | model = nn.DataParallel(model) 39 | model_dict = model.state_dict() 40 | save_model = torch.load(args.vae_path) 41 | state_dict = {k.replace('classifier.',''): v for k, v in save_model.items() if k.replace('classifier.','') in model_dict.keys()} 42 | print(state_dict.keys()) 43 | model_dict.update(state_dict) 44 | model.load_state_dict(model_dict) 45 | model.cuda() 46 | model.eval() 47 | print('load model: ' + args.net_type) 48 | vae = models.CVAE(d=32, z=2048) 49 | vae = nn.DataParallel(vae) 50 | model_dict = vae.state_dict() 51 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()} 52 | print(state_dict.keys()) 53 | model_dict.update(state_dict) 54 | vae.load_state_dict(model_dict) 55 | vae.cuda() 56 | vae.eval() 57 | 58 | # load dataset 59 | print('load target data: ', args.dataset) 60 | train_loader, _ = data_loader.getTargetDataSet(args.dataset, args.batch_size, in_transform, args.dataroot) 61 | test_clean_data = torch.load(args.outf + 'clean_data_%s_%s_%s.pth' % (args.net_type, args.dataset, args.adv_type)) 62 | test_adv_data = torch.load(args.outf + 'adv_data_%s_%s_%s.pth' % (args.net_type, args.dataset, args.adv_type)) 63 | 
test_noisy_data = torch.load(args.outf + 'noisy_data_%s_%s_%s.pth' % (args.net_type, args.dataset, args.adv_type)) 64 | test_label = torch.load(args.outf + 'label_%s_%s_%s.pth' % (args.net_type, args.dataset, args.adv_type)) 65 | 66 | # set information about feature extaction 67 | model.eval() 68 | temp_x = torch.rand(2,3,32,32).cuda() 69 | temp_x = Variable(temp_x) 70 | temp_list = model.module.feature_list(temp_x-vae(temp_x))[1] 71 | num_output = len(temp_list) 72 | feature_list = np.empty(num_output) 73 | count = 0 74 | for out in temp_list: 75 | feature_list[count] = out.size(1) 76 | count += 1 77 | 78 | print('get sample mean and covariance') 79 | sample_mean, precision = lib_generation.sample_estimator(model, vae, args.num_classes, feature_list, train_loader) 80 | 81 | print('get Mahalanobis scores') 82 | m_list = [0.0, 0.01, 0.005, 0.002, 0.0014, 0.001, 0.0005] 83 | for magnitude in m_list: 84 | print('\nNoise: ' + str(magnitude)) 85 | for i in range(num_output): 86 | M_in \ 87 | = lib_generation.get_Mahalanobis_score_adv(model, vae, test_clean_data, test_label, \ 88 | args.num_classes, args.outf, args.net_type, \ 89 | sample_mean, precision, i, magnitude) 90 | M_in = np.asarray(M_in, dtype=np.float32) 91 | if i == 0: 92 | Mahalanobis_in = M_in.reshape((M_in.shape[0], -1)) 93 | else: 94 | Mahalanobis_in = np.concatenate((Mahalanobis_in, M_in.reshape((M_in.shape[0], -1))), axis=1) 95 | 96 | for i in range(num_output): 97 | M_out \ 98 | = lib_generation.get_Mahalanobis_score_adv(model, vae, test_adv_data, test_label, \ 99 | args.num_classes, args.outf, args.net_type, \ 100 | sample_mean, precision, i, magnitude) 101 | M_out = np.asarray(M_out, dtype=np.float32) 102 | if i == 0: 103 | Mahalanobis_out = M_out.reshape((M_out.shape[0], -1)) 104 | else: 105 | Mahalanobis_out = np.concatenate((Mahalanobis_out, M_out.reshape((M_out.shape[0], -1))), axis=1) 106 | 107 | for i in range(num_output): 108 | M_noisy \ 109 | = lib_generation.get_Mahalanobis_score_adv(model, vae, test_noisy_data, test_label, \ 110 | args.num_classes, args.outf, args.net_type, \ 111 | sample_mean, precision, i, magnitude) 112 | M_noisy = np.asarray(M_noisy, dtype=np.float32) 113 | if i == 0: 114 | Mahalanobis_noisy = M_noisy.reshape((M_noisy.shape[0], -1)) 115 | else: 116 | Mahalanobis_noisy = np.concatenate((Mahalanobis_noisy, M_noisy.reshape((M_noisy.shape[0], -1))), axis=1) 117 | Mahalanobis_in = np.asarray(Mahalanobis_in, dtype=np.float32) 118 | Mahalanobis_out = np.asarray(Mahalanobis_out, dtype=np.float32) 119 | Mahalanobis_noisy = np.asarray(Mahalanobis_noisy, dtype=np.float32) 120 | Mahalanobis_pos = np.concatenate((Mahalanobis_in, Mahalanobis_noisy)) 121 | 122 | Mahalanobis_data, Mahalanobis_labels = lib_generation.merge_and_generate_labels(Mahalanobis_out, Mahalanobis_pos) 123 | file_name = os.path.join(args.outf, 'Mahalanobis_%s_%s_%s.npy' % (str(magnitude), args.dataset, args.adv_type)) 124 | 125 | Mahalanobis_data = np.concatenate((Mahalanobis_data, Mahalanobis_labels), axis=1) 126 | np.save(file_name, Mahalanobis_data) 127 | 128 | 129 | 130 | if __name__ == '__main__': 131 | main() 132 | 133 | -------------------------------------------------------------------------------- /detection/models/densenet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, dropRate=0.0): 9 | super(BasicBlock, 
self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, 13 | padding=1, bias=False) 14 | self.droprate = dropRate 15 | def forward(self, x): 16 | out = self.conv1(self.relu(self.bn1(x))) 17 | if self.droprate > 0: 18 | out = F.dropout(out, p=self.droprate, training=self.training) 19 | return torch.cat([x, out], 1) 20 | 21 | class BottleneckBlock(nn.Module): 22 | def __init__(self, in_planes, out_planes, dropRate=0.0): 23 | super(BottleneckBlock, self).__init__() 24 | inter_planes = out_planes * 4 25 | self.bn1 = nn.BatchNorm2d(in_planes) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1, 28 | padding=0, bias=False) 29 | self.bn2 = nn.BatchNorm2d(inter_planes) 30 | self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1, 31 | padding=1, bias=False) 32 | self.droprate = dropRate 33 | def forward(self, x): 34 | out = self.conv1(self.relu(self.bn1(x))) 35 | if self.droprate > 0: 36 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) 37 | out = self.conv2(self.relu(self.bn2(out))) 38 | if self.droprate > 0: 39 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) 40 | return torch.cat([x, out], 1) 41 | 42 | class TransitionBlock(nn.Module): 43 | def __init__(self, in_planes, out_planes, dropRate=0.0): 44 | super(TransitionBlock, self).__init__() 45 | self.bn1 = nn.BatchNorm2d(in_planes) 46 | self.relu = nn.ReLU(inplace=True) 47 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, 48 | padding=0, bias=False) 49 | self.droprate = dropRate 50 | def forward(self, x): 51 | out = self.conv1(self.relu(self.bn1(x))) 52 | if self.droprate > 0: 53 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) 54 | return F.avg_pool2d(out, 2) 55 | 56 | class DenseBlock(nn.Module): 57 | def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0): 58 | super(DenseBlock, self).__init__() 59 | self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate) 60 | def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate): 61 | layers = [] 62 | for i in range(int(nb_layers)): 63 | layers.append(block(in_planes+i*growth_rate, growth_rate, dropRate)) 64 | return nn.Sequential(*layers) 65 | def forward(self, x): 66 | return self.layer(x) 67 | 68 | class DenseNet3(nn.Module): 69 | def __init__(self, depth, num_classes, growth_rate=12, 70 | reduction=0.5, bottleneck=True, dropRate=0.0): 71 | super(DenseNet3, self).__init__() 72 | in_planes = 2 * growth_rate 73 | n = (depth - 4) / 3 74 | if bottleneck == True: 75 | n = n/2 76 | block = BottleneckBlock 77 | else: 78 | block = BasicBlock 79 | # 1st conv before any dense block 80 | self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1, 81 | padding=1, bias=False) 82 | # 1st block 83 | self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate) 84 | in_planes = int(in_planes+n*growth_rate) 85 | self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate) 86 | in_planes = int(math.floor(in_planes*reduction)) 87 | # 2nd block 88 | self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate) 89 | in_planes = int(in_planes+n*growth_rate) 90 | self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate) 91 | in_planes = 
int(math.floor(in_planes*reduction)) 92 | # 3rd block 93 | self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate) 94 | in_planes = int(in_planes+n*growth_rate) 95 | # global average pooling and classifier 96 | self.bn1 = nn.BatchNorm2d(in_planes) 97 | self.relu = nn.ReLU(inplace=True) 98 | self.fc = nn.Linear(in_planes, num_classes) 99 | self.in_planes = in_planes 100 | 101 | for m in self.modules(): 102 | if isinstance(m, nn.Conv2d): 103 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 104 | m.weight.data.normal_(0, math.sqrt(2. / n)) 105 | elif isinstance(m, nn.BatchNorm2d): 106 | m.weight.data.fill_(1) 107 | m.bias.data.zero_() 108 | elif isinstance(m, nn.Linear): 109 | m.bias.data.zero_() 110 | 111 | def forward(self, x): 112 | out = self.conv1(x) 113 | out = self.trans1(self.block1(out)) 114 | out = self.trans2(self.block2(out)) 115 | out = self.block3(out) 116 | out = self.relu(self.bn1(out)) 117 | out = F.avg_pool2d(out, 8) 118 | out = out.view(-1, self.in_planes) 119 | return self.fc(out) 120 | 121 | # function to extact the multiple features 122 | def feature_list(self, x): 123 | out_list = [] 124 | out = self.conv1(x) 125 | out_list.append(out) 126 | out = self.trans1(self.block1(out)) 127 | out_list.append(out) 128 | out = self.trans2(self.block2(out)) 129 | out_list.append(out) 130 | out = self.block3(out) 131 | out = self.relu(self.bn1(out)) 132 | out_list.append(out) 133 | out = F.avg_pool2d(out, 8) 134 | out = out.view(-1, self.in_planes) 135 | 136 | return self.fc(out), out_list 137 | 138 | def intermediate_forward(self, x, layer_index): 139 | out = self.conv1(x) 140 | if layer_index == 1: 141 | out = self.trans1(self.block1(out)) 142 | elif layer_index == 2: 143 | out = self.trans1(self.block1(out)) 144 | out = self.trans2(self.block2(out)) 145 | elif layer_index == 3: 146 | out = self.trans1(self.block1(out)) 147 | out = self.trans2(self.block2(out)) 148 | out = self.block3(out) 149 | out = self.relu(self.bn1(out)) 150 | return out 151 | 152 | # function to extact the penultimate features 153 | def penultimate_forward(self, x): 154 | out = self.conv1(x) 155 | out = self.trans1(self.block1(out)) 156 | out = self.trans2(self.block2(out)) 157 | out = self.block3(out) 158 | penultimate = self.relu(self.bn1(out)) 159 | out = F.avg_pool2d(penultimate, 8) 160 | out = out.view(-1, self.in_planes) 161 | return self.fc(out), penultimate -------------------------------------------------------------------------------- /detection/lib/quilting.cpp: -------------------------------------------------------------------------------- 1 | #include "quilting.h" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "findseam.h" 7 | 8 | void generatePatches( 9 | float* result, 10 | float* img, 11 | unsigned int imgH, 12 | unsigned int imgW, 13 | unsigned int patchSize, 14 | unsigned int overlap) { 15 | int n = 0; 16 | 17 | for (int y = 0; y < imgH - patchSize; y += patchSize - overlap) { 18 | for (int x = 0; x < imgW - patchSize; x += patchSize - overlap) { 19 | for (int c = 0; c < 3; c++) { 20 | for (int j = 0; j < patchSize; j++) { 21 | for (int i = 0; i < patchSize; i++) { 22 | result 23 | [n * 3 * patchSize * patchSize + c * patchSize * patchSize + 24 | j * patchSize + i] = 25 | img[c * imgH * imgW + (y + j) * imgW + (x + i)]; 26 | } 27 | } 28 | } 29 | n++; 30 | } 31 | } 32 | } 33 | 34 | // "from" nodes and "to" nodes 35 | using graphLattice = std::pair, std::vector>; 36 | 37 | std::map, graphLattice> cache; 38 | 39 | graphLattice _getFourLattice(unsigned 
int h, unsigned int w, bool useCache) { 40 | if (useCache) { 41 | auto iter = cache.find(std::make_pair(h, w)); 42 | if (iter != cache.end()) { 43 | return iter->second; 44 | } 45 | } 46 | 47 | std::vector from, to; 48 | 49 | // right 50 | for (int j = 0; j < h; j++) { 51 | for (int i = 0; i < w - 1; i++) { 52 | from.push_back(j * w + i); 53 | to.push_back(j * w + (i + 1)); 54 | } 55 | } 56 | 57 | // left 58 | for (int j = 0; j < h; j++) { 59 | for (int i = 1; i < w; i++) { 60 | from.push_back(j * w + i); 61 | to.push_back(j * w + (i - 1)); 62 | } 63 | } 64 | 65 | // down 66 | for (int j = 0; j < h - 1; j++) { 67 | for (int i = 0; i < w; i++) { 68 | from.push_back(j * w + i); 69 | to.push_back((j + 1) * w + i); 70 | } 71 | } 72 | 73 | // up 74 | for (int j = 1; j < h; j++) { 75 | for (int i = 0; i < w; i++) { 76 | from.push_back(j * w + i); 77 | to.push_back((j - 1) * w + i); 78 | } 79 | } 80 | 81 | graphLattice result = std::make_pair(from, to); 82 | 83 | if (useCache) { 84 | cache[std::make_pair(h, w)] = result; 85 | } 86 | 87 | return result; 88 | } 89 | 90 | void _findSeam( 91 | int* result, 92 | float* im1, 93 | float* im2, 94 | unsigned int patchSize, 95 | unsigned int* mask) { 96 | graphLattice graph = _getFourLattice(patchSize, patchSize, true); 97 | std::vector from = graph.first; 98 | std::vector to = graph.second; 99 | int edgeNum = 4 * patchSize * patchSize - 2 * (patchSize + patchSize); 100 | 101 | float* values = new float[edgeNum]; 102 | for (int i = 0; i < edgeNum; i++) { 103 | values[i] = 0; 104 | } 105 | 106 | for (int c = 0; c < 3; c++) { 107 | for (int i = 0; i < edgeNum; i++) { 108 | values[i] += fabs( 109 | im2[c * patchSize * patchSize + to[i]] - 110 | im1[c * patchSize * patchSize + from[i]]); 111 | } 112 | } 113 | 114 | int nodeNum = patchSize * patchSize; 115 | float* tvalues = new float[nodeNum * 2]; 116 | for (int i = 0; i < nodeNum * 2; i++) { 117 | tvalues[i] = 0; 118 | } 119 | 120 | for (int j = 0; j < patchSize; j++) { 121 | for (int i = 0; i < patchSize; i++) { 122 | for (int c = 0; c < 2; c++) { 123 | if (mask[j * patchSize + i] == c + 1) { 124 | tvalues[(j * patchSize + i) * 2 + c] = 125 | std::numeric_limits::infinity(); 126 | } 127 | } 128 | } 129 | } 130 | 131 | findseam(nodeNum, edgeNum, from.data(), to.data(), values, tvalues, result); 132 | delete[] values; 133 | delete[] tvalues; 134 | } 135 | 136 | void stitch( 137 | float* result, 138 | float* im1, 139 | float* im2, 140 | unsigned int patchSize, 141 | unsigned int overlap, 142 | unsigned int y, 143 | unsigned int x) { 144 | unsigned int* mask = new unsigned int[patchSize * patchSize]; 145 | 146 | for (int j = 0; j < patchSize; j++) { 147 | for (int i = 0; i < patchSize; i++) { 148 | mask[j * patchSize + i] = 2; 149 | } 150 | } 151 | 152 | if (y > 0) { 153 | for (int j = 0; j < overlap; j++) { 154 | for (int i = 0; i < patchSize; i++) { 155 | mask[j * patchSize + i] = 0; 156 | } 157 | } 158 | } 159 | 160 | if (x > 0) { 161 | for (int j = 0; j < patchSize; j++) { 162 | for (int i = 0; i < overlap; i++) { 163 | mask[j * patchSize + i] = 0; 164 | } 165 | } 166 | } 167 | 168 | int* seamMask = new int[patchSize * patchSize]; 169 | _findSeam(seamMask, im1, im2, patchSize, mask); 170 | 171 | int offset; 172 | for (int c = 0; c < 3; c++) { 173 | for (int j = 0; j < patchSize; j++) { 174 | for (int i = 0; i < patchSize; i++) { 175 | offset = c * patchSize * patchSize + j * patchSize + i; 176 | result[offset] = 177 | (seamMask[j * patchSize + i] == 1) ? 
im2[offset] : im1[offset]; 178 | } 179 | } 180 | } 181 | delete [] mask; 182 | delete [] seamMask; 183 | } 184 | 185 | void generateQuiltedImages( 186 | float* result, 187 | long* neighbors, 188 | float* patchDict, 189 | unsigned int imgH, 190 | unsigned int imgW, 191 | unsigned int patchSize, 192 | unsigned int overlap, 193 | bool graphcut) { 194 | int n = 0; 195 | for (int y = 0; y < imgH - patchSize; y += patchSize - overlap) { 196 | for (int x = 0; x < imgW - patchSize; x += patchSize - overlap) { 197 | if (neighbors[n] != -1) { 198 | if (graphcut) { 199 | float* patch = new float[3 * patchSize * patchSize]; 200 | for (int c = 0; c < 3; c++) { 201 | for (int j = 0; j < patchSize; j++) { 202 | for (int i = 0; i < patchSize; i++) { 203 | patch[c * patchSize * patchSize + j * patchSize + i] = 204 | result[c * imgH * imgW + (y + j) * imgW + (x + i)]; 205 | } 206 | } 207 | } 208 | 209 | float* stitched = new float[3 * patchSize * patchSize]; 210 | float* matched = 211 | patchDict + (neighbors[n] * 3 * patchSize * patchSize); 212 | stitch(stitched, patch, matched, patchSize, overlap, y, x); 213 | for (int c = 0; c < 3; c++) { 214 | for (int j = 0; j < patchSize; j++) { 215 | for (int i = 0; i < patchSize; i++) { 216 | result[c * imgH * imgW + (y + j) * imgW + (x + i)] = 217 | stitched[c * patchSize * patchSize + j * patchSize + i]; 218 | } 219 | } 220 | } 221 | delete[] patch; 222 | delete[] stitched; 223 | } else { 224 | for (int c = 0; c < 3; c++) { 225 | for (int j = 0; j < patchSize; j++) { 226 | for (int i = 0; i < patchSize; i++) { 227 | result[c * imgH * imgW + (y + j) * imgW + (x + i)] = patchDict 228 | [neighbors[n] * 3 * patchSize * patchSize + 229 | c * patchSize * patchSize + j * patchSize + i]; 230 | } 231 | } 232 | } 233 | } 234 | } 235 | 236 | n++; 237 | } 238 | } 239 | } 240 | -------------------------------------------------------------------------------- /detection/lib/quilting.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import ctypes 9 | import torch 10 | 11 | # load seam-finding library: 12 | FINDSEAM_LIB = ctypes.cdll.LoadLibrary( 13 | 'libexperimental_deeplearning_lvdmaaten_adversarial_findseam.so') 14 | 15 | # other globals: 16 | LATTICE_CACHE = {} # cache lattices here 17 | 18 | 19 | # function that constructs a four-connected lattice: 20 | def __four_lattice__(height, width, use_cache=True): 21 | 22 | # try the cache first: 23 | if use_cache and (height, width) in LATTICE_CACHE: 24 | return LATTICE_CACHE[(height, width)] 25 | 26 | # assertions and initialization: 27 | assert type(width) == int and type(height) == int and \ 28 | width > 0 and height > 0, 'height and width should be positive integers' 29 | N = height * width 30 | height, width = width, height # tensors are in row-major format 31 | graph = { 32 | 'from': torch.LongTensor(4 * N - (height + width) * 2), 33 | 'to': torch.LongTensor(4 * N - (height + width) * 2), 34 | } 35 | 36 | # closure that copies stuff in: 37 | def add_edges(i, j, offset): 38 | graph['from'].narrow(0, offset, i.nelement()).copy_(i) 39 | graph['from'].narrow(0, offset + i.nelement(), j.nelement()).copy_(j) 40 | graph['to'].narrow(0, offset, j.nelement()).copy_(j) 41 | graph['to'].narrow(0, offset + j.nelement(), i.nelement()).copy_(i) 42 | 43 | # add vertical connections: 44 | i = torch.arange(0, 
N).squeeze().long() 45 | mask = torch.ByteTensor(N).fill_(1) 46 | mask.index_fill_(0, torch.arange(height - 1, N, height).squeeze().long(), 0) 47 | i = i[mask] 48 | add_edges(i, torch.add(i, 1), 0) 49 | 50 | # add horizontal connections: 51 | offset = 2 * i.nelement() 52 | i = torch.arange(0, N - height).squeeze().long() 53 | add_edges(i, torch.add(i, height), offset) 54 | 55 | # cache and return graph: 56 | if use_cache: 57 | LATTICE_CACHE[(height, width)] = graph 58 | return graph 59 | 60 | 61 | # utility function for checking inputs: 62 | def __assert_inputs__(im1, im2, mask=None): 63 | assert type(im1) == torch.ByteTensor or type(im1) == torch.FloatTensor, \ 64 | 'im1 should be a ByteTensor or FloatTensor' 65 | assert type(im2) == torch.ByteTensor or type(im2) == torch.FloatTensor, \ 66 | 'im2 should be a ByteTensor or FloatTensor' 67 | assert im1.dim() == 3, 'im1 should be three-dimensional' 68 | assert im2.dim() == 3, 'im2 should be three-dimensional' 69 | assert im1.size() == im2.size(), 'im1 and im2 should have same size' 70 | if mask is not None: 71 | assert mask.dim() == 2, 'mask should be two-dimensional' 72 | assert type(mask) == torch.ByteTensor, 'mask should be torch.ByteTensor' 73 | assert mask.size(0) == im1.size(1) and mask.size(1) == im1.size(2), \ 74 | 'mask should have same height and width as images' 75 | 76 | 77 | # function that finds seam between two images: 78 | def find_seam(im1, im2, mask): 79 | 80 | # assertions: 81 | __assert_inputs__(im1, im2, mask) 82 | im1 = im1.float() 83 | im2 = im2.float() 84 | 85 | # construct edge weights: 86 | graph = __four_lattice__(im1.size(1), im1.size(2)) 87 | values = torch.FloatTensor(graph['from'].size(0)).fill_(0.) 88 | for c in range(im1.size(0)): 89 | im1c = im1[c].contiguous().view(im1.size(1) * im1.size(2)) 90 | im2c = im2[c].contiguous().view(im2.size(1) * im2.size(2)) 91 | values.add_(torch.abs( 92 | im2c.index_select(0, graph['to']) - 93 | im1c.index_select(0, graph['from']) 94 | )) 95 | 96 | # construct terminal weights: 97 | idxim = torch.arange(0, mask.nelement()).long().view(mask.size()) 98 | tvalues = torch.FloatTensor(mask.nelement(), 2).fill_(0) 99 | for c in range(2): 100 | select_c = (mask == (c + 1)) 101 | if select_c.any(): 102 | tvalues.select(1, c).index_fill_(0, idxim[select_c], float('inf')) 103 | 104 | # convert graph to IntTensor (make sure this is not GC'ed): 105 | graph_from = graph['from'].int() 106 | graph_to = graph['to'].int() 107 | 108 | # run the Boykov algorithm to obtain stitching mask: 109 | labels = torch.IntTensor(mask.nelement()) 110 | FINDSEAM_LIB.findseam( 111 | ctypes.c_int(mask.nelement()), 112 | ctypes.c_int(values.nelement()), 113 | ctypes.c_void_p(graph_from.data_ptr()), 114 | ctypes.c_void_p(graph_to.data_ptr()), 115 | ctypes.c_void_p(values.data_ptr()), 116 | ctypes.c_void_p(tvalues.data_ptr()), 117 | ctypes.c_void_p(labels.data_ptr()), 118 | ) 119 | mask = labels.resize_(mask.size()).byte() 120 | return mask 121 | 122 | 123 | # function that performs the stitch: 124 | def __stitch__(im1, im2, overlap, y, x): 125 | 126 | # assertions: 127 | __assert_inputs__(im1, im2) 128 | 129 | # construct mask: 130 | patch_size = im1.size(1) 131 | mask = torch.ByteTensor(patch_size, patch_size).fill_(2) 132 | if y > 0: # there is not overlap at the border 133 | mask.narrow(0, 0, overlap).fill_(0) 134 | if x > 0: # there is not overlap at the border 135 | mask.narrow(1, 0, overlap).fill_(0) 136 | 137 | # seam the two patches: 138 | seam_mask = find_seam(im1, im2, mask) 139 | stitched_im = 
im1.clone() 140 | for c in range(stitched_im.size(0)): 141 | stitched_im[c][seam_mask == 1] = im2[c][seam_mask] 142 | return stitched_im 143 | 144 | 145 | # main quilting function: 146 | def quilting(img, faiss_index, patch_dict, patch_size=5, overlap=2, 147 | graphcut=False, patch_transform=None): 148 | 149 | # assertions: 150 | assert torch.is_tensor(img) 151 | assert torch.is_tensor(patch_dict) and patch_dict.dim() == 2 152 | assert type(patch_size) == int and patch_size > 0 153 | assert type(overlap) == int and overlap > 0 154 | assert patch_size > overlap 155 | if patch_transform is not None: 156 | assert callable(patch_transform) 157 | 158 | # gather all image patches: 159 | patches = [] 160 | y_range = range(0, img.size(1) - patch_size, patch_size - overlap) 161 | x_range = range(0, img.size(2) - patch_size, patch_size - overlap) 162 | for y in y_range: 163 | for x in range(0, img.size(2) - patch_size, patch_size - overlap): 164 | patch = img[:, y:y + patch_size, x:x + patch_size] 165 | if patch_transform is not None: 166 | patch = patch_transform(patch) 167 | patches.append(patch) 168 | 169 | # find nearest patches in faiss index: 170 | patches = torch.stack(patches, dim=0) 171 | patches = patches.view(patches.size(0), int(patches.nelement() / patches.size(0))) 172 | faiss_index.nprobe = 5 173 | _, neighbors = faiss_index.search(patches.numpy(), 1) 174 | neighbors = torch.LongTensor(neighbors).squeeze() 175 | if (neighbors == -1).any(): 176 | print('WARNING: %d out of %d neighbor searches failed.' % 177 | ((neighbors == -1).sum(), neighbors.nelement())) 178 | 179 | # piece the image back together: 180 | n = 0 181 | quilt_img = img.clone().fill_(0) 182 | for y in y_range: 183 | for x in x_range: 184 | if neighbors[n] != -1: 185 | 186 | # get current image and new patch: 187 | patch = patch_dict[neighbors[n]].view( 188 | img.size(0), patch_size, patch_size 189 | ) 190 | cur_img = quilt_img[:, y:y + patch_size, x:x + patch_size] 191 | 192 | # compute graph cut if requested: 193 | if graphcut: 194 | patch = __stitch__(cur_img, patch, overlap, y, x) 195 | 196 | # copy the patch into the image: 197 | cur_img.copy_(patch) 198 | n += 1 199 | 200 | # return the quilted image: 201 | return quilt_img 202 | -------------------------------------------------------------------------------- /detection/models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | BasicBlock and Bottleneck module is from the original ResNet paper: 3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 4 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 5 | PreActBlock and PreActBottleneck module is from the later paper: 6 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Identity Mappings in Deep Residual Networks. 
arXiv:1603.05027 8 | Original code is from https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py 9 | ''' 10 | import os 11 | import math 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | from torch.autograd import Variable 17 | from torch.nn.parameter import Parameter 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | 26 | def __init__(self, in_planes, planes, stride=1): 27 | super(BasicBlock, self).__init__() 28 | self.conv1 = conv3x3(in_planes, planes, stride) 29 | self.bn1 = nn.BatchNorm2d(planes) 30 | self.conv2 = conv3x3(planes, planes) 31 | self.bn2 = nn.BatchNorm2d(planes) 32 | 33 | self.shortcut = nn.Sequential() 34 | if stride != 1 or in_planes != self.expansion*planes: 35 | self.shortcut = nn.Sequential( 36 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 37 | nn.BatchNorm2d(self.expansion*planes) 38 | ) 39 | 40 | def forward(self, x): 41 | out = F.relu(self.bn1(self.conv1(x))) 42 | out = self.bn2(self.conv2(out)) 43 | out += self.shortcut(x) 44 | out = F.relu(out) 45 | return out 46 | 47 | 48 | class PreActBlock(nn.Module): 49 | '''Pre-activation version of the BasicBlock.''' 50 | expansion = 1 51 | 52 | def __init__(self, in_planes, planes, stride=1): 53 | super(PreActBlock, self).__init__() 54 | self.bn1 = nn.BatchNorm2d(in_planes) 55 | self.conv1 = conv3x3(in_planes, planes, stride) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.conv2 = conv3x3(planes, planes) 58 | 59 | if stride != 1 or in_planes != self.expansion*planes: 60 | self.shortcut = nn.Sequential( 61 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 62 | ) 63 | 64 | def forward(self, x): 65 | out = F.relu(self.bn1(x)) 66 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 67 | out = self.conv1(out) 68 | out = self.conv2(F.relu(self.bn2(out))) 69 | out += shortcut 70 | return out 71 | 72 | 73 | class Bottleneck(nn.Module): 74 | expansion = 4 75 | 76 | def __init__(self, in_planes, planes, stride=1): 77 | super(Bottleneck, self).__init__() 78 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 79 | self.bn1 = nn.BatchNorm2d(planes) 80 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 81 | self.bn2 = nn.BatchNorm2d(planes) 82 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 83 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 84 | 85 | self.shortcut = nn.Sequential() 86 | if stride != 1 or in_planes != self.expansion*planes: 87 | self.shortcut = nn.Sequential( 88 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 89 | nn.BatchNorm2d(self.expansion*planes) 90 | ) 91 | 92 | def forward(self, x): 93 | out = F.relu(self.bn1(self.conv1(x))) 94 | out = F.relu(self.bn2(self.conv2(out))) 95 | out = self.bn3(self.conv3(out)) 96 | out += self.shortcut(x) 97 | out = F.relu(out) 98 | return out 99 | 100 | 101 | class PreActBottleneck(nn.Module): 102 | '''Pre-activation version of the original Bottleneck module.''' 103 | expansion = 4 104 | 105 | def __init__(self, in_planes, planes, stride=1): 106 | super(PreActBottleneck, self).__init__() 107 | self.bn1 = nn.BatchNorm2d(in_planes) 108 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 109 | self.bn2 = 
nn.BatchNorm2d(planes) 110 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 111 | self.bn3 = nn.BatchNorm2d(planes) 112 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 113 | 114 | if stride != 1 or in_planes != self.expansion*planes: 115 | self.shortcut = nn.Sequential( 116 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 117 | ) 118 | 119 | def forward(self, x): 120 | out = F.relu(self.bn1(x)) 121 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 122 | out = self.conv1(out) 123 | out = self.conv2(F.relu(self.bn2(out))) 124 | out = self.conv3(F.relu(self.bn3(out))) 125 | out += shortcut 126 | return out 127 | 128 | 129 | class ResNet(nn.Module): 130 | def __init__(self, block, num_blocks, num_classes=9): 131 | super(ResNet, self).__init__() 132 | self.in_planes = 64 133 | 134 | self.conv1 = nn.Conv2d(3, self.in_planes, kernel_size=7, stride=2, padding=3, 135 | bias=False) 136 | self.bn1 = nn.BatchNorm2d(64) 137 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 138 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 139 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 140 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 141 | self.fc = nn.Linear(512*block.expansion, num_classes) 142 | 143 | def _make_layer(self, block, planes, num_blocks, stride): 144 | strides = [stride] + [1]*(num_blocks-1) 145 | layers = [] 146 | for stride in strides: 147 | layers.append(block(self.in_planes, planes, stride)) 148 | self.in_planes = planes * block.expansion 149 | return nn.Sequential(*layers) 150 | 151 | def forward(self, x): 152 | out = F.relu(self.bn1(self.conv1(x))) 153 | out = self.layer1(out) 154 | out = self.layer2(out) 155 | out = self.layer3(out) 156 | out = self.layer4(out) 157 | out = F.avg_pool2d(out, 4) 158 | out = out.view(out.size(0), -1) 159 | y = self.fc(out) 160 | return y 161 | 162 | # function to extact the multiple features 163 | def feature_list(self, x): 164 | out_list = [] 165 | out = F.relu(self.bn1(self.conv1(x))) 166 | out_list.append(out) 167 | out = self.layer1(out) 168 | out_list.append(out) 169 | out = self.layer2(out) 170 | out_list.append(out) 171 | out = self.layer3(out) 172 | out_list.append(out) 173 | out = self.layer4(out) 174 | out_list.append(out) 175 | out = F.avg_pool2d(out, 4) 176 | out = out.view(out.size(0), -1) 177 | y = self.fc(out) 178 | return y, out_list 179 | 180 | # function to extact a specific feature 181 | def intermediate_forward(self, x, layer_index): 182 | out = F.relu(self.bn1(self.conv1(x))) 183 | if layer_index == 1: 184 | out = self.layer1(out) 185 | elif layer_index == 2: 186 | out = self.layer1(out) 187 | out = self.layer2(out) 188 | elif layer_index == 3: 189 | out = self.layer1(out) 190 | out = self.layer2(out) 191 | out = self.layer3(out) 192 | elif layer_index == 4: 193 | out = self.layer1(out) 194 | out = self.layer2(out) 195 | out = self.layer3(out) 196 | out = self.layer4(out) 197 | return out 198 | 199 | # function to extact the penultimate features 200 | def penultimate_forward(self, x): 201 | out = F.relu(self.bn1(self.conv1(x))) 202 | out = self.layer1(out) 203 | out = self.layer2(out) 204 | out = self.layer3(out) 205 | penultimate = self.layer4(out) 206 | out = F.avg_pool2d(penultimate, 4) 207 | out = out.view(out.size(0), -1) 208 | y = self.fc(out) 209 | return y, penultimate 210 | 211 | def ResNet18(num_c): 212 | 
return ResNet(PreActBlock, [2,2,2,2], num_classes=num_c) 213 | 214 | def ResNet34(num_c): 215 | return ResNet(BasicBlock, [3,4,6,3], num_classes=num_c) 216 | 217 | def ResNet50(num_c): 218 | return ResNet(Bottleneck, [3,4,6,3], num_classes=num_c) 219 | 220 | def ResNet101(): 221 | return ResNet(Bottleneck, [3,4,23,3]) 222 | 223 | def ResNet152(): 224 | return ResNet(Bottleneck, [3,8,36,3]) 225 | 226 | 227 | def test(): 228 | net = ResNet18(10) 229 | y = net(Variable(torch.randn(1,3,32,32))) 230 | print(y.size()) 231 | 232 | # test() 233 | -------------------------------------------------------------------------------- /detection/ADV_Samples_Subspace.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Sun Oct 25 2018 3 | @author: Kimin Lee 4 | """ 5 | from __future__ import print_function 6 | import argparse 7 | import torch 8 | import torch.nn as nn 9 | import data_loader 10 | import numpy as np 11 | import models 12 | import os 13 | import lib.adversary as adversary 14 | from lib.attacks import DeltaAttack 15 | import pdb 16 | from torchvision import transforms 17 | from torch.autograd import Variable 18 | import lib_generation 19 | parser = argparse.ArgumentParser(description='PyTorch code: Mahalanobis detector') 20 | parser.add_argument('--batch_size', type=int, default=200, metavar='N', help='batch size for data loader') 21 | parser.add_argument('--dataset', required=True, help='cifar10 | imagenet') 22 | parser.add_argument('--dataroot', default='../../data/', help='path to dataset') 23 | parser.add_argument('--outf', default='./adv_output/', help='folder to output results') 24 | parser.add_argument('--num_classes', type=int, default=10, help='the # of classes') 25 | parser.add_argument('--net_type', required=True, help='resnet') 26 | parser.add_argument('--gpu', type=int, default=0, help='gpu index') 27 | parser.add_argument('--adv_type', required=True, help='FGSM | BIM | PGD | CW') 28 | parser.add_argument('--vae_path', default='./data/96.32/model_epoch252.pth', help='path to the trained model checkpoint') 29 | parser.add_argument('--pertubation', type=float, default=8/255, help='adversarial perturbation') 30 | parser.add_argument('--steps', type=int, default=5, help='adversarial iteration') 31 | args = parser.parse_args() 32 | print(args) 33 | 34 | 35 | def main(): 36 | args.outf = args.outf + args.net_type + '_' + args.dataset + '/' 37 | if not os.path.isdir(args.outf): 38 | os.makedirs(args.outf) 39 | torch.cuda.manual_seed(0) 40 | torch.cuda.set_device(args.gpu) 41 | 42 | if args.adv_type == 'FGSM': 43 | adv_noise = args.pertubation 44 | 45 | in_transform = transforms.Compose([transforms.ToTensor(), \ 46 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),]) 47 | 48 | min_pixel = -1.9894736842105263 49 | max_pixel = 2.126488706365503 50 | if args.adv_type == 'FGSM': 51 | random_noise_size = 0.25 / 4 52 | else: 53 | random_noise_size = 0.13 / 2 54 | model = models.Wide_ResNet(28, 10, 0.3, 10) 55 | model = nn.DataParallel(model) 56 | model_dict = model.state_dict() 57 | save_model = torch.load(args.vae_path) 58 | state_dict = {k.replace('classifier.',''): v for k, v in save_model.items() if k.replace('classifier.','') in model_dict.keys()} 59 | print(state_dict.keys()) 60 | model_dict.update(state_dict) 61 | model.load_state_dict(model_dict) 62 | model.cuda() 63 | model.eval() 64 | print('load model: ' + args.net_type) 65 | 66 | vae = models.CVAE(d=32, z=2048) 67 | vae = nn.DataParallel(vae) 68 | model_dict =
vae.state_dict() 69 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()} 70 | print(state_dict.keys()) 71 | model_dict.update(state_dict) 72 | vae.load_state_dict(model_dict) 73 | vae.cuda() 74 | vae.eval() 75 | 76 | # load dataset 77 | print('load target data: ', args.dataset) 78 | train_loader, test_loader = data_loader.getTargetDataSet(args.dataset, args.batch_size, in_transform, args.dataroot) 79 | 80 | print('Attack: ' + args.adv_type + ', Dist: ' + args.dataset + '\n') 81 | model.eval() 82 | adv_data_tot, clean_data_tot, noisy_data_tot = 0, 0, 0 83 | label_tot = 0 84 | 85 | correct, adv_correct, noise_correct = 0, 0, 0 86 | total, generated_noise = 0, 0 87 | 88 | criterion = nn.CrossEntropyLoss().cuda() 89 | print('Generating testset:') 90 | selected_list = [] 91 | selected_index = 0 92 | 93 | for data, target in test_loader: 94 | data, target = data.cuda(), target.cuda() 95 | data, target = Variable(data), Variable(target) 96 | output = model( data - vae(data)) 97 | 98 | # compute the accuracy 99 | pred = output.data.max(1)[1] 100 | equal_flag = pred.eq(target.data).cpu() 101 | correct += equal_flag.sum() 102 | 103 | noisy_data = torch.add(data.data, torch.randn(data.size()).cuda(), alpha=random_noise_size) 104 | noisy_data = torch.clamp(noisy_data, min_pixel, max_pixel) 105 | 106 | if total == 0: 107 | clean_data_tot = data.clone().data.cpu() 108 | label_tot = target.clone().data.cpu() 109 | noisy_data_tot = noisy_data.clone().cpu() 110 | else: 111 | clean_data_tot = torch.cat((clean_data_tot, data.clone().data.cpu()),0) 112 | label_tot = torch.cat((label_tot, target.clone().data.cpu()), 0) 113 | noisy_data_tot = torch.cat((noisy_data_tot, noisy_data.clone().cpu()),0) 114 | 115 | # generate adversarial 116 | model.zero_grad() 117 | vae.zero_grad() 118 | 119 | if args.adv_type == 'FGSM': 120 | inputs = Variable(data.data, requires_grad=True) 121 | output = model(inputs - vae(inputs)) 122 | loss = criterion(output, target) 123 | loss.backward() 124 | gradient = torch.ge(inputs.grad.data, 0) 125 | gradient = (gradient.float()-0.5)*2 126 | gradient.index_copy_(1, torch.LongTensor([0]).cuda(), \ 127 | gradient.index_select(1, torch.LongTensor([0]).cuda()) / (0.2470)) 128 | gradient.index_copy_(1, torch.LongTensor([1]).cuda(), \ 129 | gradient.index_select(1, torch.LongTensor([1]).cuda()) / (0.2435)) 130 | gradient.index_copy_(1, torch.LongTensor([2]).cuda(), \ 131 | gradient.index_select(1, torch.LongTensor([2]).cuda()) / (0.2616)) 132 | adv_data = torch.add(inputs.data, gradient, alpha=adv_noise) 133 | 134 | adv_data = torch.clamp(adv_data, min_pixel, max_pixel) 135 | 136 | elif args.adv_type == 'BIM': 137 | attack = DeltaAttack(model, vae, num_iterations=5, datasets=args.dataset, rand_init=False) 138 | adv_data = attack(data, target) 139 | 140 | elif args.adv_type == 'PGD': 141 | attack = DeltaAttack(model, vae, num_iterations=5, datasets=args.dataset) 142 | adv_data = attack(data, target) 143 | 144 | elif args.adv_type == 'CW': 145 | attack = DeltaAttack(model, vae, num_iterations=5, datasets=args.dataset, loss='margin') 146 | adv_data = attack(data, target) 147 | 148 | elif args.adv_type == 'PGD-L2': 149 | attack = DeltaAttack(model, vae, eps_max=1.0, num_iterations=5, datasets=args.dataset, norm='l2') 150 | adv_data = attack(data, target) 151 | # measure the noise 152 | temp_noise_max = torch.abs((data.data - adv_data).view(adv_data.size(0), -1)) 153 | temp_noise_max, _ = torch.max(temp_noise_max, dim=1) 154 | generated_noise += torch.sum(temp_noise_max) 
155 | 156 | if total == 0: 157 | adv_data_tot = adv_data.clone().cpu() 158 | else: 159 | adv_data_tot = torch.cat((adv_data_tot, adv_data.clone().cpu()),0) 160 | with torch.no_grad(): 161 | output = model(Variable(adv_data)-vae(Variable(adv_data))) 162 | # compute the accuracy 163 | pred = output.data.max(1)[1] 164 | equal_flag_adv = pred.eq(target.data).cpu() 165 | adv_correct += equal_flag_adv.sum() 166 | with torch.no_grad(): 167 | output = model(Variable(noisy_data)-vae(Variable(noisy_data))) 168 | # compute the accuracy 169 | pred = output.data.max(1)[1] 170 | equal_flag_noise = pred.eq(target.data).cpu() 171 | noise_correct += equal_flag_noise.sum() 172 | 173 | for i in range(data.size(0)): 174 | if equal_flag[i] == 1 and equal_flag_noise[i] == 1 and equal_flag_adv[i] == 0: 175 | selected_list.append(selected_index) 176 | selected_index += 1 177 | 178 | total += data.size(0) 179 | 180 | selected_list = torch.LongTensor(selected_list) 181 | clean_data_tot = torch.index_select(clean_data_tot, 0, selected_list) 182 | adv_data_tot = torch.index_select(adv_data_tot, 0, selected_list) 183 | noisy_data_tot = torch.index_select(noisy_data_tot, 0, selected_list) 184 | label_tot = torch.index_select(label_tot, 0, selected_list) 185 | 186 | torch.save(clean_data_tot, '%s/clean_data_%s_%s_%s.pth' % (args.outf, args.net_type, args.dataset, args.adv_type)) 187 | torch.save(adv_data_tot, '%s/adv_data_%s_%s_%s.pth' % (args.outf, args.net_type, args.dataset, args.adv_type)) 188 | torch.save(noisy_data_tot, '%s/noisy_data_%s_%s_%s.pth' % (args.outf, args.net_type, args.dataset, args.adv_type)) 189 | torch.save(label_tot, '%s/label_%s_%s_%s.pth' % (args.outf, args.net_type, args.dataset, args.adv_type)) 190 | 191 | print('Adversarial Noise:({:.2f})\n'.format(generated_noise / total)) 192 | print('Final Accuracy: {}/{} ({:.2f}%)\n'.format(correct, total, 100. * correct / total)) 193 | print('Adversarial Accuracy: {}/{} ({:.2f}%)\n'.format(adv_correct, total, 100. * adv_correct / total)) 194 | print('Noisy Accuracy: {}/{} ({:.2f}%)\n'.format(noise_correct, total, 100. 
* noise_correct / total)) 195 | 196 | 197 | if __name__ == '__main__': 198 | main() 199 | -------------------------------------------------------------------------------- /networks/adv_vae.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import abc 3 | import os 4 | import math 5 | 6 | import numpy as np 7 | import logging 8 | import torch 9 | import torch.utils.data 10 | from torch import nn 11 | from torch.nn import init 12 | from torch.nn import functional as F 13 | from torch.autograd import Variable 14 | import pdb 15 | import sys 16 | sys.path.append('.') 17 | sys.path.append('..') 18 | from utils.normalize import * 19 | 20 | 21 | def conv3x3(in_planes, out_planes, stride=1): 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 23 | 24 | 25 | def conv_init(m): 26 | classname = m.__class__.__name__ 27 | if classname.find('Conv') != -1: 28 | init.xavier_uniform_(m.weight, gain=np.sqrt(2)) 29 | init.constant_(m.bias, 0) 30 | elif classname.find('BatchNorm') != -1: 31 | init.constant_(m.weight, 1) 32 | init.constant_(m.bias, 0) 33 | 34 | 35 | class CIFARNormalizer(nn.Module): 36 | def __init__(self, mean=[0.4914, 0.4822, 0.4465], 37 | std=[0.2470, 0.2435, 0.2616]): 38 | super().__init__() 39 | self.mean = mean 40 | self.std = std 41 | 42 | def forward(self, x): 43 | mean = torch.tensor(self.mean, device=x.device) 44 | std = torch.tensor(self.std, device=x.device) 45 | 46 | return ( 47 | (x - mean[None, :, None, None]) / 48 | std[None, :, None, None] 49 | ) 50 | 51 | 52 | class wide_basic(nn.Module): 53 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 54 | super(wide_basic, self).__init__() 55 | self.bn1 = nn.BatchNorm2d(in_planes) 56 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 57 | self.dropout = nn.Dropout(p=dropout_rate) 58 | self.bn2 = nn.BatchNorm2d(planes) 59 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 60 | 61 | self.shortcut = nn.Sequential() 62 | if stride != 1 or in_planes != planes: 63 | self.shortcut = nn.Sequential( 64 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 65 | ) 66 | 67 | def forward(self, x): 68 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 69 | out = self.conv2(F.relu(self.bn2(out))) 70 | out += self.shortcut(x) 71 | 72 | return out 73 | 74 | 75 | class Wide_ResNet(nn.Module): 76 | def __init__(self, depth, widen_factor, dropout_rate, num_classes, norm=False): 77 | super(Wide_ResNet, self).__init__() 78 | self.in_planes = 16 79 | 80 | assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4' 81 | n = (depth-4)/6 82 | k = widen_factor 83 | 84 | print('| Wide-Resnet %dx%d' %(depth, k)) 85 | nStages = [16, 16*k, 32*k, 64*k] 86 | 87 | self.conv1 = conv3x3(3,nStages[0]) 88 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 89 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 90 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 91 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 92 | self.linear = nn.Linear(nStages[3], num_classes) 93 | self.norm = norm 94 | self.normalizer = CIFARNormalizer() 95 | 96 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 97 | strides = [stride] + [1]*(int(num_blocks)-1) 98 | layers = [] 99 | 100 | for stride in strides: 101 | 
layers.append(block(self.in_planes, planes, dropout_rate, stride)) 102 | self.in_planes = planes 103 | 104 | return nn.Sequential(*layers) 105 | 106 | def forward(self, x): 107 | if self.norm: 108 | x = self.normalizer(x) 109 | out = self.conv1(x) 110 | out = self.layer1(out) 111 | out = self.layer2(out) 112 | out = self.layer3(out) 113 | out = F.relu(self.bn1(out)) 114 | out = F.avg_pool2d(out, 8) 115 | out = out.view(out.size(0), -1) 116 | out = self.linear(out) 117 | 118 | return out 119 | 120 | 121 | class CD_VAE(nn.Module): 122 | def __init__(self, vae_path, model_path): 123 | super(CD_VAE, self).__init__() 124 | self.vae = CVAE_cifar(d=32, z=2048, with_classifier=False) 125 | self.model = Wide_ResNet(28, 10, 0.3, 10) 126 | self.vae.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(vae_path).items()}) 127 | self.model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path).items()}) 128 | self.normalize = CIFARNORMALIZE(32) 129 | 130 | def forward(self, x): 131 | gx, _, _ = self.vae(self.normalize(x)) 132 | out = self.model(gx) 133 | return out 134 | 135 | 136 | class ResBlock(nn.Module): 137 | def __init__(self, in_channels, out_channels, mid_channels=None, bn=False): 138 | super(ResBlock, self).__init__() 139 | 140 | if mid_channels is None: 141 | mid_channels = out_channels 142 | 143 | layers = [ 144 | nn.LeakyReLU(), 145 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1), 146 | nn.LeakyReLU(), 147 | nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1, padding=0)] 148 | if bn: 149 | layers.insert(2, nn.BatchNorm2d(out_channels)) 150 | self.convs = nn.Sequential(*layers) 151 | 152 | def forward(self, x): 153 | return x + self.convs(x) 154 | 155 | 156 | class AbstractAutoEncoder(nn.Module): 157 | __metaclass__ = abc.ABCMeta 158 | 159 | @abc.abstractmethod 160 | def encode(self, x): 161 | return 162 | 163 | @abc.abstractmethod 164 | def decode(self, z): 165 | return 166 | 167 | @abc.abstractmethod 168 | def forward(self, x): 169 | """model return (reconstructed_x, *)""" 170 | return 171 | 172 | @abc.abstractmethod 173 | def sample(self, size): 174 | """sample new images from model""" 175 | return 176 | 177 | @abc.abstractmethod 178 | def loss_function(self, **kwargs): 179 | """accepts (original images, *) where * is the same as returned from forward()""" 180 | return 181 | 182 | @abc.abstractmethod 183 | def latest_losses(self): 184 | """returns the latest losses in a dictionary. 
Useful for logging.""" 185 | return 186 | 187 | 188 | class CVAE_cifar(AbstractAutoEncoder): 189 | def __init__(self, d, z, with_classifier=True, **kwargs): 190 | super(CVAE_cifar, self).__init__() 191 | 192 | self.encoder = nn.Sequential( 193 | nn.Conv2d(3, d // 2, kernel_size=4, stride=2, padding=1, bias=False), 194 | nn.BatchNorm2d(d // 2), 195 | nn.ReLU(inplace=True), 196 | nn.Conv2d(d // 2, d, kernel_size=4, stride=2, padding=1, bias=False), 197 | nn.BatchNorm2d(d), 198 | nn.ReLU(inplace=True), 199 | ResBlock(d, d, bn=True), 200 | nn.BatchNorm2d(d), 201 | ResBlock(d, d, bn=True), 202 | ) 203 | 204 | self.decoder = nn.Sequential( 205 | ResBlock(d, d, bn=True), 206 | nn.BatchNorm2d(d), 207 | ResBlock(d, d, bn=True), 208 | nn.BatchNorm2d(d), 209 | 210 | nn.ConvTranspose2d(d, d // 2, kernel_size=4, stride=2, padding=1, bias=False), 211 | nn.BatchNorm2d(d // 2), 212 | nn.LeakyReLU(inplace=True), 213 | nn.ConvTranspose2d(d // 2, 3, kernel_size=4, stride=2, padding=1, bias=False), 214 | ) 215 | 216 | self.xi_bn = nn.BatchNorm2d(3) 217 | 218 | self.f = 8 219 | self.d = d 220 | self.z = z 221 | self.fc11 = nn.Linear(d * self.f ** 2, self.z) 222 | self.fc12 = nn.Linear(d * self.f ** 2, self.z) 223 | self.fc21 = nn.Linear(self.z, d * self.f ** 2) 224 | 225 | self.with_classifier = with_classifier 226 | if self.with_classifier: 227 | self.classifier = Wide_ResNet(28, 10, 0.3, 10) 228 | 229 | def encode(self, x): 230 | h = self.encoder(x) 231 | h1 = h.view(-1, self.d * self.f ** 2) 232 | return h, self.fc11(h1), self.fc12(h1) 233 | 234 | def reparameterize(self, mu, logvar): 235 | std = logvar.mul(0.5).exp_() 236 | eps = std.new(std.size()).normal_() 237 | return eps.mul(std).add_(mu) 238 | 239 | def decode(self, z): 240 | z = z.view(-1, self.d, self.f, self.f) 241 | h3 = self.decoder(z) 242 | return torch.tanh(h3) 243 | 244 | def forward(self, x): 245 | _, mu, logvar = self.encode(x) 246 | hi = self.reparameterize(mu, logvar) 247 | hi_projected = self.fc21(hi) 248 | xi = self.decode(hi_projected) 249 | xi = self.xi_bn(xi) 250 | if self.with_classifier: 251 | out = self.classifier(torch.cat((xi, x-xi), dim=0)) 252 | out_g = out[0:x.size(0)] 253 | out_r = out[x.size(0):] 254 | return out_g, out_r, hi, xi, mu, logvar 255 | else: 256 | return xi, mu, logvar -------------------------------------------------------------------------------- /detection/lib/attacks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import pdb 5 | import lib.runutils 6 | from torch.autograd import Variable 7 | import operator as op 8 | 9 | from typing import Union, Tuple 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | def _var2numpy(var): 16 | """ 17 | Make Variable to numpy array. No transposition will be made. 
18 | 19 | :param var: Variable instance on whatever device 20 | :type var: Variable 21 | :return: the corresponding numpy array 22 | :rtype: np.ndarray 23 | """ 24 | return var.data.cpu().numpy() 25 | 26 | CIFAR_MEAN = [0.4914, 0.4822, 0.4465] 27 | CIFAR_STD = [0.2470, 0.2435, 0.2616] 28 | 29 | def get_eps_params(base_eps, resol): 30 | eps_list = [] 31 | max_list = [] 32 | min_list = [] 33 | for i in range(3): 34 | eps_list.append(torch.full((resol, resol), base_eps, device='cuda')) 35 | min_list.append(torch.full((resol, resol), 0., device='cuda')) 36 | max_list.append(torch.full((resol, resol), 255., device='cuda')) 37 | 38 | eps_t = torch.unsqueeze(torch.stack(eps_list), 0) 39 | max_t = torch.unsqueeze(torch.stack(max_list), 0) 40 | min_t = torch.unsqueeze(torch.stack(min_list), 0) 41 | return eps_t, max_t, min_t 42 | 43 | def get_cifar_params(resol): 44 | mean_list = [] 45 | std_list = [] 46 | for i in range(3): 47 | mean_list.append(torch.full((resol, resol), CIFAR_MEAN[i], device='cuda')) 48 | std_list.append(torch.full((resol, resol), CIFAR_STD[i], device='cuda')) 49 | return torch.unsqueeze(torch.stack(mean_list), 0), torch.unsqueeze(torch.stack(std_list), 0) 50 | 51 | class CIFARNORMALIZE(nn.Module): 52 | def __init__(self, resol): 53 | super().__init__() 54 | self.mean, self.std = get_cifar_params(resol) 55 | 56 | def forward(self, x): 57 | ''' 58 | Parameters: 59 | x: input image with pixels in [0, 1]; returns (x - CIFAR_MEAN) / CIFAR_STD 60 | ''' 61 | x = x.sub(self.mean) 62 | x = x.div(self.std) 63 | return x 64 | 65 | class CIFARINNORMALIZE(nn.Module): 66 | def __init__(self, resol): 67 | super().__init__() 68 | self.mean, self.std = get_cifar_params(resol) 69 | 70 | def forward(self, x): 71 | ''' 72 | Parameters: 73 | x: input image normalized with CIFAR_MEAN and CIFAR_STD; returns pixels back in [0, 1] 74 | ''' 75 | x = x.mul(self.std) 76 | x = x.add(self.mean) 77 | return x 78 | 79 | 80 | class _MahalanobisLoss(nn.Module): 81 | def __init__(self): 82 | """ 83 | Mean squared Mahalanobis distance between features and a class mean under the given precision matrix. 84 | """ 85 | super(_MahalanobisLoss, self).__init__() 86 | 87 | def forward(self, feature, mean, inverse_cov): 88 | mahalanobis_loss = 0 89 | zero_f = feature - mean 90 | mahalanobis_loss = torch.mm(torch.mm(zero_f, inverse_cov), zero_f.t()).diag() 91 | mahalanobis_loss = torch.mean(mahalanobis_loss) 92 | return mahalanobis_loss 93 | 94 | 95 | class _MahalanobisEnsembleLoss(nn.Module): 96 | def __init__(self): 97 | """ 98 | Weighted sum of layer-wise squared Mahalanobis distances, evaluated at the selected (top-2) class means. 99 | """ 100 | super(_MahalanobisEnsembleLoss, self).__init__() 101 | 102 | def forward(self, feature, mean, inverse_cov, weight, top2_index): 103 | mahalanobis_loss = 0 104 | for i in range(len(mean)): 105 | temp_loss = 0 106 | final_mean = mean[i].index_select(0, top2_index.cuda()) 107 | final_mean = Variable(final_mean) 108 | zero_f = feature[i] - final_mean 109 | temp_loss = torch.mm(torch.mm(zero_f, Variable(inverse_cov[i])), zero_f.t()).diag() 110 | mahalanobis_loss += weight[i]*torch.mean(temp_loss) 111 | return mahalanobis_loss 112 | 113 | 114 | class MarginLoss(nn.Module): 115 | """ 116 | Calculates the margin loss min(kappa, (max z_k (x) k != y) - z_y(x)), 117 | also known as the f6 loss used by the Carlini & Wagner attack.
118 | """ 119 | 120 | def __init__(self, kappa=float('inf'), targeted=False): 121 | super().__init__() 122 | self.kappa = kappa 123 | self.targeted = targeted 124 | 125 | def forward(self, logits, labels): 126 | correct_logits = torch.gather(logits, 1, labels.view(-1, 1)) 127 | 128 | max_2_logits, argmax_2_logits = torch.topk(logits, 2, dim=1) 129 | top_max, second_max = max_2_logits.chunk(2, dim=1) 130 | top_argmax, second_argmax = argmax_2_logits.chunk(2, dim=1) 131 | labels_eq_max = top_argmax.squeeze().eq(labels).float().view(-1, 1) 132 | labels_ne_max = top_argmax.squeeze().ne(labels).float().view(-1, 1) 133 | max_incorrect_logits = labels_eq_max * second_max + labels_ne_max * top_max 134 | max_incorrect_index = labels_eq_max * second_argmax + labels_ne_max * top_argmax 135 | if self.targeted: 136 | return (correct_logits - max_incorrect_logits) \ 137 | .clamp(max=self.kappa).squeeze() 138 | else: 139 | return (max_incorrect_logits - correct_logits) \ 140 | .clamp(max=self.kappa).squeeze() 141 | 142 | 143 | class DeltaAttack(nn.Module): 144 | def __init__(self, model, vae, eps_max=8/255, step_size=None, num_iterations=7, datasets = 'cifar10', norm='linf', rand_init=True, scale_each=False, loss='ce'): 145 | super().__init__() 146 | self.nb_its = num_iterations 147 | self.eps_max = eps_max 148 | if step_size is None: 149 | step_size = eps_max / (self.nb_its ** 0.5) 150 | self.step_size = step_size 151 | 152 | self.norm = norm 153 | self.rand_init = rand_init 154 | self.scale_each = scale_each 155 | self.loss = loss 156 | 157 | if self.loss == 'margin': 158 | self.criterion = MarginLoss(kappa=1000) 159 | else: 160 | self.criterion = nn.CrossEntropyLoss().cuda() 161 | self.model = model 162 | self.vae = vae 163 | self.datasets = datasets 164 | if self.datasets == 'cifar10': 165 | self.normalize = CIFARNORMALIZE(32) 166 | self.innormalize = CIFARINNORMALIZE(32) 167 | 168 | def _init(self, shape, eps): 169 | if self.rand_init: 170 | if self.norm == 'linf': 171 | init = torch.rand(shape, dtype=torch.float32, device='cuda') * 2 - 1 172 | elif self.norm == 'l2': 173 | init = torch.randn(shape, dtype=torch.float32, device='cuda') 174 | init_norm = torch.norm(init.view(init.size()[0], -1), 2.0, dim=1) 175 | normalized_init = init / init_norm[:, None, None, None] 176 | dim = init.size()[1] * init.size()[2] * init.size()[3] 177 | rand_norms = torch.pow(torch.rand(init.size()[0], dtype=torch.float32, device='cuda'), 1/dim) 178 | init = normalized_init * rand_norms[:, None, None, None] 179 | else: 180 | raise NotImplementedError 181 | init = eps[:, None, None, None] * init 182 | init.requires_grad_() 183 | return init 184 | else: 185 | return torch.zeros(shape, requires_grad=True, device='cuda') 186 | 187 | def forward(self, img, labels): 188 | img = self.innormalize(img) #0-1 189 | base_eps = self.eps_max * torch.ones(img.size()[0], device='cuda') 190 | step_size = self.step_size * torch.ones(img.size()[0], device='cuda') 191 | 192 | img = img.detach() 193 | img.requires_grad = True 194 | delta = self._init(img.size(), base_eps) 195 | 196 | s = self.model(self.normalize(img + delta) 197 | - self.vae(self.normalize(img + delta))) 198 | if self.norm == 'l2': 199 | l2_max = base_eps 200 | for it in range(self.nb_its): 201 | loss = self.criterion(s, labels) 202 | 203 | if self.loss == 'margin': 204 | loss.sum().backward() 205 | else: 206 | loss.backward() 207 | ''' 208 | Because of batching, this grad is scaled down by 1 / batch_size, which does not matter 209 | for what follows because of 
normalization. 210 | ''' 211 | grad = delta.grad.data 212 | 213 | if self.norm == 'linf': 214 | grad_sign = grad.sign() 215 | delta.data = delta.data + step_size[:, None, None, None] * grad_sign 216 | delta.data = torch.max(torch.min(delta.data, base_eps[:, None, None, None]), -base_eps[:, None, None, None]) 217 | delta.data = torch.clamp(img.data + delta.data, 0., 1.) - img.data 218 | elif self.norm == 'l2': 219 | batch_size = delta.data.size()[0] 220 | grad_norm = torch.norm(grad.view(batch_size, -1), 2.0, dim=1) 221 | normalized_grad = grad / grad_norm[:, None, None, None] 222 | delta.data = delta.data + step_size[:, None, None, None] * normalized_grad 223 | l2_delta = torch.norm(delta.data.view(batch_size, -1), 2.0, dim=1) 224 | # Check for numerical instability 225 | proj_scale = torch.min(torch.ones_like(l2_delta, device='cuda'), l2_max / l2_delta) 226 | delta.data *= proj_scale[:, None, None, None] 227 | delta.data = torch.clamp(img.data + delta.data, 0., 1.) - img.data 228 | else: 229 | raise NotImplementedError 230 | 231 | if it != self.nb_its - 1: 232 | s = self.model(self.normalize(img + delta) 233 | - self.vae(self.normalize(img + delta))) 234 | delta.grad.data.zero_() 235 | delta.data[torch.isnan(delta.data)] = 0 236 | adv_sample = img + delta 237 | adv_sample = torch.clamp(adv_sample.detach(), 0, 1) 238 | return self.normalize(adv_sample) 239 | -------------------------------------------------------------------------------- /networks/vae.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import abc 3 | import os 4 | import math 5 | 6 | import numpy as np 7 | import logging 8 | import torch 9 | import torch.utils.data 10 | from torch import nn 11 | from torch.nn import init 12 | from torch.nn import functional as F 13 | from torch.autograd import Variable 14 | import pdb 15 | import sys 16 | from .resnet import resnet50 17 | from .nearest_embed import NearestEmbed 18 | sys.path.append('.') 19 | sys.path.append('..') 20 | from utils.normalize import * 21 | 22 | 23 | def conv3x3(in_planes, out_planes, stride=1): 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 25 | 26 | 27 | def conv_init(m): 28 | classname = m.__class__.__name__ 29 | if classname.find('Conv') != -1: 30 | init.xavier_uniform_(m.weight, gain=np.sqrt(2)) 31 | init.constant_(m.bias, 0) 32 | elif classname.find('BatchNorm') != -1: 33 | init.constant_(m.weight, 1) 34 | init.constant_(m.bias, 0) 35 | 36 | 37 | class wide_basic(nn.Module): 38 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 39 | super(wide_basic, self).__init__() 40 | self.bn1 = nn.BatchNorm2d(in_planes) 41 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 42 | self.dropout = nn.Dropout(p=dropout_rate) 43 | self.bn2 = nn.BatchNorm2d(planes) 44 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 45 | 46 | self.shortcut = nn.Sequential() 47 | if stride != 1 or in_planes != planes: 48 | self.shortcut = nn.Sequential( 49 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 50 | ) 51 | 52 | def forward(self, x): 53 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 54 | out = self.conv2(F.relu(self.bn2(out))) 55 | out += self.shortcut(x) 56 | 57 | return out 58 | 59 | 60 | class Wide_ResNet(nn.Module): 61 | def __init__(self, depth, widen_factor, dropout_rate, num_classes, norm=False): 62 | super(Wide_ResNet, 
self).__init__() 63 | self.in_planes = 16 64 | 65 | assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4' 66 | n = (depth-4)/6 67 | k = widen_factor 68 | 69 | print('| Wide-Resnet %dx%d' %(depth, k)) 70 | nStages = [16, 16*k, 32*k, 64*k] 71 | 72 | self.conv1 = conv3x3(3,nStages[0]) 73 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 74 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 75 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 76 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 77 | self.linear = nn.Linear(nStages[3], num_classes) 78 | self.normalize = CIFARNORMALIZE(32) 79 | self.norm = norm 80 | 81 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 82 | strides = [stride] + [1]*(int(num_blocks)-1) 83 | layers = [] 84 | 85 | for stride in strides: 86 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 87 | self.in_planes = planes 88 | 89 | return nn.Sequential(*layers) 90 | 91 | def forward(self, x): 92 | if self.norm: 93 | x = self.normalize(x) 94 | out = self.conv1(x) 95 | out = self.layer1(out) 96 | out = self.layer2(out) 97 | out = self.layer3(out) 98 | out = F.relu(self.bn1(out)) 99 | out = F.avg_pool2d(out, 8) 100 | out = out.view(out.size(0), -1) 101 | out = self.linear(out) 102 | 103 | return out 104 | 105 | 106 | class ResBlock(nn.Module): 107 | def __init__(self, in_channels, out_channels, mid_channels=None, bn=False): 108 | super(ResBlock, self).__init__() 109 | 110 | if mid_channels is None: 111 | mid_channels = out_channels 112 | 113 | layers = [ 114 | nn.LeakyReLU(), 115 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1), 116 | nn.LeakyReLU(), 117 | nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1, padding=0)] 118 | if bn: 119 | layers.insert(2, nn.BatchNorm2d(out_channels)) 120 | self.convs = nn.Sequential(*layers) 121 | 122 | def forward(self, x): 123 | return x + self.convs(x) 124 | 125 | 126 | class AbstractAutoEncoder(nn.Module): 127 | __metaclass__ = abc.ABCMeta 128 | 129 | @abc.abstractmethod 130 | def encode(self, x): 131 | return 132 | 133 | @abc.abstractmethod 134 | def decode(self, z): 135 | return 136 | 137 | @abc.abstractmethod 138 | def forward(self, x): 139 | """model return (reconstructed_x, *)""" 140 | return 141 | 142 | @abc.abstractmethod 143 | def sample(self, size): 144 | """sample new images from model""" 145 | return 146 | 147 | @abc.abstractmethod 148 | def loss_function(self, **kwargs): 149 | """accepts (original images, *) where * is the same as returned from forward()""" 150 | return 151 | 152 | @abc.abstractmethod 153 | def latest_losses(self): 154 | """returns the latest losses in a dictionary. 
Useful for logging.""" 155 | return 156 | 157 | 158 | class CVAE_cifar(AbstractAutoEncoder): 159 | def __init__(self, d, z, with_classifier=True, **kwargs): 160 | super(CVAE_cifar, self).__init__() 161 | 162 | self.encoder = nn.Sequential( 163 | nn.Conv2d(3, d // 2, kernel_size=4, stride=2, padding=1, bias=False), 164 | nn.BatchNorm2d(d // 2), 165 | nn.ReLU(inplace=True), 166 | nn.Conv2d(d // 2, d, kernel_size=4, stride=2, padding=1, bias=False), 167 | nn.BatchNorm2d(d), 168 | nn.ReLU(inplace=True), 169 | ResBlock(d, d, bn=True), 170 | nn.BatchNorm2d(d), 171 | ResBlock(d, d, bn=True), 172 | ) 173 | 174 | self.decoder = nn.Sequential( 175 | ResBlock(d, d, bn=True), 176 | nn.BatchNorm2d(d), 177 | ResBlock(d, d, bn=True), 178 | nn.BatchNorm2d(d), 179 | 180 | nn.ConvTranspose2d(d, d // 2, kernel_size=4, stride=2, padding=1, bias=False), 181 | nn.BatchNorm2d(d // 2), 182 | nn.LeakyReLU(inplace=True), 183 | nn.ConvTranspose2d(d // 2, 3, kernel_size=4, stride=2, padding=1, bias=False), 184 | ) 185 | 186 | self.xi_bn = nn.BatchNorm2d(3) 187 | 188 | self.f = 8 189 | self.d = d 190 | self.z = z 191 | self.fc11 = nn.Linear(d * self.f ** 2, self.z) 192 | self.fc12 = nn.Linear(d * self.f ** 2, self.z) 193 | self.fc21 = nn.Linear(self.z, d * self.f ** 2) 194 | 195 | self.with_classifier = with_classifier 196 | if self.with_classifier: 197 | self.classifier = Wide_ResNet(28, 10, 0.3, 10) 198 | 199 | def encode(self, x): 200 | h = self.encoder(x) 201 | h1 = h.view(-1, self.d * self.f ** 2) 202 | return h, self.fc11(h1), self.fc12(h1) 203 | 204 | def reparameterize(self, mu, logvar): 205 | if self.training: 206 | std = logvar.mul(0.5).exp_() 207 | eps = std.new(std.size()).normal_() 208 | return eps.mul(std).add_(mu) 209 | else: 210 | return mu 211 | 212 | def decode(self, z): 213 | z = z.view(-1, self.d, self.f, self.f) 214 | h3 = self.decoder(z) 215 | return torch.tanh(h3) 216 | 217 | def forward(self, x): 218 | _, mu, logvar = self.encode(x) 219 | hi = self.reparameterize(mu, logvar) #+ noise* torch.randn(mu.size()).cuda() 220 | hi_projected = self.fc21(hi) 221 | xi = self.decode(hi_projected) 222 | xi = self.xi_bn(xi) 223 | 224 | if self.with_classifier: 225 | out = self.classifier(x - xi) 226 | return out, hi, xi, mu, logvar 227 | else: 228 | return xi 229 | 230 | class CVAE_imagenet(nn.Module): 231 | def __init__(self, d, k=10, num_classes=9, num_channels=3, **kwargs): 232 | super(CVAE_imagenet, self).__init__() 233 | 234 | self.encoder = nn.Sequential( 235 | nn.Conv2d(num_channels, d, kernel_size=4, stride=2, padding=1), 236 | nn.BatchNorm2d(d), 237 | nn.LeakyReLU(inplace=True), 238 | nn.Conv2d(d, d, kernel_size=4, stride=2, padding=1), 239 | nn.BatchNorm2d(d), 240 | nn.LeakyReLU(inplace=True), 241 | ResBlock(d, d), 242 | nn.BatchNorm2d(d), 243 | ResBlock(d, d), 244 | nn.BatchNorm2d(d), 245 | ) 246 | self.decoder = nn.Sequential( 247 | ResBlock(d, d), 248 | nn.BatchNorm2d(d), 249 | ResBlock(d, d), 250 | nn.ConvTranspose2d(d, d, kernel_size=4, stride=2, padding=1), 251 | nn.BatchNorm2d(d), 252 | nn.LeakyReLU(inplace=True), 253 | nn.ConvTranspose2d(d, num_channels, kernel_size=4, stride=2, padding=1), 254 | ) 255 | self.d = d 256 | self.emb = NearestEmbed(k, d) 257 | 258 | for l in self.modules(): 259 | if isinstance(l, nn.Linear) or isinstance(l, nn.Conv2d): 260 | l.weight.detach().normal_(0, 0.02) 261 | torch.fmod(l.weight, 0.04) 262 | nn.init.constant_(l.bias, 0) 263 | 264 | self.encoder[-1].weight.detach().fill_(1 / 40) 265 | 266 | self.emb.weight.detach().normal_(0, 0.02) 267 | 
torch.fmod(self.emb.weight, 0.04) 268 | 269 | self.classifier = resnet50(pretrained=True, num_classes=num_classes) 270 | 271 | self.L_bn = nn.BatchNorm2d(num_channels) 272 | 273 | def encode(self, x): 274 | return self.encoder(x) 275 | 276 | def decode(self, x): 277 | return torch.tanh(self.decoder(x)) 278 | 279 | def forward(self, x): 280 | 281 | z_e = self.encode(x) 282 | 283 | z_q, _ = self.emb(z_e, weight_sg=True) 284 | emb, _ = self.emb(z_e.detach()) 285 | 286 | l = self.decode(z_q) 287 | gx = self.L_bn(l) 288 | out = self.classifier(x-gx) 289 | 290 | return out, gx, z_e, emb 291 | -------------------------------------------------------------------------------- /tools/disentangle_cifar.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | import torch.backends.cudnn as cudnn 8 | from tqdm import tqdm 9 | from copy import deepcopy 10 | import torchvision 11 | import torchvision.transforms as transforms 12 | import wandb 13 | import os 14 | import time 15 | import argparse 16 | import datetime 17 | from torch.autograd import Variable 18 | import pdb 19 | import sys 20 | 21 | sys.path.append('.') 22 | 23 | from networks.vae import * 24 | from utils.set import * 25 | from utils.randaugment4fixmatch import RandAugmentMC 26 | 27 | 28 | def reconst_images(epoch=2, batch_size=64, batch_num=2, dataloader=None, model=None): 29 | cifar10_dataloader = dataloader 30 | 31 | model.eval() 32 | 33 | with torch.no_grad(): 34 | for batch_idx, (X, y) in enumerate(cifar10_dataloader): 35 | if batch_idx >= batch_num: 36 | break 37 | else: 38 | X, y = X.cuda(), y.cuda().view(-1, ) 39 | _,_, gx, _, _ = model(X) 40 | 41 | grid_X = torchvision.utils.make_grid(X[:batch_size].data, nrow=8, padding=2, normalize=True) 42 | wandb.log({"_Batch_{batch}_X.jpg".format(batch=batch_idx): [ 43 | wandb.Image(grid_X)]}, commit=False) 44 | grid_GX = torchvision.utils.make_grid(gx[:batch_size].data, nrow=8, padding=2, normalize=True) 45 | wandb.log({"_Batch_{batch}_GX.jpg".format(batch=batch_idx): [ 46 | wandb.Image(grid_GX)]}, commit=False) 47 | grid_RX = torchvision.utils.make_grid((X[:batch_size] - gx[:batch_size]).data, nrow=8, padding=2, 48 | normalize=True) 49 | wandb.log({"_Batch_{batch}_RX.jpg".format(batch=batch_idx): [ 50 | wandb.Image(grid_RX)]}, commit=False) 51 | print('reconstruction complete!') 52 | 53 | 54 | def test(epoch, model, testloader): 55 | # set model as testing mode 56 | model.eval() 57 | # all_l, all_s, all_y, all_z, all_mu, all_logvar = [], [], [], [], [], [] 58 | acc_avg = AverageMeter() 59 | sparse_avg = AverageMeter() 60 | top1 = AverageMeter() 61 | TC = AverageMeter() 62 | 63 | with torch.no_grad(): 64 | for batch_idx, (x, y) in enumerate(testloader): 65 | # distribute data to device 66 | x, y = x.cuda(), y.cuda().view(-1, ) 67 | bs = x.size(0) 68 | norm = torch.norm(torch.abs(x.view(100, -1)), p=2, dim=1) 69 | out, hi, gx, mu, logvar = model(x) 70 | acc_gx = 1 - F.mse_loss(torch.div(gx, norm.unsqueeze(1).unsqueeze(2).unsqueeze(3)), \ 71 | torch.div(x, norm.unsqueeze(1).unsqueeze(2).unsqueeze(3)), \ 72 | reduction='sum') / 100 73 | acc_rx = 1 - F.mse_loss(torch.div(x - gx, norm.unsqueeze(1).unsqueeze(2).unsqueeze(3)), \ 74 | torch.div(x, norm.unsqueeze(1).unsqueeze(2).unsqueeze(3)), \ 75 | reduction='sum') / 100 76 | 77 | acc_avg.update(acc_gx.data.item(), bs) 78 | # measure accuracy and record loss 79 | 
sparse_avg.update(acc_rx.data.item(), bs) 80 | # measure accuracy and record loss 81 | prec1, _, _, _ = accuracy(out.data, y.data, topk=(1, 5)) 82 | top1.update(prec1.item(), bs) 83 | 84 | tc = total_correlation(hi, mu, logvar) / bs / args.dim 85 | TC.update(tc.item(), bs) 86 | 87 | wandb.log({'acc_avg': acc_avg.avg, \ 88 | 'sparse_avg': sparse_avg.avg, \ 89 | 'test-RX-acc': top1.avg, \ 90 | 'test-TC': TC.avg}, commit=False) 91 | # plot progress 92 | print("\n| Validation Epoch #%d\t\tRec Acc: %.4f Class Acc: %.4f TC: %.4f" % (epoch, acc_avg.avg, top1.avg, TC.avg)) 93 | reconst_images(epoch=epoch, batch_size=64, batch_num=2, dataloader=testloader, model=model) 94 | torch.save(model.state_dict(), 95 | os.path.join(args.save_dir, 'model_epoch{}.pth'.format(epoch + 1))) # save motion_encoder 96 | print("Epoch {} model saved!".format(epoch + 1)) 97 | 98 | 99 | def train(args, epoch, model, optimizer, trainloader): 100 | model.train() 101 | model.training = True 102 | 103 | loss_avg = AverageMeter() 104 | loss_rec = AverageMeter() 105 | loss_ce = AverageMeter() 106 | loss_entropy = AverageMeter() 107 | loss_kl = AverageMeter() 108 | top1 = AverageMeter() 109 | 110 | print('\n=> Training Epoch #%d, LR=%.4f' % (epoch, optimizer.param_groups[0]['lr'])) 111 | for batch_idx, (x, y) in enumerate(trainloader): 112 | x, y, y_b, lam, mixup_index = mixup_data(x, y, alpha=args.alpha) 113 | x, y, y_b = x.cuda(), y.cuda().view(-1, ), y_b.cuda().view(-1, ) 114 | x, y = Variable(x), [Variable(y), Variable(y_b)] 115 | bs = x.size(0) 116 | optimizer.zero_grad() 117 | 118 | out, _, xi, mu, logvar = model(x) 119 | 120 | if args.curriculum: 121 | if epoch < 100: 122 | re = 10*args.re 123 | elif epoch < 200: 124 | re = 5*args.re 125 | else: 126 | re = args.re 127 | else: 128 | re = args.re 129 | 130 | l1 = F.mse_loss(xi, x) 131 | cross_entropy = lam * F.cross_entropy(out, y[0]) + (1. 
- lam) * F.cross_entropy(out, y[1]) 132 | l2 = cross_entropy 133 | l3 = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 134 | l3 /= bs * 3 * args.dim 135 | loss = re * l1 + args.ce * l2 + args.kl * l3 136 | loss.backward() 137 | optimizer.step() 138 | 139 | prec1, prec5, correct, pred = accuracy(out.data, y[0].data, topk=(1, 5)) 140 | loss_avg.update(loss.data.item(), bs) 141 | loss_rec.update(l1.data.item(), bs) 142 | loss_ce.update(cross_entropy.data.item(), bs) 143 | loss_kl.update(l3.data.item(), bs) 144 | top1.update(prec1.item(), bs) 145 | 146 | n_iter = (epoch - 1) * len(trainloader) + batch_idx 147 | wandb.log({'loss': loss_avg.avg, \ 148 | 'loss_rec': loss_rec.avg, \ 149 | 'loss_ce': loss_ce.avg, \ 150 | 'loss_kl': loss_kl.avg, \ 151 | 'acc': top1.avg, 152 | 're_weight': re, 153 | 'lr':optimizer.param_groups[0]['lr']}, step=n_iter) 154 | if (batch_idx + 1) % 30 == 0: 155 | sys.stdout.write('\r') 156 | sys.stdout.write( 157 | '| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f Loss_rec: %.4f Loss_ce: %.4f Loss_entropy: %.4f Loss_kl: %.4f Acc@1: %.3f%%' 158 | % (epoch, args.epochs, batch_idx + 1, 159 | len(trainloader), loss_avg.avg, loss_rec.avg, loss_ce.avg, loss_entropy.avg, loss_kl.avg, top1.avg)) 160 | 161 | 162 | def main(args): 163 | learning_rate = 1.e-3 164 | learning_rate_min = 2.e-4 165 | CNN_embed_dim = args.dim 166 | feature_dim = args.fdim 167 | setup_logger(args.save_dir) 168 | use_cuda = torch.cuda.is_available() 169 | best_acc = 0 170 | print('\n[Phase 1] : Data Preparation') 171 | transform_train = transforms.Compose([ 172 | transforms.RandomCrop(32, padding=4), 173 | transforms.RandomHorizontalFlip(), 174 | RandAugmentMC(n=2, m=10), 175 | transforms.ToTensor(), 176 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), 177 | ]) 178 | 179 | transform_test = transforms.Compose([ 180 | transforms.ToTensor(), 181 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), 182 | ]) 183 | if (args.dataset == 'cifar10'): 184 | print("| Preparing CIFAR-10 dataset...") 185 | sys.stdout.write("| ") 186 | trainset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train) 187 | testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=False, transform=transform_test) 188 | 189 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=0) 190 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0) 191 | 192 | # Model 193 | print('\n[Phase 2] : Model setup') 194 | model = CVAE_cifar(d=feature_dim, z=CNN_embed_dim) 195 | 196 | if use_cuda: 197 | model.cuda() 198 | model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) 199 | cudnn.benchmark = True 200 | 201 | optimizer = AdamW([ 202 | {'params': model.parameters()} 203 | ], lr=learning_rate, betas=(0.9, 0.999), weight_decay=1.e-6) 204 | 205 | if args.optim == 'cosine': 206 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, 207 | eta_min=learning_rate_min) 208 | else: 209 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.step, gamma=0.1, last_epoch=-1) 210 | 211 | print('\n[Phase 3] : Training model') 212 | print('| Training Epochs = ' + str(args.epochs)) 213 | 214 | start_epoch = 1 215 | elapsed_time = 0 216 | for epoch in range(start_epoch, start_epoch + args.epochs): 217 | start_time = time.time() 218 | train(args, epoch, model, optimizer, trainloader) 219 | 
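# step the LR scheduler once per epoch; every 10 epochs, test() also saves a checkpoint to args.save_dir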
scheduler.step() 220 | if epoch % 10 == 0: 221 | test(epoch, model, testloader) 222 | 223 | epoch_time = time.time() - start_time 224 | elapsed_time += epoch_time 225 | 226 | wandb.finish() 227 | print('\n[Phase 4] : Testing model') 228 | print('* Test results : Acc@1 = %.2f%%' % (best_acc)) 229 | 230 | 231 | if __name__ == '__main__': 232 | parser = argparse.ArgumentParser(description='PyTorch CIFAR-10 Training') 233 | parser.add_argument('--save_dir', default='./results/autoaug_new_8_0.5/', type=str, help='save_dir') 234 | parser.add_argument('--seed', default=666, type=int, help='seed') 235 | parser.add_argument('--dataset', default='cifar10', type=str, help='dataset = [cifar10/cifar100]') 236 | parser.add_argument('--optim', default='cosine', type=str, help='optimizer') 237 | parser.add_argument('--alpha', default=2.0, type=float, help='mix up') 238 | parser.add_argument('--epochs', default=300, type=int, help='training_epochs') 239 | parser.add_argument('--batch_size', default=256, type=int, help='batch size') 240 | parser.add_argument('--dim', default=2048, type=int, help='CNN_embed_dim') 241 | parser.add_argument('--T', default=50, type=int, help='Cosine T') 242 | parser.add_argument('--fdim', default=32, type=int, help='featdim') 243 | parser.add_argument('--step', nargs='+', type=int) 244 | parser.add_argument('--re', default=1.0, type=float, help='reconstruction weight') 245 | parser.add_argument('--curriculum', default=True, 246 | help='Curriculum for reconstruction term which helps for better convergence') 247 | parser.add_argument('--kl', default=0.2, type=float, help='kl weight') 248 | parser.add_argument('--ce', default=0.2, type=float, help='cross entropy weight') 249 | args = parser.parse_args() 250 | wandb.init(config=args, name=args.save_dir.replace("results/", '')) 251 | set_random_seed(args.seed) 252 | main(args) 253 | -------------------------------------------------------------------------------- /tools/adv_train_cifar.py: -------------------------------------------------------------------------------- 1 | import torch.backends.cudnn as cudnn 2 | from tqdm import tqdm 3 | from copy import deepcopy 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | import wandb 7 | import os 8 | import time 9 | import argparse 10 | import datetime 11 | from torch.autograd import Variable 12 | import pdb 13 | import sys 14 | import torch.optim as optim 15 | sys.path.append('.') 16 | 17 | from networks.adv_vae import * 18 | from utils.set import * 19 | from utils.randaugment4fixmatch import RandAugmentMC 20 | from utils.normalize import * 21 | from advex.attacks import * 22 | from evaluation import evaluate_cdvae 23 | normalize = CIFARNORMALIZE(32) 24 | 25 | 26 | def Incorrect_Logits(logits, labels, margin): 27 | max_2_logits, argmax_2_logits = torch.topk(logits, 2, dim=1) 28 | top_max, second_max = max_2_logits.chunk(2, dim=1) 29 | top_argmax, second_argmax = argmax_2_logits.chunk(2, dim=1) 30 | labels_eq_max = top_argmax.squeeze().eq(labels).float().view(-1, 1) 31 | labels_ne_max = top_argmax.squeeze().ne(labels).float().view(-1, 1) 32 | max_incorrect_logits = labels_eq_max * second_max + labels_ne_max * top_max 33 | correct_logits = torch.gather(logits, 1, labels.view(-1, 1)) 34 | return ((correct_logits - max_incorrect_logits) < margin).squeeze() 35 | 36 | 37 | def adjust_learning_rate(optimizer, epoch, step, len_epoch): 38 | # LR schedule (assumed): divide the base LR by 10 every 30 epochs, with an extra drop after epoch 80 39 | factor = epoch // 30 40 | if epoch >= 80: 41 | factor = factor + 1 42 | 43 | lr = args.lr * (0.1 ** factor) 44 | 45 | """Warmup""" 46 | if epoch < 5: 47 | lr = lr * float(1 + step) / (5.
* len_epoch) 48 | 49 | # if(args.local_rank == 0): 50 | # print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr)) 51 | 52 | for param_group in optimizer.param_groups: 53 | param_group['lr'] = lr 54 | 55 | 56 | if __name__ == '__main__': 57 | parser = argparse.ArgumentParser(description='PyTorch CIFAR-10 Training') 58 | parser.add_argument('--lr', default=0.1, type=float, help='learning_rate') 59 | parser.add_argument('--save_dir', default='./results/autoaug_new_8_0.5/', type=str, help='save_dir') 60 | parser.add_argument('--seed', default=666, type=int, help='seed') 61 | parser.add_argument('--batch_size', default=128, type=int, help='batch size') 62 | parser.add_argument('--dataset', default='cifar10', type=str, help='dataset = [cifar10/cifar100]') 63 | parser.add_argument('--epochs', default=150, type=int, help='training_epochs') 64 | parser.add_argument('--dim', default=2048, type=int, help='CNN_embed_dim') 65 | parser.add_argument('--fdim', default=32, type=int, help='featdim') 66 | parser.add_argument('--margin', default=8.0, type=float, help='margin') 67 | parser.add_argument('--re', default=1.0, type=float, help='reconstruction weight') 68 | parser.add_argument('--kl', default=0.01, type=float, help='kl weight') 69 | parser.add_argument('--cr', default=1.0, type=float, help='cross entropy weight for model_r') 70 | parser.add_argument('--cg', default=1.0, type=float, help='cross entropy weight for model_g') 71 | parser.add_argument("--model_path", type=str, default="./pretrained/wide_resnet.pth") 72 | parser.add_argument("--vae_path", type=str, default="./pretrained/cd-vae-2.pth") 73 | parser.add_argument('--clip_grad', type=float, default=1.0, 74 | help='clip gradients to this value') 75 | args = parser.parse_args() 76 | 77 | wandb.init(config=args, name=args.save_dir.replace("results/", '')) 78 | set_random_seed(args.seed) 79 | CNN_embed_dim = args.dim 80 | feature_dim = args.fdim 81 | setup_logger(args.save_dir) 82 | use_cuda = torch.cuda.is_available() 83 | 84 | print('\n[Phase 1] : Data Preparation') 85 | transform_train = transforms.Compose([ 86 | transforms.RandomCrop(32, padding=4), 87 | transforms.RandomHorizontalFlip(), 88 | RandAugmentMC(n=2, m=10), 89 | transforms.ToTensor(), 90 | ]) # no Normalize here; CIFARNORMALIZE is applied inside run_iter 91 | 92 | transform_test = transforms.Compose([ 93 | transforms.ToTensor(), 94 | ]) 95 | 96 | print("| Preparing CIFAR-10 dataset...") 97 | sys.stdout.write("| ") 98 | trainset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train) 99 | testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=False, transform=transform_test) 100 | 101 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=0) 102 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0) 103 | 104 | # Model 105 | print('\n[Phase 2] : Model setup') 106 | vae = CVAE_cifar(d=feature_dim, z=CNN_embed_dim, with_classifier=False) 107 | model_r = Wide_ResNet(28, 10, 0.3, 10) 108 | model_g = Wide_ResNet(28, 10, 0.3, 10) 109 | if use_cuda: 110 | vae.cuda() 111 | vae = torch.nn.DataParallel(vae, device_ids=range(torch.cuda.device_count())) 112 | 113 | model_r.cuda() 114 | model_r = torch.nn.DataParallel(model_r, device_ids=range(torch.cuda.device_count())) 115 | 116 | model_g.cuda() 117 | model_g = torch.nn.DataParallel(model_g, device_ids=range(torch.cuda.device_count())) 118 | 119 | cudnn.benchmark = True 120 | 121 | save_model = torch.load(args.vae_path) 122 | model_dict =
vae.state_dict() 123 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()} 124 | model_dict.update(state_dict) 125 | vae.load_state_dict(model_dict) 126 | 127 | model_g.load_state_dict(torch.load(args.model_path)) 128 | 129 | model_dict = model_r.state_dict() 130 | state_dict = {k.replace('classifier.', ''): v for k, v in save_model.items() if 131 | k.replace('classifier.', '') in model_dict.keys()} 132 | model_dict.update(state_dict) 133 | model_r.load_state_dict(model_dict) 134 | 135 | optimizer = optim.SGD([ 136 | {'params': vae.parameters()}, 137 | {'params': model_r.parameters()}, 138 | {'params': model_g.parameters()}], 139 | lr=args.lr, 140 | momentum=0.9, 141 | weight_decay=2e-4) 142 | 143 | print('\n[Phase 3] : Training model') 144 | print('| Training Epochs = ' + str(args.epochs)) 145 | 146 | iteration = 0 147 | attack = AttackV2(model_g, vae, num_iterations=10, loss='margin') 148 | validation_attacks = [NoAttack(), 149 | AttackV2(model_g, vae, num_iterations=100), 150 | AttackV2(model_g, vae, num_iterations=100, norm='l2', eps_max=1.0), 151 | AttackV2(model_g, vae, num_iterations=100, loss='margin'), 152 | AttackV2(model_g, vae, num_iterations=100, norm='l2', eps_max=1.0,loss='margin')] 153 | elapsed_time = 0 154 | 155 | def run_iter(inputs, labels, iteration): 156 | model_r.eval() # set model to eval to generate adversarial examples 157 | model_g.eval() 158 | vae.eval() 159 | 160 | if torch.cuda.is_available(): 161 | inputs = inputs.cuda() 162 | labels = labels.cuda() 163 | bs = inputs.size(0) 164 | with torch.no_grad(): 165 | gx, _, _ = vae(normalize(inputs)) 166 | orig_logits = model_g(gx) 167 | orig_accuracy, _, _, _ = accuracy(orig_logits.data, labels.data, topk=(1, 5)) 168 | to_attack = orig_logits.argmax(1) == labels 169 | 170 | adv_inputs = inputs.clone() 171 | if to_attack.sum()>0: 172 | adv_inputs[to_attack]= attack(inputs[to_attack], labels[to_attack]) 173 | 174 | with torch.no_grad(): 175 | gx, _, _ = vae(normalize(adv_inputs)) 176 | adv_logits_g = model_g(gx) 177 | adv_accuracy, _, _, _ = accuracy(adv_logits_g.data, labels.data, topk=(1, 5)) 178 | incorrect = Incorrect_Logits(adv_logits_g, labels, args.margin) 179 | 180 | optimizer.zero_grad() 181 | model_r.train() 182 | model_g.train() 183 | vae.train() 184 | 185 | gx, mu, logvar = vae(normalize(adv_inputs)) 186 | logits_g = model_g(gx) 187 | logits_r = model_r(normalize(adv_inputs)-gx) 188 | 189 | l1 = F.mse_loss(gx, normalize(inputs)) 190 | 191 | l2 = args.cr * F.cross_entropy(logits_r, adv_logits_g.argmax(1)) \ 192 | + args.cg * F.cross_entropy(logits_g[incorrect], labels[incorrect]) 193 | 194 | l3 = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 195 | l3 /= bs * 3 * args.dim 196 | 197 | loss = args.re * l1 + l2 + args.kl * l3 198 | loss.backward() 199 | nn.utils.clip_grad_value_(model_r.parameters(), args.clip_grad) 200 | nn.utils.clip_grad_value_(model_g.parameters(), args.clip_grad) 201 | nn.utils.clip_grad_value_(vae.parameters(), args.clip_grad) 202 | optimizer.step() 203 | 204 | wandb.log({'loss': loss.item()}, step=iteration) 205 | wandb.log({'loss1': l1.item()}, step=iteration) 206 | wandb.log({'loss2': l2.item()}, step=iteration) 207 | wandb.log({'loss3': l3.item()}, step=iteration) 208 | wandb.log({'adversarial_accuracy': adv_accuracy.item()}, step=iteration) 209 | wandb.log({'orig_accuracy': orig_accuracy.item()}, step=iteration) 210 | wandb.log({'lr': optimizer.param_groups[0]['lr']}, step=iteration) 211 | 212 | print(f'ITER {iteration:06d}', 213 | f'loss: 
{loss.item():.2f}', 214 | f'loss1: {l1.item():.2f}', 215 | f'loss2: {l2.item():.2f}', 216 | f'loss3: {l3.item():.2f}', 217 | f'orig_accuracy: {orig_accuracy.item():5.1f}%', 218 | f'acc_adv: {adv_accuracy.item() :5.1f}%', 219 | sep='\t') 220 | 221 | start_epoch = 1 222 | for epoch in range(start_epoch, start_epoch + args.epochs): 223 | start_time = time.time() 224 | for batch_index, (inputs, labels) in enumerate(trainloader): 225 | adjust_learning_rate(optimizer, epoch, iteration, len(trainloader)) 226 | run_iter(inputs, labels, iteration) 227 | iteration += 1 228 | 229 | if epoch % 10 == 1: 230 | print("\n| Validation Epoch #%d\t\t" % (epoch)) 231 | evaluate_cdvae.evaluate(wandb, model_r, model_g, vae, testloader, validation_attacks, 10) 232 | torch.save(model_r.state_dict(), 233 | os.path.join(args.save_dir, 'model_r_epoch{}.pth'.format(epoch + 1))) # save motion_encoder 234 | torch.save(model_g.state_dict(), 235 | os.path.join(args.save_dir, 'robust_model_g_epoch{}.pth'.format(epoch + 1))) # save motion_encoder 236 | torch.save(vae.state_dict(), 237 | os.path.join(args.save_dir, 'robust_vae_epoch{}.pth'.format(epoch + 1))) # save motion_encoder 238 | print("Epoch {} model saved!".format(epoch + 1)) 239 | 240 | evaluate_cdvae.evaluate(wandb, model_r, model_g, vae, testloader, validation_attacks) 241 | torch.save(model_r.state_dict(), 242 | os.path.join(args.save_dir, 'model_r_epoch{}.pth'.format(epoch + 1))) # save motion_encoder 243 | torch.save(model_g.state_dict(), 244 | os.path.join(args.save_dir, 'robust_model_g_epoch{}.pth'.format(epoch + 1))) # save motion_encoder 245 | torch.save(vae.state_dict(), 246 | os.path.join(args.save_dir, 'robust_vae_epoch{}.pth'.format(epoch + 1))) # save motion_encoder 247 | print("Epoch {} model saved!".format(epoch + 1)) 248 | wandb.finish() 249 | -------------------------------------------------------------------------------- /detection/lib/_tv_bregman.patch: -------------------------------------------------------------------------------- 1 | --- _denoise_cy.pyx 2017-12-11 10:48:58.425296545 -0800 2 | +++ _tv_bregman.pyx 2017-12-11 10:50:10.234351675 -0800 3 | @@ -7,229 +7,82 @@ 4 | import numpy as np 5 | from libc.math cimport exp, fabs, sqrt 6 | from libc.float cimport DBL_MAX 7 | -from .._shared.interpolation cimport get_pixel3d 8 | -from ..util import img_as_float 9 | - 10 | - 11 | -cdef inline double _gaussian_weight(double sigma, double value): 12 | - return exp(-0.5 * (value / sigma)**2) 13 | - 14 | - 15 | -cdef double[:] _compute_color_lut(Py_ssize_t bins, double sigma, double max_value): 16 | - 17 | - cdef: 18 | - double[:] color_lut = np.empty(bins, dtype=np.double) 19 | - Py_ssize_t b 20 | - 21 | - for b in range(bins): 22 | - color_lut[b] = _gaussian_weight(sigma, b * max_value / bins) 23 | - 24 | - return color_lut 25 | - 26 | - 27 | -cdef double[:] _compute_range_lut(Py_ssize_t win_size, double sigma): 28 | - 29 | - cdef: 30 | - double[:] range_lut = np.empty(win_size**2, dtype=np.double) 31 | - Py_ssize_t kr, kc 32 | - Py_ssize_t window_ext = (win_size - 1) / 2 33 | - double dist 34 | - 35 | - for kr in range(win_size): 36 | - for kc in range(win_size): 37 | - dist = sqrt((kr - window_ext)**2 + (kc - window_ext)**2) 38 | - range_lut[kr * win_size + kc] = _gaussian_weight(sigma, dist) 39 | - 40 | - return range_lut 41 | - 42 | - 43 | -cdef inline Py_ssize_t Py_ssize_t_min(Py_ssize_t value1, Py_ssize_t value2): 44 | - if value1 < value2: 45 | - return value1 46 | - else: 47 | - return value2 48 | - 49 | - 50 | -def 
_denoise_bilateral(image, Py_ssize_t win_size, sigma_color, 51 | - double sigma_spatial, Py_ssize_t bins, 52 | - mode, double cval): 53 | - cdef: 54 | - double min_value, max_value 55 | - 56 | - min_value = image.min() 57 | - max_value = image.max() 58 | - 59 | - if min_value == max_value: 60 | - return image 61 | - 62 | - # if image.max() is 0, then dist_scale can have an unverified value 63 | - # and color_lut[(dist * dist_scale)] may cause a segmentation fault 64 | - # so we verify we have a positive image and that the max is not 0.0. 65 | - if min_value < 0.0: 66 | - raise ValueError("Image must contain only positive values") 67 | - 68 | - if max_value == 0.0: 69 | - raise ValueError("The maximum value found in the image was 0.") 70 | - 71 | - image = np.atleast_3d(img_as_float(image)) 72 | - 73 | - cdef: 74 | - Py_ssize_t rows = image.shape[0] 75 | - Py_ssize_t cols = image.shape[1] 76 | - Py_ssize_t dims = image.shape[2] 77 | - Py_ssize_t window_ext = (win_size - 1) / 2 78 | - Py_ssize_t max_color_lut_bin = bins - 1 79 | - 80 | - double[:, :, ::1] cimage 81 | - double[:, :, ::1] out 82 | - 83 | - double[:] color_lut 84 | - double[:] range_lut 85 | - 86 | - Py_ssize_t r, c, d, wr, wc, kr, kc, rr, cc, pixel_addr, color_lut_bin 87 | - double value, weight, dist, total_weight, csigma_color, color_weight, \ 88 | - range_weight 89 | - double dist_scale 90 | - double[:] values 91 | - double[:] centres 92 | - double[:] total_values 93 | - 94 | - if sigma_color is None: 95 | - csigma_color = image.std() 96 | - else: 97 | - csigma_color = sigma_color 98 | - 99 | - if mode not in ('constant', 'wrap', 'symmetric', 'reflect', 'edge'): 100 | - raise ValueError("Invalid mode specified. Please use `constant`, " 101 | - "`edge`, `wrap`, `symmetric` or `reflect`.") 102 | - cdef char cmode = ord(mode[0].upper()) 103 | - 104 | - cimage = np.ascontiguousarray(image) 105 | - 106 | - out = np.zeros((rows, cols, dims), dtype=np.double) 107 | - color_lut = _compute_color_lut(bins, csigma_color, max_value) 108 | - range_lut = _compute_range_lut(win_size, sigma_spatial) 109 | - dist_scale = bins / dims / max_value 110 | - values = np.empty(dims, dtype=np.double) 111 | - centres = np.empty(dims, dtype=np.double) 112 | - total_values = np.empty(dims, dtype=np.double) 113 | - 114 | - for r in range(rows): 115 | - for c in range(cols): 116 | - total_weight = 0 117 | - for d in range(dims): 118 | - total_values[d] = 0 119 | - centres[d] = cimage[r, c, d] 120 | - for wr in range(-window_ext, window_ext + 1): 121 | - rr = wr + r 122 | - kr = wr + window_ext 123 | - for wc in range(-window_ext, window_ext + 1): 124 | - cc = wc + c 125 | - kc = wc + window_ext 126 | - 127 | - # save pixel values for all dims and compute euclidian 128 | - # distance between centre stack and current position 129 | - dist = 0 130 | - for d in range(dims): 131 | - value = get_pixel3d(&cimage[0, 0, 0], rows, cols, dims, 132 | - rr, cc, d, cmode, cval) 133 | - values[d] = value 134 | - dist += (centres[d] - value)**2 135 | - dist = sqrt(dist) 136 | 137 | - range_weight = range_lut[kr * win_size + kc] 138 | - 139 | - color_lut_bin = Py_ssize_t_min( 140 | - (dist * dist_scale), max_color_lut_bin) 141 | - color_weight = color_lut[color_lut_bin] 142 | 143 | - weight = range_weight * color_weight 144 | - for d in range(dims): 145 | - total_values[d] += values[d] * weight 146 | - total_weight += weight 147 | - for d in range(dims): 148 | - out[r, c, d] = total_values[d] / total_weight 149 | - 150 | - return np.squeeze(np.asarray(out)) 151 | - 152 
| - 153 | -def _denoise_tv_bregman(image, double weight, int max_iter, double eps, 154 | - char isotropic): 155 | - image = np.atleast_3d(img_as_float(image)) 156 | +def _denoise_tv_bregman(image, mask, double weight, int max_iter, int gs_iter, 157 | + double eps, char isotropic): 158 | + image = np.atleast_3d(image) 159 | 160 | cdef: 161 | Py_ssize_t rows = image.shape[0] 162 | Py_ssize_t cols = image.shape[1] 163 | Py_ssize_t dims = image.shape[2] 164 | - Py_ssize_t rows2 = rows + 2 165 | - Py_ssize_t cols2 = cols + 2 166 | Py_ssize_t r, c, k 167 | 168 | Py_ssize_t total = rows * cols * dims 169 | 170 | - shape_ext = (rows2, cols2, dims) 171 | - u = np.zeros(shape_ext, dtype=np.double) 172 | + u = np.zeros(image.shape, dtype=np.double) 173 | + u[:, :, :] = image 174 | 175 | cdef: 176 | double[:, :, ::1] cimage = np.ascontiguousarray(image) 177 | + char[:, :, ::1] cmask = mask 178 | double[:, :, ::1] cu = u 179 | 180 | - double[:, :, ::1] dx = np.zeros(shape_ext, dtype=np.double) 181 | - double[:, :, ::1] dy = np.zeros(shape_ext, dtype=np.double) 182 | - double[:, :, ::1] bx = np.zeros(shape_ext, dtype=np.double) 183 | - double[:, :, ::1] by = np.zeros(shape_ext, dtype=np.double) 184 | + double[:, :, ::1] dx = np.zeros(image.shape, dtype=np.double) 185 | + double[:, :, ::1] dy = np.zeros(image.shape, dtype=np.double) 186 | + double[:, :, ::1] bx = np.zeros(image.shape, dtype=np.double) 187 | + double[:, :, ::1] by = np.zeros(image.shape, dtype=np.double) 188 | + double[:, :, ::1] z = np.zeros(image.shape, dtype=np.double) 189 | + double[:, :, ::1] uprev = np.ascontiguousarray(image) 190 | 191 | - double ux, uy, uprev, unew, bxx, byy, dxx, dyy, s 192 | + double ux, uy, unew, bxx, byy, dxx, dyy, s 193 | int i = 0 194 | double lam = 2 * weight 195 | double rmse = DBL_MAX 196 | - double norm = (weight + 4 * lam) 197 | - 198 | - u[1:-1, 1:-1] = image 199 | - 200 | - # reflect image 201 | - u[0, 1:-1] = image[1, :] 202 | - u[1:-1, 0] = image[:, 1] 203 | - u[-1, 1:-1] = image[-2, :] 204 | - u[1:-1, -1] = image[:, -2] 205 | + double neighbors = 0 206 | + double inner = 0 207 | 208 | while i < max_iter and rmse > eps: 209 | 210 | - rmse = 0 211 | - 212 | + for _ in range(gs_iter): 213 | - for k in range(dims): 214 | + for k in range(dims): 215 | - for r in range(1, rows + 1): 216 | - for c in range(1, cols + 1): 217 | - 218 | - uprev = cu[r, c, k] 219 | - 220 | - # forward derivatives 221 | - ux = cu[r, c + 1, k] - uprev 222 | - uy = cu[r + 1, c, k] - uprev 223 | - 224 | + for r in range(rows): 225 | + for c in range(cols): 226 | # Gauss-Seidel method 227 | - unew = ( 228 | - lam * ( 229 | - + cu[r + 1, c, k] 230 | - + cu[r - 1, c, k] 231 | - + cu[r, c + 1, k] 232 | - + cu[r, c - 1, k] 233 | - 234 | - + dx[r, c - 1, k] 235 | - - dx[r, c, k] 236 | - + dy[r - 1, c, k] 237 | - - dy[r, c, k] 238 | - 239 | - - bx[r, c - 1, k] 240 | - + bx[r, c, k] 241 | - - by[r - 1, c, k] 242 | - + by[r, c, k] 243 | - ) + weight * cimage[r - 1, c - 1, k] 244 | - ) / norm 245 | + inner = z[r, c, k] 246 | + neighbors = 0 247 | + if r > 0: 248 | + inner += cu[r - 1, c, k] 249 | + neighbors += 1 250 | + if r < rows - 1: 251 | + inner += cu[r + 1, c, k] 252 | + neighbors += 1 253 | + if c > 0: 254 | + inner += cu[r, c - 1, k] 255 | + neighbors += 1 256 | + if c < cols - 1: 257 | + inner += cu[r, c + 1, k] 258 | + neighbors += 1 259 | + if cmask[r, c, k] == 1: 260 | + unew = (lam * inner + weight * cimage[r, c, k]) / (weight + neighbors * lam) 261 | + else: 262 | + unew = inner / 4 263 | - cu[r, c, k] = unew 264 | + cu[r, 
c, k] = unew 265 | 266 | - # update root mean square error 267 | - rmse += (unew - uprev)**2 268 | + rmse = 0 269 | + for k in range(dims): 270 | + for r in range(rows): 271 | + for c in range(cols): 272 | + # forward derivatives 273 | + if c == cols - 1: 274 | + ux = 0 275 | + else: 276 | + ux = cu[r, c + 1, k] - cu[r, c, k] 277 | + if r == rows - 1: 278 | + uy = 0 279 | + else: 280 | + uy = cu[r + 1, c, k] - cu[r, c, k] 281 | 282 | bxx = bx[r, c, k] 283 | byy = by[r, c, k] 284 | @@ -262,7 +114,17 @@ 285 | bx[r, c, k] += ux - dxx 286 | by[r, c, k] += uy - dyy 287 | 288 | + z[r, c, k] = -dx[r, c, k] - dy[r, c, k] + bx[r, c, k] + by[r, c, k] 289 | + if r > 0: 290 | + z[r, c, k] += dy[r - 1, c, k] - by[r - 1, c, k] 291 | + if c > 0: 292 | + z[r, c, k] += dx[r, c - 1, k] - bx[r, c - 1, k] 293 | + 294 | + # update rmse 295 | + rmse += (cu[r, c, k] - uprev[r, c, k])**2 296 | + 297 | rmse = sqrt(rmse / total) 298 | + uprev = np.copy(cu) 299 | i += 1 300 | 301 | - return np.squeeze(np.asarray(u[1:-1, 1:-1])) 302 | + return np.squeeze(np.asarray(u)) 303 | --------------------------------------------------------------------------------