├── data ├── mnist.py ├── __init__.py ├── imagenet.py ├── svhn.py └── cifar.py ├── trainer ├── __init__.py ├── base.py ├── adv.py ├── freeadv.py ├── smooth.py ├── mixtrain.py └── crown-ibp.py ├── requirements.txt ├── images ├── results_table.png ├── comparison_plot.png └── weight_histogram.png ├── symbolic_interval └── __init__.py ├── configs ├── configs.yml ├── configs_imagenet.yml ├── configs_mixtrain.yml └── configs_crown-ibp.yml ├── utils ├── misc.py ├── semisup.py ├── schedules.py ├── logging.py ├── smoothing.py ├── adv.py ├── model.py └── eval.py ├── crown ├── converter.py └── eps_scheduler.py ├── models ├── __init__.py ├── layers.py ├── resnet_cifar.py ├── wrn_cifar.py ├── vgg_cifar.py ├── resnet.py └── basic.py ├── get_compact_net_adv_train.sh ├── get_compact_net_rand_smoothing.sh ├── .gitignore ├── get_compact_net_crown-ibp.sh ├── get_compact_net_mixtrain.sh ├── eval_smoothing.py ├── args.py ├── train_imagenet.py ├── train.py ├── README.md └── locuslab_smoothing └── analyze.py /data/mnist.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | tensorboard 4 | pyyaml 5 | easydict 6 | -------------------------------------------------------------------------------- /images/results_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inspire-group/hydra/HEAD/images/results_table.png -------------------------------------------------------------------------------- /images/comparison_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inspire-group/hydra/HEAD/images/comparison_plot.png -------------------------------------------------------------------------------- /images/weight_histogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inspire-group/hydra/HEAD/images/weight_histogram.png -------------------------------------------------------------------------------- /symbolic_interval/__init__.py: -------------------------------------------------------------------------------- 1 | from .interval import Interval, Symbolic_interval 2 | from .symbolic_network import Interval_network -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from data.cifar import CIFAR10, CIFAR100 2 | from data.svhn import SVHN 3 | from data.imagenet import imagenet 4 | 5 | __all__ = ["CIFAR10", "CIFAR100", "SVHN", "imagenet"] 6 | -------------------------------------------------------------------------------- /configs/configs.yml: -------------------------------------------------------------------------------- 1 | # ->->->->-> Primary <-<-<-<-<- 2 | arch: "vgg16_bn" 3 | exp_name: "temp" 4 | result_dir: "./trained_models" 5 | num_classes: 10 6 | exp_mode: "pretrain" 7 | layer_type: "subnet" 8 | init_type: "kaiming_normal" 9 | 10 | 11 | # ->->->->-> Pruning <-<-<-<-<- 12 | k: 1.0 13 | 14 | # ->->->->-> Train <-<-<-<-<- 15 | trainer: "base" 16 | 
epochs: 100 17 | optimizer: "sgd" 18 | lr: 0.1 19 | lr_schedule: "cosine" 20 | wd: 0.0001 21 | momentum: 0.9 22 | #warmup 23 | warmup_epochs: 0 24 | warmup_lr: 0.1 25 | 26 | 27 | # ->->->->-> Eval <-<-<-<-<- 28 | val_method: base 29 | 30 | 31 | # ->->->->-> Dataset <-<-<-<-<- 32 | dataset: CIFAR10 33 | batch_size: 128 34 | test_batch_size: 128 35 | data_dir: "./datasets" 36 | data_fraction: 1.0 37 | 38 | # ->->->->-> Semi-supervised training <-<-<-<-<- 39 | semisup_data: "tinyimages" 40 | semisup_fraction: 1.0 41 | 42 | 43 | # ->->->->-> Adv <-<-<-<-<- 44 | epsilon: 0.031 45 | num_steps: 10 46 | step_size: 0.0078 47 | clip_min: 0 48 | clip_max: 1 49 | distance: "l_inf" 50 | beta: 6.0 51 | 52 | 53 | # ->->->->-> Misc <-<-<-<-<- 54 | gpu: "0" 55 | seed: 1234 56 | print_freq: 100 -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import datasets, transforms 4 | from torch.utils.data.dataset import Dataset 5 | 6 | 7 | class CustomDatasetFromNumpy(Dataset): 8 | def __init__(self, img, label, transform): 9 | self.img = img 10 | self.label = label 11 | self.transform = transform 12 | self.len = len(self.img) 13 | 14 | def __getitem__(self, index): 15 | img_tensor = transforms.ToPILImage()(self.img[index]) 16 | img_tensor = self.transform(img_tensor) 17 | label_tensor = self.label[index] 18 | return (img_tensor, label_tensor) 19 | 20 | def __len__(self): 21 | return self.len 22 | 23 | 24 | def xe_with_one_hot(out, target): 25 | """ 26 | out: [N,k] dim tensor with output logits. 27 | target: [N,k] dim tensor with ground truth probs. 28 | 29 | return: calcuate mean(-1*sum(p_i*out_i)) 30 | """ 31 | log_prob = nn.LogSoftmax(dim=1)(out) 32 | loss = -1 * torch.sum(log_prob * target, dim=1) 33 | loss = torch.sum(loss) / len(loss) 34 | return loss 35 | 36 | -------------------------------------------------------------------------------- /configs/configs_imagenet.yml: -------------------------------------------------------------------------------- 1 | # ->->->->-> Primary <-<-<-<-<- 2 | arch: "ResNet50" 3 | exp_name: "temp" 4 | result_dir: "./trained_models" 5 | num_classes: 1000 6 | exp_mode: "pretrain" 7 | layer_type: "subnet" 8 | init_type: "kaiming_normal" 9 | 10 | 11 | # ->->->->-> Pruning <-<-<-<-<- 12 | k: 1.0 13 | 14 | # ->->->->-> Train <-<-<-<-<- 15 | trainer: "freeadv" 16 | epochs: 90 17 | optimizer: "sgd" 18 | lr: 0.1 19 | lr_schedule: "cosine" 20 | wd: 0.0001 21 | momentum: 0.9 22 | #warmup 23 | warmup_epochs: 0 24 | warmup_lr: 0.1 25 | 26 | 27 | # ->->->->-> Eval <-<-<-<-<- 28 | val_method: base 29 | 30 | 31 | # ->->->->-> Dataset <-<-<-<-<- 32 | dataset: imagenet 33 | batch_size: 256 34 | test_batch_size: 256 35 | data_dir: "/data/nvme/imagenet_data/raw-data/" 36 | data_fraction: 1.0 37 | image_dim: 224 38 | mean: !!python/tuple [0.485, 0.456, 0.406] 39 | std: !!python/tuple [0.229, 0.224, 0.225] 40 | 41 | # ->->->->-> Adv <-<-<-<-<- 42 | epsilon: 0.0156 #(4/255) 43 | num_steps: 10 44 | step_size: 0.00392 #(1/255) 45 | distance: "l_inf" 46 | beta: 6.0 47 | 48 | n_repeats: 4 49 | 50 | 51 | 52 | # ->->->->-> Misc <-<-<-<-<- 53 | gpu: "0,1,2,3" 54 | seed: 1234 55 | print_freq: 10 56 | -------------------------------------------------------------------------------- /configs/configs_mixtrain.yml: -------------------------------------------------------------------------------- 1 | # ->->->->-> Primary 
<-<-<-<-<- 2 | arch: "vgg16_bn" 3 | exp_name: "temp" 4 | result_dir: "./trained_models" 5 | num_classes: 10 6 | exp_mode: "pretrain" 7 | layer_type: "subnet" 8 | init_type: "kaiming_normal" 9 | 10 | 11 | # ->->->->-> Pruning <-<-<-<-<- 12 | k: 1.0 13 | 14 | # ->->->->-> Train <-<-<-<-<- 15 | trainer: "base" 16 | epochs: 60 17 | optimizer: "adam" 18 | lr: 0.001 19 | lr_schedule: "cosine" 20 | wd: 0.0001 21 | momentum: 0.9 22 | #warmup 23 | warmup_epochs: 0 24 | warmup_lr: 0.1 25 | 26 | 27 | # ->->->->-> Eval <-<-<-<-<- 28 | val_method: base 29 | 30 | 31 | # ->->->->-> Dataset <-<-<-<-<- 32 | dataset: CIFAR10 33 | batch_size: 25 34 | test_batch_size: 10 35 | data_dir: "./datasets" 36 | data_fraction: 1.0 37 | 38 | # ->->->->-> Semi-supervised training <-<-<-<-<- 39 | semisup_data: "tinyimages" 40 | semisup_fraction: 1.0 41 | 42 | 43 | # ->->->->-> Adv <-<-<-<-<- 44 | epsilon: 0.007 45 | num_steps: 10 46 | step_size: 0.0078 47 | clip_min: 0 48 | clip_max: 1 49 | distance: "l_inf" 50 | beta: 6.0 51 | starting_epsilon: 0 52 | schedule_length: 10 53 | interval_weight: 0.1 54 | 55 | 56 | # ->->->->-> Misc <-<-<-<-<- 57 | gpu: "0" 58 | seed: 1234 59 | print_freq: 600 -------------------------------------------------------------------------------- /configs/configs_crown-ibp.yml: -------------------------------------------------------------------------------- 1 | # ->->->->-> Primary <-<-<-<-<- 2 | arch: "vgg16_bn" 3 | exp_name: "temp" 4 | result_dir: "./trained_models" 5 | num_classes: 10 6 | exp_mode: "pretrain" 7 | layer_type: "subnet" 8 | init_type: "kaiming_normal" 9 | 10 | 11 | # ->->->->-> Pruning <-<-<-<-<- 12 | k: 1.0 13 | 14 | # ->->->->-> Train <-<-<-<-<- 15 | trainer: "base" 16 | epochs: 100 17 | optimizer: "adam" 18 | lr: 0.001 19 | lr_schedule: "cosine" 20 | wd: 0.0001 21 | momentum: 0.9 22 | #warmup 23 | warmup_epochs: 0 24 | warmup_lr: 0.1 25 | 26 | 27 | # ->->->->-> Eval <-<-<-<-<- 28 | val_method: base 29 | 30 | 31 | # ->->->->-> Dataset <-<-<-<-<- 32 | dataset: CIFAR10 33 | batch_size: 64 34 | test_batch_size: 128 35 | data_dir: "./datasets" 36 | data_fraction: 1.0 37 | 38 | # ->->->->-> Semi-supervised training <-<-<-<-<- 39 | semisup_data: "tinyimages" 40 | semisup_fraction: 1.0 41 | 42 | 43 | # ->->->->-> Adv <-<-<-<-<- 44 | epsilon: 0.007 45 | num_steps: 10 46 | step_size: 0.0078 47 | clip_min: 0 48 | clip_max: 1 49 | distance: "l_inf" 50 | beta: 6.0 51 | schedule_start: 0 52 | starting_epsilon: 0 53 | schedule_length: 60 54 | interval_weight: 0.1 55 | 56 | 57 | # ->->->->-> Misc <-<-<-<-<- 58 | gpu: "0" 59 | seed: 1234 60 | print_freq: 500 -------------------------------------------------------------------------------- /crown/converter.py: -------------------------------------------------------------------------------- 1 | ## Copyright (C) 2019, Huan Zhang 2 | ## Hongge Chen 3 | ## Chaowei Xiao 4 | ## 5 | ## This program is licenced under the BSD 2-Clause License, 6 | ## contained in the LICENCE file in this directory. 
7 | ## 8 | import sys 9 | import copy 10 | import torch 11 | from torch.nn import Sequential, Linear, ReLU, CrossEntropyLoss 12 | import numpy as np 13 | from datasets import loaders 14 | from model_defs import Flatten, model_mlp_any, model_cnn_1layer, model_cnn_2layer, model_cnn_4layer, model_cnn_3layer 15 | from bound_layers import BoundSequential 16 | import torch.optim as optim 17 | import time 18 | from datetime import datetime 19 | 20 | from config import load_config, get_path, config_modelloader, config_dataloader, config_modelloader_and_convert2mlp 21 | from argparser import argparser 22 | from pdb import set_trace as st 23 | # sys.settrace(gpu_profile) 24 | 25 | 26 | def main(args): 27 | config = load_config(args) 28 | global_train_config = config["training_params"] 29 | models, model_names = config_modelloader_and_convert2mlp(config) 30 | 31 | if __name__ == "__main__": 32 | args = argparser() 33 | main(args) 34 | -------------------------------------------------------------------------------- /utils/semisup.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | 4 | import torch 5 | import torchvision 6 | import os 7 | 8 | from utils.misc import CustomDatasetFromNumpy 9 | 10 | 11 | def get_semisup_dataloader(args, transform): 12 | if args.semisup_data == "splitgan": 13 | print(f"Loading {args.semisup_data} generated data") 14 | img, label = ( 15 | np.load( 16 | "/data/scsi/home/vvikash/research/mini_projects/trades_minimal/filter_gan_generate_images_c_99.npy" 17 | ), 18 | np.load( 19 | "/data/scsi/home/vvikash/research/mini_projects/trades_minimal/filter_gan_generate_labels_c_99.npy" 20 | ).astype(np.int64), 21 | ) 22 | if args.semisup_data == "tinyimages": 23 | print(f"Loading {args.semisup_data} dataset") 24 | with open( 25 | os.path.join(args.data_dir, "tiny_images/ti_top_50000_pred_v3.1.pickle"), 26 | "rb", 27 | ) as f: 28 | data = pickle.load(f) 29 | img, label = data["data"], data["extrapolated_targets"] 30 | 31 | # select random subset 32 | index = np.random.permutation(np.arange(len(label)))[ 33 | 0 : int(args.semisup_fraction * len(label)) 34 | ] 35 | 36 | sm_loader = torch.utils.data.DataLoader( 37 | CustomDatasetFromNumpy(img[index], label[index], transform), 38 | batch_size=args.batch_size, 39 | shuffle=True, 40 | ) 41 | print(f"Semisup dataset: {len(sm_loader.dataset)} images.") 42 | return sm_loader 43 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # small-caps refers to cifar-style models i.e., resnet18 -> for cifar vs ResNet18 -> standard arch. 
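# For example, "resnet18" below is the CIFAR-style ResNet from models.resnet_cifar,
# while "ResNet18" (imported at the bottom from models.resnet) is the standard
# ImageNet-style architecture; both are exported via __all__.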
2 | from models.vgg_cifar import ( 3 | vgg2, 4 | vgg2_bn, 5 | vgg4, 6 | vgg4_bn, 7 | vgg6, 8 | vgg6_bn, 9 | vgg8, 10 | vgg8_bn, 11 | vgg11, 12 | vgg11_bn, 13 | vgg13, 14 | vgg13_bn, 15 | vgg16, 16 | vgg16_bn, 17 | ) 18 | from models.resnet_cifar import resnet18, resnet34, resnet50, resnet101, resnet152 19 | from models.wrn_cifar import wrn_28_10, wrn_28_1, wrn_28_4, wrn_34_10, wrn_40_2 20 | from models.basic import ( 21 | lin_1, 22 | lin_2, 23 | lin_3, 24 | lin_4, 25 | mnist_model, 26 | mnist_model_large, 27 | cifar_model, 28 | cifar_model_large, 29 | cifar_model_resnet, 30 | vgg4_without_maxpool, 31 | ) 32 | 33 | from models.resnet import ResNet18, ResNet34, ResNet50 34 | 35 | __all__ = [ 36 | "vgg2", 37 | "vgg2_bn", 38 | "vgg4", 39 | "vgg4_bn", 40 | "vgg6", 41 | "vgg6_bn", 42 | "vgg8", 43 | "vgg8_bn", 44 | "vgg11", 45 | "vgg11_bn", 46 | "vgg13", 47 | "vgg13_bn", 48 | "vgg16", 49 | "vgg16_bn", 50 | "resnet18", 51 | "resnet34", 52 | "resnet50", 53 | "resnet101", 54 | "resnet152", 55 | "wrn_28_10", 56 | "wrn_28_1", 57 | "wrn_28_4", 58 | "wrn_34_10", 59 | "wrn_40_2", 60 | "lin_1", 61 | "lin_2", 62 | "lin_3", 63 | "lin_4", 64 | "mnist_model", 65 | "mnist_model_large", 66 | "cifar_model", 67 | "cifar_model_large", 68 | "cifar_model_resnet", 69 | "vgg4_without_maxpool", 70 | "ResNet18", 71 | "ResNet34", 72 | "ResNet50", 73 | ] 74 | -------------------------------------------------------------------------------- /data/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import torch 5 | from torchvision import datasets, transforms 6 | from torch.utils.data import DataLoader, SubsetRandomSampler 7 | 8 | # NOTE: Each dataset class must have public norm_layer, tr_train, tr_test objects. 9 | # These are needed for ood/semi-supervised dataset used alongwith in the training and eval. 10 | class imagenet: 11 | """ 12 | imagenet dataset. 
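    Expects an ImageFolder layout with train/ and val/ subdirectories under
    args.data_dir (see data_loaders below). Unlike the CIFAR/SVHN classes,
    normalization is enabled by default (normalize=True).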
13 | """ 14 | 15 | def __init__(self, args, normalize=True): 16 | self.args = args 17 | 18 | self.norm_layer = transforms.Normalize( 19 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 20 | ) 21 | 22 | self.tr_train = [ 23 | transforms.RandomResizedCrop(224), 24 | transforms.RandomHorizontalFlip(), 25 | transforms.ToTensor(), 26 | ] 27 | self.tr_test = [ 28 | transforms.Resize(256), 29 | transforms.CenterCrop(224), 30 | transforms.ToTensor(), 31 | ] 32 | 33 | if normalize: 34 | self.tr_train.append(self.norm_layer) 35 | self.tr_test.append(self.norm_layer) 36 | 37 | self.tr_train = transforms.Compose(self.tr_train) 38 | self.tr_test = transforms.Compose(self.tr_test) 39 | 40 | def data_loaders(self, **kwargs): 41 | trainset = datasets.ImageFolder( 42 | os.path.join(self.args.data_dir, "train"), self.tr_train 43 | ) 44 | testset = datasets.ImageFolder( 45 | os.path.join(self.args.data_dir, "val"), self.tr_test 46 | ) 47 | 48 | train_loader = DataLoader( 49 | trainset, 50 | shuffle=True, 51 | batch_size=self.args.batch_size, 52 | num_workers=8, 53 | pin_memory=True, 54 | **kwargs, 55 | ) 56 | 57 | test_loader = DataLoader( 58 | testset, 59 | batch_size=self.args.test_batch_size, 60 | shuffle=False, 61 | num_workers=4, 62 | pin_memory=True, 63 | **kwargs, 64 | ) 65 | 66 | print( 67 | f"Traing loader: {len(train_loader.dataset)} images, Test loader: {len(test_loader.dataset)} images" 68 | ) 69 | return train_loader, test_loader 70 | -------------------------------------------------------------------------------- /data/svhn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import torch 5 | from torchvision import datasets, transforms 6 | from torch.utils.data import DataLoader, SubsetRandomSampler 7 | 8 | # NOTE: Each dataset class must have public norm_layer, tr_train, tr_test objects. 9 | # These are needed for ood/semi-supervised dataset used alongwith in the training and eval. 10 | class SVHN: 11 | """ 12 | SVHN dataset. 
13 | """ 14 | 15 | def __init__(self, args, normalize=False): 16 | self.args = args 17 | 18 | self.norm_layer = transforms.Normalize( 19 | mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5] 20 | ) 21 | 22 | self.tr_train = [ 23 | transforms.RandomCrop(32, padding=4), 24 | transforms.RandomHorizontalFlip(), 25 | transforms.ToTensor(), 26 | ] 27 | self.tr_test = [transforms.ToTensor()] 28 | 29 | if normalize: 30 | self.tr_train.append(self.norm_layer) 31 | self.tr_test.append(self.norm_layer) 32 | 33 | self.tr_train = transforms.Compose(self.tr_train) 34 | self.tr_test = transforms.Compose(self.tr_test) 35 | 36 | def data_loaders(self, **kwargs): 37 | trainset = datasets.SVHN( 38 | root=os.path.join(self.args.data_dir, "SVHN"), 39 | split="train", 40 | download=True, 41 | transform=self.tr_train, 42 | ) 43 | 44 | subset_indices = np.random.permutation(np.arange(len(trainset)))[ 45 | : int(self.args.data_fraction * len(trainset)) 46 | ] 47 | 48 | train_loader = DataLoader( 49 | trainset, 50 | batch_size=self.args.batch_size, 51 | sampler=SubsetRandomSampler(subset_indices), 52 | **kwargs, 53 | ) 54 | testset = datasets.SVHN( 55 | root=os.path.join(self.args.data_dir, "SVHN"), 56 | split="test", 57 | download=True, 58 | transform=self.tr_test, 59 | ) 60 | test_loader = DataLoader( 61 | testset, batch_size=self.args.test_batch_size, shuffle=False, **kwargs 62 | ) 63 | 64 | print( 65 | f"Traing loader: {len(train_loader.dataset)} images, Test loader: {len(test_loader.dataset)} images" 66 | ) 67 | return train_loader, test_loader 68 | -------------------------------------------------------------------------------- /get_compact_net_adv_train.sh: -------------------------------------------------------------------------------- 1 | dt=$(date '+%d/%m/%Y %H:%M:%S'); 2 | echo $dt 3 | 4 | # Note: --is-semisup use additional labelled data for CIFAR-10 released by Carmon et al. Do not use this flag with SVHN. 
5 | 
6 | pretrain_prune_finetune_semisup() {
7 | # Order: exp_name ($1), arch ($2), trainer ($3), val_method ($4), gpu ($5), k ($6), pruning_epochs ($7)
8 | 
9 | # pre-training
10 | python train.py --is-semisup --exp-name $1 --arch $2 --exp-mode pretrain --configs configs/configs.yml \
11 | --trainer $3 --val_method $4 --gpu $5 --k 1.0 --save-dense ;
12 | 
13 | # pruning
14 | python train.py --is-semisup --exp-name $1 --arch $2 --exp-mode prune --configs configs/configs.yml \
15 | --trainer $3 --val_method $4 --gpu $5 --k $6 --save-dense --scaled-score-init \
16 | --source-net ./trained_models/$1/pretrain/latest_exp/checkpoint/checkpoint.pth.tar --epochs $7;
17 | 
18 | # finetuning
19 | python train.py --is-semisup --exp-name $1 --arch $2 --exp-mode finetune --configs configs/configs.yml \
20 | --trainer $3 --val_method $4 --gpu $5 --k $6 --save-dense \
21 | --source-net ./trained_models/$1/prune/latest_exp/checkpoint/checkpoint.pth.tar --lr 0.01 ;
22 | 
23 | # weight-based pruning
24 | python train.py --is-semisup --exp-name $1"_weight_based_pruning" --arch $2 --exp-mode finetune --configs configs/configs.yml \
25 | --trainer $3 --val_method $4 --gpu $5 --k $6 --save-dense --scaled-score-init \
26 | --source-net ./trained_models/$1/pretrain/latest_exp/checkpoint/checkpoint.pth.tar --lr 0.01 ;
27 | 
28 | }
29 | 
30 | 
31 | arch="wrn_28_4"
32 | 
33 | 
34 | # Iterative adv training
35 | (
36 | pretrain_prune_finetune_semisup "semisup-$arch-trainer_adv-k_0.1-prunepochs_20" $arch "adv" "adv" "0" 0.1 20 &
37 | pretrain_prune_finetune_semisup "semisup-$arch-trainer_adv-k_0.05-prunepochs_20" $arch "adv" "adv" "1" 0.05 20 &
38 | pretrain_prune_finetune_semisup "semisup-$arch-trainer_adv-k_0.01-prunepochs_20" $arch "adv" "adv" "2" 0.01 20 ;
39 | );
40 | 
41 | # Natural training
42 | (
43 | pretrain_prune_finetune_semisup "semisup-$arch-trainer_base-k_0.1-prunepochs_20" $arch "base" "base" "0" 0.1 20 &
44 | pretrain_prune_finetune_semisup "semisup-$arch-trainer_base-k_0.05-prunepochs_20" $arch "base" "base" "1" 0.05 20 &
45 | pretrain_prune_finetune_semisup "semisup-$arch-trainer_base-k_0.01-prunepochs_20" $arch "base" "base" "2" 0.01 20 ;
46 | );
47 | 
--------------------------------------------------------------------------------
/utils/schedules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import sys
4 | import numpy as np
5 | 
6 | 
7 | def get_lr_policy(lr_schedule):
8 | """Implement a new scheduler directly in this file.
9 | Args should contain a single choice for learning rate scheduler.""" 10 | 11 | d = { 12 | "constant": constant_schedule, 13 | "cosine": cosine_schedule, 14 | "step": step_schedule, 15 | } 16 | return d[lr_schedule] 17 | 18 | 19 | def get_optimizer(model, args): 20 | if args.optimizer == "sgd": 21 | optim = torch.optim.SGD( 22 | model.parameters(), 23 | lr=args.lr, 24 | momentum=args.momentum, 25 | weight_decay=args.wd, 26 | ) 27 | elif args.optimizer == "adam": 28 | optim = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd,) 29 | elif args.optimizer == "rmsprop": 30 | optim = torch.optim.RMSprop( 31 | model.parameters(), 32 | lr=args.lr, 33 | momentum=args.momentum, 34 | weight_decay=args.wd, 35 | ) 36 | else: 37 | print(f"{args.optimizer} is not supported.") 38 | sys.exit(0) 39 | return optim 40 | 41 | 42 | def new_lr(optimizer, lr): 43 | for param_group in optimizer.param_groups: 44 | param_group["lr"] = lr 45 | 46 | 47 | def constant_schedule(optimizer, args): 48 | def set_lr(epoch, lr=args.lr, epochs=args.epochs): 49 | if epoch < args.warmup_epochs: 50 | lr = args.warmup_lr 51 | 52 | new_lr(optimizer, lr) 53 | 54 | return set_lr 55 | 56 | 57 | def cosine_schedule(optimizer, args): 58 | def set_lr(epoch, lr=args.lr, epochs=args.epochs): 59 | if epoch < args.warmup_epochs: 60 | a = args.warmup_lr 61 | else: 62 | epoch = epoch - args.warmup_epochs 63 | a = lr * 0.5 * (1 + np.cos((epoch - 1) / epochs * np.pi)) 64 | 65 | new_lr(optimizer, a) 66 | 67 | return set_lr 68 | 69 | 70 | def step_schedule(optimizer, args): 71 | def set_lr(epoch, lr=args.lr, epochs=args.epochs): 72 | if epoch < args.warmup_epochs: 73 | a = args.warmup_lr 74 | else: 75 | epoch = epoch - args.warmup_epochs 76 | 77 | a = lr 78 | if epoch >= 0.75 * epochs: 79 | a = lr * 0.1 80 | if epoch >= 0.9 * epochs: 81 | a = lr * 0.01 82 | if epoch >= epochs: 83 | a = lr * 0.001 84 | 85 | new_lr(optimizer, a) 86 | 87 | return set_lr 88 | -------------------------------------------------------------------------------- /get_compact_net_rand_smoothing.sh: -------------------------------------------------------------------------------- 1 | dt=$(date '+%d/%m/%Y %H:%M:%S'); 2 | echo $dt 3 | 4 | # Note: --is-semisup use additional labelled data for CIFAR-10 released by Carmon et al. Do not use this flag with SVHN. 
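# All stages below train with Gaussian noise augmentation (sigma = 0.25 via
# --noise-std); certified radii are computed separately (see eval_smoothing.py).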
5 | 6 | pretrain_prune_finetune_semisup() { 7 | # Order: exp_name ($1), arch ($2), trainer ($3), val_method ($4), gpu ($5), k ($6), pruning_epochs ($7) 8 | 9 | # pre-training 10 | python train.py --is-semisup --exp-name $1 --arch $2 --exp-mode pretrain --configs configs/configs.yml \ 11 | --trainer $3 --val_method $4 --gpu $5 --k 1.0 --save-dense --dataset CIFAR10 --noise-std 0.25 ; 12 | 13 | # pruning 14 | python train.py --is-semisup --exp-name $1 --arch $2 --exp-mode prune --configs configs/configs.yml \ 15 | --trainer $3 --val_method $4 --gpu $5 --k $6 --save-dense --scaled-score-init \ 16 | --source-net ./trained_models/$1/pretrain/latest_exp/checkpoint/checkpoint.pth.tar \ 17 | --epochs $7 --dataset CIFAR10 --noise-std 0.25 ; 18 | 19 | # finetuning 20 | python train.py --is-semisup --exp-name $1 --arch $2 --exp-mode finetune --configs configs/configs.yml \ 21 | --trainer $3 --val_method $4 --gpu $5 --k $6 --save-dense --dataset CIFAR10 --noise-std 0.25 \ 22 | --source-net ./trained_models/$1/prune/latest_exp/checkpoint/checkpoint.pth.tar --lr 0.01 ; 23 | 24 | # weight base pruning 25 | python train.py --is-semisup --exp-name $1"_weight_based_pruning" --arch $2 --exp-mode finetune --configs configs/configs.yml \ 26 | --trainer $3 --val_method $4 --gpu $5 --k $6 --save-dense --scaled-score-init --dataset CIFAR10 --noise-std 0.25 \ 27 | --source-net ./trained_models/$1/pretrain/latest_exp/checkpoint/checkpoint.pth.tar --lr 0.01 ; 28 | } 29 | 30 | 31 | arch="vgg16_bn" 32 | ( 33 | pretrain_prune_finetune_semisup "semisup-$arch-trainer_smooth-k_0.1-cifar10" $arch "smooth" "smooth" "0" 0.1 20 & 34 | pretrain_prune_finetune_semisup "semisup-$arch-trainer_smooth-k_0.05-cifar10" $arch "smooth" "smooth" "1" 0.05 20 & 35 | pretrain_prune_finetune_semisup "semisup-$arch-trainer_smooth-k_0.01-cifar10" $arch "smooth" "smooth" "2" 0.01 20 ; 36 | ); 37 | 38 | arch="wrn_28_4" 39 | ( 40 | pretrain_prune_finetune_semisup "semisup-$arch-trainer_smooth-k_0.1-cifar10" $arch "smooth" "smooth" "0" 0.1 20 & 41 | pretrain_prune_finetune_semisup "semisup-$arch-trainer_smooth-k_0.05-cifar10" $arch "smooth" "smooth" "1" 0.05 20 & 42 | pretrain_prune_finetune_semisup "semisup-$arch-trainer_smooth-k_0.01-cifar10" $arch "smooth" "smooth" "2" 0.01 20 ; 43 | ); -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | *.pickle 7 | *.pkl 8 | *.npy 9 | *.tar 10 | *.ckpt 11 | *.pth 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | -------------------------------------------------------------------------------- /get_compact_net_crown-ibp.sh: -------------------------------------------------------------------------------- 1 | dt=$(date '+%d/%m/%Y %H:%M:%S'); 2 | echo $dt 3 | 4 | # Note: --is-semisup use additional labelled data for CIFAR-10 released by Carmon et al. Do not use this flag with SVHN. 
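# Epsilon is ramped from starting_epsilon to epsilon over schedule_length epochs
# (see configs/configs_crown-ibp.yml); pre-training below stretches the ramp to
# 120 epochs, while the prune/finetune stages keep epsilon fixed (--schedule_length 1).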
5 | 
6 | pretrain_prune_finetune() {
7 | # Order: exp_name ($1), arch ($2), trainer ($3), val_method ($4), gpu ($5), k ($6), pruning_epochs ($7)
8 | 
9 | # pre-training
10 | python train.py --exp-name $1 --arch $2 --exp-mode pretrain --configs configs/configs_crown-ibp.yml \
11 | --trainer $3 --val_method $4 --gpu $5 --k 1.0 --save-dense --dataset SVHN --batch-size 128 --epochs 200 --schedule_length 120;
12 | 
13 | # pruning
14 | python train.py --exp-name $1 --arch $2 --exp-mode prune --configs configs/configs_crown-ibp.yml \
15 | --trainer $3 --val_method $4 --gpu $5 --k $6 --save-dense --scaled-score-init \
16 | --source-net ./trained_models/$1/pretrain/latest_exp/checkpoint/checkpoint.pth.tar --epochs $7 \
17 | --schedule_length 1 --lr 0.00001 --dataset SVHN --batch-size 128;
18 | 
19 | # finetuning
20 | python train.py --exp-name $1 --arch $2 --exp-mode finetune --configs configs/configs_crown-ibp.yml \
21 | --trainer $3 --val_method $4 --gpu $5 --k $6 --save-dense \
22 | --source-net ./trained_models/$1/prune/latest_exp/checkpoint/checkpoint.pth.tar --lr 0.00001 \
23 | --schedule_length 1 --dataset SVHN --batch-size 128;
24 | 
25 | # weight-based pruning
26 | python train.py --exp-name $1"_weight_based_pruning" --arch $2 --exp-mode finetune --configs configs/configs_crown-ibp.yml \
27 | --trainer $3 --val_method $4 --gpu $5 --k $6 --save-dense --scaled-score-init \
28 | --source-net ./trained_models/$1/pretrain/latest_exp/checkpoint/checkpoint.pth.tar --lr 0.0005 \
29 | --schedule_length 1 --dataset SVHN --batch-size 128;
30 | }
31 | 
32 | 
33 | arch="cifar_model_large"
34 | 
35 | (
36 | pretrain_prune_finetune "svhn_large_model-trainer_crown-ibp-k_0.1-prunepochs_20" $arch "crown-ibp" "ibp" "0" 0.1 20 &
37 | pretrain_prune_finetune "svhn_large_model-trainer_crown-ibp-k_0.05-prunepochs_20" $arch "crown-ibp" "ibp" "1" 0.05 20 &
38 | pretrain_prune_finetune "svhn_large_model-trainer_crown-ibp-k_0.01-prunepochs_20" $arch "crown-ibp" "ibp" "2" 0.01 20 ;
39 | );
40 | 
41 | 
42 | arch="cifar_model"
43 | 
44 | (
45 | pretrain_prune_finetune "svhn_model-trainer_crown-ibp_new-k_0.1-prunepochs_20" $arch "crown-ibp" "ibp" "0" 0.1 20 &
46 | pretrain_prune_finetune "svhn_model-trainer_crown-ibp_new-k_0.05-prunepochs_20" $arch "crown-ibp" "ibp" "1" 0.05 20 &
47 | pretrain_prune_finetune "svhn_model-trainer_crown-ibp_new-k_0.01-prunepochs_20" $arch "crown-ibp" "ibp" "2" 0.01 20 ;
48 | );
--------------------------------------------------------------------------------
/get_compact_net_mixtrain.sh:
--------------------------------------------------------------------------------
1 | dt=$(date '+%d/%m/%Y %H:%M:%S');
2 | echo $dt
3 | 
4 | # Note: --is-semisup uses the additional labelled data for CIFAR-10 released by Carmon et al. Do not use this flag with SVHN.
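# mixtraink ($8) sets how many examples per batch receive verifiable (symbolic
# interval) training in MixTrain; the runs below use k=1 for cifar_model_large
# and k=5 for cifar_model.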
5 | 6 | pretrain_prune_finetune() { 7 | # Order: exp_name ($1), arch ($2), trainer ($3), val_method ($4), gpu ($5), k ($6), pruning_epochs ($7), mixtraink($8) 8 | 9 | # pre-training 10 | python train.py --exp-name $1 --arch $2 --exp-mode pretrain --configs configs/configs_mixtrain.yml \ 11 | --trainer $3 --val_method $4 --gpu $5 --k 1.0 --save-dense --dataset SVHN --schedule_length 15 \ 12 | --mixtraink $8 --batch-size 50; 13 | 14 | # pruning 15 | python train.py --exp-name $1 --arch $2 --exp-mode prune --configs configs/configs_mixtrain.yml \ 16 | --trainer $3 --val_method $4 --gpu $5 --k $6 --save-dense --scaled-score-init \ 17 | --source-net ./trained_models/$1/pretrain/latest_exp/checkpoint/checkpoint.pth.tar --epochs $7 \ 18 | --schedule_length 0 --lr 0.00001 --dataset SVHN --mixtraink $8 --batch-size 50; 19 | 20 | # finetuning 21 | python train.py --exp-name $1 --arch $2 --exp-mode finetune --configs configs/configs_mixtrain.yml \ 22 | --trainer $3 --val_method $4 --gpu $5 --k $6 --save-dense \ 23 | --source-net ./trained_models/$1/prune/latest_exp/checkpoint/checkpoint.pth.tar --lr 0.0005 \ 24 | --schedule_length 0 --dataset SVHN --mixtraink $8 --batch-size 50; 25 | 26 | # weight base pruning 27 | python train.py --exp-name $1"_weight_based_pruning" --arch $2 --exp-mode finetune --configs configs/configs_mixtrain.yml \ 28 | --trainer $3 --val_method $4 --gpu $5 --k $6 --save-dense --scaled-score-init \ 29 | --source-net ./trained_models/$1/pretrain/latest_exp/checkpoint/checkpoint.pth.tar --lr 0.0005 \ 30 | --schedule_length 0 --dataset SVHN --mixtraink $8 --batch-size 50; 31 | } 32 | 33 | arch="cifar_model_large" 34 | 35 | ( 36 | pretrain_prune_finetune "svhn_model_large-trainer_mixtraink1-k_0.1-prunepochs_20" $arch "mixtrain" "mixtrain" "0" 0.1 20 1 & 37 | pretrain_prune_finetune "svhn_model_large-trainer_mixtraink1-k_0.05-prunepochs_20" $arch "mixtrain" "mixtrain" "1" 0.05 20 1 & 38 | pretrain_prune_finetune "svhn_model_large-trainer_mixtraink1-k_0.01-prunepochs_20" $arch "mixtrain" "mixtrain" "2" 0.01 20 1 ; 39 | ); 40 | 41 | 42 | arch="cifar_model" 43 | 44 | ( 45 | pretrain_prune_finetune "svhn_model-trainer_mixtraink5-k_0.1-prunepochs_20" $arch "mixtrain" "mixtrain" "0" 0.1 20 5 & 46 | pretrain_prune_finetune "svhn_model-trainer_mixtraink5-k_0.05-prunepochs_20" $arch "mixtrain" "mixtrain" "1" 0.05 20 5 & 47 | pretrain_prune_finetune "svhn_model-trainer_mixtraink5-k_0.01-prunepochs_20" $arch "mixtrain" "mixtrain" "2" 0.01 20 5 ; 48 | ); -------------------------------------------------------------------------------- /trainer/base.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision 6 | 7 | from utils.logging import AverageMeter, ProgressMeter 8 | from utils.eval import accuracy 9 | 10 | # TODO: support sm_loader when len(sm_loader.dataset) < len(train_loader.dataset) 11 | def train( 12 | model, device, train_loader, sm_loader, criterion, optimizer, epoch, args, writer 13 | ): 14 | print(" ->->->->->->->->->-> One epoch with Natural training <-<-<-<-<-<-<-<-<-<-") 15 | 16 | batch_time = AverageMeter("Time", ":6.3f") 17 | data_time = AverageMeter("Data", ":6.3f") 18 | losses = AverageMeter("Loss", ":.4f") 19 | top1 = AverageMeter("Acc_1", ":6.2f") 20 | top5 = AverageMeter("Acc_5", ":6.2f") 21 | progress = ProgressMeter( 22 | len(train_loader), 23 | [batch_time, data_time, losses, top1, top5], 24 | prefix="Epoch: [{}]".format(epoch), 25 | ) 26 | 27 | model.train() 
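    # train mode: enables dropout and lets batch-norm update its running statistics.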
28 | end = time.time() 29 | 30 | dataloader = train_loader if sm_loader is None else zip(train_loader, sm_loader) 31 | 32 | for i, data in enumerate(dataloader): 33 | if sm_loader: 34 | images, target = ( 35 | torch.cat([d[0] for d in data], 0).to(device), 36 | torch.cat([d[1] for d in data], 0).to(device), 37 | ) 38 | else: 39 | images, target = data[0].to(device), data[1].to(device) 40 | 41 | # basic properties of training 42 | if i == 0: 43 | print( 44 | images.shape, 45 | target.shape, 46 | f"Batch_size from args: {args.batch_size}", 47 | "lr: {:.5f}".format(optimizer.param_groups[0]["lr"]), 48 | ) 49 | print( 50 | "Pixel range for training images : [{}, {}]".format( 51 | torch.min(images).data.cpu().numpy(), 52 | torch.max(images).data.cpu().numpy(), 53 | ) 54 | ) 55 | 56 | output = model(images) 57 | loss = criterion(output, target) 58 | 59 | # measure accuracy and record loss 60 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 61 | losses.update(loss.item(), images.size(0)) 62 | top1.update(acc1[0], images.size(0)) 63 | top5.update(acc5[0], images.size(0)) 64 | 65 | optimizer.zero_grad() 66 | loss.backward() 67 | optimizer.step() 68 | 69 | # measure elapsed time 70 | batch_time.update(time.time() - end) 71 | end = time.time() 72 | 73 | if i % args.print_freq == 0: 74 | progress.display(i) 75 | progress.write_to_tensorboard( 76 | writer, "train", epoch * len(train_loader) + i 77 | ) 78 | 79 | # write a sample of training images to tensorboard (helpful for debugging) 80 | if i == 0: 81 | writer.add_image( 82 | "training-images", 83 | torchvision.utils.make_grid(images[0 : len(images) // 4]), 84 | ) 85 | 86 | -------------------------------------------------------------------------------- /trainer/adv.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision 6 | 7 | from utils.logging import AverageMeter, ProgressMeter 8 | from utils.eval import accuracy 9 | from utils.adv import trades_loss 10 | 11 | # TODO: add adversarial accuracy. 
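# One way to address the TODO (sketch only; `pgd_attack` is a hypothetical helper,
# not provided in utils/adv.py): compute x_adv = pgd_attack(model, images, target, args)
# per batch, then track accuracy(model(x_adv), target) in an extra AverageMeter,
# mirroring how top1/top5 are tracked below.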
12 | def train( 13 | model, device, train_loader, sm_loader, criterion, optimizer, epoch, args, writer 14 | ): 15 | print( 16 | " ->->->->->->->->->-> One epoch with Adversarial training (TRADES) <-<-<-<-<-<-<-<-<-<-" 17 | ) 18 | 19 | batch_time = AverageMeter("Time", ":6.3f") 20 | data_time = AverageMeter("Data", ":6.3f") 21 | losses = AverageMeter("Loss", ":.4f") 22 | top1 = AverageMeter("Acc_1", ":6.2f") 23 | top5 = AverageMeter("Acc_5", ":6.2f") 24 | progress = ProgressMeter( 25 | len(train_loader), 26 | [batch_time, data_time, losses, top1, top5], 27 | prefix="Epoch: [{}]".format(epoch), 28 | ) 29 | 30 | model.train() 31 | end = time.time() 32 | 33 | dataloader = train_loader if sm_loader is None else zip(train_loader, sm_loader) 34 | 35 | for i, data in enumerate(dataloader): 36 | if sm_loader: 37 | images, target = ( 38 | torch.cat([d[0] for d in data], 0).to(device), 39 | torch.cat([d[1] for d in data], 0).to(device), 40 | ) 41 | else: 42 | images, target = data[0].to(device), data[1].to(device) 43 | 44 | # basic properties of training data 45 | if i == 0: 46 | print( 47 | images.shape, 48 | target.shape, 49 | f"Batch_size from args: {args.batch_size}", 50 | "lr: {:.5f}".format(optimizer.param_groups[0]["lr"]), 51 | ) 52 | print(f"Training images range: {[torch.min(images), torch.max(images)]}") 53 | 54 | output = model(images) 55 | 56 | # calculate robust loss 57 | loss = trades_loss( 58 | model=model, 59 | x_natural=images, 60 | y=target, 61 | device=device, 62 | optimizer=optimizer, 63 | step_size=args.step_size, 64 | epsilon=args.epsilon, 65 | perturb_steps=args.num_steps, 66 | beta=args.beta, 67 | clip_min=args.clip_min, 68 | clip_max=args.clip_max, 69 | distance=args.distance, 70 | ) 71 | 72 | # measure accuracy and record loss 73 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 74 | losses.update(loss.item(), images.size(0)) 75 | top1.update(acc1[0], images.size(0)) 76 | top5.update(acc5[0], images.size(0)) 77 | 78 | optimizer.zero_grad() 79 | loss.backward() 80 | optimizer.step() 81 | 82 | # measure elapsed time 83 | batch_time.update(time.time() - end) 84 | end = time.time() 85 | 86 | if i % args.print_freq == 0: 87 | progress.display(i) 88 | progress.write_to_tensorboard( 89 | writer, "train", epoch * len(train_loader) + i 90 | ) 91 | 92 | # write a sample of training images to tensorboard (helpful for debugging) 93 | if i == 0: 94 | writer.add_image( 95 | "training-images", 96 | torchvision.utils.make_grid(images[0 : len(images) // 4]), 97 | ) 98 | -------------------------------------------------------------------------------- /crown/eps_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | 5 | class EpsilonScheduler(): 6 | def __init__(self, schedule_type, init_step, final_step, init_value, final_value, num_steps_per_epoch, mid_point=.25, beta=4.): 7 | self.schedule_type = schedule_type 8 | self.init_step = init_step 9 | self.final_step = final_step 10 | self.init_value = init_value 11 | self.final_value = final_value 12 | self.mid_point = mid_point 13 | self.beta = beta 14 | self.num_steps_per_epoch = num_steps_per_epoch 15 | assert self.final_value >= self.init_value 16 | assert self.final_step >= self.init_step,\ 17 | "{} should be larger than {}".format(self.final_step, self.init_step) 18 | assert self.beta >= 2. 19 | assert self.mid_point >= 0. and self.mid_point <= 1. 
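        # Example (illustrative): ramp eps linearly from 0 to 8/255 over the first
        # 10 epochs, then query it per batch:
        #   sched = EpsilonScheduler("linear", 0, 10 * steps, 0.0, 8 / 255, steps)
        #   eps = sched.get_eps(epoch, step)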
20 | 21 | def get_eps(self, epoch, step): 22 | if self.schedule_type == "smoothed": 23 | return self.smooth_schedule(epoch * self.num_steps_per_epoch + step, self.init_step, self.final_step, self.init_value, self.final_value, self.mid_point, self.beta) 24 | else: 25 | return self.linear_schedule(epoch * self.num_steps_per_epoch + step, self.init_step, self.final_step, self.init_value, self.final_value) 26 | 27 | # Smooth schedule that slowly morphs into a linear schedule. 28 | # Code is adapted from DeepMind's IBP implementation: 29 | # https://github.com/deepmind/interval-bound-propagation/blob/2c1a56cb0497d6f34514044877a8507c22c1bd85/interval_bound_propagation/src/utils.py#L84 30 | def smooth_schedule(self, step, init_step, final_step, init_value, final_value, mid_point=.25, beta=4.): 31 | """Smooth schedule that slowly morphs into a linear schedule.""" 32 | assert final_value >= init_value 33 | assert final_step >= init_step 34 | assert beta >= 2. 35 | assert mid_point >= 0. and mid_point <= 1. 36 | mid_step = int((final_step - init_step) * mid_point) + init_step 37 | if mid_step <= init_step: 38 | alpha = 1. 39 | else: 40 | t = (mid_step - init_step) ** (beta - 1.) 41 | alpha = (final_value - init_value) / ((final_step - mid_step) * beta * t + (mid_step - init_step) * t) 42 | mid_value = alpha * (mid_step - init_step) ** beta + init_value 43 | is_ramp = float(step > init_step) 44 | is_linear = float(step >= mid_step) 45 | return (is_ramp * ( 46 | (1. - is_linear) * ( 47 | init_value + 48 | alpha * float(step - init_step) ** beta) + 49 | is_linear * self.linear_schedule( 50 | step, mid_step, final_step, mid_value, final_value)) + 51 | (1. - is_ramp) * init_value) 52 | 53 | # Linear schedule. 54 | # Code is adapted from DeepMind's IBP implementation: 55 | # https://github.com/deepmind/interval-bound-propagation/blob/2c1a56cb0497d6f34514044877a8507c22c1bd85/interval_bound_propagation/src/utils.py#L73 56 | def linear_schedule(self, step, init_step, final_step, init_value, final_value): 57 | """Linear schedule.""" 58 | assert final_step >= init_step 59 | if init_step == final_step: 60 | return final_value 61 | rate = float(step - init_step) / float(final_step - init_step) 62 | linear_value = rate * (final_value - init_value) + init_value 63 | return np.clip(linear_value, min(init_value, final_value), max(init_value, final_value)) 64 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.autograd as autograd 4 | from torch.nn.parameter import Parameter 5 | import torch.nn.functional as F 6 | import math 7 | 8 | # https://github.com/allenai/hidden-networks 9 | class GetSubnet(autograd.Function): 10 | @staticmethod 11 | def forward(ctx, scores, k): 12 | # Get the subnetwork by sorting the scores and using the top k% 13 | out = scores.clone() 14 | _, idx = scores.flatten().sort() 15 | j = int((1 - k) * scores.numel()) 16 | 17 | # flat_out and out access the same memory. 18 | flat_out = out.flatten() 19 | flat_out[idx[:j]] = 0 20 | flat_out[idx[j:]] = 1 21 | 22 | return out 23 | 24 | @staticmethod 25 | def backward(ctx, g): 26 | # send the gradient g straight-through on the backward pass. 
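        # (straight-through estimator: the binarization in forward is treated as the
        # identity w.r.t. popup_scores, so g flows back unchanged; k is a plain float
        # with no gradient, hence the None.)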
27 | return g, None 28 | 29 | 30 | class SubnetConv(nn.Conv2d): 31 | # self.k is the % of weights remaining, a real number in [0,1] 32 | # self.popup_scores is a Parameter which has the same shape as self.weight 33 | # Gradients to self.weight, self.bias have been turned off by default. 34 | 35 | def __init__( 36 | self, 37 | in_channels, 38 | out_channels, 39 | kernel_size, 40 | stride=1, 41 | padding=0, 42 | dilation=1, 43 | groups=1, 44 | bias=True, 45 | ): 46 | super(SubnetConv, self).__init__( 47 | in_channels, 48 | out_channels, 49 | kernel_size, 50 | stride, 51 | padding, 52 | dilation, 53 | groups, 54 | bias, 55 | ) 56 | self.popup_scores = Parameter(torch.Tensor(self.weight.shape)) 57 | nn.init.kaiming_uniform_(self.popup_scores, a=math.sqrt(5)) 58 | 59 | self.weight.requires_grad = False 60 | if self.bias is not None: 61 | self.bias.requires_grad = False 62 | self.w = 0 63 | 64 | def set_prune_rate(self, k): 65 | self.k = k 66 | 67 | def forward(self, x): 68 | # Get the subnetwork by sorting the scores. 69 | adj = GetSubnet.apply(self.popup_scores.abs(), self.k) 70 | 71 | # Use only the subnetwork in the forward pass. 72 | self.w = self.weight * adj 73 | x = F.conv2d( 74 | x, self.w, self.bias, self.stride, self.padding, self.dilation, self.groups 75 | ) 76 | return x 77 | 78 | 79 | class SubnetLinear(nn.Linear): 80 | # self.k is the % of weights remaining, a real number in [0,1] 81 | # self.popup_scores is a Parameter which has the same shape as self.weight 82 | # Gradients to self.weight, self.bias have been turned off. 83 | 84 | def __init__(self, in_features, out_features, bias=True): 85 | super(SubnetLinear, self).__init__(in_features, out_features, bias=True) 86 | self.popup_scores = Parameter(torch.Tensor(self.weight.shape)) 87 | nn.init.kaiming_uniform_(self.popup_scores, a=math.sqrt(5)) 88 | self.weight.requires_grad = False 89 | self.bias.requires_grad = False 90 | self.w = 0 91 | # self.register_buffer('w', None) 92 | 93 | def set_prune_rate(self, k): 94 | self.k = k 95 | 96 | def forward(self, x): 97 | # Get the subnetwork by sorting the scores. 98 | adj = GetSubnet.apply(self.popup_scores.abs(), self.k) 99 | 100 | # Use only the subnetwork in the forward pass. 101 | self.w = self.weight * adj 102 | x = F.linear(x, self.w, self.bias) 103 | 104 | return x 105 | 106 | -------------------------------------------------------------------------------- /trainer/freeadv.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import sys 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torchvision 8 | from torch.autograd import Variable 9 | 10 | from utils.logging import AverageMeter, ProgressMeter 11 | from utils.eval import accuracy 12 | from utils.adv import fgsm 13 | 14 | 15 | def train( 16 | model, 17 | device, 18 | train_loader, 19 | sm_loader, 20 | criterion, 21 | optimizer, 22 | epoch, 23 | args, 24 | writer=None, 25 | ): 26 | 27 | assert ( 28 | not args.normalize 29 | ), "Explicit normalization is done in the training loop, Dataset should have [0, 1] dynamic range." 
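    # "Free" adversarial training: each minibatch is replayed n_repeats times, and a
    # single persistent noise buffer doubles as the adversarial perturbation, updated
    # with an FGSM step from the gradient of the previous backward pass.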
30 | 31 | global_noise_data = torch.zeros( 32 | [args.batch_size, 3, args.image_dim, args.image_dim] 33 | ).to(device) 34 | 35 | mean = torch.Tensor(np.array(args.mean)[:, np.newaxis, np.newaxis]) 36 | mean = mean.expand(3, args.image_dim, args.image_dim).to(device) 37 | std = torch.Tensor(np.array(args.std)[:, np.newaxis, np.newaxis]) 38 | std = std.expand(3, args.image_dim, args.image_dim).to(device) 39 | 40 | batch_time = AverageMeter("Time", ":6.3f") 41 | data_time = AverageMeter("Data", ":6.3f") 42 | losses = AverageMeter("Loss", ":.4f") 43 | top1 = AverageMeter("Acc_1", ":6.2f") 44 | top5 = AverageMeter("Acc_5", ":6.2f") 45 | progress = ProgressMeter( 46 | len(train_loader), 47 | [batch_time, data_time, losses, top1, top5], 48 | prefix="Epoch: [{}]".format(epoch), 49 | ) 50 | 51 | # switch to train mode 52 | model.train() 53 | for i, (input, target) in enumerate(train_loader): 54 | end = time.time() 55 | input = input.to(device, non_blocking=True) 56 | target = target.to(device, non_blocking=True) 57 | data_time.update(time.time() - end) 58 | 59 | for _ in range(args.n_repeats): 60 | # Ascend on the global noise 61 | noise_batch = Variable( 62 | global_noise_data[0 : input.size(0)], requires_grad=True 63 | ).to(device) 64 | in1 = input + noise_batch 65 | in1.clamp_(0, 1.0) 66 | in1.sub_(mean).div_(std) 67 | output = model(in1) 68 | loss = criterion(output, target) 69 | 70 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 71 | losses.update(loss.item(), input.size(0)) 72 | top1.update(prec1[0], input.size(0)) 73 | top5.update(prec5[0], input.size(0)) 74 | 75 | # compute gradient and do SGD step 76 | optimizer.zero_grad() 77 | loss.backward() 78 | 79 | # Update the noise for the next iteration 80 | pert = fgsm(noise_batch.grad, args.epsilon) 81 | global_noise_data[0 : input.size(0)] += pert.data 82 | global_noise_data.clamp_(-args.epsilon, args.epsilon) 83 | 84 | optimizer.step() 85 | 86 | # measure elapsed time 87 | batch_time.update(time.time() - end) 88 | end = time.time() 89 | 90 | if i % args.print_freq == 0: 91 | progress.display(i) 92 | progress.write_to_tensorboard( 93 | writer, "train", epoch * len(train_loader) + i 94 | ) 95 | 96 | if i == 0: 97 | print( 98 | in1.shape, 99 | target.shape, 100 | f"Batch_size from args: {args.batch_size}", 101 | "lr: {:.5f}".format(optimizer.param_groups[0]["lr"]), 102 | ) 103 | print(f"Training images range: {[torch.min(in1), torch.max(in1)]}") 104 | 105 | # write a sample of training images to tensorboard (helpful for debugging) 106 | if i == 0: 107 | writer.add_image( 108 | "training-images", 109 | torchvision.utils.make_grid(input[0 : len(input) // 4]), 110 | ) 111 | 112 | -------------------------------------------------------------------------------- /trainer/smooth.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision 7 | 8 | from utils.logging import AverageMeter, ProgressMeter 9 | from utils.eval import accuracy 10 | 11 | # TODO: support sm_loader when len(sm_loader.dataset) < len(train_loader.dataset) 12 | def train( 13 | model, device, train_loader, sm_loader, criterion, optimizer, epoch, args, writer 14 | ): 15 | print(" ->->->->->->->->->-> One epoch with Natural training <-<-<-<-<-<-<-<-<-<-") 16 | 17 | batch_time = AverageMeter("Time", ":6.3f") 18 | data_time = AverageMeter("Data", ":6.3f") 19 | losses = AverageMeter("Loss", ":.4f") 20 | top1 = AverageMeter("Acc_1", ":6.2f") 
21 | top5 = AverageMeter("Acc_5", ":6.2f") 22 | progress = ProgressMeter( 23 | len(train_loader), 24 | [batch_time, data_time, losses, top1, top5], 25 | prefix="Epoch: [{}]".format(epoch), 26 | ) 27 | 28 | model.train() 29 | end = time.time() 30 | 31 | dataloader = train_loader if sm_loader is None else zip(train_loader, sm_loader) 32 | 33 | for i, data in enumerate(dataloader): 34 | if sm_loader: 35 | images, target = ( 36 | torch.cat([d[0] for d in data], 0).to(device), 37 | torch.cat([d[1] for d in data], 0).to(device), 38 | ) 39 | else: 40 | images, target = data[0].to(device), data[1].to(device) 41 | 42 | # basic properties of training 43 | if i == 0: 44 | print( 45 | images.shape, 46 | target.shape, 47 | f"Batch_size from args: {args.batch_size}", 48 | "lr: {:.5f}".format(optimizer.param_groups[0]["lr"]), 49 | ) 50 | print( 51 | "Pixel range for training images : [{}, {}]".format( 52 | torch.min(images).data.cpu().numpy(), 53 | torch.max(images).data.cpu().numpy(), 54 | ) 55 | ) 56 | 57 | # stability-loss 58 | if args.dataset == "imagenet": 59 | std = ( 60 | torch.tensor([0.229, 0.224, 0.225]) 61 | .unsqueeze(0) 62 | .unsqueeze(-1) 63 | .unsqueeze(-1) 64 | ).to(device) 65 | noise = (torch.randn_like(images) / std).to(device) * args.noise_std 66 | output = model(images + noise) 67 | loss = nn.CrossEntropyLoss()(output, target) 68 | else: 69 | output = model(images) 70 | loss_natural = nn.CrossEntropyLoss()(output, target) 71 | loss_robust = (1.0 / len(images)) * nn.KLDivLoss(size_average=False)( 72 | F.log_softmax( 73 | model( 74 | images + torch.randn_like(images).to(device) * args.noise_std 75 | ), 76 | dim=1, 77 | ), 78 | F.softmax(output, dim=1), 79 | ) 80 | loss = loss_natural + args.beta * loss_robust 81 | 82 | # measure accuracy and record loss 83 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 84 | losses.update(loss.item(), images.size(0)) 85 | top1.update(acc1[0], images.size(0)) 86 | top5.update(acc5[0], images.size(0)) 87 | 88 | optimizer.zero_grad() 89 | loss.backward() 90 | optimizer.step() 91 | 92 | # measure elapsed time 93 | batch_time.update(time.time() - end) 94 | end = time.time() 95 | 96 | if i % args.print_freq == 0: 97 | progress.display(i) 98 | progress.write_to_tensorboard( 99 | writer, "train", epoch * len(train_loader) + i 100 | ) 101 | 102 | # write a sample of training images to tensorboard (helpful for debugging) 103 | if i == 0: 104 | writer.add_image( 105 | "training-images", 106 | torchvision.utils.make_grid(images[0 : len(images) // 4]), 107 | ) 108 | -------------------------------------------------------------------------------- /utils/logging.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import shutil 3 | import os 4 | import yaml 5 | import sys 6 | import shutil, errno 7 | from distutils.dir_util import copy_tree 8 | from utils.model import subnet_to_dense 9 | 10 | 11 | def save_checkpoint( 12 | state, is_best, args, result_dir, filename="checkpoint.pth.tar", save_dense=False 13 | ): 14 | torch.save(state, os.path.join(result_dir, filename)) 15 | if is_best: 16 | shutil.copyfile( 17 | os.path.join(result_dir, filename), 18 | os.path.join(result_dir, "model_best.pth.tar"), 19 | ) 20 | 21 | if save_dense: 22 | state["state_dict"] = subnet_to_dense(state["state_dict"], args.k) 23 | torch.save( 24 | subnet_to_dense(state, args.k), 25 | os.path.join(result_dir, "checkpoint_dense.pth.tar"), 26 | ) 27 | if is_best: 28 | shutil.copyfile( 29 | os.path.join(result_dir, 
"checkpoint_dense.pth.tar"), 30 | os.path.join(result_dir, "model_best_dense.pth.tar"), 31 | ) 32 | 33 | 34 | def create_subdirs(sub_dir): 35 | os.mkdir(sub_dir) 36 | os.mkdir(os.path.join(sub_dir, "checkpoint")) 37 | 38 | 39 | def write_to_file(file, data, option): 40 | with open(file, option) as f: 41 | f.write(data) 42 | 43 | 44 | def clone_results_to_latest_subdir(src, dst): 45 | if not os.path.exists(dst): 46 | os.mkdir(dst) 47 | copy_tree(src, dst) 48 | 49 | 50 | # ref:https://github.com/allenai/hidden-networks/blob/master/configs/parser.py 51 | def trim_preceding_hyphens(st): 52 | i = 0 53 | while st[i] == "-": 54 | i += 1 55 | 56 | return st[i:] 57 | 58 | 59 | def arg_to_varname(st: str): 60 | st = trim_preceding_hyphens(st) 61 | st = st.replace("-", "_") 62 | 63 | return st.split("=")[0] 64 | 65 | 66 | def argv_to_vars(argv): 67 | var_names = [] 68 | for arg in argv: 69 | if arg.startswith("-") and arg_to_varname(arg) != "config": 70 | var_names.append(arg_to_varname(arg)) 71 | 72 | return var_names 73 | 74 | 75 | # ref: https://github.com/allenai/hidden-networks/blob/master/args.py 76 | def parse_configs_file(args): 77 | # get commands from command line 78 | override_args = argv_to_vars(sys.argv) 79 | 80 | # load yaml file 81 | yaml_txt = open(args.configs).read() 82 | 83 | # override args 84 | loaded_yaml = yaml.load(yaml_txt, Loader=yaml.FullLoader) 85 | for v in override_args: 86 | loaded_yaml[v] = getattr(args, v) 87 | 88 | print(f"=> Reading YAML config from {args.configs}") 89 | args.__dict__.update(loaded_yaml) 90 | 91 | 92 | class AverageMeter(object): 93 | """Computes and stores the average and current value""" 94 | 95 | def __init__(self, name, fmt=":f"): 96 | self.name = name 97 | self.fmt = fmt 98 | self.reset() 99 | 100 | def reset(self): 101 | self.val = 0 102 | self.avg = 0 103 | self.sum = 0 104 | self.count = 0 105 | 106 | def update(self, val, n=1): 107 | self.val = val 108 | self.sum += val * n 109 | self.count += n 110 | self.avg = self.sum / self.count 111 | 112 | def __str__(self): 113 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 114 | return fmtstr.format(**self.__dict__) 115 | 116 | 117 | class ProgressMeter(object): 118 | def __init__(self, num_batches, meters, prefix=""): 119 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 120 | self.meters = meters 121 | self.prefix = prefix 122 | 123 | def display(self, batch): 124 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 125 | entries += [str(meter) for meter in self.meters] 126 | print("\t".join(entries)) 127 | 128 | def _get_batch_fmtstr(self, num_batches): 129 | num_digits = len(str(num_batches // 1)) 130 | fmt = "{:" + str(num_digits) + "d}" 131 | return "[" + fmt + "/" + fmt.format(num_batches) + "]" 132 | 133 | def write_to_tensorboard(self, writer, prefix, global_step): 134 | for meter in self.meters: 135 | writer.add_scalar(f"{prefix}/{meter.name}", meter.val, global_step) 136 | 137 | -------------------------------------------------------------------------------- /data/cifar.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import torch 5 | from torchvision import datasets, transforms 6 | from torch.utils.data import DataLoader, SubsetRandomSampler 7 | 8 | # NOTE: Each dataset class must have public norm_layer, tr_train, tr_test objects. 9 | # These are needed for ood/semi-supervised dataset used alongwith in the training and eval. 
10 | class CIFAR10:
11 |     """
12 |     CIFAR-10 dataset.
13 |     """
14 | 
15 |     def __init__(self, args, normalize=False):
16 |         self.args = args
17 | 
18 |         self.norm_layer = transforms.Normalize(
19 |             mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262]
20 |         )
21 | 
22 |         self.tr_train = [
23 |             transforms.RandomCrop(32, padding=4),
24 |             transforms.RandomHorizontalFlip(),
25 |             transforms.ToTensor(),
26 |         ]
27 |         self.tr_test = [transforms.ToTensor()]
28 | 
29 |         if normalize:
30 |             self.tr_train.append(self.norm_layer)
31 |             self.tr_test.append(self.norm_layer)
32 | 
33 |         self.tr_train = transforms.Compose(self.tr_train)
34 |         self.tr_test = transforms.Compose(self.tr_test)
35 | 
36 |     def data_loaders(self, **kwargs):
37 |         trainset = datasets.CIFAR10(
38 |             root=os.path.join(self.args.data_dir, "CIFAR10"),
39 |             train=True,
40 |             download=True,
41 |             transform=self.tr_train,
42 |         )
43 | 
44 |         subset_indices = np.random.permutation(np.arange(len(trainset)))[
45 |             : int(self.args.data_fraction * len(trainset))
46 |         ]
47 | 
48 |         train_loader = DataLoader(
49 |             trainset,
50 |             batch_size=self.args.batch_size,
51 |             sampler=SubsetRandomSampler(subset_indices),
52 |             **kwargs,
53 |         )
54 |         testset = datasets.CIFAR10(
55 |             root=os.path.join(self.args.data_dir, "CIFAR10"),
56 |             train=False,
57 |             download=True,
58 |             transform=self.tr_test,
59 |         )
60 |         test_loader = DataLoader(
61 |             testset, batch_size=self.args.test_batch_size, shuffle=False, **kwargs
62 |         )
63 | 
64 |         print(
65 |             f"Training loader: {len(train_loader.dataset)} images, Test loader: {len(test_loader.dataset)} images"
66 |         )
67 |         return train_loader, test_loader
68 | 
69 | 
70 | 
71 | class CIFAR100:
72 |     """
73 |     CIFAR-100 dataset.
74 |     """
75 | 
76 |     def __init__(self, args, normalize=False):
77 |         self.args = args
78 | 
79 |         self.norm_layer = transforms.Normalize(
80 |             mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]
81 |         )
82 | 
83 |         self.tr_train = [
84 |             transforms.RandomCrop(32, padding=4),
85 |             transforms.RandomHorizontalFlip(),
86 |             transforms.ToTensor(),
87 |         ]
88 |         self.tr_test = [transforms.ToTensor()]
89 | 
90 |         if normalize:
91 |             self.tr_train.append(self.norm_layer)
92 |             self.tr_test.append(self.norm_layer)
93 | 
94 |         self.tr_train = transforms.Compose(self.tr_train)
95 |         self.tr_test = transforms.Compose(self.tr_test)
96 | 
97 |     def data_loaders(self, **kwargs):
98 |         trainset = datasets.CIFAR100(
99 |             root=os.path.join(self.args.data_dir, "CIFAR100"),
100 |             train=True,
101 |             download=True,
102 |             transform=self.tr_train,
103 |         )
104 | 
105 |         subset_indices = np.random.permutation(np.arange(len(trainset)))[
106 |             : int(self.args.data_fraction * len(trainset))
107 |         ]
108 | 
109 |         train_loader = DataLoader(
110 |             trainset,
111 |             batch_size=self.args.batch_size,
112 |             sampler=SubsetRandomSampler(subset_indices),
113 |             **kwargs,
114 |         )
115 |         testset = datasets.CIFAR100(
116 |             root=os.path.join(self.args.data_dir, "CIFAR100"),
117 |             train=False,
118 |             download=True,
119 |             transform=self.tr_test,
120 |         )
121 |         test_loader = DataLoader(
122 |             testset, batch_size=self.args.test_batch_size, shuffle=False, **kwargs
123 |         )
124 | 
125 |         print(
126 |             f"Training loader: {len(train_loader.dataset)} images, Test loader: {len(test_loader.dataset)} images"
127 |         )
128 |         return train_loader, test_loader
--------------------------------------------------------------------------------
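# --- editor's note, not part of the original repository ---
# Both data_loaders() methods above subsample the training set via
# SubsetRandomSampler: a random permutation of all indices is truncated to
# int(args.data_fraction * len(trainset)) entries. For example,
# args.data_fraction = 0.1 on CIFAR-10 trains on 5,000 of the 50,000 training
# images; the test loader is never subsampled.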
/trainer/mixtrain.py:
--------------------------------------------------------------------------------
1 | import time
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torchvision
6 | 
7 | from utils.logging import AverageMeter, ProgressMeter
8 | from utils.eval import accuracy
9 | from utils.adv import trades_loss
10 | 
11 | import numpy as np
12 | 
13 | from symbolic_interval.symbolic_network import sym_interval_analyze, naive_interval_analyze, mix_interval_analyze
14 | 
15 | 
16 | def set_epsilon(args, epoch):
17 |     # NOTE: the body of this function and the head of train() were garbled in
18 |     # this copy; the lines below are a best-effort reconstruction from context
19 |     # (a linear warm-up of epsilon plus the variables used further down).
20 |     if epoch < args.schedule_length:
21 |         epsilon = args.starting_epsilon + epoch * (
22 |             args.epsilon - args.starting_epsilon
23 |         ) / args.schedule_length
24 |     else:
25 |         epsilon = args.epsilon
26 |     return epsilon
27 | 
28 | 
29 | def train(
30 |     model, device, train_loader, sm_loader, criterion, optimizer, epoch, args, writer
31 | ):
32 |     epsilon = set_epsilon(args, epoch)
33 |     k = args.mixtraink  # images per batch sent through symbolic analysis (assumed arg name)
34 |     alpha = 0.8  # probability of adding the verified loss in a step (assumed value)
35 |     iw = 1.0  # weight of the verified (interval) loss term (assumed value)
36 | 
37 |     print(
38 |         " ->->->->->->->->->-> One epoch with MixTrain{} (SYM {:.3f})"
39 |         " <-<-<-<-<-<-<-<-<-<-".format(k, epsilon)
40 |     )
41 | 
42 |     batch_time = AverageMeter("Time", ":6.3f")
43 |     data_time = AverageMeter("Data", ":6.3f")
44 |     losses = AverageMeter("Loss", ":.4f")
45 |     sym_losses = AverageMeter("Sym_Loss", ":.4f")
46 |     top1 = AverageMeter("Acc_1", ":6.2f")
47 |     sym1 = AverageMeter("Sym1", ":6.2f")
48 |     progress = ProgressMeter(
49 |         len(train_loader),
50 |         [batch_time, data_time, losses, sym_losses, top1, sym1],
51 |         prefix="Epoch: [{}]".format(epoch),
52 |     )
53 | 
54 |     model.train()
55 |     end = time.time()
56 | 
57 |     dataloader = train_loader if sm_loader is None else zip(train_loader, sm_loader)
58 | 
59 |     for i, data in enumerate(dataloader):
60 |         if sm_loader:
61 |             images, target = (
62 |                 torch.cat([d[0] for d in data], 0).to(device),
63 |                 torch.cat([d[1] for d in data], 0).to(device),
64 |             )
65 |         else:
66 |             images, target = data[0].to(device), data[1].to(device)
67 | 
68 |         # basic properties of training data
69 |         if i == 0:
70 |             print(
71 |                 images.shape,
72 |                 target.shape,
73 |                 f"Batch_size from args: {args.batch_size}",
74 |                 "lr: {:.5f}".format(optimizer.param_groups[0]["lr"]),
75 |             )
76 |             print(f"Training images range: {[torch.min(images), torch.max(images)]}")
77 | 
78 |         output = model(images)
79 |         ce = nn.CrossEntropyLoss()(output, target)
80 | 
81 |         if np.random.uniform() <= alpha:
82 |             # verify a random subset of k images with symbolic interval analysis
83 |             r = np.random.randint(low=0, high=images.shape[0], size=k)
84 |             rce, rerr = sym_interval_analyze(
85 |                 model,
86 |                 epsilon,
87 |                 images[r],
88 |                 target[r],
89 |                 use_cuda=torch.cuda.is_available(),
90 |                 parallel=False,
91 |             )
92 | 
93 |             # print("sym:", rce.item(), ce.item())
94 |             loss = iw * rce + ce
95 |             sym_losses.update(rce.item(), k)
96 |             sym1.update((1 - rerr) * 100.0, images.size(0))
97 |         else:
98 |             loss = ce
99 | 
100 |         # measure accuracy and record loss
101 |         acc1, acc5 = accuracy(output, target, topk=(1, 5))
102 |         top1.update(acc1[0], images.size(0))
103 |         losses.update(ce.item(), images.size(0))
104 | 
105 |         optimizer.zero_grad()
106 |         loss.backward()
107 |         optimizer.step()
108 | 
109 |         # measure elapsed time
110 |         batch_time.update(time.time() - end)
111 |         end = time.time()
112 | 
113 |         if i % args.print_freq == 0:
114 |             progress.display(i)
115 |             progress.write_to_tensorboard(
116 |                 writer, "train", epoch * len(train_loader) + i
117 |             )
118 | 
119 |         # write a sample of training images to tensorboard (helpful for debugging)
120 |         if i == 0:
121 |             writer.add_image(
122 |                 "training-images",
123 |                 torchvision.utils.make_grid(images[0 : len(images) // 4]),
124 |             )
--------------------------------------------------------------------------------
/models/resnet_cifar.py:
--------------------------------------------------------------------------------
1 | # https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | 
8 | class BasicBlock(nn.Module):
9 |     expansion = 1
10 | 
11 |     def __init__(self, in_planes, planes, conv_layer, stride=1):
12 |         super(BasicBlock, self).__init__()
13 |         self.conv1 = conv_layer(
14 |             in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
15 |         )
16 |         self.bn1 = nn.BatchNorm2d(planes)
17 |         self.conv2 = conv_layer(
18 |             planes, planes, kernel_size=3, stride=1, padding=1, bias=False
19 |         )
20 |         self.bn2 = nn.BatchNorm2d(planes)
21 | 
22 |         self.shortcut = nn.Sequential()
23 |         if stride != 1 or in_planes != self.expansion * planes:
24 |             self.shortcut = nn.Sequential(
25 |                 conv_layer(
26 |                     in_planes,
27 |                     self.expansion * planes,
28 |                     kernel_size=1,
29 |                     stride=stride,
30 |                     bias=False,
31 |                 ),
32 |                 nn.BatchNorm2d(self.expansion * planes),
33 |             )
34 | 
35 |     def forward(self, x):
36 |         out = F.relu(self.bn1(self.conv1(x)))
37 |         out = self.bn2(self.conv2(out))
38 |         out += self.shortcut(x)
39 |         out = F.relu(out)
40 |         return out
41 | 
42 | 
43 | class Bottleneck(nn.Module):
44 |     expansion = 4
45 | 
46 |     def __init__(self, in_planes, planes, conv_layer, stride=1):
47 |         super(Bottleneck, self).__init__()
48 |         self.conv1 = conv_layer(in_planes, planes, kernel_size=1, bias=False)
49 |         self.bn1 = nn.BatchNorm2d(planes)
50 |         self.conv2 = conv_layer(
51 |             planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
52 |         )
53 |         self.bn2 = nn.BatchNorm2d(planes)
54 |         self.conv3 = conv_layer(
55 |             planes, self.expansion * planes, kernel_size=1, bias=False
56 |         )
57 |         self.bn3 = nn.BatchNorm2d(self.expansion * planes)
58 | 
59 |         self.shortcut = nn.Sequential()
60 |         if stride != 1 or in_planes != self.expansion * planes:
61 |             self.shortcut = nn.Sequential(
62 |                 conv_layer(
63 |                     in_planes,
64 |                     self.expansion * planes,
65 |                     kernel_size=1,
66 |                     stride=stride,
67 |                     bias=False,
68 |                 ),
69 |                 nn.BatchNorm2d(self.expansion * planes),
70 |             )
71 | 
72 |     def forward(self, x):
73 |         out = F.relu(self.bn1(self.conv1(x)))
74 |         out = F.relu(self.bn2(self.conv2(out)))
75 |         out = self.bn3(self.conv3(out))
76 |         out += self.shortcut(x)
77 |         out = F.relu(out)
78 |         return out
79 | 
80 | 
81 | class ResNet(nn.Module):
82 |     def __init__(self, conv_layer, linear_layer, block, num_blocks, num_classes=10):
83 |         super(ResNet, self).__init__()
84 |         self.in_planes = 64
85 |         self.conv_layer = conv_layer
86 | 
87 |         self.conv1 = conv_layer(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
88 |         self.bn1 = nn.BatchNorm2d(64)
89 |         self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
90 |         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
91 |         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
92 |         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
93 |         self.linear = linear_layer(512 * block.expansion, num_classes)
94 | 
95 |     def _make_layer(self, block, planes, num_blocks, stride):
96 |         strides = [stride] + [1] * (num_blocks - 1)
97 |         layers = []
98 |         for stride in strides:
99 |             layers.append(block(self.in_planes, planes, self.conv_layer, stride))
100 |             self.in_planes = planes * block.expansion
101 |         return nn.Sequential(*layers)
102 | 
103 |     def forward(self, x):
104 |         out = F.relu(self.bn1(self.conv1(x)))
105 |         out = self.layer1(out)
106 |         out = self.layer2(out)
107 |         out = self.layer3(out)
108 |         out = self.layer4(out)
109 |         out = F.avg_pool2d(out, 4)
110 |         out = out.view(out.size(0), -1)
111 |         out = self.linear(out)
112 |         return out
113 | 
114 | 
115 | # NOTE: Only supporting the default (kaiming_normal) initialization.
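# --- editor's note: illustrative sketch, not part of the original repository ---
# The constructors below take the conv/linear layer classes as arguments, so the
# same architecture can be built either with dense layers or with the
# score-based "subnet" layers used for pruning. A usage sketch, assuming the
# get_layers helper from utils/model.py:
#
#     from utils.model import get_layers
#     conv_layer, linear_layer = get_layers("subnet")  # or "dense"
#     net = resnet18(conv_layer, linear_layer, "kaiming_normal", num_classes=10)
#     out = net(torch.randn(2, 3, 32, 32))  # -> logits of shape [2, 10]
#
# For "subnet" layers, a prune rate still has to be configured (e.g., via
# utils.model.prepare_model or set_prune_rate_model) before meaningful use.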
116 | def resnet18(conv_layer, linear_layer, init_type, **kwargs):
117 |     assert init_type == "kaiming_normal", "only supporting default init for Resnets"
118 |     return ResNet(conv_layer, linear_layer, BasicBlock, [2, 2, 2, 2], **kwargs)
119 | 
120 | 
121 | def resnet34(conv_layer, linear_layer, init_type, **kwargs):
122 |     assert init_type == "kaiming_normal", "only supporting default init for Resnets"
123 |     return ResNet(conv_layer, linear_layer, BasicBlock, [3, 4, 6, 3], **kwargs)
124 | 
125 | 
126 | def resnet50(conv_layer, linear_layer, init_type, **kwargs):
127 |     assert init_type == "kaiming_normal", "only supporting default init for Resnets"
128 |     return ResNet(conv_layer, linear_layer, Bottleneck, [3, 4, 6, 3], **kwargs)
129 | 
130 | 
131 | def resnet101(conv_layer, linear_layer, init_type, **kwargs):
132 |     assert init_type == "kaiming_normal", "only supporting default init for Resnets"
133 |     return ResNet(conv_layer, linear_layer, Bottleneck, [3, 4, 23, 3], **kwargs)
134 | 
135 | 
136 | def resnet152(conv_layer, linear_layer, init_type, **kwargs):
137 |     assert init_type == "kaiming_normal", "only supporting default init for Resnets"
138 |     return ResNet(conv_layer, linear_layer, Bottleneck, [3, 8, 36, 3], **kwargs)
139 | 
140 | 
141 | def test():
142 |     net = resnet18(nn.Conv2d, nn.Linear, "kaiming_normal")
143 |     y = net(torch.randn(1, 3, 32, 32))
144 |     print(y.size())
145 | 
--------------------------------------------------------------------------------
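# --- editor's note: illustrative usage sketch, not part of the original repository ---
# A typical certification run, assuming the argparse flags defined in
# eval_smoothing.py below; the checkpoint path is a placeholder:
#
#     python eval_smoothing.py --dataset CIFAR10 --arch vgg16_bn \
#         --layer-type subnet --num-classes 10 \
#         --base_classifier ./trained_models/<exp_name>/model_best_dense.pth.tar \
#         --noise_std 0.25 --N 10000 --outfile ./smoothing_certification.txt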
/eval_smoothing.py:
--------------------------------------------------------------------------------
1 | # evaluate a smoothed classifier on a dataset
2 | import argparse
3 | import os
4 | import sys
5 | import numpy as np
6 | from time import time
7 | import datetime
8 | import importlib
9 | import logging
10 | 
11 | import torch
12 | import torch.nn as nn
13 | 
14 | import models
15 | import data
16 | 
17 | from utils.smoothing import Smooth, quick_smoothing, eval_quick_smoothing
18 | from utils.model import get_layers
19 | 
20 | parser = argparse.ArgumentParser(description="Certify many examples")
21 | parser.add_argument(
22 |     "--dataset",
23 |     type=str,
24 |     choices=("CIFAR10", "CIFAR100", "SVHN", "MNIST", "imagenet"),
25 |     help="Dataset for training and eval",
26 | )
27 | parser.add_argument(
28 |     "--normalize",
29 |     action="store_true",
30 |     default=False,
31 |     help="whether to normalize the data",
32 | )
33 | parser.add_argument(
34 |     "--data-dir", type=str, default="./datasets", help="path to datasets"
35 | )
36 | parser.add_argument(
37 |     "--batch-size",
38 |     type=int,
39 |     default=128,
40 |     metavar="N",
41 |     help="input batch size for training (default: 128)",
42 | )
43 | parser.add_argument(
44 |     "--test-batch-size",
45 |     type=int,
46 |     default=128,
47 |     metavar="N",
48 |     help="input batch size for testing (default: 128)",
49 | )
50 | parser.add_argument(
51 |     "--data-fraction",
52 |     type=float,
53 |     default=1.0,
54 |     help="Fraction of images used from training set",
55 | )
56 | 
57 | parser.add_argument(
58 |     "--base_classifier", type=str, help="path to saved pytorch model of base classifier"
59 | )
60 | parser.add_argument("--arch", type=str, default="vgg16_bn", help="Model architecture")
61 | parser.add_argument(
62 |     "--num-classes", type=int, default=10, help="Number of output classes in the model",
63 | )
64 | parser.add_argument(
65 |     "--layer-type", type=str, choices=("dense", "subnet"), help="dense | subnet layers"
66 | )
67 | parser.add_argument(
68 |     "--noise_std", type=float, default=0.25, help="noise hyperparameter"
69 | )
70 | parser.add_argument("--outfile", type=str, help="output file")
71 | parser.add_argument("--skip", type=int, default=1, help="how many examples to skip")
72 | parser.add_argument("--max", type=int, default=-1, help="stop after this many examples")
73 | 
74 | # parser.add_argument(
75 | #     "--split", choices=["train", "test"], default="test", help="train or test set"
76 | # )
77 | parser.add_argument(
78 |     "--gpu", type=str, default="0", help="Comma separated list of GPU ids"
79 | )
80 | parser.add_argument("--N0", type=int, default=100)
81 | parser.add_argument("--N", type=int, default=10000, help="number of samples to use")
82 | parser.add_argument("--alpha", type=float, default=0.001, help="failure probability")
83 | parser.add_argument(
84 |     "--print-freq",
85 |     type=int,
86 |     default=100,
87 |     help="Number of batches to wait before printing training logs",
88 | )
89 | args = parser.parse_args()
90 | 
91 | if __name__ == "__main__":
92 | 
93 |     # add logger
94 |     logging.basicConfig(level=logging.INFO, format="%(message)s")
95 |     logger = logging.getLogger()
96 |     logger.addHandler(logging.FileHandler(args.outfile + ".log", "a"))
97 |     logger.info(args)
98 | 
99 |     gpu_list = [int(i) for i in args.gpu.strip().split(",")]
100 |     device = torch.device(f"cuda:{gpu_list[0]}")
101 | 
102 |     # Create model
103 |     cl, ll = get_layers(args.layer_type)
104 |     if len(gpu_list) > 1:
105 |         print("Using multiple GPUs")
106 |         base_classifier = nn.DataParallel(
107 |             models.__dict__[args.arch](
108 |                 cl, ll, "kaiming_normal", num_classes=args.num_classes
109 |             ),
110 |             gpu_list,
111 |         ).to(device)
112 |     else:
113 |         base_classifier = models.__dict__[args.arch](
114 |             cl, ll, "kaiming_normal", num_classes=args.num_classes
115 |         ).to(device)
116 | 
117 |     checkpoint = torch.load(args.base_classifier, map_location=device)
118 |     base_classifier.load_state_dict(checkpoint["state_dict"])
119 | 
120 |     smoothed_classifier = Smooth(base_classifier, args.num_classes, args.noise_std)
121 | 
122 |     # prepare output file
123 |     f = open(args.outfile, "w")
124 |     print("idx\tlabel\tpredict\tradius\tcorrect\ttime", file=f, flush=True)
125 | 
126 |     # Dataset
127 |     D = data.__dict__[args.dataset](args, normalize=args.normalize)
128 |     train_loader, test_loader = D.data_loaders()
129 |     dataset = test_loader.dataset  # Certify test inputs only (default)
130 | 
131 |     val = getattr(importlib.import_module("utils.eval"), "smooth")
132 |     p, r = val(base_classifier, device, test_loader, nn.CrossEntropyLoss(), args, None)
133 |     logger.info(f"Validation natural accuracy for source-net: {p}, radius: {r}")
134 | 
135 |     # for i, v in base_classifier.named_modules():
136 |     #     if isinstance(v, (nn.BatchNorm2d, nn.BatchNorm1d)):
137 |     #         v.track_running_stats = False
138 | 
139 |     # eval_quick_smoothing(base_classifier, train_loader, device, sigma=0.25, nbatch=10)
140 | 
141 |     base_classifier.eval()
142 |     # sys.exit()
143 | 
144 |     for i in range(len(dataset)):
145 | 
146 |         # only certify every args.skip examples, and stop after args.max examples
147 |         if i % args.skip != 0:
148 |             continue
149 |         if i == args.max:
150 |             break
151 | 
152 |         (x, label) = dataset[i]
153 | 
154 |         before_time = time()
155 |         # certify the prediction of g around x
156 |         x = x.to(device)
157 |         prediction, radius = smoothed_classifier.certify(
158 |             x, args.N0, args.N, args.alpha, args.batch_size, device
159 |         )
160 |         after_time = time()
161 |         correct = int(prediction == label)
162 | 
163 |         time_elapsed = str(datetime.timedelta(seconds=(after_time - before_time)))
164 |         print(
165 |
"{}\t{}\t{}\t{:.3}\t{}\t{}".format( 166 | i, label, prediction, radius, correct, time_elapsed 167 | ), 168 | file=f, 169 | flush=True, 170 | ) 171 | 172 | f.close() 173 | -------------------------------------------------------------------------------- /models/wrn_cifar.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, conv_layer, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = conv_layer( 13 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 14 | ) 15 | self.bn2 = nn.BatchNorm2d(out_planes) 16 | self.relu2 = nn.ReLU(inplace=True) 17 | self.conv2 = conv_layer( 18 | out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False 19 | ) 20 | self.droprate = dropRate 21 | self.equalInOut = in_planes == out_planes 22 | self.convShortcut = ( 23 | (not self.equalInOut) 24 | and conv_layer( 25 | in_planes, 26 | out_planes, 27 | kernel_size=1, 28 | stride=stride, 29 | padding=0, 30 | bias=False, 31 | ) 32 | or None 33 | ) 34 | 35 | def forward(self, x): 36 | if not self.equalInOut: 37 | x = self.relu1(self.bn1(x)) 38 | else: 39 | out = self.relu1(self.bn1(x)) 40 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 41 | if self.droprate > 0: 42 | out = F.dropout(out, p=self.droprate, training=self.training) 43 | out = self.conv2(out) 44 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 45 | 46 | 47 | class NetworkBlock(nn.Module): 48 | def __init__( 49 | self, nb_layers, in_planes, out_planes, block, conv_layer, stride, dropRate=0.0 50 | ): 51 | super(NetworkBlock, self).__init__() 52 | self.layer = self._make_layer( 53 | conv_layer, block, in_planes, out_planes, nb_layers, stride, dropRate 54 | ) 55 | 56 | def _make_layer( 57 | self, conv_layer, block, in_planes, out_planes, nb_layers, stride, dropRate 58 | ): 59 | layers = [] 60 | for i in range(int(nb_layers)): 61 | layers.append( 62 | block( 63 | conv_layer, 64 | i == 0 and in_planes or out_planes, 65 | out_planes, 66 | i == 0 and stride or 1, 67 | dropRate, 68 | ) 69 | ) 70 | return nn.Sequential(*layers) 71 | 72 | def forward(self, x): 73 | return self.layer(x) 74 | 75 | 76 | class WideResNet(nn.Module): 77 | def __init__( 78 | self, 79 | conv_layer, 80 | linear_layer, 81 | depth=34, 82 | num_classes=10, 83 | widen_factor=10, 84 | dropRate=0.0, 85 | ): 86 | super(WideResNet, self).__init__() 87 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 88 | assert (depth - 4) % 6 == 0 89 | n = (depth - 4) / 6 90 | block = BasicBlock 91 | # 1st conv before any network block 92 | self.conv1 = conv_layer( 93 | 3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False 94 | ) 95 | # 1st block 96 | self.block1 = NetworkBlock( 97 | n, nChannels[0], nChannels[1], block, conv_layer, 1, dropRate 98 | ) 99 | # 1st sub-block 100 | self.sub_block1 = NetworkBlock( 101 | n, nChannels[0], nChannels[1], block, conv_layer, 1, dropRate 102 | ) 103 | # 2nd block 104 | self.block2 = NetworkBlock( 105 | n, nChannels[1], nChannels[2], block, conv_layer, 2, dropRate 106 | ) 107 | # 3rd block 108 | self.block3 = NetworkBlock( 109 | n, nChannels[2], nChannels[3], block, conv_layer, 2, dropRate 110 | ) 111 | # global average pooling and classifier 
112 |         self.bn1 = nn.BatchNorm2d(nChannels[3])
113 |         self.relu = nn.ReLU(inplace=True)
114 |         self.fc = linear_layer(nChannels[3], num_classes)
115 |         self.nChannels = nChannels[3]
116 | 
117 |         for m in self.modules():
118 |             if isinstance(m, nn.Conv2d):
119 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
120 |                 m.weight.data.normal_(0, math.sqrt(2.0 / n))
121 |             elif isinstance(m, nn.BatchNorm2d):
122 |                 m.weight.data.fill_(1)
123 |                 m.bias.data.zero_()
124 |             elif isinstance(m, linear_layer):
125 |                 m.bias.data.zero_()
126 | 
127 |     def forward(self, x):
128 |         out = self.conv1(x)
129 |         out = self.block1(out)
130 |         out = self.block2(out)
131 |         out = self.block3(out)
132 |         out = self.relu(self.bn1(out))
133 |         out = F.avg_pool2d(out, 8)
134 |         out = out.view(-1, self.nChannels)
135 |         return self.fc(out)
136 | 
137 | 
138 | # NOTE: Only supporting the default (kaiming_normal) initialization.
139 | def wrn_28_10(conv_layer, linear_layer, init_type, **kwargs):
140 |     assert init_type == "kaiming_normal", "only supporting default init for WRN"
141 |     return WideResNet(conv_layer, linear_layer, depth=28, widen_factor=10, **kwargs)
142 | 
143 | 
144 | def wrn_28_4(conv_layer, linear_layer, init_type, **kwargs):
145 |     assert init_type == "kaiming_normal", "only supporting default init for WRN"
146 |     return WideResNet(conv_layer, linear_layer, depth=28, widen_factor=4, **kwargs)
147 | 
148 | 
149 | def wrn_28_1(conv_layer, linear_layer, init_type, **kwargs):
150 |     assert init_type == "kaiming_normal", "only supporting default init for WRN"
151 |     return WideResNet(conv_layer, linear_layer, depth=28, widen_factor=1, **kwargs)
152 | 
153 | 
154 | def wrn_34_10(conv_layer, linear_layer, init_type, **kwargs):
155 |     assert init_type == "kaiming_normal", "only supporting default init for WRN"
156 |     return WideResNet(conv_layer, linear_layer, depth=34, widen_factor=10, **kwargs)
157 | 
158 | 
159 | def wrn_40_2(conv_layer, linear_layer, init_type, **kwargs):
160 |     assert init_type == "kaiming_normal", "only supporting default init for WRN"
161 |     return WideResNet(conv_layer, linear_layer, depth=40, widen_factor=2, **kwargs)
162 | 
--------------------------------------------------------------------------------
/trainer/crown-ibp.py:
--------------------------------------------------------------------------------
1 | import time
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torchvision
6 | 
7 | from utils.logging import AverageMeter, ProgressMeter
8 | from utils.eval import accuracy
9 | from utils.adv import trades_loss
10 | 
11 | import numpy as np
12 | 
13 | from symbolic_interval.symbolic_network import sym_interval_analyze, naive_interval_analyze
14 | from crown.eps_scheduler import EpsilonScheduler
15 | from crown.bound_layers import BoundSequential, BoundLinear, BoundConv2d, BoundDataParallel, Flatten
16 | 
17 | 
18 | def train(
19 |     model, device, train_loader, sm_loader, criterion, optimizer, epoch, args, writer
20 | ):
21 |     num_class = 10
22 | 
23 |     sa = np.zeros((num_class, num_class - 1), dtype=np.int32)
24 |     for i in range(sa.shape[0]):
25 |         for j in range(sa.shape[1]):
26 |             if j < i:
27 |                 sa[i][j] = j
28 |             else:
29 |                 sa[i][j] = j + 1
30 |     sa = torch.LongTensor(sa)
31 |     batch_size = args.batch_size * 2
32 | 
33 |     schedule_start = 0
34 |     num_steps_per_epoch = len(train_loader)
35 |     eps_scheduler = EpsilonScheduler(
36 |         "linear",
37 |         args.schedule_start,
38 |         ((args.schedule_start + args.schedule_length) - 1) * num_steps_per_epoch,
39 |         args.starting_epsilon,
40 |         args.epsilon, num_steps_per_epoch)
41 | 
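# --- editor's note, not part of the original repository ---
# The scheduler above is assumed to ramp the certification radius linearly,
# per optimization step, from args.starting_epsilon up to args.epsilon over
# roughly args.schedule_length epochs beginning at args.schedule_start; the
# per-epoch start/end values are queried next, and the per-batch value inside
# the loop via eps_scheduler.get_eps(epoch, i). For instance, with
# starting_epsilon = 0 and epsilon = 2/255, the value halfway through the
# ramp would be about 1/255.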
42 | end_eps = eps_scheduler.get_eps(epoch+1, 0) 43 | start_eps = eps_scheduler.get_eps(epoch, 0) 44 | 45 | 46 | print( 47 | " ->->->->->->->->->-> One epoch with CROWN-IBP ({:.6f}-{:.6f})" 48 | " <-<-<-<-<-<-<-<-<-<-".format(start_eps, end_eps) 49 | ) 50 | 51 | batch_time = AverageMeter("Time", ":6.3f") 52 | data_time = AverageMeter("Data", ":6.3f") 53 | losses = AverageMeter("Loss", ":.4f") 54 | ibp_losses = AverageMeter("IBP_Loss", ":.4f") 55 | top1 = AverageMeter("Acc_1", ":6.2f") 56 | ibp_acc1 = AverageMeter("IBP1", ":6.2f") 57 | progress = ProgressMeter( 58 | len(train_loader), 59 | [batch_time, data_time, losses, ibp_losses, top1, ibp_acc1], 60 | prefix="Epoch: [{}]".format(epoch), 61 | ) 62 | 63 | model = BoundSequential.convert(model,\ 64 | {'same-slope': False, 'zero-lb': False,\ 65 | 'one-lb': False}).to(device) 66 | 67 | model.train() 68 | end = time.time() 69 | 70 | dataloader = train_loader if sm_loader is None else zip(train_loader, sm_loader) 71 | 72 | for i, data in enumerate(dataloader): 73 | if sm_loader: 74 | images, target = ( 75 | torch.cat([d[0] for d in data], 0).to(device), 76 | torch.cat([d[1] for d in data], 0).to(device), 77 | ) 78 | else: 79 | images, target = data[0].to(device), data[1].to(device) 80 | 81 | # basic properties of training data 82 | if i == 0: 83 | print( 84 | images.shape, 85 | target.shape, 86 | f"Batch_size from args: {args.batch_size}", 87 | "lr: {:.5f}".format(optimizer.param_groups[0]["lr"]), 88 | ) 89 | print(f"Training images range: {[torch.min(images), torch.max(images)]}") 90 | 91 | output = model(images, method_opt="forward") 92 | ce = nn.CrossEntropyLoss()(output, target) 93 | 94 | eps = eps_scheduler.get_eps(epoch, i) 95 | # generate specifications 96 | c = torch.eye(num_class).type_as(images)[target].unsqueeze(1) -\ 97 | torch.eye(num_class).type_as(images).unsqueeze(0) 98 | # remove specifications to self 99 | I = (~(target.unsqueeze(1) ==\ 100 | torch.arange(num_class).to(device).type_as(target).unsqueeze(0))) 101 | c = (c[I].view(images.size(0),num_class-1,num_class)).to(device) 102 | # scatter matrix to avoid compute margin to self 103 | sa_labels = sa[target].to(device) 104 | # storing computed lower bounds after scatter 105 | lb_s = torch.zeros(images.size(0), num_class).to(device) 106 | ub_s = torch.zeros(images.size(0), num_class).to(device) 107 | 108 | data_ub = torch.min(images + eps, images.max()).to(device) 109 | data_lb = torch.max(images - eps, images.min()).to(device) 110 | 111 | ub, ilb, relu_activity, unstable, dead, alive =\ 112 | model(norm=np.inf, x_U=data_ub, x_L=data_lb,\ 113 | eps=eps, C=c, method_opt="interval_range") 114 | 115 | crown_final_beta = 0. 116 | beta = (args.epsilon - eps * (1.0 - crown_final_beta)) / args.epsilon 117 | 118 | if beta < 1e-5: 119 | # print("pure naive") 120 | lb = ilb 121 | else: 122 | # print("crown-ibp") 123 | # get the CROWN bound using interval bounds 124 | _, _, clb, bias = model(norm=np.inf, x_U=data_ub,\ 125 | x_L=data_lb, eps=eps, C=c,\ 126 | method_opt="backward_range") 127 | # how much better is crown-ibp better than ibp? 
128 | # diff = (clb - ilb).sum().item() 129 | lb = clb * beta + ilb * (1 - beta) 130 | 131 | lb = lb_s.scatter(1, sa_labels, lb) 132 | robust_ce = criterion(-lb, target) 133 | 134 | #print(ce, robust_ce) 135 | racc = accuracy(-lb, target, topk=(1,)) 136 | 137 | loss = robust_ce 138 | 139 | # measure accuracy and record loss 140 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 141 | top1.update(acc1[0].item(), images.size(0)) 142 | losses.update(ce.item(), images.size(0)) 143 | ibp_losses.update(robust_ce.item(), images.size(0)) 144 | ibp_acc1.update(racc[0].item(), images.size(0)) 145 | 146 | optimizer.zero_grad() 147 | loss.backward() 148 | optimizer.step() 149 | 150 | # measure elapsed time 151 | batch_time.update(time.time() - end) 152 | end = time.time() 153 | 154 | if i % args.print_freq == 0: 155 | progress.display(i) 156 | progress.write_to_tensorboard( 157 | writer, "train", epoch * len(train_loader) + i 158 | ) 159 | 160 | # write a sample of training images to tensorboard (helpful for debugging) 161 | if i == 0: 162 | writer.add_image( 163 | "training-images", 164 | torchvision.utils.make_grid(images[0 : len(images) // 4]), 165 | ) 166 | -------------------------------------------------------------------------------- /models/vgg_cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class VGG(nn.Module): 7 | def __init__(self, features, last_conv_channels_len, linear_layer, num_classes=10): 8 | super(VGG, self).__init__() 9 | self.features = features 10 | self.avgpool = nn.AdaptiveAvgPool2d((2, 2)) 11 | self.classifier = nn.Sequential( 12 | linear_layer(last_conv_channels_len * 2 * 2, 256), 13 | nn.ReLU(True), 14 | linear_layer(256, 256), 15 | nn.ReLU(True), 16 | linear_layer(256, num_classes), 17 | ) 18 | 19 | def forward(self, x): 20 | x = self.features(x) 21 | x = self.avgpool(x) 22 | x = torch.flatten(x, 1) 23 | x = self.classifier(x) 24 | return x 25 | 26 | 27 | def initialize_weights(model, init_type): 28 | print(f"Initializing model with {init_type}") 29 | assert init_type in ["kaiming_normal", "kaiming_uniform", "signed_const"] 30 | for m in model.modules(): 31 | if isinstance(m, nn.Conv2d): 32 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 33 | if init_type == "signed_const": 34 | n = math.sqrt( 35 | 2.0 / (m.kernel_size[0] * m.kernel_size[1] * m.in_channels) 36 | ) 37 | m.weight.data = m.weight.data.sign() * n 38 | elif init_type == "kaiming_uniform": 39 | nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="relu") 40 | if m.bias is not None: 41 | nn.init.constant_(m.bias, 0) 42 | elif isinstance(m, nn.Linear): 43 | m.weight.data.normal_(0, 0.01) 44 | m.bias.data.zero_() 45 | if init_type == "signed_const": 46 | n = math.sqrt(2.0 / m.in_features) 47 | m.weight.data = m.weight.data.sign() * n 48 | elif isinstance(m, nn.BatchNorm2d): 49 | m.weight.data.fill_(1) 50 | m.bias.data.zero_() 51 | 52 | 53 | def make_layers(cfg, conv_layer, batch_norm=True): 54 | layers = [] 55 | in_channels = 3 56 | for v in cfg: 57 | if v == "M": 58 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 59 | else: 60 | conv2d = conv_layer(in_channels, v, kernel_size=3, padding=1, bias=False) 61 | if batch_norm: 62 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 63 | else: 64 | layers += [conv2d, nn.ReLU(inplace=True)] 65 | in_channels = v 66 | return nn.Sequential(*layers) 67 | 68 | 69 | cfgs = { 70 | "2": [64, "M", 64, "M"], 71 | "4": 
[64, 64, "M", 128, 128, "M"], 72 | "6": [64, 64, "M", 128, 128, "M", 256, 256, "M"], 73 | "8": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M"], 74 | "11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512], 75 | "13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512], 76 | "16": [ 77 | 64, 78 | 64, 79 | "M", 80 | 128, 81 | 128, 82 | "M", 83 | 256, 84 | 256, 85 | 256, 86 | "M", 87 | 512, 88 | 512, 89 | 512, 90 | "M", 91 | 512, 92 | 512, 93 | 512, 94 | ], 95 | } 96 | 97 | 98 | def vgg2(conv_layer, linear_layer, init_type, **kwargs): 99 | n = [i for i in cfgs["2"] if isinstance(i, int)][-1] 100 | model = VGG( 101 | make_layers(cfgs["2"], conv_layer, batch_norm=False), n, linear_layer, **kwargs 102 | ) 103 | initialize_weights(model, init_type) 104 | return model 105 | 106 | 107 | def vgg2_bn(conv_layer, linear_layer, init_type, **kwargs): 108 | n = [i for i in cfgs["2"] if isinstance(i, int)][-1] 109 | model = VGG( 110 | make_layers(cfgs["2"], conv_layer, batch_norm=True), n, linear_layer, **kwargs 111 | ) 112 | initialize_weights(model, init_type) 113 | return model 114 | 115 | 116 | def vgg4(conv_layer, linear_layer, init_type, **kwargs): 117 | n = [i for i in cfgs["4"] if isinstance(i, int)][-1] 118 | model = VGG( 119 | make_layers(cfgs["4"], conv_layer, batch_norm=False), n, linear_layer, **kwargs 120 | ) 121 | initialize_weights(model, init_type) 122 | return model 123 | 124 | 125 | def vgg4_bn(conv_layer, linear_layer, init_type, **kwargs): 126 | n = [i for i in cfgs["4"] if isinstance(i, int)][-1] 127 | model = VGG( 128 | make_layers(cfgs["4"], conv_layer, batch_norm=True), n, linear_layer, **kwargs 129 | ) 130 | initialize_weights(model, init_type) 131 | return model 132 | 133 | 134 | def vgg6(conv_layer, linear_layer, init_type, **kwargs): 135 | n = [i for i in cfgs["6"] if isinstance(i, int)][-1] 136 | model = VGG( 137 | make_layers(cfgs["6"], conv_layer, batch_norm=False), n, linear_layer, **kwargs 138 | ) 139 | initialize_weights(model, init_type) 140 | return model 141 | 142 | 143 | def vgg6_bn(conv_layer, linear_layer, init_type, **kwargs): 144 | n = [i for i in cfgs["6"] if isinstance(i, int)][-1] 145 | model = VGG( 146 | make_layers(cfgs["6"], conv_layer, batch_norm=True), n, linear_layer, **kwargs 147 | ) 148 | initialize_weights(model, init_type) 149 | return model 150 | 151 | 152 | def vgg8(conv_layer, linear_layer, init_type, **kwargs): 153 | n = [i for i in cfgs["8"] if isinstance(i, int)][-1] 154 | model = VGG( 155 | make_layers(cfgs["8"], conv_layer, batch_norm=False), n, linear_layer, **kwargs 156 | ) 157 | initialize_weights(model, init_type) 158 | return model 159 | 160 | 161 | def vgg8_bn(conv_layer, linear_layer, init_type, **kwargs): 162 | n = [i for i in cfgs["8"] if isinstance(i, int)][-1] 163 | model = VGG( 164 | make_layers(cfgs["8"], conv_layer, batch_norm=True), n, linear_layer, **kwargs 165 | ) 166 | initialize_weights(model, init_type) 167 | return model 168 | 169 | 170 | def vgg11(conv_layer, linear_layer, init_type, **kwargs): 171 | n = [i for i in cfgs["11"] if isinstance(i, int)][-1] 172 | model = VGG( 173 | make_layers(cfgs["11"], conv_layer, batch_norm=False), n, linear_layer, **kwargs 174 | ) 175 | initialize_weights(model, init_type) 176 | return model 177 | 178 | 179 | def vgg11_bn(conv_layer, linear_layer, init_type, **kwargs): 180 | n = [i for i in cfgs["11"] if isinstance(i, int)][-1] 181 | model = VGG( 182 | make_layers(cfgs["11"], conv_layer, batch_norm=True), n, linear_layer, **kwargs 183 | ) 184 | 
initialize_weights(model, init_type) 185 | return model 186 | 187 | 188 | def vgg13(conv_layer, linear_layer, init_type, **kwargs): 189 | n = [i for i in cfgs["13"] if isinstance(i, int)][-1] 190 | model = VGG( 191 | make_layers(cfgs["13"], conv_layer, batch_norm=False), n, linear_layer, **kwargs 192 | ) 193 | initialize_weights(model, init_type) 194 | return model 195 | 196 | 197 | def vgg13_bn(conv_layer, linear_layer, init_type, **kwargs): 198 | n = [i for i in cfgs["13"] if isinstance(i, int)][-1] 199 | model = VGG( 200 | make_layers(cfgs["13"], conv_layer, batch_norm=True), n, linear_layer, **kwargs 201 | ) 202 | initialize_weights(model, init_type) 203 | return model 204 | 205 | 206 | def vgg16(conv_layer, linear_layer, init_type, **kwargs): 207 | n = [i for i in cfgs["16"] if isinstance(i, int)][-1] 208 | model = VGG( 209 | make_layers(cfgs["16"], conv_layer, batch_norm=False), n, linear_layer, **kwargs 210 | ) 211 | initialize_weights(model, init_type) 212 | return model 213 | 214 | 215 | def vgg16_bn(conv_layer, linear_layer, init_type, **kwargs): 216 | n = [i for i in cfgs["16"] if isinstance(i, int)][-1] 217 | model = VGG( 218 | make_layers(cfgs["16"], conv_layer, batch_norm=True), n, linear_layer, **kwargs 219 | ) 220 | initialize_weights(model, init_type) 221 | return model 222 | -------------------------------------------------------------------------------- /utils/smoothing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from scipy.stats import norm, binom_test 4 | import numpy as np 5 | from math import ceil 6 | from statsmodels.stats.proportion import proportion_confint 7 | 8 | 9 | def eval_quick_smoothing(model, loader, device, sigma=0.25, nbatch=10): 10 | pred = [] 11 | rad = [] 12 | 13 | # model.eval() 14 | 15 | for index, (x, y) in enumerate(loader): 16 | q = quick_smoothing( 17 | model, 18 | x.to(device), 19 | y.to(device), 20 | device, 21 | sigma=sigma, 22 | eps=0.25, 23 | num_smooth=10, 24 | batch_size=100, 25 | softmax_temperature=1.0, 26 | detailed_output=True, 27 | ) 28 | pred += list(q[0]) 29 | rad += list(q[1]) 30 | if index == nbatch: 31 | break 32 | # print(pred) 33 | print(f"Mean smooth accuracy (len = {len(pred)})= ", np.mean(pred)) 34 | print(f"Mean rad (len={len(rad)})", np.mean(rad)) 35 | 36 | 37 | def quick_smoothing( 38 | model, 39 | x, 40 | y, 41 | device, 42 | sigma=1.0, 43 | eps=1.0, 44 | num_smooth=100, 45 | batch_size=1000, 46 | softmax_temperature=100.0, 47 | detailed_output=False, 48 | ): 49 | """Quick and dirty randomized smoothing 'certification', without proper 50 | confidence bounds. 
We use it only to monitor training."""
51 | 
52 |     x_noise = x.view(1, *x.shape) + sigma * torch.randn(num_smooth, *x.shape).to(device)
53 |     x_noise = x_noise.view(-1, *x.shape[1:])
54 |     # by setting a high softmax temperature, we are effectively using the
55 |     # randomized smoothing approach as originally defined
56 |     # it will be interesting to see if lower temperatures help
57 | 
58 |     preds = torch.cat(
59 |         [
60 |             F.softmax(softmax_temperature * model(batch), dim=-1)
61 |             for batch in torch.split(x_noise, batch_size)
62 |         ]
63 |     )
64 | 
65 |     preds = preds.view(num_smooth, x.shape[0], -1).mean(dim=0)
66 |     p_max, y_pred = preds.max(dim=-1)
67 | 
68 |     correct = (y_pred == y).data.cpu().numpy().astype("int64")
69 |     radii = (sigma + 1e-16) * norm.ppf(p_max.data.cpu().numpy())
70 | 
71 |     err = (1 - correct).sum()
72 |     robust_err = (1 - correct * (radii >= eps)).sum()
73 | 
74 |     if not detailed_output:
75 |         return err, robust_err
76 |     else:
77 |         return correct, radii
78 | 
79 | 
80 | class Smooth(object):
81 |     """A smoothed classifier g."""
82 | 
83 |     # to abstain, Smooth returns this int
84 |     ABSTAIN = -1
85 | 
86 |     def __init__(
87 |         self, base_classifier: torch.nn.Module, num_classes: int, sigma: float
88 |     ):
89 |         """
90 |         :param base_classifier: maps from [batch x channel x height x width] to [batch x num_classes]
91 |         :param num_classes:
92 |         :param sigma: the noise level hyperparameter
93 |         """
94 |         self.base_classifier = base_classifier
95 |         self.num_classes = num_classes
96 |         self.sigma = sigma
97 | 
98 |     def certify(
99 |         self,
100 |         x: torch.tensor,
101 |         n0: int,
102 |         n: int,
103 |         alpha: float,
104 |         batch_size: int,
105 |         device: str,
106 |     ) -> (int, float):
107 |         """ Monte Carlo algorithm for certifying that g's prediction around x is constant within some L2 radius.
108 |         With probability at least 1 - alpha, the class returned by this method will equal g(x), and g's prediction will
109 |         be robust within an L2 ball of radius R around x.
110 | 
111 |         :param x: the input [channel x height x width]
112 |         :param n0: the number of Monte Carlo samples to use for selection
113 |         :param n: the number of Monte Carlo samples to use for estimation
114 |         :param alpha: the failure probability
115 |         :param batch_size: batch size to use when evaluating the base classifier
116 |         :return: (predicted class, certified radius)
117 |         in the case of abstention, the class will be ABSTAIN and the radius 0.
118 |         """
119 |         self.base_classifier.eval()
120 |         # draw samples of f(x + epsilon)
121 |         counts_selection = self._sample_noise(x, n0, batch_size, device)
122 |         # use these samples to take a guess at the top class
123 |         cAHat = counts_selection.argmax().item()
124 |         # draw more samples of f(x + epsilon)
125 |         counts_estimation = self._sample_noise(x, n, batch_size, device)
126 |         # use these samples to estimate a lower bound on pA
127 |         nA = counts_estimation[cAHat].item()
128 |         pABar = self._lower_confidence_bound(nA, n, alpha)
129 |         if pABar < 0.5:
130 |             return Smooth.ABSTAIN, 0.0
131 |         else:
132 |             radius = self.sigma * norm.ppf(pABar)
133 |             return cAHat, radius
134 | 
135 |     def predict(
136 |         self, x: torch.tensor, n: int, alpha: float, batch_size: int, device: str
137 |     ) -> int:
138 |         """ Monte Carlo algorithm for evaluating the prediction of g at x. With probability at least 1 - alpha, the
139 |         class returned by this method will equal g(x).
140 | 
141 |         This function uses the hypothesis test described in https://arxiv.org/abs/1610.03944
142 |         for identifying the top category of a multinomial distribution.
143 | 144 | :param x: the input [channel x height x width] 145 | :param n: the number of Monte Carlo samples to use 146 | :param alpha: the failure probability 147 | :param batch_size: batch size to use when evaluating the base classifier 148 | :return: the predicted class, or ABSTAIN 149 | """ 150 | self.base_classifier.eval() 151 | counts = self._sample_noise(x, n, batch_size, device) 152 | top2 = counts.argsort()[::-1][:2] 153 | count1 = counts[top2[0]] 154 | count2 = counts[top2[1]] 155 | if binom_test(count1, count1 + count2, p=0.5) > alpha: 156 | return Smooth.ABSTAIN 157 | else: 158 | return top2[0] 159 | 160 | def _sample_noise( 161 | self, x: torch.tensor, num: int, batch_size, device 162 | ) -> np.ndarray: 163 | """ Sample the base classifier's prediction under noisy corruptions of the input x. 164 | 165 | :param x: the input [channel x width x height] 166 | :param num: number of samples to collect 167 | :param batch_size: 168 | :return: an ndarray[int] of length num_classes containing the per-class counts 169 | """ 170 | with torch.no_grad(): 171 | counts = np.zeros(self.num_classes, dtype=int) 172 | for _ in range(ceil(num / batch_size)): 173 | this_batch_size = min(batch_size, num) 174 | num -= this_batch_size 175 | 176 | batch = x.repeat((this_batch_size, 1, 1, 1)) 177 | noise = torch.randn_like(batch, device=device) * self.sigma 178 | predictions = self.base_classifier(batch + noise).argmax(1) 179 | counts += self._count_arr( 180 | predictions.data.cpu().numpy(), self.num_classes 181 | ) 182 | return counts 183 | 184 | def _count_arr(self, arr: np.ndarray, length: int) -> np.ndarray: 185 | counts = np.zeros(length, dtype=int) 186 | for idx in arr: 187 | counts[idx] += 1 188 | return counts 189 | 190 | def _lower_confidence_bound(self, NA: int, N: int, alpha: float) -> float: 191 | """ Returns a (1 - alpha) lower confidence bound on a bernoulli proportion. 192 | 193 | This function uses the Clopper-Pearson method. 
194 | 195 | :param NA: the number of "successes" 196 | :param N: the number of total draws 197 | :param alpha: the confidence level 198 | :return: a lower bound on the binomial proportion which holds true w.p at least (1 - alpha) over the samples 199 | """ 200 | return proportion_confint(NA, N, alpha=2 * alpha, method="beta")[0] 201 | -------------------------------------------------------------------------------- /utils/adv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import torch.optim as optim 6 | from utils.misc import xe_with_one_hot 7 | 8 | 9 | def squared_l2_norm(x): 10 | flattened = x.view(x.unsqueeze(0).shape[0], -1) 11 | return (flattened ** 2).sum(1) 12 | 13 | 14 | def l2_norm(x): 15 | return squared_l2_norm(x).sqrt() 16 | 17 | 18 | # ref: https://github.com/yaodongyu/TRADES 19 | def trades_loss( 20 | model, 21 | x_natural, 22 | y, 23 | device, 24 | optimizer, 25 | step_size, 26 | epsilon, 27 | perturb_steps, 28 | beta, 29 | clip_min, 30 | clip_max, 31 | distance="l_inf", 32 | natural_criterion=nn.CrossEntropyLoss(), 33 | ): 34 | # define KL-loss 35 | criterion_kl = nn.KLDivLoss(size_average=False) 36 | model.eval() 37 | batch_size = len(x_natural) 38 | # generate adversarial example 39 | x_adv = ( 40 | x_natural.detach() + 0.001 * torch.randn(x_natural.shape).to(device).detach() 41 | ) 42 | if distance == "l_inf": 43 | for _ in range(perturb_steps): 44 | x_adv.requires_grad_() 45 | with torch.enable_grad(): 46 | loss_kl = criterion_kl( 47 | F.log_softmax(model(x_adv), dim=1), 48 | F.softmax(model(x_natural), dim=1), 49 | ) 50 | grad = torch.autograd.grad(loss_kl, [x_adv])[0] 51 | x_adv = x_adv.detach() + step_size * torch.sign(grad.detach()) 52 | x_adv = torch.min( 53 | torch.max(x_adv, x_natural - epsilon), x_natural + epsilon 54 | ) 55 | x_adv = torch.clamp(x_adv, clip_min, clip_max) 56 | elif distance == "l_2": 57 | delta = 0.001 * torch.randn(x_natural.shape).to(device).detach() 58 | delta = Variable(delta.data, requires_grad=True) 59 | 60 | # Setup optimizers 61 | optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2) 62 | 63 | for _ in range(perturb_steps): 64 | adv = x_natural + delta 65 | 66 | # optimize 67 | optimizer_delta.zero_grad() 68 | with torch.enable_grad(): 69 | loss = (-1) * criterion_kl( 70 | F.log_softmax(model(adv), dim=1), F.softmax(model(x_natural), dim=1) 71 | ) 72 | loss.backward() 73 | # renorming gradient 74 | grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1) 75 | delta.grad.div_(grad_norms.view(-1, 1, 1, 1)) 76 | # avoid nan or inf if gradient is 0 77 | if (grad_norms == 0).any(): 78 | delta.grad[grad_norms == 0] = torch.randn_like( 79 | delta.grad[grad_norms == 0] 80 | ) 81 | optimizer_delta.step() 82 | 83 | # projection 84 | delta.data.add_(x_natural) 85 | delta.data.clamp_(clip_min, clip_max).sub_(x_natural) 86 | delta.data.renorm_(p=2, dim=0, maxnorm=epsilon) 87 | x_adv = Variable(x_natural + delta, requires_grad=False) 88 | else: 89 | x_adv = torch.clamp(x_adv, clip_min, clip_max) 90 | model.train() 91 | 92 | x_adv = Variable(torch.clamp(x_adv, clip_min, clip_max), requires_grad=False) 93 | # zero gradient 94 | optimizer.zero_grad() 95 | # calculate robust loss 96 | logits = model(x_natural) 97 | loss_natural = natural_criterion(logits, y) 98 | loss_robust = (1.0 / batch_size) * criterion_kl( 99 | F.log_softmax(model(x_adv), dim=1), 
F.softmax(model(x_natural), dim=1) 100 | ) 101 | loss = loss_natural + beta * loss_robust 102 | return loss 103 | 104 | 105 | def trades_loss_hot_vector( 106 | model, 107 | x_natural, 108 | y, 109 | device, 110 | optimizer, 111 | step_size, 112 | epsilon, 113 | perturb_steps, 114 | beta, 115 | clip_min, 116 | clip_max, 117 | distance="l_inf", 118 | ): 119 | # define KL-loss 120 | criterion_kl = nn.KLDivLoss(size_average=False) 121 | model.eval() 122 | batch_size = len(x_natural) 123 | # generate adversarial example 124 | x_adv = ( 125 | x_natural.detach() + 0.001 * torch.randn(x_natural.shape).to(device).detach() 126 | ) 127 | if distance == "l_inf": 128 | for _ in range(perturb_steps): 129 | x_adv.requires_grad_() 130 | with torch.enable_grad(): 131 | loss_kl = criterion_kl( 132 | F.log_softmax(model(x_adv), dim=1), 133 | F.softmax(model(x_natural), dim=1), 134 | ) 135 | grad = torch.autograd.grad(loss_kl, [x_adv])[0] 136 | x_adv = x_adv.detach() + step_size * torch.sign(grad.detach()) 137 | x_adv = torch.min( 138 | torch.max(x_adv, x_natural - epsilon), x_natural + epsilon 139 | ) 140 | x_adv = torch.clamp(x_adv, clip_min, clip_max) 141 | elif distance == "l_2": 142 | delta = 0.001 * torch.randn(x_natural.shape).to(device).detach() 143 | delta = Variable(delta.data, requires_grad=True) 144 | 145 | # Setup optimizers 146 | optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2) 147 | 148 | for _ in range(perturb_steps): 149 | adv = x_natural + delta 150 | 151 | # optimize 152 | optimizer_delta.zero_grad() 153 | with torch.enable_grad(): 154 | loss = (-1) * criterion_kl( 155 | F.log_softmax(model(adv), dim=1), F.softmax(model(x_natural), dim=1) 156 | ) 157 | loss.backward() 158 | # renorming gradient 159 | grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1) 160 | delta.grad.div_(grad_norms.view(-1, 1, 1, 1)) 161 | # avoid nan or inf if gradient is 0 162 | if (grad_norms == 0).any(): 163 | delta.grad[grad_norms == 0] = torch.randn_like( 164 | delta.grad[grad_norms == 0] 165 | ) 166 | optimizer_delta.step() 167 | 168 | # projection 169 | delta.data.add_(x_natural) 170 | delta.data.clamp_(clip_min, clip_max).sub_(x_natural) 171 | delta.data.renorm_(p=2, dim=0, maxnorm=epsilon) 172 | x_adv = Variable(x_natural + delta, requires_grad=False) 173 | else: 174 | x_adv = torch.clamp(x_adv, clip_min, clip_max) 175 | model.train() 176 | 177 | x_adv = Variable(torch.clamp(x_adv, clip_min, clip_max), requires_grad=False) 178 | # zero gradient 179 | optimizer.zero_grad() 180 | # calculate robust loss 181 | logits = model(x_natural) 182 | loss_natural = F.cross_entropy(logits, y) 183 | loss_robust = (1.0 / batch_size) * criterion_kl( 184 | F.log_softmax(model(x_adv), dim=1), F.softmax(model(x_natural), dim=1) 185 | ) 186 | loss = loss_natural + beta * loss_robust 187 | return loss 188 | 189 | 190 | # TODO: support L-2 attacks too. 
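# --- editor's note: illustrative usage sketch, not part of the original repository ---
# Typical L_inf evaluation call for pgd_whitebox below, assuming inputs in
# [0, 1]; the attack hyperparameters shown are common CIFAR-10 values, not
# necessarily this repo's configured defaults:
#
#     x_adv = pgd_whitebox(
#         model, x, y, device,
#         epsilon=8.0 / 255, num_steps=10, step_size=2.0 / 255,
#         clip_min=0.0, clip_max=1.0,
#     )
#     robust_acc = (model(x_adv).argmax(dim=1) == y).float().mean()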
191 | def pgd_whitebox( 192 | model, 193 | x, 194 | y, 195 | device, 196 | epsilon, 197 | num_steps, 198 | step_size, 199 | clip_min, 200 | clip_max, 201 | is_random=True, 202 | ): 203 | 204 | x_pgd = Variable(x.data, requires_grad=True) 205 | if is_random: 206 | random_noise = ( 207 | torch.FloatTensor(x_pgd.shape).uniform_(-epsilon, epsilon).to(device) 208 | ) 209 | x_pgd = Variable(x_pgd.data + random_noise, requires_grad=True) 210 | 211 | for _ in range(num_steps): 212 | opt = optim.SGD([x_pgd], lr=1e-3) 213 | opt.zero_grad() 214 | 215 | with torch.enable_grad(): 216 | loss = nn.CrossEntropyLoss()(model(x_pgd), y) 217 | loss.backward() 218 | eta = step_size * x_pgd.grad.data.sign() 219 | x_pgd = Variable(x_pgd.data + eta, requires_grad=True) 220 | eta = torch.clamp(x_pgd.data - x.data, -epsilon, epsilon) 221 | x_pgd = Variable(x.data + eta, requires_grad=True) 222 | x_pgd = Variable(torch.clamp(x_pgd, clip_min, clip_max), requires_grad=True) 223 | 224 | return x_pgd 225 | 226 | 227 | def fgsm(gradz, step_size): 228 | return step_size * torch.sign(gradz) 229 | 230 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def conv3x3(in_planes, out_planes, conv_layer, stride=1, groups=1, dilation=1): 5 | """3x3 convolution with padding""" 6 | return conv_layer(in_planes, out_planes, kernel_size=3, stride=stride, 7 | padding=dilation, groups=groups, bias=False, dilation=dilation) 8 | 9 | 10 | def conv1x1(in_planes, out_planes, conv_layer, stride=1): 11 | """1x1 convolution""" 12 | return conv_layer(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False, dilation=1) 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | __constants__ = ['downsample'] 18 | 19 | def __init__(self, inplanes, planes, conv_layer, stride=1, downsample=None, groups=1, 20 | base_width=64, dilation=1, norm_layer=None): 21 | super(BasicBlock, self).__init__() 22 | if norm_layer is None: 23 | norm_layer = nn.BatchNorm2d 24 | if groups != 1 or base_width != 64: 25 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 26 | if dilation > 1: 27 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 28 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 29 | self.conv1 = conv3x3(inplanes, planes, conv_layer, stride) 30 | self.bn1 = norm_layer(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes, conv_layer) 33 | self.bn2 = norm_layer(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | identity = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | identity = self.downsample(x) 49 | 50 | out += identity 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | __constants__ = ['downsample'] 59 | 60 | def __init__(self, inplanes, planes, conv_layer, stride=1, downsample=None, groups=1, 61 | base_width=64, dilation=1, norm_layer=None): 62 | super(Bottleneck, self).__init__() 63 | if norm_layer is None: 64 | norm_layer = nn.BatchNorm2d 65 | width = int(planes * (base_width / 64.)) * groups 66 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 67 
| self.conv1 = conv1x1(inplanes, width, conv_layer) 68 | self.bn1 = norm_layer(width) 69 | self.conv2 = conv3x3(width, width, conv_layer, stride, groups, dilation) 70 | self.bn2 = norm_layer(width) 71 | self.conv3 = conv1x1(width, planes * self.expansion, conv_layer) 72 | self.bn3 = norm_layer(planes * self.expansion) 73 | self.relu = nn.ReLU(inplace=True) 74 | self.downsample = downsample 75 | self.stride = stride 76 | 77 | def forward(self, x): 78 | identity = x 79 | 80 | out = self.conv1(x) 81 | out = self.bn1(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv2(out) 85 | out = self.bn2(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv3(out) 89 | out = self.bn3(out) 90 | 91 | if self.downsample is not None: 92 | identity = self.downsample(x) 93 | out += identity 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | 99 | class ResNet(nn.Module): 100 | 101 | def __init__(self, conv_layer, linear_layer, block, layers, num_classes=1000, zero_init_residual=False, 102 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 103 | norm_layer=None): 104 | super(ResNet, self).__init__() 105 | if norm_layer is None: 106 | norm_layer = nn.BatchNorm2d 107 | self._norm_layer = norm_layer 108 | self.conv_layer = conv_layer 109 | 110 | self.inplanes = 64 111 | self.dilation = 1 112 | if replace_stride_with_dilation is None: 113 | # each element in the tuple indicates if we should replace 114 | # the 2x2 stride with a dilated convolution instead 115 | replace_stride_with_dilation = [False, False, False] 116 | if len(replace_stride_with_dilation) != 3: 117 | raise ValueError("replace_stride_with_dilation should be None " 118 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 119 | self.groups = groups 120 | self.base_width = width_per_group 121 | 122 | self.conv1 = conv_layer(3, self.inplanes, kernel_size=7, stride=2, padding=3, 123 | bias=False) 124 | self.bn1 = norm_layer(self.inplanes) 125 | self.relu = nn.ReLU(inplace=True) 126 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 127 | self.layer1 = self._make_layer(block, 64, layers[0]) 128 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 129 | dilate=replace_stride_with_dilation[0]) 130 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 131 | dilate=replace_stride_with_dilation[1]) 132 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 133 | dilate=replace_stride_with_dilation[2]) 134 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 135 | self.fc = linear_layer(512 * block.expansion, num_classes) 136 | 137 | for m in self.modules(): 138 | if isinstance(m, nn.Conv2d): 139 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 140 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 141 | nn.init.constant_(m.weight, 1) 142 | nn.init.constant_(m.bias, 0) 143 | 144 | # Zero-initialize the last BN in each residual branch, 145 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
146 |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
147 |         if zero_init_residual:
148 |             for m in self.modules():
149 |                 if isinstance(m, Bottleneck):
150 |                     nn.init.constant_(m.bn3.weight, 0)
151 |                 elif isinstance(m, BasicBlock):
152 |                     nn.init.constant_(m.bn2.weight, 0)
153 | 
154 |     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
155 |         norm_layer = self._norm_layer
156 |         downsample = None
157 |         previous_dilation = self.dilation
158 |         if dilate:
159 |             self.dilation *= stride
160 |             stride = 1
161 |         if stride != 1 or self.inplanes != planes * block.expansion:
162 |             downsample = nn.Sequential(
163 |                 conv1x1(self.inplanes, planes * block.expansion, self.conv_layer, stride),
164 |                 norm_layer(planes * block.expansion),
165 |             )
166 | 
167 |         layers = []
168 |         layers.append(block(self.inplanes, planes, self.conv_layer, stride, downsample, self.groups,
169 |                             self.base_width, previous_dilation, norm_layer))
170 |         self.inplanes = planes * block.expansion
171 |         for _ in range(1, blocks):
172 |             layers.append(block(self.inplanes, planes, self.conv_layer, groups=self.groups,
173 |                                 base_width=self.base_width, dilation=self.dilation,
174 |                                 norm_layer=norm_layer))
175 | 
176 |         return nn.Sequential(*layers)
177 | 
178 |     def _forward_impl(self, x):
179 |         # See note [TorchScript super()]
180 |         x = self.conv1(x)
181 |         x = self.bn1(x)
182 |         x = self.relu(x)
183 |         x = self.maxpool(x)
184 | 
185 |         x = self.layer1(x)
186 |         x = self.layer2(x)
187 |         x = self.layer3(x)
188 |         x = self.layer4(x)
189 | 
190 |         x = self.avgpool(x)
191 |         x = torch.flatten(x, 1)
192 |         x = self.fc(x)
193 | 
194 |         return x
195 | 
196 |     def forward(self, x):
197 |         return self._forward_impl(x)
198 | 
199 | 
200 | # NOTE: Only supporting the default (kaiming_normal) initialization.
201 | def ResNet18(conv_layer, linear_layer, init_type, **kwargs):
202 |     assert init_type == "kaiming_normal", "only supporting default init for Resnets"
203 |     return ResNet(conv_layer, linear_layer, BasicBlock, [2, 2, 2, 2], **kwargs)
204 | 
205 | 
206 | def ResNet34(conv_layer, linear_layer, init_type, **kwargs):
207 |     assert init_type == "kaiming_normal", "only supporting default init for Resnets"
208 |     return ResNet(conv_layer, linear_layer, BasicBlock, [3, 4, 6, 3], **kwargs)
209 | 
210 | 
211 | def ResNet50(conv_layer, linear_layer, init_type, **kwargs):
212 |     assert init_type == "kaiming_normal", "only supporting default init for Resnets"
213 |     return ResNet(conv_layer, linear_layer, Bottleneck, [3, 4, 6, 3], **kwargs)
--------------------------------------------------------------------------------
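# --- editor's note: illustrative sketch, not part of the original repository ---
# The helpers in utils/model.py below drive HYDRA's three phases by toggling
# requires_grad: "pretrain" and "finetune" train weight/bias with the
# popup_scores frozen, while "prune" trains only the popup_scores that rank
# connections. A hypothetical driver loop:
#
#     from utils.model import get_layers, prepare_model
#     conv_layer, linear_layer = get_layers("subnet")
#     model = models.__dict__[args.arch](
#         conv_layer, linear_layer, "kaiming_normal", num_classes=args.num_classes
#     )
#     for args.exp_mode in ("pretrain", "prune", "finetune"):
#         prepare_model(model, args)
#         ...  # run the trainer selected in the configs for this phase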
16 | """ 17 | 18 | assert var_name in ["weight", "bias", "popup_scores"] 19 | for i, v in model.named_modules(): 20 | if hasattr(v, var_name): 21 | if not isinstance(v, (nn.BatchNorm2d, nn.BatchNorm2d)) or freeze_bn: 22 | if getattr(v, var_name) is not None: 23 | getattr(v, var_name).requires_grad = False 24 | 25 | 26 | def unfreeze_vars(model, var_name): 27 | assert var_name in ["weight", "bias", "popup_scores"] 28 | for i, v in model.named_modules(): 29 | if hasattr(v, var_name): 30 | if getattr(v, var_name) is not None: 31 | getattr(v, var_name).requires_grad = True 32 | 33 | 34 | def set_prune_rate_model(model, prune_rate): 35 | for _, v in model.named_modules(): 36 | if hasattr(v, "set_prune_rate"): 37 | v.set_prune_rate(prune_rate) 38 | 39 | 40 | def get_layers(layer_type): 41 | """ 42 | Returns: (conv_layer, linear_layer) 43 | """ 44 | if layer_type == "dense": 45 | return nn.Conv2d, nn.Linear 46 | elif layer_type == "subnet": 47 | return SubnetConv, SubnetLinear 48 | else: 49 | raise ValueError("Incorrect layer type") 50 | 51 | 52 | def show_gradients(model): 53 | for i, v in model.named_parameters(): 54 | print(f"variable = {i}, Gradient requires_grad = {v.requires_grad}") 55 | 56 | 57 | def snip_init(model, criterion, optimizer, train_loader, device, args): 58 | print("Using SNIP initialization") 59 | assert args.exp_mode == "pretrain" 60 | optimizer.zero_grad() 61 | # init the score with kaiming normal init 62 | for m in model.modules(): 63 | if hasattr(m, "popup_scores"): 64 | nn.init.kaiming_normal_(m.popup_scores, mode="fan_in") 65 | 66 | set_prune_rate_model(model, 1.0) 67 | unfreeze_vars(model, "popup_scores") 68 | 69 | # take a forward pass and get gradients 70 | for _, data in enumerate(train_loader): 71 | images, target = data[0].to(device), data[1].to(device) 72 | 73 | output = model(images) 74 | loss = criterion(output, target) 75 | 76 | loss.backward() 77 | break 78 | 79 | # update scores with their respective connection sensitivty 80 | for m in model.modules(): 81 | if hasattr(m, "popup_scores"): 82 | print(m.popup_scores.data) 83 | m.popup_scores.data = m.popup_scores.grad.data.abs() 84 | print(m.popup_scores.data) 85 | 86 | # update k back to args.k. 
87 |     set_prune_rate_model(model, args.k)
88 |     freeze_vars(model, "popup_scores")
89 | 
90 | 
91 | def initialize_scores(model, init_type):
92 |     print(f"Initializing relevance scores with {init_type} initialization")
93 |     for m in model.modules():
94 |         if hasattr(m, "popup_scores"):
95 |             if init_type == "kaiming_uniform":
96 |                 nn.init.kaiming_uniform_(m.popup_scores)
97 |             elif init_type == "kaiming_normal":
98 |                 nn.init.kaiming_normal_(m.popup_scores)
99 |             elif init_type == "xavier_uniform":
100 |                 nn.init.xavier_uniform_(
101 |                     m.popup_scores, gain=nn.init.calculate_gain("relu")
102 |                 )
103 |             elif init_type == "xavier_normal":
104 |                 nn.init.xavier_normal_(
105 |                     m.popup_scores, gain=nn.init.calculate_gain("relu")
106 |                 )
107 | 
108 | 
109 | def initialize_scaled_score(model):
110 |     print(
111 |         "Initializing relevance scores proportional to weight magnitudes (OVERWRITING SOURCE NET SCORES)"
112 |     )
113 |     for m in model.modules():
114 |         if hasattr(m, "popup_scores"):
115 |             n = nn.init._calculate_correct_fan(m.popup_scores, "fan_in")
116 |             # Close to kaiming uniform init
117 |             m.popup_scores.data = (
118 |                 math.sqrt(6 / n) * m.weight.data / torch.max(torch.abs(m.weight.data))
119 |             )
120 | 
121 | 
122 | def scale_rand_init(model, k):
123 |     print(
124 |         f"Initializing random weights with scaling by 1/sqrt({k}) | Only applied to CONV & FC layers"
125 |     )
126 |     for m in model.modules():
127 |         if isinstance(m, (nn.Conv2d, nn.Linear)):
128 |             # print(f"previous std = {torch.std(m.weight.data)}")
129 |             m.weight.data = 1 / math.sqrt(k) * m.weight.data
130 |             # print(f"new std = {torch.std(m.weight.data)}")
131 | 
132 | 
133 | def prepare_model(model, args):
134 |     """
135 |     1. Set the model pruning rate.
136 |     2. Set gradients based on the training mode.
137 |     """
138 | 
139 |     set_prune_rate_model(model, args.k)
140 | 
141 |     if args.exp_mode == "pretrain":
142 |         print("#################### Pre-training network ####################")
143 |         print("===>> gradient for importance_scores: None | training weights only")
144 |         freeze_vars(model, "popup_scores", args.freeze_bn)
145 |         unfreeze_vars(model, "weight")
146 |         unfreeze_vars(model, "bias")
147 | 
148 |     elif args.exp_mode == "prune":
149 |         print("#################### Pruning network ####################")
150 |         print("===>> gradient for weights: None | training importance scores only")
151 | 
152 |         unfreeze_vars(model, "popup_scores")
153 |         freeze_vars(model, "weight", args.freeze_bn)
154 |         freeze_vars(model, "bias", args.freeze_bn)
155 | 
156 |     elif args.exp_mode == "finetune":
157 |         print("#################### Fine-tuning network ####################")
158 |         print(
159 |             "===>> gradient for importance_scores: None | fine-tuning important weights only"
160 |         )
161 |         freeze_vars(model, "popup_scores", args.freeze_bn)
162 |         unfreeze_vars(model, "weight")
163 |         unfreeze_vars(model, "bias")
164 | 
165 |     else:
166 |         assert False, f"{args.exp_mode} mode is not supported"
167 | 
168 |     initialize_scores(model, args.scores_init_type)
169 | 
170 | 
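# Note on initialize_scaled_score above: within each layer the scores are
# proportional to the weights, so selecting the top-k scores by magnitude at
# initialization is equivalent to per-layer least-weight-magnitude (LWM) pruning.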
175 | """ 176 | dense = {} 177 | 178 | # load dense variables 179 | for (k, v) in subnet_dict.items(): 180 | if "popup_scores" not in k: 181 | dense[k] = v 182 | 183 | # update dense variables 184 | for (k, v) in subnet_dict.items(): 185 | if "popup_scores" in k: 186 | s = torch.abs(subnet_dict[k]) 187 | 188 | out = s.clone() 189 | _, idx = s.flatten().sort() 190 | j = int((1 - p) * s.numel()) 191 | 192 | flat_out = out.flatten() 193 | flat_out[idx[:j]] = 0 194 | flat_out[idx[j:]] = 1 195 | dense[k.replace("popup_scores", "weight")] = ( 196 | subnet_dict[k.replace("popup_scores", "weight")] * out 197 | ) 198 | return dense 199 | 200 | 201 | def dense_to_subnet(model, state_dict): 202 | """ 203 | Load a dict with dense-layer in a model trained with subnet layers. 204 | """ 205 | model.load_state_dict(state_dict, strict=False) 206 | 207 | 208 | def current_model_pruned_fraction(model, result_dir, verbose=True): 209 | """ 210 | Find pruning raio per layer. Return average of them. 211 | Result_dict should correspond to the checkpoint of model. 212 | """ 213 | 214 | # load the dense models 215 | path = os.path.join(result_dir, "checkpoint_dense.pth.tar") 216 | 217 | pl = [] 218 | 219 | if os.path.exists(path): 220 | state_dict = torch.load(path, map_location="cpu")["state_dict"] 221 | for i, v in model.named_modules(): 222 | if isinstance(v, (nn.Conv2d, nn.Linear)): 223 | if i + ".weight" in state_dict.keys(): 224 | d = state_dict[i + ".weight"].data.cpu().numpy() 225 | p = 100 * np.sum(d == 0) / np.size(d) 226 | pl.append(p) 227 | if verbose: 228 | print(i, v, p) 229 | return np.mean(pl) 230 | 231 | 232 | def sanity_check_paramter_updates(model, last_ckpt): 233 | """ 234 | Check whether weigths/popup_scores gets updated or not compared to last ckpt. 235 | ONLY does it for 1 layer (to avoid computational overhead) 236 | """ 237 | for i, v in model.named_modules(): 238 | if hasattr(v, "weight") and hasattr(v, "popup_scores"): 239 | if getattr(v, "weight") is not None: 240 | w1 = getattr(v, "weight").data.cpu() 241 | w2 = last_ckpt[i + ".weight"].data.cpu() 242 | if getattr(v, "popup_scores") is not None: 243 | s1 = getattr(v, "popup_scores").data.cpu() 244 | s2 = last_ckpt[i + ".popup_scores"].data.cpu() 245 | return not torch.allclose(w1, w2), not torch.allclose(s1, s2) 246 | 247 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | # Inherited from https://github.com/yaodongyu/TRADES/blob/master/train_trades_cifar10.py 5 | def parse_args(): 6 | parser = argparse.ArgumentParser(description="PyTorch Training") 7 | 8 | # primary 9 | parser.add_argument( 10 | "--configs", type=str, default="", help="configs file", 11 | ) 12 | parser.add_argument( 13 | "--result-dir", 14 | default="./trained_models", 15 | type=str, 16 | help="directory to save results", 17 | ) 18 | parser.add_argument( 19 | "--exp-name", 20 | type=str, 21 | help="Name of the experiment (creates dir with this name in --result-dir)", 22 | ) 23 | 24 | parser.add_argument( 25 | "--exp-mode", 26 | type=str, 27 | choices=("pretrain", "prune", "finetune"), 28 | help="Train networks following one of these methods.", 29 | ) 30 | 31 | # Model 32 | parser.add_argument("--arch", type=str, help="Model achitecture") 33 | parser.add_argument( 34 | "--num-classes", 35 | type=int, 36 | default=10, 37 | help="Number of output classes in the model", 38 | ) 39 | parser.add_argument( 40 | 
"--layer-type", type=str, choices=("dense", "subnet"), help="dense | subnet" 41 | ) 42 | parser.add_argument( 43 | "--init_type", 44 | choices=("kaiming_normal", "kaiming_uniform", "signed_const"), 45 | help="Which init to use for weight parameters: kaiming_normal | kaiming_uniform | signed_const", 46 | ) 47 | 48 | # Pruning 49 | parser.add_argument( 50 | "--snip-init", 51 | action="store_true", 52 | default=False, 53 | help="Whether implemnet snip init", 54 | ) 55 | 56 | parser.add_argument( 57 | "--k", 58 | type=float, 59 | default=1.0, 60 | help="Fraction of weight variables kept in subnet", 61 | ) 62 | 63 | parser.add_argument( 64 | "--scaled-score-init", 65 | action="store_true", 66 | default=False, 67 | help="Init importance scores proportaional to weights (default kaiming init)", 68 | ) 69 | 70 | parser.add_argument( 71 | "--scale_rand_init", 72 | action="store_true", 73 | default=False, 74 | help="Init weight with scaling using pruning ratio", 75 | ) 76 | 77 | parser.add_argument( 78 | "--freeze-bn", 79 | action="store_true", 80 | default=False, 81 | help="freeze batch-norm parameters in pruning", 82 | ) 83 | 84 | parser.add_argument( 85 | "--source-net", 86 | type=str, 87 | default="", 88 | help="Checkpoint which will be pruned/fine-tuned", 89 | ) 90 | 91 | # Semi-supervision dataset setting 92 | parser.add_argument( 93 | "--is-semisup", 94 | action="store_true", 95 | default=False, 96 | help="Use semisupervised training", 97 | ) 98 | 99 | parser.add_argument( 100 | "--semisup-data", 101 | type=str, 102 | choices=("tinyimages", "splitgan"), 103 | help="Name for semi-supervision dataset", 104 | ) 105 | 106 | parser.add_argument( 107 | "--semisup-fraction", 108 | type=float, 109 | default=0.25, 110 | help="Fraction of images used in training from semisup dataset", 111 | ) 112 | 113 | # Randomized smoothing 114 | parser.add_argument( 115 | "--noise-std", 116 | type=float, 117 | default=0.25, 118 | help="Std of normal distribution used to generate noise", 119 | ) 120 | 121 | #parser.add_argument( 122 | # "--scale_rand_init", 123 | # action="store_true", 124 | # default=False, 125 | # help="Init weight with scaling using pruning ratio", 126 | #) 127 | 128 | parser.add_argument( 129 | "--scores_init_type", 130 | choices=("kaiming_normal", "kaiming_uniform", "xavier_uniform", "xavier_normal"), 131 | help="Which init to use for relevance scores", 132 | ) 133 | 134 | # Data 135 | parser.add_argument( 136 | "--dataset", 137 | type=str, 138 | choices=("CIFAR10", "CIFAR100", "SVHN", "MNIST", "imagenet"), 139 | help="Dataset for training and eval", 140 | ) 141 | parser.add_argument( 142 | "--batch-size", 143 | type=int, 144 | default=128, 145 | metavar="N", 146 | help="input batch size for training (default: 128)", 147 | ) 148 | parser.add_argument( 149 | "--test-batch-size", 150 | type=int, 151 | default=128, 152 | metavar="N", 153 | help="input batch size for testing (default: 128)", 154 | ) 155 | parser.add_argument( 156 | "--normalize", 157 | action="store_true", 158 | default=False, 159 | help="whether to normalize the data", 160 | ) 161 | parser.add_argument( 162 | "--data-dir", type=str, default="./datasets", help="path to datasets" 163 | ) 164 | 165 | parser.add_argument( 166 | "--data-fraction", 167 | type=float, 168 | default=1.0, 169 | help="Fraction of images used from training set", 170 | ) 171 | parser.add_argument( 172 | "--image-dim", type=int, default=32, help="Image size: dim x dim x 3" 173 | ) 174 | parser.add_argument( 175 | "--mean", type=tuple, default=(0, 0, 0), 
help="Mean for data normalization" 176 | ) 177 | parser.add_argument( 178 | "--std", type=tuple, default=(1, 1, 1), help="Std for data normalization" 179 | ) 180 | 181 | # Training 182 | parser.add_argument( 183 | "--trainer", 184 | type=str, 185 | default="base", 186 | choices=("base", "adv", "mixtrain", "crown-ibp", "smooth", "freeadv"), 187 | help="Natural (base) or adversarial or verifiable training", 188 | ) 189 | parser.add_argument( 190 | "--epochs", type=int, default=100, metavar="N", help="number of epochs to train" 191 | ) 192 | parser.add_argument( 193 | "--optimizer", type=str, default="sgd", choices=("sgd", "adam", "rmsprop") 194 | ) 195 | parser.add_argument("--wd", default=1e-4, type=float, help="Weight decay") 196 | parser.add_argument("--lr", type=float, default=0.1, help="learning rate") 197 | parser.add_argument( 198 | "--lr-schedule", 199 | type=str, 200 | default="cosine", 201 | choices=("step", "cosine"), 202 | help="Learning rate schedule", 203 | ) 204 | parser.add_argument("--momentum", type=float, default=0.9, help="SGD momentum") 205 | parser.add_argument( 206 | "--warmup-epochs", type=int, default=0, help="Number of warmup epochs" 207 | ) 208 | parser.add_argument( 209 | "--warmup-lr", type=float, default=0.1, help="warmup learning rate" 210 | ) 211 | parser.add_argument( 212 | "--save-dense", 213 | action="store_true", 214 | default=False, 215 | help="Save dense model alongwith subnets.", 216 | ) 217 | 218 | # Free-adv training (only for imagenet) 219 | parser.add_argument( 220 | "--n-repeats", 221 | type=int, 222 | default=4, 223 | help="--number of repeats in free-adv training", 224 | ) 225 | 226 | # Adversarial attacks 227 | parser.add_argument("--epsilon", default=8.0 / 255, type=float, help="perturbation") 228 | parser.add_argument( 229 | "--num-steps", default=10, type=int, help="perturb number of steps" 230 | ) 231 | parser.add_argument( 232 | "--step-size", default=2.0 / 255, type=float, help="perturb step size" 233 | ) 234 | parser.add_argument("--clip-min", default=0, type=float, help="perturb step size") 235 | parser.add_argument("--clip-max", default=1.0, type=float, help="perturb step size") 236 | parser.add_argument( 237 | "--distance", 238 | type=str, 239 | default="l_inf", 240 | choices=("l_inf", "l_2"), 241 | help="attack distance metric", 242 | ) 243 | parser.add_argument( 244 | "--const-init", 245 | action="store_true", 246 | default=False, 247 | help="use random initialization of epsilon for attacks", 248 | ) 249 | parser.add_argument( 250 | "--beta", 251 | default=6.0, 252 | type=float, 253 | help="regularization, i.e., 1/lambda in TRADES", 254 | ) 255 | 256 | # Evaluate 257 | parser.add_argument( 258 | "--evaluate", action="store_true", default=False, help="Evaluate model" 259 | ) 260 | 261 | parser.add_argument( 262 | "--val_method", 263 | type=str, 264 | default="base", 265 | choices=("base", "adv", "mixtrain", "ibp", "smooth", "freeadv"), 266 | help="base: evaluation on unmodified inputs | adv: evaluate on adversarial inputs", 267 | ) 268 | 269 | # Restart 270 | parser.add_argument( 271 | "--start-epoch", 272 | type=int, 273 | default=0, 274 | help="manual start epoch (useful in restarts)", 275 | ) 276 | parser.add_argument( 277 | "--resume", 278 | type=str, 279 | default="", 280 | help="path to latest checkpoint (default:None)", 281 | ) 282 | 283 | # Additional 284 | parser.add_argument( 285 | "--gpu", type=str, default="0", help="Comma separated list of GPU ids" 286 | ) 287 | parser.add_argument( 288 | "--no-cuda", action="store_true", 
default=False, help="disables CUDA training" 289 | ) 290 | parser.add_argument("--seed", type=int, default=1234, help="random seed") 291 | parser.add_argument( 292 | "--print-freq", 293 | type=int, 294 | default=100, 295 | help="Number of batches to wait before printing training logs", 296 | ) 297 | 298 | parser.add_argument( 299 | "--schedule_length", 300 | type=int, 301 | default=0, 302 | help="Number of epochs to schedule the training epsilon.", 303 | ) 304 | 305 | parser.add_argument( 306 | "--mixtraink", 307 | type=int, 308 | default=1, 309 | help="Number of samples out of a batch to train with sym in mixtrain.", 310 | ) 311 | 312 | return parser.parse_args() 313 | -------------------------------------------------------------------------------- /models/basic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class Flatten(nn.Module): 7 | def forward(self, input): 8 | return input.view(input.size(0), -1) 9 | 10 | 11 | # lin_i: i layer linear feedforard network. 12 | def lin_1(input_dim=3072, num_classes=10): 13 | model = nn.Sequential(nn.Flatten(), nn.Linear(input_dim, num_classes)) 14 | return model 15 | 16 | 17 | def lin_2(input_dim=3072, hidden_dim=100, num_classes=10): 18 | model = nn.Sequential( 19 | nn.Flatten(), 20 | nn.Linear(input_dim, hidden_dim), 21 | nn.Linear(hidden_dim, num_classes), 22 | ) 23 | return model 24 | 25 | 26 | def lin_3(input_dim=3072, hidden_dim=100, num_classes=10): 27 | model = nn.Sequential( 28 | nn.Flatten(), 29 | nn.Linear(input_dim, hidden_dim), 30 | nn.Linear(hidden_dim, hidden_dim), 31 | nn.Linear(hidden_dim, num_classes), 32 | ) 33 | return model 34 | 35 | 36 | def lin_4(input_dim=3072, hidden_dim=100, num_classes=10): 37 | model = nn.Sequential( 38 | nn.Flatten(), 39 | nn.Linear(input_dim, hidden_dim), 40 | nn.Linear(hidden_dim, hidden_dim), 41 | nn.Linear(hidden_dim, num_classes), 42 | ) 43 | return model 44 | 45 | 46 | def mnist_model(conv_layer, linear_layer, init_type, **kwargs): 47 | assert init_type == "kaiming_normal", "only supporting kaiming_normal init" 48 | model = nn.Sequential( 49 | conv_layer(1, 16, 4, stride=2, padding=1), 50 | nn.ReLU(), 51 | conv_layer(16, 32, 4, stride=2, padding=1), 52 | nn.ReLU(), 53 | Flatten(), 54 | linear_layer(32 * 7 * 7, 100), 55 | nn.ReLU(), 56 | linear_layer(100, 10), 57 | ) 58 | return model 59 | 60 | 61 | def mnist_model_large(conv_layer, linear_layer, init_type, **kwargs): 62 | assert init_type == "kaiming_normal", "only supporting kaiming_normal init" 63 | model = nn.Sequential( 64 | conv_layer(1, 32, 3, stride=1, padding=1), 65 | nn.ReLU(), 66 | conv_layer(32, 32, 4, stride=2, padding=1), 67 | nn.ReLU(), 68 | conv_layer(32, 64, 3, stride=1, padding=1), 69 | nn.ReLU(), 70 | conv_layer(64, 64, 4, stride=2, padding=1), 71 | nn.ReLU(), 72 | Flatten(), 73 | linear_layer(64 * 7 * 7, 512), 74 | nn.ReLU(), 75 | linear_layer(512, 512), 76 | nn.ReLU(), 77 | linear_layer(512, 10), 78 | ) 79 | return model 80 | 81 | 82 | def cifar_model(conv_layer, linear_layer, init_type, **kwargs): 83 | assert init_type == "kaiming_normal", "only supporting kaiming_normal init" 84 | model = nn.Sequential( 85 | conv_layer(3, 16, 4, stride=2, padding=1), 86 | nn.ReLU(), 87 | conv_layer(16, 32, 4, stride=2, padding=1), 88 | nn.ReLU(), 89 | Flatten(), 90 | linear_layer(32 * 8 * 8, 100), 91 | nn.ReLU(), 92 | linear_layer(100, 10), 93 | ) 94 | for m in model.modules(): 95 | if isinstance(m, nn.Conv2d): 96 | n = m.kernel_size[0] * 
m.kernel_size[1] * m.out_channels 97 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 98 | m.bias.data.zero_() 99 | return model 100 | 101 | 102 | def cifar_model_large(conv_layer, linear_layer, init_type, **kwargs): 103 | assert init_type == "kaiming_normal", "only supporting kaiming_normal init" 104 | model = nn.Sequential( 105 | conv_layer(3, 32, 3, stride=1, padding=1), 106 | nn.ReLU(), 107 | conv_layer(32, 32, 4, stride=2, padding=1), 108 | nn.ReLU(), 109 | conv_layer(32, 64, 3, stride=1, padding=1), 110 | nn.ReLU(), 111 | conv_layer(64, 64, 4, stride=2, padding=1), 112 | nn.ReLU(), 113 | Flatten(), 114 | linear_layer(64 * 8 * 8, 512), 115 | nn.ReLU(), 116 | linear_layer(512, 512), 117 | nn.ReLU(), 118 | linear_layer(512, 10), 119 | ) 120 | for m in model.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 123 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 124 | m.bias.data.zero_() 125 | return model 126 | 127 | 128 | class DenseSequential(nn.Sequential): 129 | def forward(self, x): 130 | xs = [x] 131 | for module in self._modules.values(): 132 | if "Dense" in type(module).__name__: 133 | xs.append(module(*xs)) 134 | else: 135 | xs.append(module(xs[-1])) 136 | return xs[-1] 137 | 138 | 139 | class Dense(nn.Module): 140 | def __init__(self, *Ws): 141 | super(Dense, self).__init__() 142 | self.Ws = nn.ModuleList(list(Ws)) 143 | if len(Ws) > 0 and hasattr(Ws[0], "out_features"): 144 | self.out_features = Ws[0].out_features 145 | 146 | def forward(self, *xs): 147 | xs = xs[-len(self.Ws) :] 148 | out = sum(W(x) for x, W in zip(xs, self.Ws) if W is not None) 149 | return out 150 | 151 | 152 | def cifar_model_resnet(conv_layer, linear_layer, init_type, N=5, factor=1, **kwargs): 153 | def block(in_filters, out_filters, k, downsample): 154 | if not downsample: 155 | k_first = 3 156 | skip_stride = 1 157 | k_skip = 1 158 | else: 159 | k_first = 4 160 | skip_stride = 2 161 | k_skip = 2 162 | return [ 163 | Dense( 164 | conv_layer( 165 | in_filters, out_filters, k_first, stride=skip_stride, padding=1 166 | ) 167 | ), 168 | nn.ReLU(), 169 | Dense( 170 | conv_layer( 171 | in_filters, out_filters, k_skip, stride=skip_stride, padding=0 172 | ), 173 | None, 174 | conv_layer(out_filters, out_filters, k, stride=1, padding=1), 175 | ), 176 | nn.ReLU(), 177 | ] 178 | 179 | conv1 = [conv_layer(3, 16, 3, stride=1, padding=1), nn.ReLU()] 180 | conv2 = block(16, 16 * factor, 3, False) 181 | for _ in range(N): 182 | conv2.extend(block(16 * factor, 16 * factor, 3, False)) 183 | conv3 = block(16 * factor, 32 * factor, 3, True) 184 | for _ in range(N - 1): 185 | conv3.extend(block(32 * factor, 32 * factor, 3, False)) 186 | conv4 = block(32 * factor, 64 * factor, 3, True) 187 | for _ in range(N - 1): 188 | conv4.extend(block(64 * factor, 64 * factor, 3, False)) 189 | layers = ( 190 | conv1 191 | + conv2 192 | + conv3 193 | + conv4 194 | + [ 195 | Flatten(), 196 | linear_layer(64 * factor * 8 * 8, 1000), 197 | nn.ReLU(), 198 | linear_layer(1000, 10), 199 | ] 200 | ) 201 | model = DenseSequential(*layers) 202 | 203 | for m in model.modules(): 204 | if isinstance(m, nn.Conv2d): 205 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 206 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 207 | if m.bias is not None: 208 | m.bias.data.zero_() 209 | return model 210 | 211 | 212 | def vgg4_without_maxpool(conv_layer, linear_layer, init_type, **kwargs): 213 | assert init_type == "kaiming_normal", "only supporting kaiming_normal init" 214 | model = nn.Sequential( 
215 |         conv_layer(3, 64, 3, stride=1, padding=1),
216 |         nn.ReLU(),
217 |         conv_layer(64, 64, 3, stride=2, padding=1),
218 |         nn.ReLU(),
219 |         conv_layer(64, 128, 3, stride=1, padding=1),
220 |         nn.ReLU(),
221 |         conv_layer(128, 128, 3, stride=2, padding=1),
222 |         nn.ReLU(),
223 |         Flatten(),
224 |         linear_layer(128 * 8 * 8, 256),
225 |         nn.ReLU(),
226 |         linear_layer(256, 256),
227 |         nn.ReLU(),
228 |         linear_layer(256, 10),
229 |     )
230 |     for m in model.modules():
231 |         if isinstance(m, nn.Conv2d):
232 |             n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
233 |             m.weight.data.normal_(0, math.sqrt(2.0 / n))
234 |             m.bias.data.zero_()
235 |     return model
236 | 
237 | 
238 | def cifar_model_resnet_dense(N=5, factor=10):  # dense-only variant (plain nn.Conv2d / nn.Linear), named distinctly so it does not shadow the layer-parameterized builder above
239 |     def block(in_filters, out_filters, k, downsample):
240 |         if not downsample:
241 |             k_first = 3
242 |             skip_stride = 1
243 |             k_skip = 1
244 |         else:
245 |             k_first = 4
246 |             skip_stride = 2
247 |             k_skip = 2
248 |         return [
249 |             Dense(
250 |                 nn.Conv2d(
251 |                     in_filters, out_filters, k_first, stride=skip_stride, padding=1
252 |                 )
253 |             ),
254 |             nn.ReLU(),
255 |             Dense(
256 |                 nn.Conv2d(
257 |                     in_filters, out_filters, k_skip, stride=skip_stride, padding=0
258 |                 ),
259 |                 None,
260 |                 nn.Conv2d(out_filters, out_filters, k, stride=1, padding=1),
261 |             ),
262 |             nn.ReLU(),
263 |         ]
264 | 
265 |     conv1 = [nn.Conv2d(3, 16, 3, stride=1, padding=1), nn.ReLU()]
266 |     conv2 = block(16, 16 * factor, 3, False)
267 |     for _ in range(N):
268 |         conv2.extend(block(16 * factor, 16 * factor, 3, False))
269 |     conv3 = block(16 * factor, 32 * factor, 3, True)
270 |     for _ in range(N - 1):
271 |         conv3.extend(block(32 * factor, 32 * factor, 3, False))
272 |     conv4 = block(32 * factor, 64 * factor, 3, True)
273 |     for _ in range(N - 1):
274 |         conv4.extend(block(64 * factor, 64 * factor, 3, False))
275 |     layers = (
276 |         conv1
277 |         + conv2
278 |         + conv3
279 |         + conv4
280 |         + [
281 |             Flatten(),
282 |             nn.Linear(64 * factor * 8 * 8, 1000),
283 |             nn.ReLU(),
284 |             nn.Linear(1000, 10),
285 |         ]
286 |     )
287 |     model = DenseSequential(*layers)
288 | 
289 |     for m in model.modules():
290 |         if isinstance(m, nn.Conv2d):
291 |             n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
292 |             m.weight.data.normal_(0, math.sqrt(2.0 / n))
293 |             if m.bias is not None:
294 |                 m.bias.data.zero_()
295 |     return model
296 | 
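# Minimal smoke test (added for illustration; not part of the original module).
if __name__ == "__main__":
    # Build the small MNIST net with standard dense layers and check the output shape.
    net = mnist_model(nn.Conv2d, nn.Linear, "kaiming_normal")
    out = net(torch.randn(2, 1, 28, 28))
    print(out.shape)  # expected: torch.Size([2, 10])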
--------------------------------------------------------------------------------
/train_imagenet.py:
--------------------------------------------------------------------------------
1 | # Some part borrowed from official tutorial https://github.com/pytorch/examples/blob/master/imagenet/main.py
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 | 
5 | import os
6 | import sys
7 | import numpy as np
8 | import argparse
9 | import importlib
10 | import time
11 | import logging
12 | from pathlib import Path
13 | import copy
14 | import math
15 | 
16 | import torch
17 | import torch.nn as nn
18 | from torch.utils.data.dataset import Dataset
19 | from torchvision import datasets, transforms
20 | from torch.utils.tensorboard import SummaryWriter
21 | 
22 | import models
23 | import data
24 | from args import parse_args
25 | from utils.schedules import get_lr_policy, get_optimizer
26 | from utils.logging import (
27 |     save_checkpoint,
28 |     create_subdirs,
29 |     parse_configs_file,
30 |     clone_results_to_latest_subdir,
31 | )
32 | from utils.semisup import get_semisup_dataloader
33 | from utils.model import (
34 |     get_layers,
35 |     prepare_model,
36 |     initialize_scaled_score,
37 |     show_gradients,
38 |     current_model_pruned_fraction,
39 |     sanity_check_paramter_updates,
40 | )
41 | 
42 | 
43 | # TODO: update wrn, resnet models. Save both subnet and dense versions.
44 | # TODO: take care of BN, bias in pruning; support structured pruning.
45 | 
46 | 
47 | def main():
48 |     args = parse_args()
49 |     parse_configs_file(args)
50 | 
51 |     # sanity checks
52 |     if args.exp_mode in ["prune", "finetune"] and not args.resume:
53 |         assert args.source_net, "Provide checkpoint to prune/finetune"
54 | 
55 |     # create results dir (for logs, checkpoints, etc.)
56 |     result_main_dir = os.path.join(Path(args.result_dir), args.exp_name, args.exp_mode)
57 | 
58 |     if os.path.exists(result_main_dir):
59 |         n = len(next(os.walk(result_main_dir))[-2])  # prev experiments with same name
60 |         result_sub_dir = os.path.join(
61 |             result_main_dir,
62 |             "{}--k-{:.2f}_trainer-{}_lr-{}_epochs-{}_warmuplr-{}_warmupepochs-{}".format(
63 |                 n + 1,
64 |                 args.k,
65 |                 args.trainer,
66 |                 args.lr,
67 |                 args.epochs,
68 |                 args.warmup_lr,
69 |                 args.warmup_epochs,
70 |             ),
71 |         )
72 |     else:
73 |         os.makedirs(result_main_dir, exist_ok=True)
74 |         result_sub_dir = os.path.join(
75 |             result_main_dir,
76 |             "1--k-{:.2f}_trainer-{}_lr-{}_epochs-{}_warmuplr-{}_warmupepochs-{}".format(
77 |                 args.k,
78 |                 args.trainer,
79 |                 args.lr,
80 |                 args.epochs,
81 |                 args.warmup_lr,
82 |                 args.warmup_epochs,
83 |             ),
84 |         )
85 |     create_subdirs(result_sub_dir)
86 | 
87 |     # add logger
88 |     logging.basicConfig(level=logging.INFO, format="%(message)s")
89 |     logger = logging.getLogger()
90 |     logger.addHandler(
91 |         logging.FileHandler(os.path.join(result_sub_dir, "setup.log"), "a")
92 |     )
93 |     logger.info(args)
94 | 
95 |     # seed cuda
96 |     torch.manual_seed(args.seed)
97 |     torch.cuda.manual_seed(args.seed)
98 |     torch.cuda.manual_seed_all(args.seed)
99 |     np.random.seed(args.seed)
100 | 
101 |     # Select GPUs
102 |     use_cuda = not args.no_cuda and torch.cuda.is_available()
103 |     gpu_list = [int(i) for i in args.gpu.strip().split(",")]
104 |     device = torch.device(f"cuda:{gpu_list[0]}" if use_cuda else "cpu")
105 | 
106 |     # Create model
107 |     cl, ll = get_layers(args.layer_type)
108 |     if len(gpu_list) > 1:
109 |         print("Using multiple GPUs")
110 |         model = nn.DataParallel(
111 |             models.__dict__[args.arch](
112 |                 cl, ll, args.init_type, num_classes=args.num_classes
113 |             ),
114 |             gpu_list,
115 |         ).to(device)
116 |     else:
117 |         model = models.__dict__[args.arch](
118 |             cl, ll, args.init_type, num_classes=args.num_classes
119 |         ).to(device)
120 |     logger.info(model)
121 | 
122 |     # Customize models for training/pruning/fine-tuning
123 |     prepare_model(model, args)
124 | 
125 |     # Setup tensorboard writer
126 |     writer = SummaryWriter(os.path.join(result_sub_dir, "tensorboard"))
127 | 
128 |     # Dataloader
129 |     D = data.__dict__[args.dataset](args, normalize=args.normalize)
130 |     train_loader, test_loader = D.data_loaders()
131 | 
132 |     logger.info(f"{args.dataset} {D} | training samples: {len(train_loader.dataset)} | test samples: {len(test_loader.dataset)}")
133 | 
134 |     # Semi-sup dataloader
135 |     if args.is_semisup:
136 |         logger.info("Using semi-supervised training")
137 |         sm_loader = get_semisup_dataloader(args, D.tr_train)
138 |     else:
139 |         sm_loader = None
140 | 
141 |     # autograd
142 |     criterion = nn.CrossEntropyLoss()
143 |     optimizer = get_optimizer(model, args)
144 |     lr_policy = get_lr_policy(args.lr_schedule)(optimizer, args)
145 |     logger.info([criterion, optimizer, lr_policy])
146 | 
147 |     # train & val method
148 |     trainer = importlib.import_module(f"trainer.{args.trainer}").train
149 |     val = getattr(importlib.import_module("utils.eval"), args.val_method)
150 | 
151 |     # Load source_net (if checkpoint provided). Only load the state_dict (required for pruning and fine-tuning).
152 |     if args.source_net:
153 |         if os.path.isfile(args.source_net):
154 |             logger.info("=> loading source model from '{}'".format(args.source_net))
155 |             checkpoint = torch.load(args.source_net, map_location=device)
156 |             model.load_state_dict(
157 |                 checkpoint["state_dict"], strict=False
158 |             )  # allows loading dense models
159 |             logger.info("=> loaded checkpoint '{}'".format(args.source_net))
160 |         else:
161 |             logger.info("=> no checkpoint found at '{}'".format(args.source_net))
162 | 
163 |     # Init scores once source net is loaded.
164 |     # NOTE: scaled_init_scores will overwrite the scores in the pre-trained net.
165 |     if args.scaled_score_init:
166 |         initialize_scaled_score(model)
167 | 
168 |     assert not (args.source_net and args.resume), (
169 |         "Incorrect setup: "
170 |         "resume => required to resume a previous experiment (loads all parameters) || "
171 |         "source_net => required to start pruning/fine-tuning from a source model (only loads state_dict)"
172 |     )
173 |     # resume (if checkpoint provided). Continue training with previous settings.
174 |     if args.resume:
175 |         if os.path.isfile(args.resume):
176 |             logger.info("=> loading checkpoint '{}'".format(args.resume))
177 |             checkpoint = torch.load(args.resume)
178 |             args.start_epoch = checkpoint["epoch"]
179 |             best_prec1 = checkpoint["best_prec1"]
180 |             model.load_state_dict(checkpoint["state_dict"])
181 |             optimizer.load_state_dict(checkpoint["optimizer"])
182 |             logger.info(
183 |                 "=> loaded checkpoint '{}' (epoch {})".format(
184 |                     args.resume, checkpoint["epoch"]
185 |                 )
186 |             )
187 |         else:
188 |             logger.info("=> no checkpoint found at '{}'".format(args.resume))
189 | 
190 |     if args.evaluate:
191 |         p1, _ = val(model, device, test_loader, criterion, args, writer)
192 |         logger.info(f"Validation accuracy {args.val_method} for source-net: {p1}")
193 |         return
195 | 
196 |     best_prec1 = 0
197 |     args.epochs = int(math.ceil(args.epochs / args.n_repeats))
198 |     logger.info("New logs")
199 |     logger.info(args)
200 | 
201 |     show_gradients(model)
202 | 
203 |     # Do not select source-net as the last checkpoint, as it might even be a dense model.
204 |     # Most other functions won't work well with a dense-layer checkpoint.
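    # Snapshot of the initial parameters; sanity_check_paramter_updates compares
    # the live model against this copy after every epoch.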
205 | last_ckpt = copy.deepcopy(model.state_dict()) 206 | 207 | # Start training 208 | for epoch in range(args.start_epoch, args.epochs + args.warmup_epochs): 209 | lr_policy(epoch) # adjust learning rate 210 | 211 | # train 212 | trainer( 213 | model, 214 | device, 215 | train_loader, 216 | sm_loader, 217 | criterion, 218 | optimizer, 219 | epoch, 220 | args, 221 | writer, 222 | ) 223 | 224 | # evaluate on test set 225 | print("Evaluating on a weaker attack to save time!") 226 | if args.val_method == "smooth": 227 | prec1, radii = val( 228 | model, device, test_loader, criterion, args, writer, epoch 229 | ) 230 | logger.info(f"Epoch {epoch}, mean provable Radii {radii}") 231 | else: 232 | prec1, _ = val(model, device, test_loader, criterion, args, writer, epoch) 233 | 234 | # remember best prec@1 and save checkpoint 235 | is_best = prec1 > best_prec1 236 | best_prec1 = max(prec1, best_prec1) 237 | save_checkpoint( 238 | { 239 | "epoch": epoch + 1, 240 | "arch": args.arch, 241 | "state_dict": model.state_dict(), 242 | "best_prec1": best_prec1, 243 | "optimizer": optimizer.state_dict(), 244 | }, 245 | is_best, 246 | args, 247 | result_dir=os.path.join(result_sub_dir, "checkpoint"), 248 | save_dense=args.save_dense, 249 | ) 250 | 251 | logger.info( 252 | f"Epoch {epoch}, val-method {args.val_method}, validation accuracy {prec1}, best_prec {best_prec1}" 253 | ) 254 | if args.exp_mode in ["prune", "finetune"]: 255 | logger.info( 256 | "Pruned model: {:.2f}%".format( 257 | current_model_pruned_fraction( 258 | model, os.path.join(result_sub_dir, "checkpoint"), verbose=False 259 | ) 260 | ) 261 | ) 262 | # clone results to latest subdir (sync after every epoch) 263 | # Latest_subdir: stores results from latest run of an experiment. 264 | clone_results_to_latest_subdir( 265 | result_sub_dir, os.path.join(result_main_dir, "latest_exp") 266 | ) 267 | 268 | # Check what parameters got updated in the current epoch. 
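        # Expected outcome, given the freezing in prepare_model: pretrain/finetune
        # -> weights update (sw=True) while scores stay fixed (ss=False); prune ->
        # the reverse.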
269 |         sw, ss = sanity_check_paramter_updates(model, last_ckpt)
270 |         logger.info(
271 |             f"Sanity check (exp-mode: {args.exp_mode}): Weight update - {sw}, Scores update - {ss}"
272 |         )
273 | 
274 |     current_model_pruned_fraction(
275 |         model, os.path.join(result_sub_dir, "checkpoint"), verbose=True
276 |     )
277 | 
278 | 
279 | if __name__ == "__main__":
280 |     main()
281 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Some part borrowed from official tutorial https://github.com/pytorch/examples/blob/master/imagenet/main.py
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 | 
5 | import os
6 | import sys
7 | import numpy as np
8 | import argparse
9 | import importlib
10 | import time
11 | import logging
12 | from pathlib import Path
13 | import copy
14 | 
15 | import torch
16 | import torch.nn as nn
17 | from torch.utils.data.dataset import Dataset
18 | from torchvision import datasets, transforms
19 | from torch.utils.tensorboard import SummaryWriter
20 | 
21 | import models
22 | import data
23 | from args import parse_args
24 | from utils.schedules import get_lr_policy, get_optimizer
25 | from utils.logging import (
26 |     save_checkpoint,
27 |     create_subdirs,
28 |     parse_configs_file,
29 |     clone_results_to_latest_subdir,
30 | )
31 | from utils.semisup import get_semisup_dataloader
32 | from utils.model import (
33 |     get_layers,
34 |     prepare_model,
35 |     initialize_scaled_score,
36 |     scale_rand_init,
37 |     show_gradients,
38 |     current_model_pruned_fraction,
39 |     sanity_check_paramter_updates,
40 |     snip_init,
41 | )
42 | 
43 | 
44 | # TODO: update wrn, resnet models. Save both subnet and dense versions.
45 | # TODO: take care of BN, bias in pruning; support structured pruning.
46 | 
47 | 
48 | def main():
49 |     args = parse_args()
50 |     parse_configs_file(args)
51 | 
52 |     # sanity checks
53 |     if args.exp_mode in ["prune", "finetune"] and not args.resume:
54 |         assert args.source_net, "Provide checkpoint to prune/finetune"
55 | 
56 |     # create results dir (for logs, checkpoints, etc.)
57 |     result_main_dir = os.path.join(Path(args.result_dir), args.exp_name, args.exp_mode)
58 | 
59 |     if os.path.exists(result_main_dir):
60 |         n = len(next(os.walk(result_main_dir))[-2])  # prev experiments with same name
61 |         result_sub_dir = os.path.join(
62 |             result_main_dir,
63 |             "{}--k-{:.2f}_trainer-{}_lr-{}_epochs-{}_warmuplr-{}_warmupepochs-{}".format(
64 |                 n + 1,
65 |                 args.k,
66 |                 args.trainer,
67 |                 args.lr,
68 |                 args.epochs,
69 |                 args.warmup_lr,
70 |                 args.warmup_epochs,
71 |             ),
72 |         )
73 |     else:
74 |         os.makedirs(result_main_dir, exist_ok=True)
75 |         result_sub_dir = os.path.join(
76 |             result_main_dir,
77 |             "1--k-{:.2f}_trainer-{}_lr-{}_epochs-{}_warmuplr-{}_warmupepochs-{}".format(
78 |                 args.k,
79 |                 args.trainer,
80 |                 args.lr,
81 |                 args.epochs,
82 |                 args.warmup_lr,
83 |                 args.warmup_epochs,
84 |             ),
85 |         )
86 |     create_subdirs(result_sub_dir)
87 | 
88 |     # add logger
89 |     logging.basicConfig(level=logging.INFO, format="%(message)s")
90 |     logger = logging.getLogger()
91 |     logger.addHandler(
92 |         logging.FileHandler(os.path.join(result_sub_dir, "setup.log"), "a")
93 |     )
94 |     logger.info(args)
95 | 
96 |     # seed cuda
97 |     torch.manual_seed(args.seed)
98 |     torch.cuda.manual_seed(args.seed)
99 |     torch.cuda.manual_seed_all(args.seed)
100 |     np.random.seed(args.seed)
101 | 
102 |     # Select GPUs
103 |     use_cuda = not args.no_cuda and torch.cuda.is_available()
104 |     gpu_list = [int(i) for i in args.gpu.strip().split(",")]
105 |     device = torch.device(f"cuda:{gpu_list[0]}" if use_cuda else "cpu")
106 | 
107 |     # Create model
108 |     cl, ll = get_layers(args.layer_type)
109 |     if len(gpu_list) > 1:
110 |         print("Using multiple GPUs")
111 |         model = nn.DataParallel(
112 |             models.__dict__[args.arch](
113 |                 cl, ll, args.init_type, num_classes=args.num_classes
114 |             ),
115 |             gpu_list,
116 |         ).to(device)
117 |     else:
118 |         model = models.__dict__[args.arch](
119 |             cl, ll, args.init_type, num_classes=args.num_classes
120 |         ).to(device)
121 |     logger.info(model)
122 | 
123 |     # Customize models for training/pruning/fine-tuning
124 |     prepare_model(model, args)
125 | 
126 |     # Setup tensorboard writer
127 |     writer = SummaryWriter(os.path.join(result_sub_dir, "tensorboard"))
128 | 
129 |     # Dataloader
130 |     D = data.__dict__[args.dataset](args, normalize=args.normalize)
131 |     train_loader, test_loader = D.data_loaders()
132 | 
133 |     logger.info(f"{args.dataset} {D} | training samples: {len(train_loader.dataset)} | test samples: {len(test_loader.dataset)}")
134 | 
135 |     # Semi-sup dataloader
136 |     if args.is_semisup:
137 |         logger.info("Using semi-supervised training")
138 |         sm_loader = get_semisup_dataloader(args, D.tr_train)
139 |     else:
140 |         sm_loader = None
141 | 
142 |     # autograd
143 |     criterion = nn.CrossEntropyLoss()
144 |     optimizer = get_optimizer(model, args)
145 |     lr_policy = get_lr_policy(args.lr_schedule)(optimizer, args)
146 |     logger.info([criterion, optimizer, lr_policy])
147 | 
148 |     # train & val method
149 |     trainer = importlib.import_module(f"trainer.{args.trainer}").train
150 |     val = getattr(importlib.import_module("utils.eval"), args.val_method)
151 | 
152 |     # Load source_net (if checkpoint provided). Only load the state_dict (required for pruning and fine-tuning).
153 |     if args.source_net:
154 |         if os.path.isfile(args.source_net):
155 |             logger.info("=> loading source model from '{}'".format(args.source_net))
156 |             checkpoint = torch.load(args.source_net, map_location=device)
157 |             model.load_state_dict(checkpoint["state_dict"])
158 |             logger.info("=> loaded checkpoint '{}'".format(args.source_net))
159 |         else:
160 |             logger.info("=> no checkpoint found at '{}'".format(args.source_net))
161 | 
162 |     # Init scores once source net is loaded.
163 |     # NOTE: scaled_init_scores will overwrite the scores in the pre-trained net.
164 |     if args.scaled_score_init:
165 |         initialize_scaled_score(model)
166 | 
167 |     # Scaled random initialization. Useful when training a highly sparse net from scratch.
168 |     # If not used, a sparse net (without batch-norm) trained from scratch will not converge.
169 |     # With batch-norm it's not really necessary.
170 |     if args.scale_rand_init:
171 |         scale_rand_init(model, args.k)
172 | 
178 | 
179 |     if args.snip_init:
180 |         snip_init(model, criterion, optimizer, train_loader, device, args)
181 | 
182 |     assert not (args.source_net and args.resume), (
183 |         "Incorrect setup: "
184 |         "resume => required to resume a previous experiment (loads all parameters) || "
185 |         "source_net => required to start pruning/fine-tuning from a source model (only loads state_dict)"
186 |     )
187 |     # resume (if checkpoint provided). Continue training with previous settings.
188 |     if args.resume:
189 |         if os.path.isfile(args.resume):
190 |             logger.info("=> loading checkpoint '{}'".format(args.resume))
191 |             checkpoint = torch.load(args.resume, map_location=device)
192 |             args.start_epoch = checkpoint["epoch"]
193 |             best_prec1 = checkpoint["best_prec1"]
194 |             model.load_state_dict(checkpoint["state_dict"])
195 |             optimizer.load_state_dict(checkpoint["optimizer"])
196 |             logger.info(
197 |                 "=> loaded checkpoint '{}' (epoch {})".format(
198 |                     args.resume, checkpoint["epoch"]
199 |                 )
200 |             )
201 |         else:
202 |             logger.info("=> no checkpoint found at '{}'".format(args.resume))
203 | 
204 |     # Evaluate
205 |     if args.evaluate or args.exp_mode in ["prune", "finetune"]:
206 |         p1, _ = val(model, device, test_loader, criterion, args, writer)
207 |         logger.info(f"Validation accuracy {args.val_method} for source-net: {p1}")
208 |         if args.evaluate:
209 |             return
210 | 
211 |     best_prec1 = 0
212 | 
213 |     show_gradients(model)
214 | 
215 |     if args.source_net:
216 |         last_ckpt = checkpoint["state_dict"]
217 |     else:
218 |         last_ckpt = copy.deepcopy(model.state_dict())
219 | 
220 |     # Start training
221 |     for epoch in range(args.start_epoch, args.epochs + args.warmup_epochs):
222 |         lr_policy(epoch)  # adjust learning rate
223 | 
224 |         # train
225 |         trainer(
226 |             model,
227 |             device,
228 |             train_loader,
229 |             sm_loader,
230 |             criterion,
231 |             optimizer,
232 |             epoch,
233 |             args,
234 |             writer,
235 |         )
236 | 
237 |         # evaluate on test set
238 |         if args.val_method == "smooth":
239 |             prec1, radii = val(
240 |                 model, device, test_loader, criterion, args, writer, epoch
241 |             )
242 |             logger.info(f"Epoch {epoch}, mean provable Radii {radii}")
243 |         elif args.val_method == "mixtrain" and epoch <= args.schedule_length:
244 |             prec1 = 0.0
245 |         else:
246 |             prec1, _ = val(model, device, test_loader, criterion, args, writer, epoch)
247 | 
248 |         # remember best prec@1 and save checkpoint
249 |         is_best = prec1 > best_prec1
250 |         best_prec1 = max(prec1, best_prec1)
251 |         save_checkpoint(
252 |             {
253 |                 "epoch": epoch + 1,
254 |                 "arch": args.arch,
255 |                 "state_dict": model.state_dict(),
256 |                 "best_prec1": best_prec1,
257 |                 "optimizer": optimizer.state_dict(),
258 |             },
259 |             is_best,
260 |             args,
261 |             result_dir=os.path.join(result_sub_dir, "checkpoint"),
262 |             save_dense=args.save_dense,
263 |         )
264 | 
265 |         logger.info(
266 |             f"Epoch {epoch}, val-method {args.val_method}, validation accuracy {prec1}, best_prec {best_prec1}"
267 |         )
268 |         if args.exp_mode in ["prune", "finetune"]:
269 |             logger.info(
270 |                 "Pruned model: {:.2f}%".format(
271 |                     current_model_pruned_fraction(
272 |                         model, os.path.join(result_sub_dir, "checkpoint"), verbose=False
273 |                     )
274 |                 )
275 |             )
276 |         # clone results to latest subdir (sync after every epoch)
277 |         # Latest_subdir: stores results from the latest run of an experiment.
278 |         clone_results_to_latest_subdir(
279 |             result_sub_dir, os.path.join(result_main_dir, "latest_exp")
280 |         )
281 | 
282 |         # Check what parameters got updated in the current epoch.
283 |         sw, ss = sanity_check_paramter_updates(model, last_ckpt)
284 |         logger.info(
285 |             f"Sanity check (exp-mode: {args.exp_mode}): Weight update - {sw}, Scores update - {ss}"
286 |         )
287 | 
288 |     current_model_pruned_fraction(
289 |         model, os.path.join(result_sub_dir, "checkpoint"), verbose=True
290 |     )
291 | 
292 | 
293 | if __name__ == "__main__":
294 |     main()
295 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # HYDRA: Pruning Adversarially Robust Neural Networks (NeurIPS 2020)
2 | 
3 | Repository with code to reproduce the results and checkpoints for compressed networks in [our paper on novel pruning techniques with robust training](https://arxiv.org/abs/2002.10509). This repository supports all four robust training objectives: iterative adversarial training, randomized smoothing, MixTrain, and CROWN-IBP.
4 | 
5 | The following is a snippet of key results, where we show that accounting for the robust training objective in the pruning strategy can lead to large gains in the robustness of pruned networks.
6 | 
7 | ![results_table](/images/results_table.png)
8 | 
9 | 
10 | 
11 | In particular, the improvement arises from letting the robust training objective control which connections to prune. In almost all cases, it prefers to prune certain high-magnitude weights while preserving other small-magnitude weights, which is orthogonal to the strategy of well-established least-weight-magnitude (LWM) based pruning.
12 | 
13 | ![weight_histogram](/images/weight_histogram.png)
14 | 
15 | ## Updates
16 | **April 30, 2020**: [Checkpoints for WRN-28-10](https://www.dropbox.com/sh/56yyfy16elwbnr8/AADmr7bXgFkrNdoHjKWwIFKqa?dl=0), a common network for benchmarking adv. robustness | 90% pruned with the proposed technique | Benign test accuracy = 88.97%, PGD-50 test accuracy = 62.24%.
17 | 
18 | **May 23, 2020**: Our WRN-28-10 network with 90% connection pruning comes in second place on the [auto-attack robustness benchmark](https://github.com/fra31/auto-attack).
19 | 
20 | ## Getting started
21 | 
22 | Let's start by installing all dependencies.
23 | 
24 | `pip install -r requirements.txt`
25 | 
26 | 
27 | 
28 | We will use `train.py` for all our experiments on the CIFAR-10 and SVHN datasets. For ImageNet, we will use `train_imagenet.py`. Both provide the flexibility to work with the pre-training, pruning, and fine-tuning steps along with different training objectives.
29 | 
30 | - `exp_mode`: select from pretrain, prune, finetune
31 | - `trainer`: benign (`base`), iterative adversarial training (`adv`), randomized smoothing (`smooth`), MixTrain (`mixtrain`), and CROWN-IBP (`crown-ibp`)
32 | - `--dataset`: cifar10, svhn, imagenet
33 | 
34 | 
35 | 
36 | Following [this](https://github.com/allenai/hidden-networks) work, we modify the convolution layer to have an internal mask. We can use a masked convolution layer with `--layer-type=subnet`. The argument `k` refers to the fraction of non-pruned connections. (An illustrative sketch of such a masked layer appears after the pruning scripts below.)
37 | 
38 | 
39 | 
40 | ## Pre-training
41 | 
42 | In pre-training, we train the networks with `k=1`, i.e., without pruning. The following example pre-trains a WRN-28-4 network with adversarial training.
43 | 
44 | `python train.py --arch wrn_28_4 --exp-mode pretrain --configs configs/configs.yml --trainer adv --val_method adv --k 1.0`
45 | 
46 | 
47 | 
48 | ## Pruning
49 | 
50 | In the pruning step, we essentially freeze the weights of the network and only update the importance scores. The following command will prune the pre-trained WRN-28-4 network to a 99% pruning ratio.
51 | 
52 | `python train.py --arch wrn_28_4 --exp-mode prune --configs configs/configs.yml --trainer adv --val_method adv --k 0.01 --scaled-score-init --source-net pretrained_net_checkpoint_path --epochs 20 --save-dense`
53 | 
54 | It uses 20 epochs to optimize for better-pruned networks, following the proposed scaled initialization of importance scores. It also saves a checkpoint of the pruned network with dense layers, i.e., it throws away the mask from each layer after multiplying it with the weights. These dense checkpoints are helpful as they can be directly loaded in a model based on standard layers from torch.nn.
55 | 
56 | 
57 | 
58 | ## Fine-tuning
59 | 
60 | In the fine-tuning step, we update the non-pruned weights but freeze the importance scores. For correct results, we must select the same pruning ratio as in the pruning step.
61 | 
62 | `python train.py --arch wrn_28_4 --exp-mode finetune --configs configs/configs.yml --trainer adv --val_method adv --k 0.01 --source-net pruned_net_checkpoint_path --save-dense --lr 0.01`
63 | 
64 | 
65 | 
66 | ## Least weight magnitude (LWM) based pruning
67 | 
68 | We use a single-shot pruning approach where we prune the desired number of connections after pre-training in a single step. After that, the network is fine-tuned with a similar configuration as above.
69 | 
70 | `python train.py --arch wrn_28_4 --exp-mode finetune --configs configs/configs.yml --trainer adv --val_method adv --k 0.01 --source-net pretrained_net_checkpoint_path --save-dense --lr 0.01 --scaled-score-init`
71 | 
72 | The only difference from the fine-tuning step above is that we now initialize the importance scores with the proposed scaling. This scheme effectively prunes the connections with the lowest magnitude at the start. Since the importance scores are not updated during fine-tuning, this effectively implements LWM-based pruning.
73 | 
74 | 
75 | 
76 | ## Bringing it all together
77 | 
78 | We can use the following scripts to obtain compact networks from both LWM and the proposed pruning technique.
79 | 
80 | - `get_compact_net_adv_train.sh`: Compact networks with iterative adversarial training.
81 | - `get_compact_net_rand_smoothing.sh`: Compact networks with randomized smoothing.
82 | - `get_compact_net_mixtrain.sh`: Compact networks with MixTrain.
83 | - `get_compact_net_crown-ibp.sh`: Compact networks with CROWN-IBP.
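
### Illustrative sketch: masked (subnet) layers

The actual masked layers live in `models/layers.py` (`SubnetConv`, `SubnetLinear`). The following is only a rough sketch of the idea, mirroring the top-k selection logic also used by `subnet_to_dense` in `utils/model.py`; it is not the repository's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GetSubnet(torch.autograd.Function):
    @staticmethod
    def forward(ctx, scores, k):
        # Keep the top-k fraction of connections, ranked by |score|.
        out = scores.clone()
        _, idx = scores.abs().flatten().sort()
        j = int((1 - k) * scores.numel())
        out.flatten()[idx[:j]] = 0
        out.flatten()[idx[j:]] = 1
        return out

    @staticmethod
    def backward(ctx, g):
        # Straight-through estimator: pass gradients to the scores unchanged.
        return g, None


class MaskedConv(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One importance score per weight; these are what the pruning step optimizes.
        self.popup_scores = nn.Parameter(torch.randn_like(self.weight))
        self.k = 1.0  # fraction of connections kept

    def set_prune_rate(self, k):
        self.k = k

    def forward(self, x):
        mask = GetSubnet.apply(self.popup_scores, self.k)
        return F.conv2d(x, self.weight * mask, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
```

During pruning, `prepare_model` freezes the weights and trains only `popup_scores`, so optimizing the scores directly searches for the most robust subnetwork at the chosen `k`.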
84 | 
85 | 
86 | 
87 | 
88 | 
89 | ## Finding robust sub-networks
90 | 
91 | It is natural to ask whether pruning certain connections by itself can induce robustness in a network. In particular, given a non-robust network, does there exist a highly robust subnetwork within it? We find that such robust subnetworks do indeed exist, with a non-trivial amount of robustness. Here is an example to reproduce these results:
92 | 
93 | `python train.py --arch wrn_28_4 --configs configs/configs.yml --trainer adv --val_method adv --k 0.5 --source-net pretrained_non-robust-net_checkpoint_path`
94 | 
95 | Thus, given the checkpoint path of a non-robust network, the command aims to find a sub-network with half the connections but high empirical robust accuracy. We can similarly optimize for verifiably robust accuracy by selecting `--trainer` from `smooth | mixtrain | crown-ibp`, using the proper configs for each.
96 | 
97 | 
98 | 
99 | ## Model Zoo (checkpoints for pre-trained and compressed networks)
100 | 
101 | We are releasing pruned models for all three pruning ratios (90, 95, 99%) for all three datasets used in the paper. In case you want to compare some additional property of pruned models with a baseline, we are also releasing non-pruned, i.e., pre-trained, networks. Note that we use input normalization only for the ImageNet dataset. For each model, we are releasing two checkpoints: one with masked layers and the other with dense layers. Note that the numbers from these checkpoints might differ a bit from the ones reported in the paper.
102 | 
103 | ### Adversarial training
104 | 
105 | | Dataset | Architecture | Pre-trained (0%) | 90% pruned | 95% pruned | 99% pruned |
106 | | :-----: | :----------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
107 | | CIFAR10 | VGG16 | [ckpt](https://www.dropbox.com/sh/1037dxc9m4m6wqs/AAD62kuJRuVaoRFOto_jxKJ2a?dl=0) | [ckpt](https://www.dropbox.com/sh/ugf2xokml5uf9s0/AAALs9dvG5fwejfBFU-RbL0ma?dl=0) | [ckpt](https://www.dropbox.com/sh/xehsrmls76k85y0/AAC-QARNd_b4hJYC5V9QwEJXa?dl=0) | [ckpt](https://www.dropbox.com/sh/8zgknaiv8o19o9k/AAAG2ZncZmhdj-Hz9uM46u-ka?dl=0) |
108 | | CIFAR10 | WRN-28-4 | [ckpt](https://www.dropbox.com/sh/zvqgjd5xx06lh3t/AACT5vYS3S6b33-0uRDjK2Awa?dl=0) | [ckpt](https://www.dropbox.com/sh/b9cyx9ewg5dt981/AADMA9vVVCXe68RwrSZtC9tia?dl=0) | [ckpt](https://www.dropbox.com/sh/cbt8xqq9na4tj1b/AADyPq6J34cUWHB8GvGf_ivDa?dl=0) | [ckpt](https://www.dropbox.com/sh/pjn8thd1fw2pujr/AABcCAH7BEdVrJs0v_gMQ0lTa?dl=0) |
109 | | SVHN | VGG16 | [ckpt](https://www.dropbox.com/sh/jmo7hj25po0r7tl/AAAw756-U1bifArFr_y1GeSCa?dl=0) | [ckpt](https://www.dropbox.com/sh/7pg0aaguyzndx61/AABqL_8-XFhilpywT9jMHCHqa?dl=0) | [ckpt](https://www.dropbox.com/sh/m3t33ku6aqecv4u/AACykFCWN1-QwbMftvk-a-8na?dl=0) | [ckpt](https://www.dropbox.com/sh/d8il3fpzxvx4uhq/AACZF5GVuV5yzc781Ge5kkD9a?dl=0) |
110 | | SVHN | WRN-28-4 | [ckpt](https://www.dropbox.com/sh/0o906gxijsk4ruh/AAAAj-mwEnv7uNgildkeMqC-a?dl=0) | [ckpt](https://www.dropbox.com/sh/9hyh3iwnrjwvgon/AAC2a6vZSrN3DvzVaPeBhQ6Ya?dl=0) | [ckpt](https://www.dropbox.com/sh/5hs67w8yh9crhyx/AAB8Q4u_EE9rDlYkTF-bT95Ta?dl=0) | [ckpt](https://www.dropbox.com/sh/l0c1houep3w61b6/AAB9CXmKnOpmLe_VKkwB4Ovaa?dl=0) |
111 | 
112 | 
113 | 
114 | ### Randomized smoothing
115 | 
116 | | Dataset | Architecture |
Pre-trained (0%) | 90% pruned | 95% pruned | 99% pruned | 117 | | :-----: | :----------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | 118 | | CIFAR10 | VGG16 | [ckpt](https://www.dropbox.com/sh/y5n7000qt7004fu/AAC7eRNUkGQvFfoepwn6tTpaa?dl=0) | [ckpt](https://www.dropbox.com/sh/0pwxek9vom9cywl/AACDZ_-lmhsNK9BG1BlzWpLea?dl=0) | [ckpt](https://www.dropbox.com/sh/pe8mfstkxl621hb/AAAohk6M7o-NwRXUsvk-hLfCa?dl=0) | [ckpt](https://www.dropbox.com/sh/iahysjrj1dekzpw/AAAjvfsE9Xu1P_q23lAF7uNoa?dl=0) | 119 | | CIFAR10 | WRN-28-4 | [ckpt](https://www.dropbox.com/sh/4xwjxiyal1o7qr3/AABnCDX5dNin_NeYxmlS9XpLa?dl=0) | [ckpt](https://www.dropbox.com/sh/6jj33youpc41o4o/AAAfjYboGCg9yZc-XYyL3ABza?dl=0) | [ckpt](https://www.dropbox.com/sh/3qqw15yyza5zi6a/AABDVyGvJcCEyWT6kPDOQ-spa?dl=0) | [ckpt](https://www.dropbox.com/sh/m1dvdgedovb19yp/AACxxW-6xArpiVV4cfY7cwAYa?dl=0) | 120 | | SVHN | VGG16 | [ckpt](https://www.dropbox.com/sh/9k82top60lvngqb/AABAX9wJUBqGmF8akhoWrRA6a?dl=0) | [ckpt](https://www.dropbox.com/sh/7siuxmb6l9d1qt1/AADnA4m4-1k27eZCBkGyU6ena?dl=0) | [ckpt](https://www.dropbox.com/sh/j0eh9jyqpqurvl3/AAAS4awDRQhiyEnNEPNqwlg2a?dl=0) | [ckpt](https://www.dropbox.com/sh/3rnl9uea4cb44vs/AACaTNrTsp5JybLoCAGzid-4a?dl=0) | 121 | | SVHN | WRN-28-4 | [ckpt](https://www.dropbox.com/sh/m5he7uskva23sfr/AADUlbsXAxuROXFo7Bt2U8R6a?dl=0) | [ckpt](https://www.dropbox.com/sh/hzymmaem17pcr68/AADeFeEZJ4X2fo6WCiqfA1tFa?dl=0) | [ckpt](https://www.dropbox.com/sh/b8kqbkcsmxlhdt9/AABFYwwUHxj3-cnCgL3f0pota?dl=0) | [ckpt](https://www.dropbox.com/sh/g2z07aucy9tw4z8/AABJ1inIcVX2UFdD3e75vjMNa?dl=0) | 122 | 123 | 124 | 125 | ### Adversarial training on ImageNet (ResNet50) 126 | 127 | | Pre-trained (0%) | 95% pruned | 99% pruned | 128 | | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | 129 | | [ckpt](https://www.dropbox.com/sh/z9m0mp6jkdp0ovi/AACtN93nnlp-u48WOgeuzb8Ra?dl=0) | [ckpt](https://www.dropbox.com/sh/w003d06uga1ylu4/AADBY9zbz9dgGYi2Ir2ZyINAa?dl=0) | [ckpt](https://www.dropbox.com/sh/i9i1i50een62zae/AAAq-HNkEsYS8dEmQY3sU4ERa?dl=0) | 130 | 131 | 132 | 133 | ## Contributors 134 | 135 | * [Vikash Sehwag](https://vsehwag.github.io/) 136 | * [Shiqi Wang](https://www.cs.columbia.edu/~tcwangshiqi/) 137 | 138 | 139 | 140 | Some of the code in this repository is based on the following amazing works. 141 | 142 | * https://github.com/allenai/hidden-networks 143 | * https://github.com/yaircarmon/semisup-adv 144 | * https://github.com/locuslab/smoothing 145 | * https://github.com/huanzhang12/CROWN-IBP 146 | * https://github.com/tcwangshiqi-columbia/symbolic_interval 147 | 148 | 149 | 150 | ## Reference 151 | 152 | If you find this work helpful, consider citing it. 
153 | ```
154 | @article{sehwag2020hydra,
155 |   title={Hydra: Pruning adversarially robust neural networks},
156 |   author={Sehwag, Vikash and Wang, Shiqi and Mittal, Prateek and Jana, Suman},
157 |   journal={Advances in Neural Information Processing Systems},
158 |   volume={33},
159 |   year={2020}
160 | }
161 | ```
--------------------------------------------------------------------------------
/locuslab_smoothing/analyze.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | 
4 | # matplotlib.use("TkAgg")
5 | import matplotlib.pyplot as plt
6 | from typing import List
7 | import pandas as pd
8 | import seaborn as sns
9 | import math
10 | 
11 | sns.set()
12 | 
13 | 
14 | class Accuracy(object):
15 |     def at_radii(self, radii: np.ndarray):
16 |         raise NotImplementedError()
17 | 
18 | 
19 | class ApproximateAccuracy(Accuracy):
20 |     def __init__(self, data_file_path: str):
21 |         self.data_file_path = data_file_path
22 | 
23 |     def at_radii(self, radii: np.ndarray) -> np.ndarray:
24 |         df = pd.read_csv(self.data_file_path, delimiter="\t")
25 |         return np.array([self.at_radius(df, radius) for radius in radii])
26 | 
27 |     def at_radius(self, df: pd.DataFrame, radius: float):
28 |         return (df["correct"] & (df["radius"] >= radius)).mean()
29 | 
30 | 
31 | class HighProbAccuracy(Accuracy):
32 |     def __init__(self, data_file_path: str, alpha: float, rho: float):
33 |         self.data_file_path = data_file_path
34 |         self.alpha = alpha
35 |         self.rho = rho
36 | 
37 |     def at_radii(self, radii: np.ndarray) -> np.ndarray:
38 |         df = pd.read_csv(self.data_file_path, delimiter="\t")
39 |         return np.array([self.at_radius(df, radius) for radius in radii])
40 | 
41 |     def at_radius(self, df: pd.DataFrame, radius: float):
42 |         mean = (df["correct"] & (df["radius"] >= radius)).mean()
43 |         num_examples = len(df)
44 |         return (
45 |             mean
46 |             - self.alpha
47 |             - math.sqrt(
48 |                 self.alpha * (1 - self.alpha) * math.log(1 / self.rho) / num_examples
49 |             )
50 |             - math.log(1 / self.rho) / (3 * num_examples)
51 |         )
52 | 
53 | 
54 | class Line(object):
55 |     def __init__(
56 |         self, quantity: Accuracy, legend: str, plot_fmt: str = "", scale_x: float = 1
57 |     ):
58 |         self.quantity = quantity
59 |         self.legend = legend
60 |         self.plot_fmt = plot_fmt
61 |         self.scale_x = scale_x
62 | 
63 | 
64 | def plot_certified_accuracy(
65 |     outfile: str,
66 |     title: str,
67 |     max_radius: float,
68 |     lines: List[Line],
69 |     radius_step: float = 0.01,
70 | ) -> None:
71 |     radii = np.arange(0, max_radius + radius_step, radius_step)
72 |     plt.figure()
73 |     for line in lines:
74 |         plt.plot(radii * line.scale_x, line.quantity.at_radii(radii), line.plot_fmt)
75 | 
76 |     plt.ylim((0, 1))
77 |     plt.xlim((0, max_radius))
78 |     plt.tick_params(labelsize=14)
79 |     plt.xlabel("radius", fontsize=16)
80 |     plt.ylabel("certified accuracy", fontsize=16)
81 |     plt.legend([method.legend for method in lines], loc="upper right", fontsize=16)
82 |     plt.savefig(outfile + ".pdf")  # the PDF is saved without a title (paper-ready)
83 |     plt.tight_layout()
84 |     plt.title(title, fontsize=20)  # only the PNG below carries the title
85 |     plt.tight_layout()
86 |     plt.savefig(outfile + ".png", dpi=300)
87 |     plt.close()
88 | 
89 | 
90 | def smallplot_certified_accuracy(
91 |     outfile: str,
92 |     title: str,
93 |     max_radius: float,
94 |     methods: List[Line],
95 |     radius_step: float = 0.01,
96 |     xticks=0.5,
97 | ) -> None:
98 |     radii = np.arange(0, max_radius + radius_step, radius_step)
99 |     plt.figure()
100 |     for method in methods:
101 |         plt.plot(radii, method.quantity.at_radii(radii), method.plot_fmt)
102 | 
103 |     plt.ylim((0, 1))
104 | 
plt.xlim((0, max_radius)) 105 | plt.xlabel("radius", fontsize=22) 106 | plt.ylabel("certified accuracy", fontsize=22) 107 | plt.tick_params(labelsize=20) 108 | plt.gca().xaxis.set_major_locator(plt.MultipleLocator(xticks)) 109 | plt.legend([method.legend for method in methods], loc="upper right", fontsize=20) 110 | plt.tight_layout() 111 | plt.savefig(outfile + ".pdf") 112 | plt.close() 113 | 114 | 115 | def latex_table_certified_accuracy( 116 | outfile: str, 117 | radius_start: float, 118 | radius_stop: float, 119 | radius_step: float, 120 | methods: List[Line], 121 | ): 122 | radii = np.arange(radius_start, radius_stop + radius_step, radius_step) 123 | accuracies = np.zeros((len(methods), len(radii))) 124 | for i, method in enumerate(methods): 125 | accuracies[i, :] = method.quantity.at_radii(radii) 126 | 127 | f = open(outfile, "w") 128 | 129 | for radius in radii: 130 | f.write("& $r = {:.3}$".format(radius)) 131 | f.write("\\\\\n") 132 | 133 | f.write("\midrule\n") 134 | 135 | for i, method in enumerate(methods): 136 | f.write(method.legend) 137 | for j, radius in enumerate(radii): 138 | if i == accuracies[:, j].argmax(): 139 | txt = r" & \textbf{" + "{:.2f}".format(accuracies[i, j]) + "}" 140 | else: 141 | txt = " & {:.2f}".format(accuracies[i, j]) 142 | f.write(txt) 143 | f.write("\\\\\n") 144 | f.close() 145 | 146 | 147 | def markdown_table_certified_accuracy( 148 | outfile: str, 149 | radius_start: float, 150 | radius_stop: float, 151 | radius_step: float, 152 | methods: List[Line], 153 | ): 154 | radii = np.arange(radius_start, radius_stop + radius_step, radius_step) 155 | accuracies = np.zeros((len(methods), len(radii))) 156 | for i, method in enumerate(methods): 157 | accuracies[i, :] = method.quantity.at_radii(radii) 158 | 159 | f = open(outfile, "w") 160 | f.write("| | ") 161 | for radius in radii: 162 | f.write("r = {:.3} |".format(radius)) 163 | f.write("\n") 164 | 165 | f.write("| --- | ") 166 | for i in range(len(radii)): 167 | f.write(" --- |") 168 | f.write("\n") 169 | 170 | for i, method in enumerate(methods): 171 | f.write(" {} | ".format(method.legend)) 172 | for j, radius in enumerate(radii): 173 | if i == accuracies[:, j].argmax(): 174 | txt = "{:.2f}* |".format(accuracies[i, j]) 175 | else: 176 | txt = "{:.2f} |".format(accuracies[i, j]) 177 | f.write(txt) 178 | f.write("\n") 179 | f.close() 180 | 181 | 182 | if __name__ == "__main__": 183 | latex_table_certified_accuracy( 184 | "analysis/latex/vary_noise_cifar10", 185 | 0.25, 186 | 1.5, 187 | 0.25, 188 | [ 189 | Line( 190 | ApproximateAccuracy( 191 | "data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12" 192 | ), 193 | "$\sigma = 0.12$", 194 | ), 195 | Line( 196 | ApproximateAccuracy( 197 | "data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25" 198 | ), 199 | "$\sigma = 0.25$", 200 | ), 201 | Line( 202 | ApproximateAccuracy( 203 | "data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50" 204 | ), 205 | "$\sigma = 0.50$", 206 | ), 207 | Line( 208 | ApproximateAccuracy( 209 | "data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00" 210 | ), 211 | "$\sigma = 1.00$", 212 | ), 213 | ], 214 | ) 215 | markdown_table_certified_accuracy( 216 | "analysis/markdown/vary_noise_cifar10", 217 | 0.25, 218 | 1.5, 219 | 0.25, 220 | [ 221 | Line( 222 | ApproximateAccuracy( 223 | "data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12" 224 | ), 225 | "σ = 0.12", 226 | ), 227 | Line( 228 | ApproximateAccuracy( 229 | "data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25" 230 | ), 231 | "σ = 0.25", 232 | ), 233 | 
Line( 234 | ApproximateAccuracy( 235 | "data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50" 236 | ), 237 | "σ = 0.50", 238 | ), 239 | Line( 240 | ApproximateAccuracy( 241 | "data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00" 242 | ), 243 | "σ = 1.00", 244 | ), 245 | ], 246 | ) 247 | latex_table_certified_accuracy( 248 | "analysis/latex/vary_noise_imagenet", 249 | 0.5, 250 | 3.0, 251 | 0.5, 252 | [ 253 | Line( 254 | ApproximateAccuracy( 255 | "data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25" 256 | ), 257 | "$\sigma = 0.25$", 258 | ), 259 | Line( 260 | ApproximateAccuracy( 261 | "data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50" 262 | ), 263 | "$\sigma = 0.50$", 264 | ), 265 | Line( 266 | ApproximateAccuracy( 267 | "data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00" 268 | ), 269 | "$\sigma = 1.00$", 270 | ), 271 | ], 272 | ) 273 | markdown_table_certified_accuracy( 274 | "analysis/markdown/vary_noise_imagenet", 275 | 0.5, 276 | 3.0, 277 | 0.5, 278 | [ 279 | Line( 280 | ApproximateAccuracy( 281 | "data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25" 282 | ), 283 | "σ = 0.25", 284 | ), 285 | Line( 286 | ApproximateAccuracy( 287 | "data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50" 288 | ), 289 | "σ = 0.50", 290 | ), 291 | Line( 292 | ApproximateAccuracy( 293 | "data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00" 294 | ), 295 | "σ = 1.00", 296 | ), 297 | ], 298 | ) 299 | plot_certified_accuracy( 300 | "analysis/plots/vary_noise_cifar10", 301 | "CIFAR-10, vary $\sigma$", 302 | 1.5, 303 | [ 304 | Line( 305 | ApproximateAccuracy( 306 | "data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12" 307 | ), 308 | "$\sigma = 0.12$", 309 | ), 310 | Line( 311 | ApproximateAccuracy( 312 | "data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25" 313 | ), 314 | "$\sigma = 0.25$", 315 | ), 316 | Line( 317 | ApproximateAccuracy( 318 | "data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50" 319 | ), 320 | "$\sigma = 0.50$", 321 | ), 322 | Line( 323 | ApproximateAccuracy( 324 | "data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00" 325 | ), 326 | "$\sigma = 1.00$", 327 | ), 328 | ], 329 | ) 330 | plot_certified_accuracy( 331 | "analysis/plots/vary_train_noise_cifar_050", 332 | "CIFAR-10, vary train noise, $\sigma=0.5$", 333 | 1.5, 334 | [ 335 | Line( 336 | ApproximateAccuracy( 337 | "data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.50" 338 | ), 339 | "train $\sigma = 0.25$", 340 | ), 341 | Line( 342 | ApproximateAccuracy( 343 | "data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50" 344 | ), 345 | "train $\sigma = 0.50$", 346 | ), 347 | Line( 348 | ApproximateAccuracy( 349 | "data/certify/cifar10/resnet110/noise_1.00/test/sigma_0.50" 350 | ), 351 | "train $\sigma = 1.00$", 352 | ), 353 | ], 354 | ) 355 | plot_certified_accuracy( 356 | "analysis/plots/vary_train_noise_imagenet_050", 357 | "ImageNet, vary train noise, $\sigma=0.5$", 358 | 1.5, 359 | [ 360 | Line( 361 | ApproximateAccuracy( 362 | "data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.50" 363 | ), 364 | "train $\sigma = 0.25$", 365 | ), 366 | Line( 367 | ApproximateAccuracy( 368 | "data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50" 369 | ), 370 | "train $\sigma = 0.50$", 371 | ), 372 | Line( 373 | ApproximateAccuracy( 374 | "data/certify/imagenet/resnet50/noise_1.00/test/sigma_0.50" 375 | ), 376 | "train $\sigma = 1.00$", 377 | ), 378 | ], 379 | ) 380 | plot_certified_accuracy( 381 | "analysis/plots/vary_noise_imagenet", 382 | "ImageNet, vary $\sigma$", 
383 |         4,
384 |         [
385 |             Line(
386 |                 ApproximateAccuracy(
387 |                     "data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25"
388 |                 ),
389 |                 "$\sigma = 0.25$",
390 |             ),
391 |             Line(
392 |                 ApproximateAccuracy(
393 |                     "data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"
394 |                 ),
395 |                 "$\sigma = 0.50$",
396 |             ),
397 |             Line(
398 |                 ApproximateAccuracy(
399 |                     "data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00"
400 |                 ),
401 |                 "$\sigma = 1.00$",
402 |             ),
403 |         ],
404 |     )
405 |     plot_certified_accuracy(
406 |         "analysis/plots/high_prob",
407 |         "Approximate vs. High-Probability",
408 |         2.0,
409 |         [
410 |             Line(
411 |                 ApproximateAccuracy(
412 |                     "data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"
413 |                 ),
414 |                 "Approximate",
415 |             ),
416 |             Line(
417 |                 HighProbAccuracy(
418 |                     "data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50",
419 |                     0.001,
420 |                     0.001,
421 |                 ),
422 |                 "High-Prob",
423 |             ),
424 |         ],
425 |     )
426 | 
427 | 
--------------------------------------------------------------------------------
/utils/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 | from torch.autograd import Variable
5 | import torch.nn.functional as F
6 | 
7 | from utils.logging import AverageMeter, ProgressMeter
8 | from utils.adv import pgd_whitebox, fgsm
9 | from symbolic_interval.symbolic_network import (
10 |     sym_interval_analyze,
11 |     naive_interval_analyze,
12 |     mix_interval_analyze,
13 | )
14 | from crown.bound_layers import (
15 |     BoundSequential,
16 |     BoundLinear,
17 |     BoundConv2d,
18 |     BoundDataParallel,
19 |     Flatten,
20 | )
21 | 
22 | from scipy.stats import norm
23 | import numpy as np
24 | import time
25 | 
26 | 
27 | def get_output_for_batch(model, img, temp=1):
28 |     """
29 |     model(x) is expected to return logits (instead of softmax probabilities)
30 |     """
31 |     with torch.no_grad():
32 |         out = nn.Softmax(dim=-1)(model(img) / temp)
33 |         p, index = torch.max(out, dim=-1)
34 |         return p.data.cpu().numpy(), index.data.cpu().numpy()
35 | 
36 | 
37 | def accuracy(output, target, topk=(1,)):
38 |     """Computes the accuracy over the k top predictions for the specified values of k"""
39 |     with torch.no_grad():
40 |         maxk = max(topk)
41 |         batch_size = target.size(0)
42 | 
43 |         _, pred = output.topk(maxk, 1, True, True)
44 |         pred = pred.t()
45 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
46 | 
47 |         res = []
48 |         for k in topk:
49 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape: slice of the transposed tensor is non-contiguous
50 |             res.append(correct_k.mul_(100.0 / batch_size))
51 |         return res
52 | 
53 | 
54 | def base(model, device, val_loader, criterion, args, writer, epoch=0):
55 |     """
56 |     Evaluating on unmodified validation set inputs.
57 | """ 58 | batch_time = AverageMeter("Time", ":6.3f") 59 | losses = AverageMeter("Loss", ":.4f") 60 | top1 = AverageMeter("Acc_1", ":6.2f") 61 | top5 = AverageMeter("Acc_5", ":6.2f") 62 | progress = ProgressMeter( 63 | len(val_loader), [batch_time, losses, top1, top5], prefix="Test: " 64 | ) 65 | 66 | # switch to evaluate mode 67 | model.eval() 68 | 69 | with torch.no_grad(): 70 | end = time.time() 71 | for i, data in enumerate(val_loader): 72 | images, target = data[0].to(device), data[1].to(device) 73 | 74 | # compute output 75 | output = model(images) 76 | loss = criterion(output, target) 77 | 78 | # measure accuracy and record loss 79 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 80 | losses.update(loss.item(), images.size(0)) 81 | top1.update(acc1[0], images.size(0)) 82 | top5.update(acc5[0], images.size(0)) 83 | 84 | # measure elapsed time 85 | batch_time.update(time.time() - end) 86 | end = time.time() 87 | 88 | if (i + 1) % args.print_freq == 0: 89 | progress.display(i) 90 | 91 | if writer: 92 | progress.write_to_tensorboard( 93 | writer, "test", epoch * len(val_loader) + i 94 | ) 95 | 96 | # write a sample of test images to tensorboard (helpful for debugging) 97 | if i == 0 and writer: 98 | writer.add_image( 99 | "test-images", 100 | torchvision.utils.make_grid(images[0 : len(images) // 4]), 101 | ) 102 | progress.display(i) # print final results 103 | 104 | return top1.avg, top5.avg 105 | 106 | 107 | def adv(model, device, val_loader, criterion, args, writer, epoch=0): 108 | """ 109 | Evaluate on adversarial validation set inputs. 110 | """ 111 | 112 | batch_time = AverageMeter("Time", ":6.3f") 113 | losses = AverageMeter("Loss", ":.4f") 114 | adv_losses = AverageMeter("Adv_Loss", ":.4f") 115 | top1 = AverageMeter("Acc_1", ":6.2f") 116 | top5 = AverageMeter("Acc_5", ":6.2f") 117 | adv_top1 = AverageMeter("Adv-Acc_1", ":6.2f") 118 | adv_top5 = AverageMeter("Adv-Acc_5", ":6.2f") 119 | progress = ProgressMeter( 120 | len(val_loader), 121 | [batch_time, losses, adv_losses, top1, top5, adv_top1, adv_top5], 122 | prefix="Test: ", 123 | ) 124 | 125 | # switch to evaluation mode 126 | model.eval() 127 | 128 | with torch.no_grad(): 129 | end = time.time() 130 | for i, data in enumerate(val_loader): 131 | images, target = data[0].to(device), data[1].to(device) 132 | 133 | # clean images 134 | output = model(images) 135 | loss = criterion(output, target) 136 | 137 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 138 | losses.update(loss.item(), images.size(0)) 139 | top1.update(acc1[0], images.size(0)) 140 | top5.update(acc5[0], images.size(0)) 141 | 142 | # adversarial images 143 | images = pgd_whitebox( 144 | model, 145 | images, 146 | target, 147 | device, 148 | args.epsilon, 149 | args.num_steps, 150 | args.step_size, 151 | args.clip_min, 152 | args.clip_max, 153 | is_random=not args.const_init, 154 | ) 155 | 156 | # compute output 157 | output = model(images) 158 | loss = criterion(output, target) 159 | 160 | # measure accuracy and record loss 161 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 162 | adv_losses.update(loss.item(), images.size(0)) 163 | adv_top1.update(acc1[0], images.size(0)) 164 | adv_top5.update(acc5[0], images.size(0)) 165 | 166 | # measure elapsed time 167 | batch_time.update(time.time() - end) 168 | end = time.time() 169 | 170 | if (i + 1) % args.print_freq == 0: 171 | progress.display(i) 172 | 173 | if writer: 174 | progress.write_to_tensorboard( 175 | writer, "test", epoch * len(val_loader) + i 176 | ) 177 | 178 | # write a sample of test 
images to tensorboard (helpful for debugging) 179 | if i == 0 and writer: 180 | writer.add_image( 181 | "Adv-test-images", 182 | torchvision.utils.make_grid(images[0 : len(images) // 4]), 183 | ) 184 | progress.display(i) # print final results 185 | 186 | return adv_top1.avg, adv_top5.avg 187 | 188 | 189 | def mixtrain(model, device, val_loader, criterion, args, writer, epoch=0): 190 | batch_time = AverageMeter("Time", ":6.3f") 191 | losses = AverageMeter("Loss", ":.4f") 192 | sym_losses = AverageMeter("Sym_Loss", ":.4f") 193 | top1 = AverageMeter("Acc_1", ":6.2f") 194 | top5 = AverageMeter("Acc_5", ":6.2f") 195 | sym_top1 = AverageMeter("Sym-Acc_1", ":6.2f") 196 | progress = ProgressMeter( 197 | len(val_loader), 198 | [batch_time, losses, sym_losses, top1, top5, sym_top1], 199 | prefix="Test: ", 200 | ) 201 | 202 | # switch to evaluation mode 203 | model.eval() 204 | 205 | with torch.no_grad(): 206 | end = time.time() 207 | for i, data in enumerate(val_loader): 208 | images, target = data[0].to(device), data[1].to(device) 209 | 210 | # clean images 211 | output = model(images) 212 | loss = criterion(output, target) 213 | 214 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 215 | losses.update(loss.item(), images.size(0)) 216 | top1.update(acc1[0], images.size(0)) 217 | top5.update(acc5[0], images.size(0)) 218 | 219 | rce_avg = 0 220 | rerr_avg = 0 221 | for r in range(images.shape[0]): 222 | 223 | rce, rerr = sym_interval_analyze( 224 | model, 225 | args.epsilon, 226 | images[r : r + 1], 227 | target[r : r + 1], 228 | use_cuda=torch.cuda.is_available(), 229 | parallel=False, 230 | ) 231 | rce_avg = rce_avg + rce.item() 232 | rerr_avg = rerr_avg + rerr 233 | 234 | rce_avg = rce_avg / float(images.shape[0]) 235 | rerr_avg = rerr_avg / float(images.shape[0]) 236 | 237 | # measure accuracy and record loss 238 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 239 | sym_losses.update(rce_avg, images.size(0)) 240 | sym_top1.update((1 - rerr_avg) * 100.0, images.size(0)) 241 | 242 | # measure elapsed time 243 | batch_time.update(time.time() - end) 244 | end = time.time() 245 | 246 | if (i + 1) % args.print_freq == 0: 247 | progress.display(i) 248 | 249 | if writer: 250 | progress.write_to_tensorboard( 251 | writer, "test", epoch * len(val_loader) + i 252 | ) 253 | 254 | # write a sample of test images to tensorboard (helpful for debugging) 255 | if i == 0 and writer: 256 | writer.add_image( 257 | "Adv-test-images", 258 | torchvision.utils.make_grid(images[0 : len(images) // 4]), 259 | ) 260 | progress.display(i) # print final results 261 | 262 | return sym_top1.avg, sym_top1.avg 263 | 264 | 265 | def ibp(model, device, val_loader, criterion, args, writer, epoch=0): 266 | batch_time = AverageMeter("Time", ":6.3f") 267 | losses = AverageMeter("Loss", ":.4f") 268 | ibp_losses = AverageMeter("IBP_Loss", ":.4f") 269 | top1 = AverageMeter("Acc_1", ":6.2f") 270 | top5 = AverageMeter("Acc_5", ":6.2f") 271 | ibp_top1 = AverageMeter("IBP-Acc_1", ":6.2f") 272 | progress = ProgressMeter( 273 | len(val_loader), 274 | [batch_time, losses, ibp_losses, top1, top5, ibp_top1], 275 | prefix="Test: ", 276 | ) 277 | 278 | # switch to evaluation mode 279 | model.eval() 280 | 281 | with torch.no_grad(): 282 | end = time.time() 283 | for i, data in enumerate(val_loader): 284 | images, target = data[0].to(device), data[1].to(device) 285 | 286 | # clean images 287 | 288 | output = model(images) 289 | loss = criterion(output, target) 290 | 291 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 292 | 
            losses.update(loss.item(), images.size(0))
293 |             top1.update(acc1[0], images.size(0))
294 |             top5.update(acc5[0], images.size(0))
295 | 
296 |             rce, rerr = naive_interval_analyze(
297 |                 model,
298 |                 args.epsilon,
299 |                 images,
300 |                 target,
301 |                 use_cuda=torch.cuda.is_available(),
302 |                 parallel=False,
303 |             )
304 | 
305 |             # measure accuracy and record loss
306 |             acc1, acc5 = accuracy(output, target, topk=(1, 5))
307 |             ibp_losses.update(rce.item(), images.size(0))
308 |             ibp_top1.update((1 - rerr) * 100.0, images.size(0))
309 | 
310 |             # measure elapsed time
311 |             batch_time.update(time.time() - end)
312 |             end = time.time()
313 | 
314 |             if (i + 1) % args.print_freq == 0:
315 |                 progress.display(i)
316 | 
317 |                 if writer:
318 |                     progress.write_to_tensorboard(
319 |                         writer, "test", epoch * len(val_loader) + i
320 |                     )
321 | 
322 |             # write a sample of test images to tensorboard (helpful for debugging)
323 |             if i == 0 and writer:
324 |                 writer.add_image(
325 |                     "Adv-test-images",
326 |                     torchvision.utils.make_grid(images[0 : len(images) // 4]),
327 |                 )
328 |         progress.display(i) # print final results
329 | 
330 |     return ibp_top1.avg, ibp_top1.avg  # verified top-1 returned twice; top-5 is not tracked here
331 | 
332 | 
333 | def smooth(model, device, val_loader, criterion, args, writer, epoch=0):
334 |     """
335 |     Evaluate the smoothed classifier: accuracy of noise-averaged predictions and an estimate of the certified radius.
336 |     """
337 |     batch_time = AverageMeter("Time", ":6.3f")
338 |     top1 = AverageMeter("Acc_1", ":6.2f")
339 |     top5 = AverageMeter("Acc_5", ":6.2f")
340 |     rad = AverageMeter("rad", ":6.2f")
341 |     progress = ProgressMeter(
342 |         len(val_loader), [batch_time, top1, top5, rad], prefix="Smooth (eval): "
343 |     )
344 | 
345 |     # switch to evaluate mode
346 |     model.eval()
347 | 
348 |     with torch.no_grad():
349 |         end = time.time()
350 |         for i, data in enumerate(val_loader):
351 |             images, target = data[0].to(device), data[1].to(device)
352 | 
353 |             # Default: evaluate on 10 random samples of additive Gaussian noise.
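            # Background for the estimate below (Cohen et al., 2019): if a
            # smoothed classifier predicts its top class with probability p
            # under additive Gaussian noise N(0, sigma^2 I), the prediction is
            # certifiably robust within an L2 radius of sigma * Phi^{-1}(p),
            # where Phi^{-1} is the standard normal quantile (norm.ppf below).
            # Here p is roughly estimated by averaging softmax outputs over the
            # 10 noise samples, so the reported radius is a fast proxy rather
            # than a statistically certified bound.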
354 |             output = []
355 |             for _ in range(10):
356 |                 # add noise (for ImageNet, divide by the per-channel std so that args.noise_std applies in the original pixel space of the normalized inputs)
357 |                 if args.dataset == "imagenet":
358 |                     std = (
359 |                         torch.tensor([0.229, 0.224, 0.225])
360 |                         .unsqueeze(0)
361 |                         .unsqueeze(-1)
362 |                         .unsqueeze(-1)
363 |                     ).to(device)
364 |                     noise = (torch.randn_like(images) / std).to(device) * args.noise_std
365 |                 else:
366 |                     noise = torch.randn_like(images).to(device) * args.noise_std
367 | 
368 |                 output.append(F.softmax(model(images + noise), -1))
369 | 
370 |             output = torch.sum(torch.stack(output), dim=0)
371 | 
372 |             p_max, _ = output.max(dim=-1)
373 |             radii = (args.noise_std + 1e-16) * norm.ppf(p_max.data.cpu().numpy())  # sigma * Phi^{-1}(p)
374 | 
375 |             # measure accuracy and record loss
376 |             acc1, acc5 = accuracy(output, target, topk=(1, 5))
377 |             top1.update(acc1[0], images.size(0))
378 |             top5.update(acc5[0], images.size(0))
379 |             rad.update(np.mean(radii))
380 | 
381 |             # measure elapsed time
382 |             batch_time.update(time.time() - end)
383 |             end = time.time()
384 | 
385 |             if (i + 1) % args.print_freq == 0:
386 |                 progress.display(i)
387 | 
388 |                 if writer:
389 |                     progress.write_to_tensorboard(
390 |                         writer, "test", epoch * len(val_loader) + i
391 |                     )
392 | 
393 |             # write a sample of test images to tensorboard (helpful for debugging)
394 |             if i == 0 and writer:
395 |                 writer.add_image(
396 |                     "Adv-test-images",
397 |                     torchvision.utils.make_grid(images[0 : len(images) // 4]),
398 |                 )
399 | 
400 |         progress.display(i) # print final results
401 | 
402 |     return top1.avg, rad.avg
403 | 
404 | 
405 | def freeadv(model, device, val_loader, criterion, args, writer, epoch=0):
406 | 
407 |     assert (
408 |         not args.normalize
409 |     ), "Normalization is applied explicitly inside this evaluation loop; the dataset should have a [0, 1] dynamic range."
410 | 
411 |     # Mean/Std for normalization
412 |     mean = torch.Tensor(np.array(args.mean)[:, np.newaxis, np.newaxis])
413 |     mean = mean.expand(3, args.image_dim, args.image_dim).to(device)
414 |     std = torch.Tensor(np.array(args.std)[:, np.newaxis, np.newaxis])
415 |     std = std.expand(3, args.image_dim, args.image_dim).to(device)
416 | 
417 |     batch_time = AverageMeter("Time", ":6.3f")
418 |     losses = AverageMeter("Loss", ":.4f")
419 |     top1 = AverageMeter("Acc_1", ":6.2f")
420 |     top5 = AverageMeter("Acc_5", ":6.2f")
421 |     progress = ProgressMeter(
422 |         len(val_loader), [batch_time, losses, top1, top5], prefix="Test: ",
423 |     )
424 | 
425 |     eps = args.epsilon
426 |     K = args.num_steps
427 |     step = args.step_size
428 |     model.eval()
429 |     end = time.time()
430 |     print(" PGD eps: {}, num-steps: {}, step-size: {} ".format(eps, K, step))
431 |     for i, (input, target) in enumerate(val_loader):
432 | 
433 |         input = input.to(device, non_blocking=True)
434 |         target = target.to(device, non_blocking=True)
435 | 
436 |         orig_input = input.clone()
437 |         rand_init = torch.FloatTensor(input.size()).uniform_(-eps, eps).to(device)  # random start in the eps-ball
438 |         input += rand_init
439 |         input.clamp_(0, 1.0)
440 |         for _ in range(K):
441 |             invar = Variable(input, requires_grad=True)
442 |             in1 = invar - mean
443 |             in1.div_(std)
444 |             output = model(in1)
445 |             ascend_loss = criterion(output, target)
446 |             ascend_grad = torch.autograd.grad(ascend_loss, invar)[0]
447 |             pert = fgsm(ascend_grad, step)
448 |             # apply the perturbation, then project back into the eps-ball and [0, 1]
449 |             input += pert.data
450 |             input = torch.max(orig_input - eps, input)
451 |             input = torch.min(orig_input + eps, input)
452 |             input.clamp_(0, 1.0)
453 | 
454 |         input.sub_(mean).div_(std)
455 |         with torch.no_grad():
456 |             # compute output
457 |             output = model(input)
458 |             loss = criterion(output, target)
459 | 
460 |         # measure accuracy and record loss
461 |         prec1, prec5 = accuracy(output, target, topk=(1, 5))
462 |         losses.update(loss.item(), input.size(0))
463 |         top1.update(prec1[0], input.size(0))
464 |         top5.update(prec5[0], input.size(0))
465 | 
466 |         # measure elapsed time
467 |         batch_time.update(time.time() - end)
468 |         end = time.time()
469 | 
470 |         if (i + 1) % args.print_freq == 0:
471 |             progress.display(i)
472 | 
473 |             if writer:
474 |                 progress.write_to_tensorboard(writer, "test", epoch * len(val_loader) + i)
475 | 
476 |         # write a sample of test images to tensorboard (helpful for debugging)
477 |         if i == 0 and writer:
478 |             writer.add_image(
479 |                 "Adv-test-images",
480 |                 torchvision.utils.make_grid(input[0 : len(input) // 4]),
481 |             )
482 | 
483 |     progress.display(i) # print final results
484 | 
485 |     return top1.avg, top5.avg
486 | 
--------------------------------------------------------------------------------