├── Asym-RoCL ├── data │ ├── __init__.py │ ├── .DS_Store │ ├── vision.py │ ├── cifar.py │ └── utils.py ├── models │ ├── __init__.py │ ├── .DS_Store │ ├── projector.py │ └── resnet.py ├── warmup_scheduler │ ├── __init__.py │ ├── .DS_Store │ ├── run.py │ └── scheduler.py ├── init_paths.py ├── model_loader.py ├── shell │ ├── rocl-cifar10.sh │ ├── rocl-cifar100.sh │ ├── rocl-HN-cifar100.sh │ ├── rocl-IP-cifar10.sh │ ├── rocl-IP-cifar100.sh │ ├── rocl-HN-cifar10.sh │ ├── rocl-IPHN-cifar100.sh │ └── rocl-IPHN-cifar10.sh ├── trades.py ├── robustness_test.py ├── data_loader.py ├── attack_lib.py ├── utils.py ├── rocl_train.py ├── loss.py ├── linear_eval.py └── argument.py ├── figures ├── .DS_Store └── intro.png ├── Asym-AdvCL ├── models │ ├── .DS_Store │ ├── linear.py │ └── resnet_cifar.py ├── shell │ ├── .DS_Store │ ├── cifar10 │ │ ├── .DS_Store │ │ ├── advcl-HN-cifar10.sh │ │ ├── advcl-IP-cifar10.sh │ │ ├── advcl-cifar10.sh │ │ └── advcl-IPHN-cifar10.sh │ └── cifar100 │ │ ├── .DS_Store │ │ ├── advcl-HN-cifar100.sh │ │ ├── advcl-cifar100.sh │ │ ├── advcl-IP-cifar100.sh │ │ └── advcl-IPHN-cifar100.sh ├── data │ ├── cifar100_pseudo_labels.pkl │ └── imagenet_clPretrain_pseudo_labels.pkl ├── .idea │ ├── inspectionProfiles │ │ ├── profiles_settings.xml │ │ └── Project_Default.xml │ ├── modules.xml │ └── Asym-AdvCL.iml ├── fr_util.py ├── dataset.py ├── trades.py ├── utils.py ├── losses.py ├── finetuning_advCL_SLF.py └── pretraining_advCL.py └── README.md /Asym-RoCL/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Asym-RoCL/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .projector import * 3 | 4 | -------------------------------------------------------------------------------- /figures/.DS_Store: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/figures/.DS_Store -------------------------------------------------------------------------------- /figures/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/figures/intro.png -------------------------------------------------------------------------------- /Asym-RoCL/data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/Asym-RoCL/data/.DS_Store -------------------------------------------------------------------------------- /Asym-AdvCL/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/Asym-AdvCL/models/.DS_Store -------------------------------------------------------------------------------- /Asym-AdvCL/shell/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/Asym-AdvCL/shell/.DS_Store -------------------------------------------------------------------------------- /Asym-RoCL/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/Asym-RoCL/models/.DS_Store -------------------------------------------------------------------------------- /Asym-RoCL/warmup_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from warmup_scheduler.scheduler import GradualWarmupScheduler 3 | -------------------------------------------------------------------------------- /Asym-AdvCL/shell/cifar10/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/Asym-AdvCL/shell/cifar10/.DS_Store -------------------------------------------------------------------------------- /Asym-AdvCL/shell/cifar100/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/Asym-AdvCL/shell/cifar100/.DS_Store -------------------------------------------------------------------------------- /Asym-RoCL/init_paths.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import matplotlib 3 | matplotlib.use('Agg') 4 | sys.path.insert(0, 'lib') 5 | 6 | -------------------------------------------------------------------------------- /Asym-RoCL/warmup_scheduler/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/Asym-RoCL/warmup_scheduler/.DS_Store -------------------------------------------------------------------------------- /Asym-AdvCL/data/cifar100_pseudo_labels.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/Asym-AdvCL/data/cifar100_pseudo_labels.pkl -------------------------------------------------------------------------------- /Asym-AdvCL/data/imagenet_clPretrain_pseudo_labels.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/Asym-AdvCL/data/imagenet_clPretrain_pseudo_labels.pkl -------------------------------------------------------------------------------- /Asym-AdvCL/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /Asym-AdvCL/.idea/modules.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /Asym-AdvCL/.idea/Asym-AdvCL.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /Asym-RoCL/warmup_scheduler/run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from warmup_scheduler import GradualWarmupScheduler 4 | 5 | 6 | if __name__ == '__main__': 7 | v = torch.zeros(10) 8 | optim = torch.optim.SGD([v], lr=0.01) 9 | scheduler = GradualWarmupScheduler(optim, multiplier=8, total_epoch=10) 10 | 11 | for epoch in range(1, 20): 12 | scheduler.step(epoch) 13 | 14 | print(epoch, optim.param_groups[0]['lr']) 15 | -------------------------------------------------------------------------------- /Asym-RoCL/models/projector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Projector(nn.Module): 6 | def __init__(self, expansion=0): 7 | super(Projector, self).__init__() 8 | 9 | self.linear_1 = nn.Linear(512*expansion, 2048) 10 | self.linear_2 = nn.Linear(2048, 128) 11 | 12 | def forward(self, x): 13 | 14 | output = self.linear_1(x) 15 | output = F.relu(output) 16 | 17 | output = self.linear_2(output) 18 | 19 | return output 20 | 21 | -------------------------------------------------------------------------------- /Asym-AdvCL/shell/cifar10/advcl-HN-cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python pretraining_advCL.py --learning_rate 1.2 --dataset cifar10 --cosine --name $1 --epoch 400 --attack_ori --HN --beta 1.0 --tau_plus 0.12 4 | 5 | python finetuning_advCL_SLF.py --dataset 
cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.01 --finetune_type SLF 6 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.01 --finetune_type ALF 7 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type AFF_trades 8 | 9 | # sh shell/cifar10/advcl-HN-cifar10.sh advcl-HN-cifar10 | tee logs/advcl-HN-cifar10.out -------------------------------------------------------------------------------- /Asym-AdvCL/shell/cifar100/advcl-HN-cifar100.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python pretraining_advCL.py --learning_rate 2. --dataset cifar100 --cosine --name $1 --epoch 400 --attack_ori --HN --beta 1. --tau_plus 0.01 4 | 5 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type SLF 6 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type ALF 7 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type AFF_trades 8 | 9 | # sh shell/cifar100/advcl-HN-cifar100.sh advcl-HN-cifar100 | tee logs/advcl-HN-cifar100.out -------------------------------------------------------------------------------- /Asym-AdvCL/shell/cifar10/advcl-IP-cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python pretraining_advCL.py --dataset cifar10 --cosine --name $1 --epoch 400 --stpg_degree 0.2 --stop_grad --adv_weight 1.2 --stop_grad_adaptive 30 --learning_rate 1. 
4 | 5 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.01 --finetune_type SLF 6 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.01 --finetune_type ALF 7 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type AFF_trades 8 | 9 | # sh shell/cifar10/advcl-IP-cifar10.sh advcl-IP-cifar10 | tee logs/advcl-IP-cifar10.out -------------------------------------------------------------------------------- /Asym-AdvCL/shell/cifar10/advcl-cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # follow the settings in [AdvCL](https://arxiv.org/pdf/2111.01124.pdf) 4 | 5 | python pretraining_advCL.py --learning_rate 0.5 --dataset cifar10 --cosine --name $1 --epoch 400 6 | 7 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type SLF 8 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type ALF 9 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type AFF_trades 10 | 11 | # sh shell/cifar10/advcl-cifar10.sh advcl-cifar10 | tee logs/advcl-cifar10.out -------------------------------------------------------------------------------- /Asym-AdvCL/shell/cifar100/advcl-cifar100.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # follow the settings in [AdvCL](https://arxiv.org/pdf/2111.01124.pdf) 4 | 5 | python pretraining_advCL.py --learning_rate 0.5 --dataset cifar100 --cosine --name $1 --epoch 400 6 | 7 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type 
SLF 8 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type ALF 9 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type AFF_trades 10 | 11 | # sh shell/cifar100/advcl-cifar100.sh advcl-cifar100 | tee logs/advcl-cifar100.out -------------------------------------------------------------------------------- /Asym-AdvCL/shell/cifar100/advcl-IP-cifar100.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python pretraining_advCL.py --learning_rate 2. --dataset cifar100 --cosine --name $1 --epoch 400 --stpg_degree 0.3 --stop_grad --adv_weight 1.5 --stop_grad_adaptive 20 --d_min 0.5 4 | 5 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type SLF 6 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type ALF 7 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type AFF_trades 8 | 9 | # sh shell/cifar100/advcl-IP-cifar100.sh advcl-IP-cifar100 | tee logs/advcl-IP-cifar100.out -------------------------------------------------------------------------------- /Asym-AdvCL/shell/cifar10/advcl-IPHN-cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python pretraining_advCL.py --learning_rate 2. 
--dataset cifar10 --cosine --name $1 --epoch 400 --stpg_degree 0.2 --stop_grad --adv_weight 0.7 --stop_grad_adaptive 30 --attack_ori --HN 4 | 5 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.01 --finetune_type SLF 6 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.01 --finetune_type ALF 7 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type AFF_trades 8 | 9 | # sh shell/cifar10/advcl-IPHN-cifar10.sh advcl-IPHN-cifar10 | tee logs/advcl-IPHN-cifar10.out -------------------------------------------------------------------------------- /Asym-RoCL/model_loader.py: -------------------------------------------------------------------------------- 1 | from models.resnet import ResNet18,ResNet50 2 | def get_model(args): 3 | 4 | if args.dataset == 'cifar-10': 5 | num_classes=10 6 | elif args.dataset == 'cifar-100': 7 | num_classes=100 8 | else: 9 | raise NotImplementedError 10 | 11 | if 'contrastive' in args.train_type or 'linear_eval' in args.train_type: 12 | contrastive_learning=True 13 | else: 14 | contrastive_learning=False 15 | 16 | if args.model == 'ResNet18': 17 | model = ResNet18(num_classes,contrastive_learning) 18 | print('ResNet18 is loading ...') 19 | elif args.model == 'ResNet50': 20 | model = ResNet50(num_classes,contrastive_learning) 21 | print('ResNet 50 is loading ...') 22 | return model 23 | 24 | -------------------------------------------------------------------------------- /Asym-AdvCL/shell/cifar100/advcl-IPHN-cifar100.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python pretraining_advCL.py --learning_rate 2. 
--dataset cifar100 --cosine --name $1 --epoch 400 --stpg_degree 0.3 --stop_grad --adv_weight 1.5 --stop_grad_adaptive 20 --d_min 0.5 --attack_ori --HN --beta 1.0 --tau 0.01 4 | 5 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type SLF 6 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type ALF 7 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type AFF_trades 8 | 9 | # sh shell/cifar100/advcl-IPHN-cifar100.sh advcl-IPHN-cifar100 | tee logs/advcl-IPHN-cifar100.out -------------------------------------------------------------------------------- /Asym-AdvCL/fr_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def distance(i, j, imageSize, r): 6 | dis = np.sqrt((i - imageSize / 2) ** 2 + (j - imageSize / 2) ** 2) 7 | if dis < r: 8 | return 1.0 9 | else: 10 | return 0 11 | 12 | 13 | def mask_radial(img, r): 14 | rows, cols = img.shape 15 | mask = torch.zeros((rows, cols)) 16 | for i in range(rows): 17 | for j in range(cols): 18 | mask[i, j] = distance(i, j, imageSize=rows, r=r) 19 | return mask.cuda() 20 | 21 | 22 | def generate_high(Images, r): 23 | # Image: bsxcxhxw, input batched images 24 | # r: int, radius 25 | mask = mask_radial(torch.zeros([Images.shape[2], Images.shape[3]]), r) 26 | bs, c, h, w = Images.shape 27 | x = Images.reshape([bs * c, h, w]) 28 | fd = torch.fft.fftshift(torch.fft.fftn(x, dim=(-2, -1))) 29 | mask = mask.unsqueeze(0).repeat([bs * c, 1, 1]) 30 | fd = fd * (1. 
- mask) 31 | fd = torch.fft.ifftn(torch.fft.ifftshift(fd), dim=(-2, -1)) 32 | fd = torch.real(fd) 33 | fd = fd.reshape([bs, c, h, w]) 34 | return fd 35 | -------------------------------------------------------------------------------- /Asym-AdvCL/models/linear.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | 4 | class LinearClassifier(nn.Module): 5 | """Linear classifier""" 6 | def __init__(self, name='resnet50', feat_dim=512, num_classes=10): 7 | super(LinearClassifier, self).__init__() 8 | # _, feat_dim = model_dict[name] 9 | self.fc = nn.Linear(feat_dim, num_classes) 10 | 11 | def forward(self, features): 12 | return self.fc(features) 13 | 14 | 15 | class NonLinearClassifier(nn.Module): 16 | """Linear classifier""" 17 | def __init__(self, name='resnet50', feat_dim=512, num_classes=10): 18 | super(NonLinearClassifier, self).__init__() 19 | # _, feat_dim = model_dict[name] 20 | self.fc1 = nn.Linear(feat_dim, feat_dim) 21 | # self.bn1 = nn.BatchNorm1d(512) 22 | self.fc2 = nn.Linear(feat_dim, feat_dim) 23 | # self.bn2 = nn.BatchNorm1d(512) 24 | self.fc3 = nn.Linear(feat_dim, num_classes) 25 | 26 | def forward(self, features): 27 | # features = F.relu(self.bn1(self.fc1(features))) 28 | # features = F.relu(self.bn2(self.fc2(features))) 29 | features = F.relu(self.fc1(features)) 30 | features = F.relu(self.fc2(features)) 31 | return self.fc3(features) -------------------------------------------------------------------------------- /Asym-RoCL/shell/rocl-cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12562 \ 4 | rocl_train.py --ngpu 2 --batch-size=256 --model=$2 --k=7 --loss_type=sim --advtrain_type=Rep --attack_type=linf \ 5 | --name=$1 --regularize_to=other --attack_to=other --train_type=contrastive --dataset=$3 6 | 7 | python -m 
torch.distributed.launch --nproc_per_node=1 --master_port 12562 \ 8 | linear_eval.py --ngpu 1 --batch-size=1024 --train_type=linear_eval --model=$2 --epoch 150 --lr 0.1 --name $1 \ 9 | --load_checkpoint=checkpoint/ckpt.t7rocl-cifar10Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU2_l256 \ 10 | --clean=True --dataset=$3 11 | 12 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12562 \ 13 | robustness_test.py --ngpu 1 --train_type=linear_eval --name=$1 --batch-size=1024 --model=$2 \ 14 | --load_checkpoint='./checkpoint/ckpt.t7'$1'_Evaluate_linear_eval_ResNet18_'$3 --attack_type=linf --epsilon=0.0314 --alpha=0.00314 --k=20 --dataset=$3 15 | 16 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12352 \ 17 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type ALF --model=$2 --epoch 25 --lr 0.1 --name $1 \ 18 | --load_checkpoint=checkpoint/ckpt.t7rocl-cifar10Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU2_l256 \ 19 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20 20 | 21 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12352 \ 22 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type AFF_trades --model=$2 --epoch 25 --lr 0.1 --name $1 \ 23 | --load_checkpoint=checkpoint/ckpt.t7rocl-cifar10Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU2_l256 \ 24 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20 25 | 26 | # sh shell/rocl-cifar10.sh rocl-cifar10 ResNet18 cifar-10 | tee logs/rocl-cifar10.out 27 | -------------------------------------------------------------------------------- /Asym-RoCL/shell/rocl-cifar100.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m 
torch.distributed.launch --nproc_per_node=2 --master_port 12561 \ 4 | rocl_train.py --ngpu 2 --batch-size=256 --model=$2 --k=7 --loss_type=sim --advtrain_type=Rep --attack_type=linf \ 5 | --name=$1 --regularize_to=other --attack_to=other --train_type=contrastive --dataset=$3 6 | 7 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12561 \ 8 | linear_eval.py --ngpu 1 --batch-size=1024 --train_type=linear_eval --model=$2 --epoch 150 --lr 0.1 --name $1 \ 9 | --load_checkpoint=checkpoint/ckpt.t7rocl-cifar100Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256 \ 10 | --clean=True --dataset=$3 11 | 12 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12561 \ 13 | robustness_test.py --ngpu 1 --train_type=linear_eval --name=$1 --batch-size=1024 --model=$2 \ 14 | --load_checkpoint='./checkpoint/ckpt.t7'$1'_Evaluate_linear_eval_ResNet18_'$3 --attack_type=linf --epsilon=0.0314 --alpha=0.00314 --k=20 --dataset=$3 15 | 16 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12351 \ 17 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type ALF --model=$2 --epoch 25 --lr 0.1 --name $1 \ 18 | --load_checkpoint=checkpoint/ckpt.t7rocl-cifar100Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256 \ 19 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20 20 | 21 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12351 \ 22 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type AFF_trades --model=$2 --epoch 25 --lr 0.1 --name $1 \ 23 | --load_checkpoint=checkpoint/ckpt.t7rocl-cifar100Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256 \ 24 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20 25 | 26 | # sh shell/rocl-cifar100.sh 
rocl-cifar100 ResNet18 cifar-100 | tee logs/rocl-cifar100.out 27 | -------------------------------------------------------------------------------- /Asym-RoCL/shell/rocl-HN-cifar100.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12565 \ 4 | rocl_train.py --ngpu 2 --batch-size=256 --model=$2 --k=7 --loss_type=sim --advtrain_type=Rep --attack_type=linf \ 5 | --name=$1 --regularize_to=other --attack_to=other --train_type=contrastive --dataset=$3 --HN 6 | 7 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12565 \ 8 | linear_eval.py --ngpu 1 --batch-size=1024 --train_type=linear_eval --model=$2 --epoch 150 --lr 0.1 --name $1 \ 9 | --load_checkpoint=checkpoint/ckpt.t7rocl-HN-cifar100-2Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256_HN \ 10 | --clean=True --dataset=$3 11 | 12 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12565 \ 13 | robustness_test.py --ngpu 1 --train_type=linear_eval --name=$1 --batch-size=1024 --model=$2 \ 14 | --load_checkpoint='./checkpoint/ckpt.t7'$1'_Evaluate_linear_eval_ResNet18_'$3 --attack_type=linf --epsilon=0.0314 --alpha=0.00314 --k=20 --dataset=$3 15 | 16 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12355 \ 17 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type ALF --model=$2 --epoch 25 --lr 0.1 --name $1 \ 18 | --load_checkpoint=checkpoint/ckpt.t7rocl-HN-cifar100-2Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256_HN \ 19 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20 20 | 21 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12355 \ 22 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type AFF_trades --model=$2 --epoch 
25 --lr 0.1 --name $1 \ 23 | --load_checkpoint=checkpoint/ckpt.t7rocl-HN-cifar100-2Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256_HN \ 24 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20 25 | 26 | # sh shell/rocl-HN-cifar100.sh rocl-HN-cifar100 ResNet18 cifar-100 | tee logs/rocl-HN-cifar100.out 27 | -------------------------------------------------------------------------------- /Asym-RoCL/shell/rocl-IP-cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12564 \ 4 | rocl_train.py --ngpu 2 --batch-size=256 --model=$2 --k=7 --loss_type=sim --advtrain_type=Rep --attack_type=linf \ 5 | --name=$1 --regularize_to=other --attack_to=other --train_type=contrastive --dataset=$3 --stop_grad --stpg_degree 0.2 6 | 7 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12564 \ 8 | linear_eval.py --ngpu 1 --batch-size=1024 --train_type=linear_eval --model=$2 --epoch 150 --lr 0.1 --name $1 \ 9 | --load_checkpoint=checkpoint/ckpt.t7rocl-IP-cifar10Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU2_l256 \ 10 | --clean=True --dataset=$3 11 | 12 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12564 \ 13 | robustness_test.py --ngpu 1 --train_type=linear_eval --name=$1 --batch-size=1024 --model=$2 \ 14 | --load_checkpoint='./checkpoint/ckpt.t7'$1'_Evaluate_linear_eval_ResNet18_'$3 --attack_type=linf --epsilon=0.0314 --alpha=0.00314 --k=20 --dataset=$3 15 | 16 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12354 \ 17 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type ALF --model=$2 --epoch 25 --lr 0.1 --name $1 \ 18 | 
--load_checkpoint=checkpoint/ckpt.t7rocl-IP-cifar10Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU2_l256 \ 19 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20 20 | 21 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12354 \ 22 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type AFF_trades --model=$2 --epoch 25 --lr 0.1 --name $1 \ 23 | --load_checkpoint=checkpoint/ckpt.t7rocl-IP-cifar10Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU2_l256 \ 24 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20 25 | 26 | # sh shell/rocl-IP-cifar10.sh rocl-IP-cifar10 ResNet18 cifar-10 | tee logs/rocl-IP-cifar10.out 27 | -------------------------------------------------------------------------------- /Asym-RoCL/shell/rocl-IP-cifar100.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12567 \ 4 | rocl_train.py --ngpu 2 --batch-size=256 --model=$2 --k=7 --loss_type=sim --advtrain_type=Rep --attack_type=linf \ 5 | --name=$1 --regularize_to=other --attack_to=other --train_type=contrastive --dataset=$3 --stop_grad --stpg_degree 0.2 6 | 7 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12567 \ 8 | linear_eval.py --ngpu 1 --batch-size=1024 --train_type=linear_eval --model=$2 --epoch 150 --lr 0.1 --name $1 \ 9 | --load_checkpoint=checkpoint/ckpt.t7rocl-IP-cifar100Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256 \ 10 | --clean=True --dataset=$3 11 | 12 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12567 \ 13 | robustness_test.py --ngpu 1 --train_type=linear_eval --name=$1 --batch-size=1024 --model=$2 \ 14 | 
--load_checkpoint='./checkpoint/ckpt.t7'$1'_Evaluate_linear_eval_ResNet18_'$3 --attack_type=linf --epsilon=0.0314 --alpha=0.00314 --k=20 --dataset=$3 15 | 16 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12357 \ 17 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type ALF --model=$2 --epoch 25 --lr 0.1 --name $1 \ 18 | --load_checkpoint=checkpoint/ckpt.t7rocl-IP-cifar100Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256 \ 19 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20 20 | 21 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12357 \ 22 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type AFF_trades --model=$2 --epoch 25 --lr 0.1 --name $1 \ 23 | --load_checkpoint=checkpoint/ckpt.t7rocl-IP-cifar100Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256 \ 24 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20 25 | 26 | # sh shell/rocl-IP-cifar100.sh rocl-IP-cifar100 ResNet18 cifar-100 | tee logs/rocl-IP-cifar100.out 27 | -------------------------------------------------------------------------------- /Asym-RoCL/shell/rocl-HN-cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12563 \ 4 | rocl_train.py --ngpu 2 --batch-size=256 --model=$2 --k=7 --loss_type=sim --advtrain_type=Rep --attack_type=linf \ 5 | --name=$1 --regularize_to=other --attack_to=other --train_type=contrastive --dataset=$3 --HN --beta 1.0 --tau 0.12\ 6 | --lr 0.2 7 | 8 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12563 \ 9 | linear_eval.py --ngpu 1 --batch-size=1024 --train_type=linear_eval --model=$2 --epoch 150 --lr 0.1 --name $1 \ 10 | 
--load_checkpoint=checkpoint/ckpt.t7rocl-HN-cifar10Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU2_l256_HN \ 11 | --clean=True --dataset=$3 12 | 13 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12563 \ 14 | robustness_test.py --ngpu 1 --train_type=linear_eval --name=$1 --batch-size=1024 --model=$2 \ 15 | --load_checkpoint='./checkpoint/ckpt.t7'$1'_Evaluate_linear_eval_ResNet18_'$3 --attack_type=linf --epsilon=0.0314 --alpha=0.00314 --k=20 --dataset=$3 16 | 17 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12353 \ 18 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type ALF --model=$2 --epoch 25 --lr 0.1 --name $1 \ 19 | --load_checkpoint=checkpoint/ckpt.t7rocl-HN-cifar10Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU2_l256_HN \ 20 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20 21 | 22 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12353 \ 23 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type AFF_trades --model=$2 --epoch 25 --lr 0.1 --name $1 \ 24 | --load_checkpoint=checkpoint/ckpt.t7rocl-HN-cifar10Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU2_l256_HN \ 25 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20 26 | 27 | # sh shell/rocl-HN-cifar10.sh rocl-HN-cifar10 ResNet18 cifar-10 | tee logs/rocl-HN-cifar10.out 28 | -------------------------------------------------------------------------------- /Asym-RoCL/shell/rocl-IPHN-cifar100.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12560 \ 4 | rocl_train.py --ngpu 2 --batch-size=256 --model=$2 --k=7 --loss_type=sim --advtrain_type=Rep 
--attack_type=linf \ 5 | --name=$1 --regularize_to=other --attack_to=other --train_type=contrastive --dataset=$3 --HN --stop_grad --stpg_degree 0.2 --lr 0.2 6 | 7 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12560 \ 8 | linear_eval.py --ngpu 1 --batch-size=1024 --train_type=linear_eval --model=$2 --epoch 150 --lr 0.1 --name $1 \ 9 | --load_checkpoint=checkpoint/ckpt.t7rocl-IPHN-cifar100Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256_HN \ 10 | --clean=True --dataset=$3 11 | 12 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12560 \ 13 | robustness_test.py --ngpu 1 --train_type=linear_eval --name=$1 --batch-size=1024 --model=$2 \ 14 | --load_checkpoint='./checkpoint/ckpt.t7'$1'_Evaluate_linear_eval_ResNet18_'$3 --attack_type=linf --epsilon=0.0314 --alpha=0.00314 --k=20 --dataset=$3 15 | 16 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12350 \ 17 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type ALF --model=$2 --epoch 25 --lr 0.1 --name $1 \ 18 | --load_checkpoint=checkpoint/ckpt.t7rocl-IPHN-cifar100Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256_HN \ 19 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20 20 | 21 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12350 \ 22 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type AFF_trades --model=$2 --epoch 25 --lr 0.1 --name $1 \ 23 | --load_checkpoint=checkpoint/ckpt.t7rocl-IPHN-cifar100Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256_HN \ 24 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20 25 | 26 | # sh shell/rocl-IPHN-cifar100.sh rocl-IPHN-cifar100 ResNet18 cifar-100 | tee logs/rocl-IPHN-cifar100.out 27 | 
-------------------------------------------------------------------------------- /Asym-RoCL/shell/rocl-IPHN-cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12566 \ 4 | rocl_train.py --ngpu 2 --batch-size=256 --model=$2 --k=7 --loss_type=sim --advtrain_type=Rep --attack_type=linf \ 5 | --name=$1 --regularize_to=other --attack_to=other --train_type=contrastive --dataset=$3 \ 6 | --HN --beta 1.0 --tau 0.12 --stop_grad --stpg_degree 0.2 --lr 0.3 7 | 8 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12566 \ 9 | linear_eval.py --ngpu 1 --batch-size=1024 --train_type=linear_eval --model=$2 --epoch 150 --lr 0.1 --name $1 \ 10 | --load_checkpoint=checkpoint/ckpt.t7rocl-IPHN-cifar10Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU2_l256_HN \ 11 | --clean=True --dataset=$3 12 | 13 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12566 \ 14 | robustness_test.py --ngpu 1 --train_type=linear_eval --name=$1 --batch-size=1024 --model=$2 \ 15 | --load_checkpoint='./checkpoint/ckpt.t7'$1'_Evaluate_linear_eval_ResNet18_'$3 --attack_type=linf --epsilon=0.0314 --alpha=0.00314 --k=20 --dataset=$3 16 | 17 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12356 \ 18 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type ALF --model=$2 --epoch 25 --lr 0.1 --name $1 \ 19 | --load_checkpoint=checkpoint/ckpt.t7rocl-IPHN-cifar10Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU2_l256_HN \ 20 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20 21 | 22 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12356 \ 23 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type AFF_trades --model=$2 --epoch 
25 --lr 0.1 --name $1 \ 24 | --load_checkpoint=checkpoint/ckpt.t7rocl-IPHN-cifar10Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU2_l256_HN \ 25 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20 26 | 27 | # sh shell/rocl-IPHN-cifar10.sh rocl-IPHN-cifar10 ResNet18 cifar-10 | tee logs/rocl-IPHN-cifar10.out -------------------------------------------------------------------------------- /Asym-AdvCL/.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 47 | -------------------------------------------------------------------------------- /Asym-AdvCL/dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms, datasets 2 | from torch.utils.data import Dataset, DataLoader 3 | 4 | 5 | class CIFAR10IndexPseudoLabelEnsemble(Dataset): 6 | def __init__(self, root='', transform=None, download=False, train=True, 7 | pseudoLabel_002=None, 8 | pseudoLabel_010=None, 9 | pseudoLabel_050=None, 10 | pseudoLabel_100=None, 11 | pseudoLabel_500=None, 12 | ): 13 | self.cifar10 = datasets.CIFAR10(root=root, 14 | download=download, 15 | train=train, 16 | transform=transform) 17 | 18 | self.pseudo_label_002 = pseudoLabel_002 19 | self.pseudo_label_010 = pseudoLabel_010 20 | self.pseudo_label_050 = pseudoLabel_050 21 | self.pseudo_label_100 = pseudoLabel_100 22 | self.pseudo_label_500 = pseudoLabel_500 23 | 24 | def __getitem__(self, index): 25 | data, target = self.cifar10[index] 26 | label_p_002 = self.pseudo_label_002[index] 27 | label_p_010 = self.pseudo_label_010[index] 28 | label_p_050 = self.pseudo_label_050[index] 29 | label_p_100 = self.pseudo_label_100[index] 30 | label_p_500 = self.pseudo_label_500[index] 31 | 32 | label_p = (label_p_002, 33 | label_p_010, 34 | label_p_050, 35 | label_p_100, 36 | label_p_500) 37 | return data, target, 
label_p, index 38 | 39 | def __len__(self): 40 | return len(self.cifar10) 41 | 42 | 43 | class CIFAR100IndexPseudoLabelEnsemble(Dataset): 44 | def __init__(self, root='', transform=None, download=False, train=True, 45 | pseudoLabel_002=None, 46 | pseudoLabel_010=None, 47 | pseudoLabel_050=None, 48 | pseudoLabel_100=None, 49 | pseudoLabel_500=None, 50 | ): 51 | self.cifar100 = datasets.CIFAR100(root=root, 52 | download=download, 53 | train=train, 54 | transform=transform) 55 | 56 | self.pseudo_label_002 = pseudoLabel_002 57 | self.pseudo_label_010 = pseudoLabel_010 58 | self.pseudo_label_050 = pseudoLabel_050 59 | self.pseudo_label_100 = pseudoLabel_100 60 | self.pseudo_label_500 = pseudoLabel_500 61 | 62 | def __getitem__(self, index): 63 | data, target = self.cifar100[index] 64 | label_p_002 = self.pseudo_label_002[index] 65 | label_p_010 = self.pseudo_label_010[index] 66 | label_p_050 = self.pseudo_label_050[index] 67 | label_p_100 = self.pseudo_label_100[index] 68 | label_p_500 = self.pseudo_label_500[index] 69 | 70 | label_p = (label_p_002, 71 | label_p_010, 72 | label_p_050, 73 | label_p_100, 74 | label_p_500) 75 | return data, target, label_p, index 76 | 77 | def __len__(self): 78 | return len(self.cifar100) 79 | -------------------------------------------------------------------------------- /Asym-RoCL/trades.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | def trades_loss(model, 7 | classifier, 8 | x_natural, 9 | x_adv, 10 | y, 11 | optimizer, 12 | beta=6.0, 13 | distance='l_inf', 14 | step_size=2./255., 15 | epsilon=8./255., 16 | perturb_steps=10, 17 | trainmode='adv', 18 | fixmode='', 19 | trades=True 20 | ): 21 | if trainmode == "adv": 22 | batch_size = len(x_natural) 23 | # define KL-loss 24 | criterion_kl = nn.KLDivLoss(size_average=False) 25 | model.eval() 26 | 27 | if trades: 28 | # 
generate adversarial example 29 | x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach() 30 | if distance == 'l_inf': 31 | for _ in range(perturb_steps): 32 | x_adv.requires_grad_() 33 | with torch.enable_grad(): 34 | model.eval() 35 | loss_kl = criterion_kl(F.log_softmax(classifier(model(x_adv)), dim=1), 36 | F.softmax(classifier(model(x_natural)), dim=1)) 37 | grad = torch.autograd.grad(loss_kl, [x_adv])[0] 38 | x_adv = x_adv.detach() + step_size * torch.sign(grad.detach()) 39 | x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon) 40 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 41 | elif distance == 'l_2': 42 | assert False 43 | else: 44 | assert False 45 | 46 | x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False) 47 | 48 | # zero gradient 49 | optimizer.zero_grad() 50 | # calculate robust loss 51 | model.train() 52 | 53 | if fixmode == 'f1': 54 | for name, param in model.named_parameters(): 55 | param.requires_grad = True 56 | elif fixmode == 'f2': 57 | # fix previous three layers 58 | for name, param in model.named_parameters(): 59 | if not ("layer4" in name or "fc" in name): 60 | param.requires_grad = False 61 | else: 62 | param.requires_grad = True 63 | elif fixmode == 'f3': 64 | # fix every layer except fc 65 | # fix previous four layers 66 | for name, param in model.named_parameters(): 67 | if not ("fc" in name): 68 | param.requires_grad = False 69 | else: 70 | param.requires_grad = True 71 | else: 72 | assert False 73 | 74 | logits = classifier(model(x_natural)) 75 | 76 | loss = F.cross_entropy(logits, y) 77 | 78 | if trainmode == "adv": 79 | logits_adv = classifier(model(x_adv)) 80 | loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(logits_adv, dim=1), 81 | F.softmax(logits, dim=1)) 82 | loss += beta * loss_robust 83 | return loss, logits -------------------------------------------------------------------------------- /Asym-RoCL/data/vision.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | 6 | class VisionDataset(data.Dataset): 7 | _repr_indent = 4 8 | 9 | def __init__(self, root, transforms=None, transform=None, target_transform=None): 10 | if isinstance(root, torch._six.string_classes): 11 | root = os.path.expanduser(root) 12 | self.root = root 13 | 14 | has_transforms = transforms is not None 15 | has_separate_transform = transform is not None or target_transform is not None 16 | if has_transforms and has_separate_transform: 17 | raise ValueError("Only transforms or transform/target_transform can " 18 | "be passed as argument") 19 | 20 | # for backwards-compatibility 21 | self.transform = transform 22 | self.target_transform = target_transform 23 | 24 | if has_separate_transform: 25 | transforms = StandardTransform(transform, target_transform) 26 | self.transforms = transforms 27 | 28 | def __getitem__(self, index): 29 | raise NotImplementedError 30 | 31 | def __len__(self): 32 | raise NotImplementedError 33 | 34 | def __repr__(self): 35 | head = "Dataset " + self.__class__.__name__ 36 | body = ["Number of datapoints: {}".format(self.__len__())] 37 | if self.root is not None: 38 | body.append("Root location: {}".format(self.root)) 39 | body += self.extra_repr().splitlines() 40 | if hasattr(self, "transforms") and self.transforms is not None: 41 | body += [repr(self.transforms)] 42 | lines = [head] + [" " * self._repr_indent + line for line in body] 43 | return '\n'.join(lines) 44 | 45 | def _format_transform_repr(self, transform, head): 46 | lines = transform.__repr__().splitlines() 47 | return (["{}{}".format(head, lines[0])] + 48 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 49 | 50 | def extra_repr(self): 51 | return "" 52 | 53 | 54 | class StandardTransform(object): 55 | def __init__(self, transform=None, target_transform=None): 56 | self.transform = transform 57 | 
self.target_transform = target_transform 58 | 59 | def __call__(self, input, target): 60 | if self.transform is not None: 61 | input = self.transform(input) 62 | if self.target_transform is not None: 63 | target = self.target_transform(target) 64 | return input, target 65 | 66 | def _format_transform_repr(self, transform, head): 67 | lines = transform.__repr__().splitlines() 68 | return (["{}{}".format(head, lines[0])] + 69 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 70 | 71 | def __repr__(self): 72 | body = [self.__class__.__name__] 73 | if self.transform is not None: 74 | body += self._format_transform_repr(self.transform, 75 | "Transform: ") 76 | if self.target_transform is not None: 77 | body += self._format_transform_repr(self.target_transform, 78 | "Target transform: ") 79 | 80 | return '\n'.join(body) 81 | -------------------------------------------------------------------------------- /Asym-RoCL/warmup_scheduler/scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | from torch.optim.lr_scheduler import ReduceLROnPlateau 3 | 4 | 5 | class GradualWarmupScheduler(_LRScheduler): 6 | """ Gradually warm-up(increasing) learning rate in optimizer. 7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 8 | 9 | Args: 10 | optimizer (Optimizer): Wrapped optimizer. 11 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. 12 | total_epoch: target learning rate is reached at total_epoch, gradually 13 | after_scheduler: after target_epoch, use this scheduler(eg. 
ReduceLROnPlateau) 14 | """ 15 | 16 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 17 | self.multiplier = multiplier 18 | if self.multiplier < 1.: 19 | raise ValueError('multiplier should be greater thant or equal to 1.') 20 | self.total_epoch = total_epoch 21 | self.after_scheduler = after_scheduler 22 | self.finished = False 23 | super(GradualWarmupScheduler, self).__init__(optimizer) 24 | 25 | def get_lr(self): 26 | if self.last_epoch > self.total_epoch: 27 | if self.after_scheduler: 28 | if not self.finished: 29 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 30 | self.finished = True 31 | return self.after_scheduler.get_lr() 32 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 33 | 34 | if self.multiplier == 1.0: 35 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] 36 | else: 37 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 38 | 39 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 40 | if epoch is None: 41 | epoch = self.last_epoch + 1 42 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 43 | if self.last_epoch <= self.total_epoch: 44 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) 
for base_lr in self.base_lrs] 45 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 46 | param_group['lr'] = lr 47 | else: 48 | if epoch is None: 49 | self.after_scheduler.step(metrics, None) 50 | else: 51 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 52 | 53 | def step(self, epoch=None, metrics=None): 54 | if type(self.after_scheduler) != ReduceLROnPlateau: 55 | if self.finished and self.after_scheduler: 56 | if epoch is None: 57 | self.after_scheduler.step(None) 58 | else: 59 | self.after_scheduler.step(epoch - self.total_epoch) 60 | else: 61 | return super(GradualWarmupScheduler, self).step(epoch) 62 | else: 63 | self.step_ReduceLROnPlateau(metrics, epoch) 64 | -------------------------------------------------------------------------------- /Asym-AdvCL/trades.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | 7 | def trades_loss(model, 8 | classifier, 9 | x_natural, 10 | x_adv, 11 | y, 12 | optimizer, 13 | beta=6.0, 14 | distance='l_inf', 15 | step_size=2./255., 16 | epsilon=8./255., 17 | perturb_steps=10, 18 | trainmode='adv', 19 | fixmode='', 20 | trades=True 21 | ): 22 | if trainmode == "adv": 23 | batch_size = len(x_natural) 24 | # define KL-loss 25 | criterion_kl = nn.KLDivLoss(size_average=False) 26 | model.eval() 27 | 28 | if trades: 29 | # generate adversarial example 30 | x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach() 31 | if distance == 'l_inf': 32 | for _ in range(perturb_steps): 33 | x_adv.requires_grad_() 34 | with torch.enable_grad(): 35 | model.eval() 36 | loss_kl = criterion_kl(F.log_softmax(classifier(model(x_adv, return_feat=True)), dim=1), 37 | F.softmax(classifier(model(x_natural, return_feat=True)), dim=1)) 38 | grad = torch.autograd.grad(loss_kl, [x_adv])[0] 39 | x_adv = x_adv.detach() + step_size * 
torch.sign(grad.detach()) 40 | x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon) 41 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 42 | elif distance == 'l_2': 43 | assert False 44 | else: 45 | assert False 46 | 47 | x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False) 48 | 49 | # zero gradient 50 | optimizer.zero_grad() 51 | # calculate robust loss 52 | model.train() 53 | 54 | if fixmode == 'f1': 55 | for name, param in model.named_parameters(): 56 | param.requires_grad = True 57 | elif fixmode == 'f2': 58 | # fix previous three layers 59 | for name, param in model.named_parameters(): 60 | if not ("layer4" in name or "fc" in name): 61 | param.requires_grad = False 62 | else: 63 | param.requires_grad = True 64 | elif fixmode == 'f3': 65 | # fix every layer except fc 66 | # fix previous four layers 67 | for name, param in model.named_parameters(): 68 | if not ("fc" in name): 69 | param.requires_grad = False 70 | else: 71 | param.requires_grad = True 72 | else: 73 | assert False 74 | 75 | logits = classifier(model(x_natural, return_feat=True)) 76 | 77 | loss = F.cross_entropy(logits, y) 78 | 79 | if trainmode == "adv": 80 | logits_adv = classifier(model(x_adv, return_feat=True)) 81 | loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(logits_adv, dim=1), 82 | F.softmax(logits, dim=1)) 83 | loss += beta * loss_robust 84 | return loss, logits -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial Contrastive Learning via Asymmetric InfoNCE 2 | 3 | intro 4 | 5 | This is a PyTorch implementation of the paper [Adversarial Contrastive Learning via Asymmetric InfoNCE](https://arxiv.org/abs/2207.08374) (ECCV 2022). 
6 | 7 | ## Preparation 8 | 9 | In each sub-directory (`./Asym-AdvCL`, `./Asym-RoCL`) for each baseline, please first create `./logs` directories, then download CIFAR datasets into `~/data`. All models will be stored in `./checkpoint`, and logs stored in `./logs`. 10 | 11 | ### Environment requirements 12 | 13 | - PyTorch >= 1.8 14 | - NVIDIA apex 15 | - numpy 16 | - tensorboard_logger 17 | - [torchlars](https://github.com/kakaobrain/torchlars) == 0.1.2 18 | - [pytorch-gradual-warmup-lr](https://github.com/ildoonet/pytorch-gradual-warmup-lr) 19 | - [diffdist](https://github.com/ag14774/diffdist) == 0.1 20 | 21 | This repo is a modification on the [RoCL](https://github.com/Kim-Minseon/RoCL) and [AdvCL](https://github.com/LijieFan/AdvCL) repos. Environments can also be installed according to the requirements of [RoCL](https://github.com/Kim-Minseon/RoCL) and [AdvCL](https://github.com/LijieFan/AdvCL) for experiments of each baseline. 22 | 23 | ## Training 24 | 25 | We provide shells for reproducing our main results in Table 1. The hyperparameter settings of baselines are the same as their original papers reported. All experiments were conducted on 2 Tesla V100 GPUs. 
26 | 27 | First, for [RoCL](https://arxiv.org/abs/2006.07589): 28 | 29 | * RoCL on CIFAR10 30 | 31 | ```bash 32 | sh shell/rocl-cifar10.sh rocl-cifar10 ResNet18 cifar-10 | tee logs/rocl-cifar10.log 33 | ``` 34 | 35 | * RoCL-IP (RoCL with inferior positive) on CIFAR10 36 | 37 | ```bash 38 | sh shell/rocl-IP-cifar10.sh rocl-IP-cifar10 ResNet18 cifar-10 | tee logs/rocl-IP-cifar10.log 39 | ``` 40 | 41 | * RoCL-HN (RoCL with hard negative) on CIFAR10 42 | 43 | ```bash 44 | sh shell/rocl-HN-cifar10.sh rocl-HN-cifar10 ResNet18 cifar-10 | tee logs/rocl-HN-cifar10.log 45 | ``` 46 | 47 | * RoCL-IPHN (RoCL with both) on CIFAR10 48 | 49 | ```bash 50 | sh shell/rocl-IPHN-cifar10.sh rocl-IPHN-cifar10 ResNet18 cifar-10 | tee logs/rocl-IPHN-cifar10.log 51 | ``` 52 | 53 | * RoCL on CIFAR100 54 | 55 | ```bash 56 | sh shell/rocl-cifar100.sh rocl-cifar100 ResNet18 cifar-100 | tee logs/rocl-cifar100.log 57 | ``` 58 | 59 | * RoCL-IP on CIFAR100 60 | 61 | ```bash 62 | sh shell/rocl-IP-cifar100.sh rocl-IP-cifar100 ResNet18 cifar-100 | tee logs/rocl-IP-cifar100.log 63 | ``` 64 | 65 | * RoCL-HN on CIFAR100 66 | 67 | ```bash 68 | sh shell/rocl-HN-cifar100.sh rocl-HN-cifar100 ResNet18 cifar-100 | tee logs/rocl-HN-cifar100.log 69 | ``` 70 | 71 | * RoCL-IPHN on CIFAR100 72 | 73 | ```bash 74 | sh shell/rocl-IPHN-cifar100.sh rocl-IPHN-cifar100 ResNet18 cifar-100 | tee logs/rocl-IPHN-cifar100.log 75 | ``` 76 | 77 | For [AdvCL](https://arxiv.org/pdf/2111.01124.pdf): 78 | 79 | * AdvCL on CIFAR10 80 | 81 | ```bash 82 | sh shell/cifar10/advcl-cifar10.sh advcl-cifar10 | tee logs/advcl-cifar10.log 83 | ``` 84 | 85 | * AdvCL-IP on CIFAR10 86 | 87 | ```bash 88 | sh shell/cifar10/advcl-IP-cifar10.sh advcl-IP-cifar10 | tee logs/advcl-IP-cifar10.log 89 | ``` 90 | 91 | * AdvCL-HN on CIFAR10 92 | 93 | ```bash 94 | sh shell/cifar10/advcl-HN-cifar10.sh advcl-HN-cifar10 | tee logs/advcl-HN-cifar10.log 95 | ``` 96 | 97 | * AdvCL-IPHN on CIFAR10 98 | 99 | ```bash 100 | sh shell/cifar10/advcl-IPHN-cifar10.sh 
advcl-IPHN-cifar10 | tee logs/advcl-IPHN-cifar10.log 101 | ``` 102 | 103 | * AdvCL on CIFAR100 104 | 105 | ```bash 106 | sh shell/cifar100/advcl-cifar100.sh advcl-cifar100 | tee logs/advcl-cifar100.log 107 | ``` 108 | 109 | * AdvCL-IP on CIFAR100 110 | 111 | ```bash 112 | sh shell/cifar100/advcl-IP-cifar100.sh advcl-IP-cifar100 | tee logs/advcl-IP-cifar100.log 113 | ``` 114 | 115 | * AdvCL-HN on CIFAR100 116 | 117 | ```bash 118 | sh shell/cifar100/advcl-HN-cifar100.sh advcl-HN-cifar100 | tee logs/advcl-HN-cifar100.log 119 | ``` 120 | 121 | * AdvCL-IPHN on CIFAR100 122 | 123 | ```bash 124 | sh shell/cifar100/advcl-IPHN-cifar100.sh advcl-IPHN-cifar100 | tee logs/advcl-IPHN-cifar100.log 125 | ``` 126 | 127 | ## Citation 128 | 129 | If you find our work useful or provides some new insights about adversarial contrastive learning:blush:, please consider citing: 130 | 131 | ``` 132 | @inproceedings{yu2022adversarial, 133 | title={Adversarial Contrastive Learning via Asymmetric InfoNCE}, 134 | author={Yu, Qiying and Lou, Jieming and Zhan, Xianyuan and Li, Qizhang and Zuo, Wangmeng and Liu, Yang and Liu, Jingjing}, 135 | booktitle={European Conference on Computer Vision}, 136 | pages={53--69}, 137 | year={2022}, 138 | organization={Springer} 139 | } 140 | ``` 141 | 142 | ## Acknowledgements 143 | 144 | We thank for the code implementation from [RoCL](https://github.com/Kim-Minseon/RoCL), [AdvCL](https://github.com/LijieFan/AdvCL), [HCL](https://github.com/joshr17/HCL) and [SupContrast](https://github.com/HobbitLong/SupContrast). 
145 | -------------------------------------------------------------------------------- /Asym-RoCL/robustness_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | 3 | from __future__ import print_function 4 | 5 | import csv 6 | import os 7 | 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import torch.nn as nn 11 | import data_loader 12 | import model_loader 13 | from argument import test_parser 14 | 15 | from utils import progress_bar 16 | from collections import OrderedDict 17 | 18 | from attack_lib import FastGradientSignUntargeted 19 | 20 | args = test_parser() 21 | use_cuda = torch.cuda.is_available() 22 | if use_cuda: 23 | ngpus_per_node = torch.cuda.device_count() 24 | 25 | def print_status(string): 26 | if args.local_rank % ngpus_per_node ==0: 27 | print(string) 28 | 29 | if args.local_rank % ngpus_per_node == 0: 30 | print(torch.cuda.device_count()) 31 | print('Using CUDA..') 32 | best_acc = 0 # best test accuracy 33 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 34 | 35 | if args.seed != 0: 36 | torch.manual_seed(args.seed) 37 | 38 | # Data 39 | print_status('==> Preparing data..') 40 | 41 | trainloader, traindst, testloader, testdst = data_loader.get_dataset(args) 42 | 43 | if args.dataset == 'cifar-10': 44 | num_outputs = 10 45 | elif args.dataset == 'cifar-100': 46 | num_outputs = 100 47 | 48 | if args.model == 'ResNet50': 49 | expansion = 4 50 | else: 51 | expansion = 1 52 | 53 | # Model 54 | print_status('==> Building model..') 55 | train_type = args.train_type 56 | 57 | model = model_loader.get_model(args)#models.__dict__[args.model]() 58 | if args.dataset=='cifar-10': 59 | Linear = nn.Sequential(nn.Linear(512*expansion, 10)) 60 | elif args.dataset=='cifar-100': 61 | Linear = nn.Sequential(nn.Linear(512*expansion, 100)) 62 | 63 | checkpoint_ = torch.load(args.load_checkpoint) 64 | new_state_dict = OrderedDict() 65 | for k, v in 
checkpoint_['model'].items(): 66 | name = k[7:] 67 | new_state_dict[name] = v 68 | 69 | model.load_state_dict(new_state_dict) 70 | 71 | linearcheckpoint_ = torch.load(args.load_checkpoint+'_linear') 72 | new_state_dict = OrderedDict() 73 | for k, v in linearcheckpoint_['model'].items(): 74 | name = k[7:] # remove `module.` 75 | new_state_dict[name] = v 76 | Linear.load_state_dict(new_state_dict) 77 | 78 | criterion = nn.CrossEntropyLoss() 79 | 80 | use_cuda = torch.cuda.is_available() 81 | if use_cuda: 82 | ngpus_per_node = torch.cuda.device_count() 83 | model.cuda() 84 | Linear.cuda() 85 | print_status(torch.cuda.device_count()) 86 | print_status('Using CUDA..') 87 | cudnn.benchmark = True 88 | 89 | attack_info = 'epsilon_'+str(args.epsilon)+'_alpha_'+ str(args.alpha) + '_min_val_' + str(0.0) + '_max_val_' + str(1.0) + '_max_iters_' + str(args.k) + '_type_' + str(args.attack_type) + '_randomstart_' + str(args.random_start) 90 | print_status("Attack information...") 91 | print_status(attack_info) 92 | attacker = FastGradientSignUntargeted(model, Linear, epsilon=args.epsilon, alpha=args.alpha, min_val=0.0, max_val=1.0, max_iters=args.k, _type=args.attack_type) 93 | 94 | def test(attacker): 95 | global best_acc 96 | 97 | model.eval() 98 | Linear.eval() 99 | 100 | test_clean_loss = 0 101 | test_adv_loss = 0 102 | clean_correct = 0 103 | adv_correct = 0 104 | clean_acc = 0 105 | total = 0 106 | 107 | for idx, (image, label) in enumerate(testloader): 108 | 109 | img = image.cuda() 110 | y = label.cuda() 111 | total += y.size(0) 112 | if 'ResNet18' in args.model: 113 | if args.epsilon==0.0314 or args.epsilon==0.047: 114 | out = Linear(model(img)) 115 | _, predx = torch.max(out.data, 1) 116 | clean_loss = criterion(out, y) 117 | 118 | clean_correct += predx.eq(y.data).cpu().sum().item() 119 | 120 | clean_acc = 100.*clean_correct/total 121 | 122 | test_clean_loss += clean_loss.data 123 | 124 | adv_inputs = attacker.perturb(original_images=img, labels=y, 
random_start=args.random_start) 125 | 126 | out = Linear(model(adv_inputs)) 127 | 128 | _, predx = torch.max(out.data, 1) 129 | adv_loss = criterion(out, y) 130 | 131 | adv_correct += predx.eq(y.data).cpu().sum().item() 132 | adv_acc = 100.*adv_correct/total 133 | 134 | 135 | test_adv_loss += adv_loss.data 136 | if args.local_rank % ngpus_per_node == 0: 137 | progress_bar(idx, len(testloader),'Testing Loss {:.3f}, acc {:.3f} , adv Loss {:.3f}, adv acc {:.3f}'.format(test_clean_loss/(idx+1), clean_acc, test_adv_loss/(idx+1), adv_acc)) 138 | 139 | print ("Test accuracy: {0}/{1}".format(clean_acc, adv_acc)) 140 | 141 | return (clean_acc, adv_acc) 142 | 143 | test_acc, adv_acc = test(attacker) 144 | 145 | if not os.path.isdir('results'): 146 | os.mkdir('results') 147 | 148 | args.name += ('_Robust'+ args.train_type + '_' +args.model + '_' + args.dataset) 149 | loginfo = 'results/log_' + args.name + '_' + str(args.seed) 150 | logname = (loginfo+ '.csv') 151 | 152 | with open(logname, 'w') as logfile: 153 | logwriter = csv.writer(logfile, delimiter=',') 154 | logwriter.writerow(['random_start', 'attack_type','epsilon','k','adv_acc']) 155 | 156 | if args.local_rank % ngpus_per_node == 0: 157 | with open(logname, 'a') as logfile: 158 | logwriter = csv.writer(logfile, delimiter=',') 159 | logwriter.writerow([args.random_start, args.attack_type, args.epsilon, args.k, adv_acc]) 160 | 161 | -------------------------------------------------------------------------------- /Asym-AdvCL/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import math 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.init as init 10 | 11 | # _, term_width = os.popen('stty size', 'r').read().split() 12 | # term_width = int(term_width) 13 | term_width = 80 14 | 15 | TOTAL_BAR_LENGTH = 86. 
16 | last_time = time.time() 17 | begin_time = last_time 18 | 19 | 20 | def progress_bar(current, total, msg=None): 21 | global last_time, begin_time 22 | if current == 0: 23 | begin_time = time.time() # Reset for new bar. 24 | 25 | cur_len = int(TOTAL_BAR_LENGTH * current / total) 26 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 27 | 28 | sys.stdout.write(' [') 29 | for i in range(cur_len): 30 | sys.stdout.write('=') 31 | sys.stdout.write('>') 32 | for i in range(rest_len): 33 | sys.stdout.write('.') 34 | sys.stdout.write(']') 35 | 36 | cur_time = time.time() 37 | step_time = cur_time - last_time 38 | last_time = cur_time 39 | tot_time = cur_time - begin_time 40 | 41 | L = [] 42 | L.append(' Step: %s' % format_time(step_time)) 43 | L.append(' | Tot: %s' % format_time(tot_time)) 44 | if msg: 45 | L.append(' | ' + msg) 46 | 47 | msg = ''.join(L) 48 | sys.stdout.write(msg) 49 | for i in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3): 50 | sys.stdout.write(' ') 51 | 52 | # Go back to the center of the bar 53 | for i in range(term_width - int(TOTAL_BAR_LENGTH / 2)): 54 | sys.stdout.write('\b') 55 | sys.stdout.write(' %d/%d ' % (current + 1, total)) 56 | 57 | if current < total - 1: 58 | sys.stdout.write('\r') 59 | else: 60 | sys.stdout.write('\n') 61 | sys.stdout.flush() 62 | 63 | 64 | def load_BN_checkpoint(state_dict): 65 | new_state_dict = {} 66 | new_state_dict_normal = {} 67 | for k, v in state_dict.items(): 68 | if 'downsample' in k: 69 | k = k.replace('downsample', 'shortcut') 70 | if 'shortcut.bn.bn_list.0' in k: 71 | k = k.replace('shortcut.bn.bn_list.0', 'shortcut.1') 72 | new_state_dict_normal[k] = v 73 | elif 'shortcut.bn.bn_list.1' in k: 74 | k = k.replace('shortcut.bn.bn_list.1', 'shortcut.1') 75 | new_state_dict[k] = v 76 | elif '.bn_list.0' in k: 77 | k = k.replace('.bn_list.0', '') 78 | new_state_dict_normal[k] = v 79 | elif '.bn_list.1' in k: 80 | k = k.replace('.bn_list.1', '') 81 | new_state_dict[k] = v 82 | elif 'shortcut.conv' in k: 
83 | k = k.replace('shortcut.conv', 'shortcut.0') 84 | new_state_dict_normal[k] = v 85 | new_state_dict[k] = v 86 | else: 87 | new_state_dict_normal[k] = v 88 | new_state_dict[k] = v 89 | return new_state_dict, new_state_dict_normal 90 | 91 | 92 | def format_time(seconds): 93 | days = int(seconds / 3600 / 24) 94 | seconds = seconds - days * 3600 * 24 95 | hours = int(seconds / 3600) 96 | seconds = seconds - hours * 3600 97 | minutes = int(seconds / 60) 98 | seconds = seconds - minutes * 60 99 | secondsf = int(seconds) 100 | seconds = seconds - secondsf 101 | millis = int(seconds * 1000) 102 | 103 | f = '' 104 | i = 1 105 | if days > 0: 106 | f += str(days) + 'D' 107 | i += 1 108 | if hours > 0 and i <= 2: 109 | f += str(hours) + 'h' 110 | i += 1 111 | if minutes > 0 and i <= 2: 112 | f += str(minutes) + 'm' 113 | i += 1 114 | if secondsf > 0 and i <= 2: 115 | f += str(secondsf) + 's' 116 | i += 1 117 | if millis > 0 and i <= 2: 118 | f += str(millis) + 'ms' 119 | i += 1 120 | if f == '': 121 | f = '0ms' 122 | return f 123 | 124 | 125 | class TwoCropTransform: 126 | """Create two crops of the same image""" 127 | 128 | def __init__(self, transform): 129 | self.transform = transform 130 | 131 | def __call__(self, x): 132 | return [self.transform(x), self.transform(x)] 133 | 134 | 135 | class TwoCropTransformAdv: 136 | """Create two crops of the same image""" 137 | 138 | def __init__(self, transform, transform_adv): 139 | self.transform = transform 140 | self.transform_adv = transform_adv 141 | 142 | def __call__(self, x): 143 | return [self.transform(x), self.transform(x), self.transform_adv(x)] 144 | 145 | 146 | class AverageMeter(object): 147 | """Computes and stores the average and current value""" 148 | 149 | def __init__(self): 150 | self.reset() 151 | 152 | def reset(self): 153 | self.val = 0 154 | self.avg = 0 155 | self.sum = 0 156 | self.count = 0 157 | 158 | def update(self, val, n=1): 159 | self.val = val 160 | self.sum += val * n 161 | self.count += n 162 
| self.avg = self.sum / self.count 163 | 164 | 165 | def adjust_learning_rate(args, optimizer, epoch): 166 | lr = args.learning_rate 167 | if args.cosine: 168 | eta_min = lr * (args.lr_decay_rate ** 3) 169 | lr = eta_min + (lr - eta_min) * ( 170 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2 171 | else: 172 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs)) 173 | if steps > 0: 174 | lr = lr * (args.lr_decay_rate ** steps) 175 | 176 | for param_group in optimizer.param_groups: 177 | param_group['lr'] = lr 178 | 179 | 180 | def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer): 181 | if args.warm and epoch <= args.warm_epochs: 182 | p = (batch_id + (epoch - 1) * total_batches) / \ 183 | (args.warm_epochs * total_batches) 184 | lr = args.warmup_from + p * (args.warmup_to - args.warmup_from) 185 | 186 | for param_group in optimizer.param_groups: 187 | param_group['lr'] = lr 188 | 189 | 190 | def accuracy(output, target, topk=(1,)): 191 | """Computes the accuracy over the k top predictions for the specified values of k""" 192 | with torch.no_grad(): 193 | maxk = max(topk) 194 | batch_size = target.size(0) 195 | 196 | _, pred = output.topk(maxk, 1, True, True) 197 | pred = pred.t() 198 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 199 | 200 | res = [] 201 | for k in topk: 202 | correct_k = correct[:k].flatten().float().sum(0, keepdim=True) 203 | res.append(correct_k.mul_(100.0 / batch_size)) 204 | return res 205 | -------------------------------------------------------------------------------- /Asym-RoCL/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from data.cifar import CIFAR10, CIFAR100 3 | from torchvision import transforms 4 | 5 | def get_dataset(args): 6 | 7 | ### color augmentation ### 8 | color_jitter = transforms.ColorJitter(0.8*args.color_jitter_strength, 0.8*args.color_jitter_strength, 0.8*args.color_jitter_strength, 0.2*args.color_jitter_strength) 
9 | rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8) 10 | rnd_gray = transforms.RandomGrayscale(p=0.2) 11 | 12 | learning_type = args.train_type 13 | 14 | if args.dataset == 'cifar-10': 15 | 16 | if learning_type =='contrastive': 17 | transform_train = transforms.Compose([ 18 | rnd_color_jitter, 19 | rnd_gray, 20 | transforms.RandomHorizontalFlip(), 21 | transforms.RandomResizedCrop(32), 22 | transforms.ToTensor(), 23 | ]) 24 | 25 | transform_test = transform_train 26 | 27 | elif learning_type=='linear_eval': 28 | transform_train = transforms.Compose([ 29 | rnd_color_jitter, 30 | rnd_gray, 31 | transforms.RandomHorizontalFlip(), 32 | transforms.RandomResizedCrop(32), 33 | transforms.ToTensor(), 34 | ]) 35 | 36 | transform_test = transforms.Compose([ 37 | transforms.ToTensor(), 38 | ]) 39 | 40 | elif learning_type=='test': 41 | transform_train = transforms.Compose([ 42 | transforms.RandomHorizontalFlip(), 43 | transforms.RandomResizedCrop(32), 44 | transforms.ToTensor(), 45 | ]) 46 | 47 | transform_test = transforms.Compose([ 48 | transforms.ToTensor(), 49 | ]) 50 | else: 51 | assert('wrong learning type') 52 | 53 | train_dst = CIFAR10(root='~/data/cifar10/', train=True, download=True, 54 | transform=transform_train,contrastive_learning=learning_type) 55 | val_dst = CIFAR10(root='~/data/cifar10/', train=False, download=True, 56 | transform=transform_test,contrastive_learning=learning_type) 57 | 58 | if learning_type=='contrastive': 59 | train_sampler = torch.utils.data.distributed.DistributedSampler( 60 | train_dst, 61 | num_replicas=args.ngpu, 62 | rank=args.local_rank, 63 | ) 64 | train_loader = torch.utils.data.DataLoader(train_dst,batch_size=args.batch_size,num_workers=4, 65 | pin_memory=False, 66 | shuffle=(train_sampler is None), 67 | sampler=train_sampler, 68 | ) 69 | 70 | val_loader = torch.utils.data.DataLoader(val_dst,batch_size=100, 71 | num_workers=4, 72 | pin_memory=False, 73 | shuffle=False, 74 | ) 75 | 76 | return train_loader, 
train_dst, val_loader, val_dst, train_sampler 77 | else: 78 | train_loader = torch.utils.data.DataLoader(train_dst, 79 | batch_size=args.batch_size, 80 | shuffle=True, num_workers=4) 81 | val_batch = 100 82 | val_loader = torch.utils.data.DataLoader(val_dst, batch_size=val_batch, 83 | shuffle=False, num_workers=4) 84 | 85 | return train_loader, train_dst, val_loader, val_dst 86 | 87 | if args.dataset == 'cifar-100': 88 | 89 | if learning_type=='contrastive': 90 | transform_train = transforms.Compose([ 91 | rnd_color_jitter, 92 | rnd_gray, 93 | transforms.RandomHorizontalFlip(), 94 | transforms.RandomResizedCrop(32), 95 | transforms.ToTensor() 96 | ]) 97 | 98 | transform_test = transform_train 99 | 100 | elif learning_type=='linear_eval': 101 | transform_train = transforms.Compose([ 102 | rnd_color_jitter, 103 | rnd_gray, 104 | transforms.RandomHorizontalFlip(), 105 | transforms.RandomResizedCrop(32), 106 | transforms.ToTensor() 107 | ]) 108 | 109 | transform_test = transforms.Compose([ 110 | transforms.ToTensor() 111 | ]) 112 | 113 | elif learning_type=='test': 114 | transform_train = transforms.Compose([ 115 | transforms.RandomCrop(32, padding=4), 116 | transforms.RandomHorizontalFlip(), 117 | transforms.ToTensor() 118 | ]) 119 | 120 | transform_test = transforms.Compose([ 121 | transforms.ToTensor() 122 | ]) 123 | else: 124 | assert('wrong learning type') 125 | 126 | train_dst = CIFAR100(root='~/data/cifar100/', train=True, download=True, 127 | transform=transform_train,contrastive_learning=learning_type) 128 | val_dst = CIFAR100(root='~/data/cifar100/', train=False, download=True, 129 | transform=transform_test,contrastive_learning=learning_type) 130 | 131 | if learning_type=='contrastive': 132 | train_sampler = torch.utils.data.distributed.DistributedSampler( 133 | train_dst, 134 | num_replicas=args.ngpu, 135 | rank=args.local_rank, 136 | ) 137 | train_loader = torch.utils.data.DataLoader(train_dst,batch_size=args.batch_size,num_workers=4, 138 | 
pin_memory=True, 139 | shuffle=(train_sampler is None), 140 | sampler=train_sampler, 141 | ) 142 | 143 | val_loader = torch.utils.data.DataLoader(val_dst,batch_size=100,num_workers=4, 144 | pin_memory=True, 145 | ) 146 | return train_loader, train_dst, val_loader, val_dst, train_sampler 147 | 148 | else: 149 | train_loader = torch.utils.data.DataLoader(train_dst, 150 | batch_size=args.batch_size, 151 | shuffle=True, num_workers=4) 152 | 153 | val_loader = torch.utils.data.DataLoader(val_dst, batch_size=100, 154 | shuffle=False, num_workers=4) 155 | 156 | return train_loader, train_dst, val_loader, val_dst 157 | 158 | -------------------------------------------------------------------------------- /Asym-RoCL/models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | BasicBlock and Bottleneck module is from the original ResNet paper: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | 7 | PreActBlock and PreActBottleneck module is from the later paper: 8 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 9 | Identity Mappings in Deep Residual Networks. 
arXiv:1603.05027 10 | ''' 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, in_planes, planes, stride=1): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(in_planes, planes, stride) 24 | self.bn1 = nn.BatchNorm2d(planes) 25 | self.conv2 = conv3x3(planes, planes) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | 28 | self.shortcut = nn.Sequential() 29 | if stride != 1 or in_planes != self.expansion*planes: 30 | self.shortcut = nn.Sequential( 31 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 32 | nn.BatchNorm2d(self.expansion*planes) 33 | ) 34 | 35 | def forward(self, x): 36 | out = F.relu(self.bn1(self.conv1(x))) 37 | out = self.bn2(self.conv2(out)) 38 | out += self.shortcut(x) 39 | out = F.relu(out) 40 | return out 41 | 42 | 43 | class PreActBlock(nn.Module): 44 | '''Pre-activation version of the BasicBlock.''' 45 | expansion = 1 46 | 47 | def __init__(self, in_planes, planes, stride=1): 48 | super(PreActBlock, self).__init__() 49 | self.bn1 = nn.BatchNorm2d(in_planes) 50 | self.conv1 = conv3x3(in_planes, planes, stride) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv2 = conv3x3(planes, planes) 53 | 54 | self.shortcut = nn.Sequential() 55 | if stride != 1 or in_planes != self.expansion*planes: 56 | self.shortcut = nn.Sequential( 57 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 58 | ) 59 | 60 | def forward(self, x): 61 | out = F.relu(self.bn1(x)) 62 | shortcut = self.shortcut(out) 63 | out = self.conv1(out) 64 | out = self.conv2(F.relu(self.bn2(out))) 65 | out += shortcut 66 | return out 67 | 68 | 69 | class Bottleneck(nn.Module): 70 | expansion = 4 71 | 72 | def __init__(self, in_planes, planes, stride=1): 73 | 
super(Bottleneck, self).__init__() 74 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 75 | self.bn1 = nn.BatchNorm2d(planes) 76 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 77 | self.bn2 = nn.BatchNorm2d(planes) 78 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 79 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 80 | 81 | self.shortcut = nn.Sequential() 82 | if stride != 1 or in_planes != self.expansion*planes: 83 | self.shortcut = nn.Sequential( 84 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 85 | nn.BatchNorm2d(self.expansion*planes) 86 | ) 87 | 88 | def forward(self, x): 89 | out = F.relu(self.bn1(self.conv1(x))) 90 | out = F.relu(self.bn2(self.conv2(out))) 91 | out = self.bn3(self.conv3(out)) 92 | out += self.shortcut(x) 93 | out = F.relu(out) 94 | return out 95 | 96 | 97 | class PreActBottleneck(nn.Module): 98 | '''Pre-activation version of the original Bottleneck module.''' 99 | expansion = 4 100 | 101 | def __init__(self, in_planes, planes, stride=1): 102 | super(PreActBottleneck, self).__init__() 103 | self.bn1 = nn.BatchNorm2d(in_planes) 104 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 105 | self.bn2 = nn.BatchNorm2d(planes) 106 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 107 | self.bn3 = nn.BatchNorm2d(planes) 108 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 109 | 110 | self.shortcut = nn.Sequential() 111 | if stride != 1 or in_planes != self.expansion*planes: 112 | self.shortcut = nn.Sequential( 113 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 114 | ) 115 | 116 | def forward(self, x): 117 | out = F.relu(self.bn1(x)) 118 | shortcut = self.shortcut(out) 119 | out = self.conv1(out) 120 | out = self.conv2(F.relu(self.bn2(out))) 121 | out = 
self.conv3(F.relu(self.bn3(out))) 122 | out += shortcut 123 | return out 124 | 125 | 126 | class ResNet(nn.Module): 127 | def __init__(self, block, num_blocks, num_classes=10, contrastive_learning=True): 128 | super(ResNet, self).__init__() 129 | self.in_planes = 64 130 | 131 | self.conv1 = conv3x3(3,64) 132 | self.bn1 = nn.BatchNorm2d(64) 133 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 134 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 135 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 136 | 137 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 138 | 139 | self.contranstive_learning = contrastive_learning 140 | 141 | if not contrastive_learning: 142 | self.linear = nn.Linear(512*block.expansion, num_classes) 143 | 144 | def _make_layer(self, block, planes, num_blocks, stride): 145 | strides = [stride] + [1]*(num_blocks-1) 146 | layers = [] 147 | for stride in strides: 148 | layers.append(block(self.in_planes, planes, stride)) 149 | self.in_planes = planes * block.expansion 150 | return nn.Sequential(*layers) 151 | 152 | def forward(self, x, internal_outputs=False): 153 | out = x 154 | out_list = [] 155 | 156 | out = self.conv1(out) 157 | out = self.bn1(out) 158 | out = F.relu(out) 159 | out_list.append(out) 160 | 161 | out = self.layer1(out) 162 | out_list.append(out) 163 | 164 | out = self.layer2(out) 165 | out_list.append(out) 166 | 167 | out = self.layer3(out) 168 | out_list.append(out) 169 | 170 | out = self.layer4(out) 171 | out_list.append(out) 172 | 173 | out = F.avg_pool2d(out, 4) 174 | out = out.view(out.size(0), -1) 175 | 176 | if not self.contranstive_learning: 177 | out = self.linear(out) 178 | 179 | if internal_outputs: 180 | return out, out_list 181 | 182 | return out 183 | 184 | def ResNet18(num_classes, contrastive_learning): 185 | return ResNet(BasicBlock, [2,2,2,2], num_classes=num_classes, contrastive_learning=contrastive_learning) 186 | 187 | def 
ResNet34(num_classes,contrastive_learning): 188 | return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes, contrastive_learning=contrastive_learning) 189 | 190 | def ResNet50(num_classes,contrastive_learning): 191 | return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes, contrastive_learning=contrastive_learning) 192 | -------------------------------------------------------------------------------- /Asym-RoCL/attack_lib.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is modified from 3 | 4 | https://github.com/utkuozbulak/pytorch-cnn-adversarial-attacks 5 | https://github.com/louis2889184/pytorch-adversarial-training 6 | https://github.com/MadryLab/robustness 7 | https://github.com/yaodongyu/TRADES 8 | 9 | """ 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from loss import pairwise_similarity, NT_xent 14 | 15 | def project(x, original_x, epsilon, _type='linf'): 16 | 17 | if _type == 'linf': 18 | max_x = original_x + epsilon 19 | min_x = original_x - epsilon 20 | 21 | x = torch.max(torch.min(x, max_x), min_x) 22 | else: 23 | raise NotImplementedError 24 | 25 | return x 26 | 27 | class FastGradientSignUntargeted(): 28 | """ 29 | Fast gradient sign untargeted adversarial attack, minimizes the initial class activation 30 | with iterative grad sign updates 31 | """ 32 | def __init__(self, model, linear, epsilon, alpha, min_val, max_val, max_iters, _type='linf'): 33 | 34 | # Model 35 | self.model = model 36 | self.linear = linear 37 | # Maximum perturbation 38 | self.epsilon = epsilon 39 | # Movement multiplier per iteration 40 | self.alpha = alpha 41 | # Minimum value of the pixels 42 | self.min_val = min_val 43 | # Maximum value of the pixels 44 | self.max_val = max_val 45 | # Maximum numbers of iteration to generated adversaries 46 | self.max_iters = max_iters 47 | # The perturbation of epsilon 48 | self._type = _type 49 | 50 | def perturb(self, original_images, labels, 
reduction4loss='mean', random_start=True): 51 | # original_images: values are within self.min_val and self.max_val 52 | # The adversaries created from random close points to the original data 53 | if random_start: 54 | rand_perturb = torch.FloatTensor(original_images.shape).uniform_( 55 | -self.epsilon, self.epsilon) 56 | rand_perturb = rand_perturb.cuda() 57 | x = original_images.clone() + rand_perturb 58 | x = torch.clamp(x,self.min_val, self.max_val) 59 | else: 60 | x = original_images.clone() 61 | 62 | x.requires_grad = True 63 | 64 | self.model.eval() 65 | if not self.linear=='None': 66 | self.linear.eval() 67 | 68 | with torch.enable_grad(): 69 | for _iter in range(self.max_iters): 70 | 71 | self.model.zero_grad() 72 | if not self.linear=='None': 73 | self.linear.zero_grad() 74 | 75 | if self.linear=='None': 76 | outputs = self.model(x) 77 | else: 78 | outputs = self.linear(self.model(x)) 79 | 80 | loss = F.cross_entropy(outputs, labels, reduction=reduction4loss) 81 | 82 | grad_outputs = None 83 | grads = torch.autograd.grad(loss, x, grad_outputs=grad_outputs, only_inputs=True, retain_graph=False)[0] 84 | 85 | if self._type == 'linf': 86 | scaled_g = torch.sign(grads.data) 87 | 88 | x.data += self.alpha * scaled_g 89 | 90 | x = torch.clamp(x, self.min_val, self.max_val) 91 | x = project(x, original_images, self.epsilon, self._type) 92 | 93 | 94 | return x.detach() 95 | 96 | class RepresentationAdv(): 97 | 98 | def __init__(self, model, projector, epsilon, alpha, min_val, max_val, max_iters, args, _type='linf', loss_type='sim', regularize='original'): 99 | 100 | # Model 101 | self.model = model 102 | self.projector = projector 103 | self.regularize = regularize 104 | # Maximum perturbation 105 | self.epsilon = epsilon 106 | # Movement multiplier per iteration 107 | self.alpha = alpha 108 | # Minimum value of the pixels 109 | self.min_val = min_val 110 | # Maximum value of the pixels 111 | self.max_val = max_val 112 | # Maximum numbers of iteration to generated 
adversaries 113 | self.max_iters = max_iters 114 | # The perturbation of epsilon 115 | self._type = _type 116 | # loss type 117 | self.loss_type = loss_type 118 | 119 | self.args = args 120 | 121 | 122 | def get_loss(self, original_images, target, optimizer, weight, random_start=True): 123 | if random_start: 124 | rand_perturb = torch.FloatTensor(original_images.shape).uniform_( 125 | -self.epsilon, self.epsilon) 126 | rand_perturb = rand_perturb.float().cuda() 127 | x = original_images.float().clone() + rand_perturb 128 | x = torch.clamp(x,self.min_val, self.max_val) 129 | else: 130 | x = original_images.clone() 131 | 132 | x.requires_grad = True 133 | 134 | self.model.eval() 135 | self.projector.eval() 136 | batch_size = len(x) 137 | 138 | with torch.enable_grad(): 139 | for _iter in range(self.max_iters): 140 | 141 | self.model.zero_grad() 142 | self.projector.zero_grad() 143 | 144 | if self.loss_type == 'mse': 145 | loss = F.mse_loss(self.projector(self.model(x)),self.projector(self.model(target))) 146 | elif self.loss_type == 'sim': 147 | inputs = torch.cat((x, target)) 148 | output = self.projector(self.model(inputs)) 149 | similarity,_ = pairwise_similarity(output, temperature=0.5, multi_gpu=False, adv_type = 'None') 150 | loss = NT_xent(similarity, 'None', self.args) 151 | elif self.loss_type == 'l1': 152 | loss = F.l1_loss(self.projector(self.model(x)), self.projector(self.model(target))) 153 | elif self.loss_type =='cos': 154 | loss = 1-F.cosine_similarity(self.projector(self.model(x)), self.projector(self.model(target))).mean() 155 | 156 | grads = torch.autograd.grad(loss, x, grad_outputs=None, only_inputs=True, retain_graph=False)[0] 157 | 158 | if self._type == 'linf': 159 | scaled_g = torch.sign(grads.data) 160 | 161 | x.data += self.alpha * scaled_g 162 | 163 | x = torch.clamp(x,self.min_val,self.max_val) 164 | x = project(x, original_images, self.epsilon, self._type) 165 | 166 | self.model.train() 167 | self.projector.train() 168 | 
optimizer.zero_grad() 169 | 170 | if self.loss_type == 'mse': 171 | loss = F.mse_loss(self.projector(self.model(x)),self.projector(self.model(target))) * (1.0/batch_size) 172 | elif self.loss_type == 'sim': 173 | if self.regularize== 'original': 174 | inputs = torch.cat((x, original_images)) 175 | else: 176 | inputs = torch.cat((x, target)) 177 | output = self.projector(self.model(inputs)) 178 | similarity, _ = pairwise_similarity(output, temperature=0.5, multi_gpu=False, adv_type = 'None') 179 | loss = (1.0/weight) * NT_xent(similarity, 'None', self.args) 180 | elif self.loss_type == 'l1': 181 | loss = F.l1_loss(self.projector(self.model(x)), self.projector(self.model(target))) * (1.0/batch_size) 182 | elif self.loss_type == 'cos': 183 | loss = 1-F.cosine_similarity(self.projector(self.model(x)), self.projector(self.model(target))).sum() * (1.0/batch_size) 184 | 185 | return x.detach(), loss 186 | -------------------------------------------------------------------------------- /Asym-RoCL/data/cifar.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import numpy as np 6 | import sys 7 | from torchvision import transforms 8 | if sys.version_info[0] == 2: 9 | import cPickle as pickle 10 | else: 11 | import pickle 12 | 13 | from .vision import VisionDataset 14 | from .utils import check_integrity, download_and_extract_archive 15 | 16 | 17 | class CIFAR10(VisionDataset): 18 | """`CIFAR10 `_ Dataset. 19 | Args: 20 | root (string): Root directory of dataset where directory 21 | ``cifar-10-batches-py`` exists or will be saved to if download is set to True. 22 | train (bool, optional): If True, creates dataset from training set, otherwise 23 | creates from test set. 24 | transform (callable, optional): A function/transform that takes in an PIL image 25 | and returns a transformed version. 
E.g, ``transforms.RandomCrop`` 26 | target_transform (callable, optional): A function/transform that takes in the 27 | target and transforms it. 28 | download (bool, optional): If true, downloads the dataset from the internet and 29 | puts it in root directory. If dataset is already downloaded, it is not 30 | downloaded again. 31 | """ 32 | base_folder = 'cifar-10-batches-py' 33 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 34 | filename = "cifar-10-python.tar.gz" 35 | tgz_md5 = 'c58f30108f718f92721af3b95e74349a' 36 | train_list = [ 37 | ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], 38 | ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], 39 | ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], 40 | ['data_batch_4', '634d18415352ddfa80567beed471001a'], 41 | ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], 42 | ] 43 | 44 | test_list = [ 45 | ['test_batch', '40351d587109b95175f43aff81a1287e'], 46 | ] 47 | meta = { 48 | 'filename': 'batches.meta', 49 | 'key': 'label_names', 50 | 'md5': '5ff9c542aee3614f3951f8cda6e48888', 51 | } 52 | 53 | def __init__(self, root, train=True, transform=None, target_transform=None, 54 | download=False, contrastive_learning=False): 55 | 56 | super(CIFAR10, self).__init__(root, transform=transform, 57 | target_transform=target_transform) 58 | 59 | self.train = train # training set or test set 60 | self.learning_type = contrastive_learning 61 | 62 | if download: 63 | self.download() 64 | 65 | if not self._check_integrity(): 66 | raise RuntimeError('Dataset not found or corrupted.' 
+ 67 | ' You can use download=True to download it') 68 | 69 | if self.train: 70 | downloaded_list = self.train_list 71 | else: 72 | downloaded_list = self.test_list 73 | 74 | self.data = [] 75 | self.targets = [] 76 | 77 | # now load the picked numpy arrays 78 | for file_name, checksum in downloaded_list: 79 | file_path = os.path.join(self.root, self.base_folder, file_name) 80 | with open(file_path, 'rb') as f: 81 | if sys.version_info[0] == 2: 82 | entry = pickle.load(f) 83 | else: 84 | entry = pickle.load(f, encoding='latin1') 85 | self.data.append(entry['data']) 86 | if 'labels' in entry: 87 | self.targets.extend(entry['labels']) 88 | else: 89 | self.targets.extend(entry['fine_labels']) 90 | 91 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) 92 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 93 | 94 | self._load_meta() 95 | 96 | def _load_meta(self): 97 | path = os.path.join(self.root, self.base_folder, self.meta['filename']) 98 | if not check_integrity(path, self.meta['md5']): 99 | raise RuntimeError('Dataset metadata file not found or corrupted.' + 100 | ' You can use download=True to download it') 101 | with open(path, 'rb') as infile: 102 | if sys.version_info[0] == 2: 103 | data = pickle.load(infile) 104 | else: 105 | data = pickle.load(infile, encoding='latin1') 106 | self.classes = data[self.meta['key']] 107 | self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} 108 | 109 | def __getitem__(self, index): 110 | """ 111 | Args: 112 | index (int): Index 113 | Returns: 114 | tuple: (image, target) where target is index of the target class. 
115 | """ 116 | img, target = self.data[index], self.targets[index] 117 | ori_img = img 118 | toTensor = transforms.ToTensor() 119 | ori_img = toTensor(ori_img) 120 | 121 | if self.learning_type=='contrastive': 122 | img_2 = img.copy() 123 | 124 | elif self.learning_type=='linear_eval': 125 | if self.train: 126 | img_2 = img.copy() 127 | 128 | # doing this so that it is consistent with all other datasets 129 | # to return a PIL Image 130 | img = Image.fromarray(img) 131 | if self.learning_type=='contrastive': 132 | img_2 = Image.fromarray(img_2) 133 | elif self.learning_type=='linear_eval': 134 | if self.train: 135 | img_2 = img.copy() 136 | 137 | if self.transform is not None: 138 | img = self.transform(img) 139 | if self.learning_type=='contrastive': 140 | img_2 = self.transform(img_2) 141 | elif self.learning_type=='linear_eval': 142 | if self.train: 143 | img_2 = self.transform(img_2) 144 | 145 | if self.target_transform is not None: 146 | target = self.target_transform(target) 147 | 148 | if self.learning_type=='contrastive': 149 | return ori_img, img, img_2, target 150 | elif self.learning_type=='linear_eval': 151 | if self.train: 152 | return ori_img, img, img_2, target 153 | else: 154 | return img, target 155 | else: 156 | return img, target 157 | 158 | def __len__(self): 159 | return len(self.data) 160 | 161 | def _check_integrity(self): 162 | root = self.root 163 | for fentry in (self.train_list + self.test_list): 164 | filename, md5 = fentry[0], fentry[1] 165 | fpath = os.path.join(root, self.base_folder, filename) 166 | if not check_integrity(fpath, md5): 167 | return False 168 | return True 169 | 170 | def download(self): 171 | if self._check_integrity(): 172 | print('Files already downloaded and verified') 173 | return 174 | download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) 175 | 176 | def extra_repr(self): 177 | return "Split: {}".format("Train" if self.train is True else "Test") 178 | 179 | 180 | class 
CIFAR100(CIFAR10): 181 | """`CIFAR100 `_ Dataset. 182 | This is a subclass of the `CIFAR10` Dataset. 183 | """ 184 | base_folder = 'cifar-100-python' 185 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 186 | filename = "cifar-100-python.tar.gz" 187 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 188 | train_list = [ 189 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 190 | ] 191 | 192 | test_list = [ 193 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 194 | ] 195 | meta = { 196 | 'filename': 'meta', 197 | 'key': 'fine_label_names', 198 | 'md5': '7973b15100ade9c7d40fb424638fde48', 199 | } 200 | 201 | -------------------------------------------------------------------------------- /Asym-RoCL/utils.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 
5 | ''' 6 | import os 7 | import sys 8 | import time 9 | import math 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | import torch.nn.functional as F 15 | 16 | import numpy as np 17 | # import cv2 18 | import scipy.misc 19 | from itertools import chain 20 | 21 | class AverageMeter(object): 22 | """Computes and stores the average and current value""" 23 | def __init__(self): 24 | self.reset() 25 | 26 | def reset(self): 27 | self.val = 0 28 | self.avg = 0 29 | self.sum = 0 30 | self.count = 0 31 | 32 | def update(self, val, n=1): 33 | self.val = val 34 | self.sum += val * n 35 | self.count += n 36 | self.avg = self.sum / self.count 37 | 38 | 39 | def accuracy(output, target, topk=(1,)): 40 | """Computes the accuracy over the k top predictions for the specified values of k""" 41 | with torch.no_grad(): 42 | maxk = max(topk) 43 | batch_size = target.size(0) 44 | 45 | _, pred = output.topk(maxk, 1, True, True) 46 | pred = pred.t() 47 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 48 | 49 | res = [] 50 | for k in topk: 51 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 52 | res.append(correct_k.mul_(100.0 / batch_size)) 53 | return res 54 | 55 | def get_mean_and_std(dataset): 56 | '''Compute the mean and std value of dataset.''' 57 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 58 | mean = torch.zeros(3) 59 | std = torch.zeros(3) 60 | print('==> Computing mean and std..') 61 | for inputs, targets in dataloader: 62 | for i in range(3): 63 | mean[i] += inputs[:,i,:,:].mean() 64 | std[i] += inputs[:,i,:,:].std() 65 | mean.div_(len(dataset)) 66 | std.div_(len(dataset)) 67 | return mean, std 68 | 69 | def init_params(net): 70 | '''Init layer parameters.''' 71 | for m in net.modules(): 72 | if isinstance(m, nn.Conv2d): 73 | init.kaiming_normal(m.weight, mode='fan_out') 74 | if m.bias: 75 | init.constant(m.bias, 0) 76 | elif isinstance(m, nn.BatchNorm2d): 77 | 
init.constant(m.weight, 1) 78 | init.constant(m.bias, 0) 79 | elif isinstance(m, nn.Linear): 80 | init.normal(m.weight, std=1e-3) 81 | if m.bias: 82 | init.constant(m.bias, 0) 83 | 84 | 85 | # _, term_width = os.popen('stty size', 'r').read().split() 86 | term_width = int(80) 87 | 88 | TOTAL_BAR_LENGTH = 86. 89 | last_time = time.time() 90 | begin_time = last_time 91 | def progress_bar(current, total, msg=None): 92 | global last_time, begin_time 93 | if current == 0: 94 | begin_time = time.time() # Reset for new bar. 95 | 96 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 97 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 98 | 99 | sys.stdout.write(' [') 100 | for i in range(cur_len): 101 | sys.stdout.write('=') 102 | sys.stdout.write('>') 103 | for i in range(rest_len): 104 | sys.stdout.write('.') 105 | sys.stdout.write(']') 106 | 107 | cur_time = time.time() 108 | step_time = cur_time - last_time 109 | last_time = cur_time 110 | tot_time = cur_time - begin_time 111 | 112 | L = [] 113 | L.append(' Step: %s' % format_time(step_time)) 114 | L.append(' | Tot: %s' % format_time(tot_time)) 115 | if msg: 116 | L.append(' | ' + msg) 117 | 118 | msg = ''.join(L) 119 | sys.stdout.write(msg) 120 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 121 | sys.stdout.write(' ') 122 | 123 | # Go back to the center of the bar. 
124 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)): 125 | sys.stdout.write('\b') 126 | sys.stdout.write(' %d/%d ' % (current+1, total)) 127 | 128 | if current < total-1: 129 | sys.stdout.write('\r') 130 | else: 131 | sys.stdout.write('\n') 132 | sys.stdout.flush() 133 | 134 | def format_time(seconds): 135 | days = int(seconds / 3600/24) 136 | seconds = seconds - days*3600*24 137 | hours = int(seconds / 3600) 138 | seconds = seconds - hours*3600 139 | minutes = int(seconds / 60) 140 | seconds = seconds - minutes*60 141 | secondsf = int(seconds) 142 | seconds = seconds - secondsf 143 | millis = int(seconds*1000) 144 | 145 | f = '' 146 | i = 1 147 | if days > 0: 148 | f += str(days) + 'D' 149 | i += 1 150 | if hours > 0 and i <= 2: 151 | f += str(hours) + 'h' 152 | i += 1 153 | if minutes > 0 and i <= 2: 154 | f += str(minutes) + 'm' 155 | i += 1 156 | if secondsf > 0 and i <= 2: 157 | f += str(secondsf) + 's' 158 | i += 1 159 | if millis > 0 and i <= 2: 160 | f += str(millis) + 'ms' 161 | i += 1 162 | if f == '': 163 | f = '0ms' 164 | return f 165 | 166 | def checkpoint_train(model, acc, epoch, args, optimizer, save_name_add=''): 167 | # Save checkpoint. 168 | print('Saving..') 169 | state = { 170 | 'epoch': epoch, 171 | 'acc': acc, 172 | 'model': model.state_dict(), 173 | 'optimizer_state' : optimizer.state_dict(), 174 | 'rng_state': torch.get_rng_state() 175 | } 176 | 177 | save_name = './checkpoint/ckpt.t7' + args.name 178 | save_name += save_name_add 179 | 180 | if not os.path.isdir('./checkpoint'): 181 | os.mkdir('./checkpoint') 182 | torch.save(state, save_name) 183 | 184 | def checkpoint(model, finetune_type, acc, acc_adv, epoch, args, optimizer, save_name_add=''): 185 | # Save checkpoint. 
186 | print('Saving..') 187 | state = { 188 | 'epoch': epoch, 189 | 'acc': acc, 190 | 'acc_adv': acc_adv, 191 | 'finetune_type': finetune_type, 192 | 'model': model.state_dict(), 193 | 'optimizer_state' : optimizer.state_dict(), 194 | 'rng_state': torch.get_rng_state() 195 | } 196 | 197 | save_name = './checkpoint/ckpt.t7' + args.name + finetune_type 198 | save_name += save_name_add 199 | 200 | if not os.path.isdir('./checkpoint'): 201 | os.mkdir('./checkpoint') 202 | torch.save(state, save_name) 203 | 204 | def learning_rate_warmup(optimizer, epoch, args): 205 | """Learning rate warmup for first 10 epoch""" 206 | 207 | lr = args.lr 208 | lr /= 10 209 | lr *= (epoch+1) 210 | 211 | for param_group in optimizer.param_groups: 212 | param_group['lr'] = lr 213 | 214 | class LabelDict(): 215 | def __init__(self, dataset='cifar-10'): 216 | self.dataset = dataset 217 | if dataset == 'cifar-10': 218 | self.label_dict = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 219 | 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 220 | 8: 'ship', 9: 'truck'} 221 | 222 | self.class_dict = {v: k for k, v in self.label_dict.items()} 223 | 224 | def label2class(self, label): 225 | assert label in self.label_dict, 'the label %d is not in %s' % (label, self.dataset) 226 | return self.label_dict[label] 227 | 228 | def class2label(self, _class): 229 | assert isinstance(_class, str) 230 | assert _class in self.class_dict, 'the class %s is not in %s' % (_class, self.dataset) 231 | return self.class_dict[_class] 232 | 233 | def get_highest_incorrect_predict(outputs,targets): 234 | _, sorted_prediction = torch.topk(outputs.data,k=2,dim=1) 235 | 236 | ### correct then second predict, incorrect then highest predict ### 237 | 238 | highest_incorrect_predict = ((sorted_prediction[:,0] == targets).type(torch.cuda.LongTensor) * sorted_prediction[:,1] + (sorted_prediction[:,0] != targets).type(torch.cuda.LongTensor) * sorted_prediction[:,0]).detach() 239 | 240 | return highest_incorrect_predict 241 | 
242 | -------------------------------------------------------------------------------- /Asym-AdvCL/models/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from torch.autograd import Variable 14 | import numpy as np 15 | 16 | 17 | class BasicBlock(nn.Module): 18 | expansion = 1 19 | 20 | def __init__(self, in_planes, planes, stride=1): 21 | super(BasicBlock, self).__init__() 22 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | 27 | self.shortcut = nn.Sequential() 28 | if stride != 1 or in_planes != self.expansion*planes: 29 | self.shortcut = nn.Sequential( 30 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 31 | nn.BatchNorm2d(self.expansion*planes) 32 | ) 33 | 34 | def forward(self, x): 35 | out = F.relu(self.bn1(self.conv1(x))) 36 | out = self.bn2(self.conv2(out)) 37 | out += self.shortcut(x) 38 | out = F.relu(out) 39 | return out 40 | 41 | 42 | class Bottleneck(nn.Module): 43 | expansion = 4 44 | 45 | def __init__(self, in_planes, planes, stride=1): 46 | super(Bottleneck, self).__init__() 47 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(planes) 49 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 50 | self.bn2 = nn.BatchNorm2d(planes) 51 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 52 | self.bn3 = 
nn.BatchNorm2d(self.expansion*planes) 53 | 54 | self.shortcut = nn.Sequential() 55 | if stride != 1 or in_planes != self.expansion*planes: 56 | self.shortcut = nn.Sequential( 57 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 58 | nn.BatchNorm2d(self.expansion*planes) 59 | ) 60 | 61 | def forward(self, x): 62 | out = F.relu(self.bn1(self.conv1(x))) 63 | out = F.relu(self.bn2(self.conv2(out))) 64 | out = self.bn3(self.conv3(out)) 65 | out += self.shortcut(x) 66 | out = F.relu(out) 67 | return out 68 | 69 | 70 | class ResNet(nn.Module): 71 | def __init__(self, block, num_blocks, num_classes=10, feat_dim=128, low_freq=False, high_freq=False, radius=0): 72 | super(ResNet, self).__init__() 73 | self.in_planes = 64 74 | 75 | # self.normalize = NormalizeByChannelMeanStd( 76 | # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 77 | 78 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 79 | self.bn1 = nn.BatchNorm2d(64) 80 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 81 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 82 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 83 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 84 | self.linear = nn.Linear(512*block.expansion, num_classes) 85 | 86 | 87 | # self.linear_contrast = nn.Linear(512*block.expansion, 128) 88 | dim_in = 512*block.expansion 89 | self.head_proj = nn.Sequential( 90 | nn.Linear(dim_in, dim_in), 91 | # nn.BatchNorm1d(dim_in), 92 | nn.ReLU(inplace=True), 93 | nn.Linear(dim_in, 128) 94 | ) 95 | 96 | self.head_pred = nn.Sequential( 97 | nn.Linear(128, 128), 98 | # nn.BatchNorm1d(128), 99 | nn.ReLU(inplace=True), 100 | nn.Linear(128, 128) 101 | ) 102 | 103 | self.low_freq = low_freq 104 | self.high_freq = high_freq 105 | self.radius = radius 106 | 107 | def _make_layer(self, block, planes, num_blocks, stride): 108 | strides = [stride] + [1]*(num_blocks-1) 
109 | layers = [] 110 | for stride in strides: 111 | layers.append(block(self.in_planes, planes, stride)) 112 | self.in_planes = planes * block.expansion 113 | return nn.Sequential(*layers) 114 | 115 | def distance(self, i, j, imageSize, r): 116 | dis = np.sqrt((i - imageSize / 2) ** 2 + (j - imageSize / 2) ** 2) 117 | if dis < r: 118 | return 1.0 119 | else: 120 | return 0 121 | def mask_radial(self, img, r): 122 | rows, cols = img.shape 123 | mask = torch.zeros((rows, cols)) 124 | for i in range(rows): 125 | for j in range(cols): 126 | mask[i, j] = self.distance(i, j, imageSize=rows, r=r) 127 | return mask.cuda() 128 | def filter_low(self, Images, r): 129 | mask = self.mask_radial(torch.zeros([Images.shape[2], Images.shape[3]]), r) 130 | bs, c, h, w = Images.shape 131 | x = Images.reshape([bs*c, h, w]) 132 | fd = torch.fft.fftshift(torch.fft.fftn(x, dim=(-2, -1))) 133 | mask = mask.unsqueeze(0).repeat([bs*c, 1, 1]) 134 | fd = fd * mask 135 | fd = torch.fft.ifftn(torch.fft.ifftshift(fd), dim=(-2, -1)) 136 | fd = torch.real(fd) 137 | fd = fd.reshape([bs, c, h, w]) 138 | return fd 139 | 140 | def filter_high(self, Images, r): 141 | mask = self.mask_radial(torch.zeros([Images.shape[2], Images.shape[3]]), r) 142 | bs, c, h, w = Images.shape 143 | x = Images.reshape([bs * c, h, w]) 144 | fd = torch.fft.fftshift(torch.fft.fftn(x, dim=(-2, -1))) 145 | mask = mask.unsqueeze(0).repeat([bs * c, 1, 1]) 146 | fd = fd * (1. 
- mask) 147 | fd = torch.fft.ifftn(torch.fft.ifftshift(fd), dim=(-2, -1)) 148 | fd = torch.real(fd) 149 | fd = fd.reshape([bs, c, h, w]) 150 | return fd 151 | 152 | # return np.array(Images_freq_low), np.array(Images_freq_high) 153 | 154 | def forward(self, x, contrast=False, return_feat=False, CF=False): 155 | # img_org = x[0] 156 | # x = self.normalize(x) 157 | 158 | 159 | if self.low_freq: 160 | x = self.filter_low(x, self.radius) 161 | x = torch.clamp(x, 0, 1) 162 | 163 | if self.high_freq: 164 | x = self.filter_high(x, self.radius) 165 | x = torch.clamp(x, 0, 1) 166 | 167 | # img_filter = x[0] 168 | # import cv2 169 | # img_org = img_org.detach().cpu().numpy()*255. 170 | # img_filter = img_filter.detach().cpu().numpy()*255. 171 | # cv2.imwrite('org.jpg', img_org.transpose([1,2,0])) 172 | # cv2.imwrite('filter.jpg', img_filter.transpose([1,2,0])) 173 | # exit(0) 174 | 175 | out = F.relu(self.bn1(self.conv1(x))) 176 | out = self.layer1(out) 177 | out = self.layer2(out) 178 | out = self.layer3(out) 179 | out = self.layer4(out) 180 | out = F.avg_pool2d(out, 4) 181 | out = out.view(out.size(0), -1) 182 | feat = out 183 | if return_feat: 184 | return out 185 | if contrast: 186 | # out = self.linear_contrast(out) 187 | proj = self.head_proj(out) 188 | pred = self.head_pred(proj) 189 | proj = F.normalize(proj, dim=1) 190 | pred = F.normalize(pred, dim=1) 191 | if CF: 192 | return proj, pred, feat 193 | else: 194 | return proj, pred 195 | else: 196 | out = self.linear(out) 197 | return out 198 | 199 | 200 | def ResNet18(num_class=10, radius=8, low_freq=False, high_freq=False): 201 | return ResNet(BasicBlock, [2,2,2,2], num_classes=num_class, radius=radius, low_freq=low_freq, high_freq=high_freq) 202 | 203 | def ResNet34(num_class=10): 204 | return ResNet(BasicBlock, [3,4,6,3], num_classes=num_class) 205 | 206 | def ResNet50(num_class=10): 207 | return ResNet(Bottleneck, [3,4,6,3], num_classes=num_class) 208 | 209 | def ResNet101(): 210 | return ResNet(Bottleneck, 
[3,4,23,3]) 211 | 212 | def ResNet152(): 213 | return ResNet(Bottleneck, [3,8,36,3]) 214 | 215 | 216 | def test(): 217 | net = ResNet18() 218 | y = net(Variable(torch.randn(1,3,32,32))) 219 | print(y.size()) 220 | 221 | # test() -------------------------------------------------------------------------------- /Asym-RoCL/rocl_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | 3 | from __future__ import print_function 4 | 5 | import csv 6 | import os 7 | 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | 11 | import torch.optim as optim 12 | import data_loader 13 | import model_loader 14 | 15 | from attack_lib import FastGradientSignUntargeted,RepresentationAdv 16 | 17 | from models.projector import Projector 18 | from argument import parser, print_args 19 | from utils import progress_bar, checkpoint, checkpoint_train, AverageMeter, accuracy 20 | 21 | from loss import pairwise_similarity_train, NT_xent, NT_xent_HN 22 | from torchlars import LARS 23 | from warmup_scheduler import GradualWarmupScheduler 24 | 25 | args = parser() 26 | 27 | def print_status(string): 28 | if args.local_rank % ngpus_per_node == 0: 29 | print(string) 30 | 31 | ngpus_per_node = torch.cuda.device_count() 32 | if args.ngpu>1: 33 | multi_gpu=True 34 | elif args.ngpu==1: 35 | multi_gpu=False 36 | else: 37 | assert("Need GPU....") 38 | if args.local_rank % ngpus_per_node == 0: 39 | print_args(args) 40 | 41 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 42 | 43 | if args.seed != 0: 44 | torch.manual_seed(args.seed) 45 | 46 | world_size = args.ngpu 47 | torch.distributed.init_process_group( 48 | 'nccl', 49 | init_method='env://', 50 | world_size=world_size, 51 | rank=args.local_rank, 52 | ) 53 | 54 | # Data 55 | print_status('==> Preparing data..') 56 | if not (args.train_type=='contrastive'): 57 | assert('wrong train phase...') 58 | else: 59 | trainloader, traindst, testloader, testdst 
,train_sampler = data_loader.get_dataset(args) 60 | 61 | # Model 62 | print_status('==> Building model..') 63 | torch.cuda.set_device(args.local_rank) 64 | model = model_loader.get_model(args) 65 | 66 | if args.model=='ResNet18': 67 | expansion=1 68 | elif args.model=='ResNet50': 69 | expansion=4 70 | else: 71 | assert('wrong model type') 72 | projector = Projector(expansion=expansion) 73 | 74 | if 'Rep' in args.advtrain_type: 75 | Rep_info = 'Rep_attack_ep_'+str(args.epsilon)+'_alpha_'+ str(args.alpha) + '_min_val_' + str(args.min) + '_max_val_' + str(args.max) + '_max_iters_' + str(args.k) + '_type_' + str(args.attack_type) + '_randomstart_' + str(args.random_start) 76 | args.name += Rep_info 77 | 78 | print_status("Representation attack info ...") 79 | print_status(Rep_info) 80 | Rep = RepresentationAdv(model, projector, epsilon=args.epsilon, alpha=args.alpha, min_val=args.min, max_val=args.max, max_iters=args.k, args=args, _type=args.attack_type, loss_type=args.loss_type, regularize = args.regularize_to) 81 | else: 82 | assert('wrong adversarial train type') 83 | 84 | # Model upload to GPU # 85 | model.cuda() 86 | projector.cuda() 87 | 88 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 89 | model = torch.nn.parallel.DistributedDataParallel( 90 | model, 91 | device_ids=[args.local_rank], 92 | output_device=args.local_rank, 93 | find_unused_parameters=True, 94 | ) 95 | projector = torch.nn.parallel.DistributedDataParallel( 96 | projector, 97 | device_ids=[args.local_rank], 98 | output_device=args.local_rank, 99 | find_unused_parameters=True, 100 | ) 101 | 102 | cudnn.benchmark = True 103 | print_status('Using CUDA..') 104 | 105 | # Aggregating model parameter & projection parameter # 106 | model_params = [] 107 | model_params += model.parameters() 108 | model_params += projector.parameters() 109 | 110 | # LARS optimizer from KAKAO-BRAIN github "pip install torchlars" or git from https://github.com/kakaobrain/torchlars 111 | base_optimizer = 
optim.SGD(model_params, lr=args.lr, momentum=0.9, weight_decay=args.decay) 112 | optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001) 113 | 114 | # Cosine learning rate annealing (SGDR) & Learning rate warmup git from https://github.com/ildoonet/pytorch-gradual-warmup-lr # 115 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epoch) 116 | scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=args.lr_multiplier, total_epoch=10, after_scheduler=scheduler_cosine) 117 | 118 | def train(epoch): 119 | if args.local_rank % ngpus_per_node == 0: 120 | print('\nEpoch: %d' % epoch) 121 | 122 | model.train() 123 | projector.train() 124 | 125 | train_sampler.set_epoch(epoch) 126 | scheduler_warmup.step() 127 | 128 | total_loss = 0 129 | reg_simloss = 0 130 | reg_loss = 0 131 | 132 | for batch_idx, (ori, inputs_1, inputs_2, label) in enumerate(trainloader): 133 | ori, inputs_1, inputs_2 = ori.cuda(), inputs_1.cuda() ,inputs_2.cuda() 134 | 135 | if args.attack_to=='original': 136 | attack_target = inputs_1 137 | else: 138 | attack_target = inputs_2 139 | 140 | if 'Rep' in args.advtrain_type : 141 | advinputs, adv_loss = Rep.get_loss(original_images=inputs_1, target = attack_target, optimizer=optimizer, weight= args.lamda, random_start=args.random_start) 142 | reg_loss += adv_loss.data 143 | 144 | if not (args.advtrain_type == 'None'): 145 | inputs = torch.cat((inputs_1, inputs_2, advinputs)) 146 | else: 147 | inputs = torch.cat((inputs_1, inputs_2)) 148 | 149 | outputs = projector(model(inputs)) 150 | 151 | similarity, similarity_grad, gathered_outputs = pairwise_similarity_train(outputs, args, temperature=args.temperature, multi_gpu=multi_gpu, adv_type = args.advtrain_type) 152 | 153 | if args.HN: 154 | simloss = NT_xent_HN(similarity, args.advtrain_type, args) 155 | else: 156 | simloss = NT_xent(similarity, args.advtrain_type, args) 157 | 158 | if not (args.advtrain_type=='None'): 159 | loss = simloss + adv_loss 160 | 
else: 161 | loss = simloss 162 | 163 | optimizer.zero_grad() 164 | loss.backward() 165 | total_loss += loss.data 166 | reg_simloss += simloss.data 167 | 168 | optimizer.step() 169 | 170 | if (args.local_rank % ngpus_per_node == 0): 171 | if 'Rep' in args.advtrain_type: 172 | progress_bar(batch_idx, len(trainloader), 173 | 'Loss: %.3f | SimLoss: %.3f | Adv: %.2f' 174 | % (total_loss / (batch_idx + 1), reg_simloss / (batch_idx + 1), reg_loss / (batch_idx + 1))) 175 | else: 176 | progress_bar(batch_idx, len(trainloader), 177 | 'Loss: %.3f | Adv: %.3f' 178 | % (total_loss/(batch_idx+1), reg_simloss/(batch_idx+1))) 179 | 180 | return (total_loss/batch_idx, reg_simloss/batch_idx) 181 | 182 | 183 | def test(epoch, train_loss): 184 | model.eval() 185 | projector.eval() 186 | 187 | # Save at the last epoch # 188 | if epoch == args.epoch - 1 and args.local_rank % ngpus_per_node == 0: 189 | checkpoint_train(model, train_loss, epoch, args, optimizer) 190 | checkpoint_train(projector, train_loss, epoch, args, optimizer, save_name_add='_projector') 191 | 192 | # Save at every 100 epoch # 193 | elif epoch % 100 == 0 and args.local_rank % ngpus_per_node == 0: 194 | checkpoint_train(model, train_loss, epoch, args, optimizer, save_name_add='_epoch_'+str(epoch)) 195 | checkpoint_train(projector, train_loss, epoch, args, optimizer, save_name_add=('_projector_epoch_' + str(epoch))) 196 | 197 | 198 | # Log and saving checkpoint information # 199 | if not os.path.isdir('results') and args.local_rank % ngpus_per_node == 0: 200 | os.mkdir('results') 201 | 202 | args.name += (args.train_type + '_' +args.model + '_' + args.dataset + '_b' + str(args.batch_size)+'_nGPU'+str(args.ngpu)+'_l'+str(args.lamda)) 203 | if args.HN: 204 | args.name += '_HN' 205 | loginfo = 'results/log_' + args.name + '_' + str(args.seed) 206 | logname = (loginfo+ '.csv') 207 | print_status('Training info...') 208 | print_status(loginfo) 209 | 210 | ##### Log file ##### 211 | if args.local_rank % ngpus_per_node == 0: 
212 | with open(logname, 'w') as logfile: 213 | logwriter = csv.writer(logfile, delimiter=',') 214 | logwriter.writerow(['epoch', 'train loss', 'reg loss']) 215 | 216 | print(args.name) 217 | 218 | ##### Training ##### 219 | for epoch in range(start_epoch, args.epoch): 220 | train_loss, reg_loss = train(epoch) 221 | test(epoch, train_loss) 222 | 223 | if args.local_rank % ngpus_per_node == 0: 224 | with open(logname, 'a') as logfile: 225 | logwriter = csv.writer(logfile, delimiter=',') 226 | logwriter.writerow([epoch, train_loss.item(), reg_loss.item()]) 227 | 228 | 229 | -------------------------------------------------------------------------------- /Asym-RoCL/loss.py: -------------------------------------------------------------------------------- 1 | import nturl2path 2 | import diffdist.functional as distops 3 | import torch 4 | import torch.distributed as dist 5 | import numpy as np 6 | 7 | def pairwise_similarity(outputs,temperature=0.5,multi_gpu=False, adv_type='None'): 8 | ''' 9 | Compute pairwise similarity and return the matrix 10 | input: aggregated outputs & temperature for scaling 11 | return: pairwise cosine similarity 12 | ''' 13 | if multi_gpu and adv_type=='None': 14 | 15 | B = int(outputs.shape[0]/2) 16 | 17 | outputs_1 = outputs[0:B] 18 | outputs_2 = outputs[B:] 19 | 20 | gather_t_1 = [torch.empty_like(outputs_1) for i in range(dist.get_world_size())] 21 | gather_t_1 = distops.all_gather(gather_t_1, outputs_1) 22 | 23 | gather_t_2 = [torch.empty_like(outputs_2) for i in range(dist.get_world_size())] 24 | gather_t_2 = distops.all_gather(gather_t_2, outputs_2) 25 | 26 | outputs_1 = torch.cat(gather_t_1) 27 | outputs_2 = torch.cat(gather_t_2) 28 | outputs = torch.cat((outputs_1,outputs_2)) 29 | elif multi_gpu and 'Rep' in adv_type: 30 | if adv_type == 'Rep': 31 | N=3 32 | B = int(outputs.shape[0]/N) 33 | 34 | outputs_1 = outputs[0:B] 35 | outputs_2 = outputs[B:2*B] 36 | outputs_3 = outputs[2*B:3*B] 37 | 38 | gather_t_1 = 
[torch.empty_like(outputs_1) for i in range(dist.get_world_size())] 39 | gather_t_1 = distops.all_gather(gather_t_1, outputs_1) 40 | 41 | gather_t_2 = [torch.empty_like(outputs_2) for i in range(dist.get_world_size())] 42 | gather_t_2 = distops.all_gather(gather_t_2, outputs_2) 43 | 44 | gather_t_3 = [torch.empty_like(outputs_3) for i in range(dist.get_world_size())] 45 | gather_t_3 = distops.all_gather(gather_t_3, outputs_3) 46 | 47 | outputs_1 = torch.cat(gather_t_1) 48 | outputs_2 = torch.cat(gather_t_2) 49 | outputs_3 = torch.cat(gather_t_3) 50 | 51 | if N==3: 52 | outputs = torch.cat((outputs_1,outputs_2,outputs_3)) 53 | 54 | B = outputs.shape[0] 55 | outputs_norm = outputs/(outputs.norm(dim=1).view(B,1) + 1e-8) 56 | similarity_matrix = (1./temperature) * torch.mm(outputs_norm,outputs_norm.transpose(0,1).detach()) 57 | 58 | return similarity_matrix, outputs 59 | 60 | def pairwise_similarity_train(outputs, args, temperature=0.5,multi_gpu=False, adv_type='None'): 61 | ''' 62 | Compute pairwise similarity and return the matrix 63 | input: aggregated outputs & temperature for scaling 64 | return: pairwise cosine similarity 65 | 66 | ''' 67 | if multi_gpu and adv_type=='None': 68 | 69 | B = int(outputs.shape[0]/2) 70 | 71 | outputs_1 = outputs[0:B] 72 | outputs_2 = outputs[B:] 73 | 74 | gather_t_1 = [torch.empty_like(outputs_1) for i in range(dist.get_world_size())] 75 | gather_t_1 = distops.all_gather(gather_t_1, outputs_1) 76 | 77 | gather_t_2 = [torch.empty_like(outputs_2) for i in range(dist.get_world_size())] 78 | gather_t_2 = distops.all_gather(gather_t_2, outputs_2) 79 | 80 | outputs_1 = torch.cat(gather_t_1) 81 | outputs_2 = torch.cat(gather_t_2) 82 | outputs = torch.cat((outputs_1,outputs_2)) 83 | elif multi_gpu and 'Rep' in adv_type: 84 | if adv_type == 'Rep': 85 | N=3 86 | B = int(outputs.shape[0]/N) 87 | 88 | outputs_1 = outputs[0:B] 89 | outputs_2 = outputs[B:2*B] 90 | outputs_3 = outputs[2*B:3*B] 91 | 92 | gather_t_1 = [torch.empty_like(outputs_1) for 
i in range(dist.get_world_size())] 93 | gather_t_1 = distops.all_gather(gather_t_1, outputs_1) 94 | 95 | gather_t_2 = [torch.empty_like(outputs_2) for i in range(dist.get_world_size())] 96 | gather_t_2 = distops.all_gather(gather_t_2, outputs_2) 97 | 98 | gather_t_3 = [torch.empty_like(outputs_3) for i in range(dist.get_world_size())] 99 | gather_t_3 = distops.all_gather(gather_t_3, outputs_3) 100 | 101 | outputs_1 = torch.cat(gather_t_1) 102 | outputs_2 = torch.cat(gather_t_2) 103 | outputs_3 = torch.cat(gather_t_3) 104 | 105 | if N==3: 106 | outputs = torch.cat((outputs_1,outputs_2,outputs_3)) 107 | 108 | B = outputs.shape[0] 109 | outputs_norm = outputs/(outputs.norm(dim=1).view(B,1) + 1e-8) 110 | similarity_matrix = (1./temperature) * torch.mm(outputs_norm,outputs_norm.transpose(0,1).detach()) 111 | 112 | return similarity_matrix, None, outputs 113 | 114 | def NT_xent(similarity_matrix, adv_type, args): 115 | """ 116 | Compute NT_xent loss 117 | input: pairwise-similarity matrix 118 | return: NT xent loss 119 | """ 120 | 121 | N2 = len(similarity_matrix) 122 | if adv_type=='None': 123 | N = int(len(similarity_matrix) / 2) 124 | contrast_num = 2 125 | elif adv_type=='Rep': # [inp1, inp2, adv1] 126 | N = int(len(similarity_matrix) / 3) 127 | contrast_num = 3 128 | 129 | # Removing diagonal # 130 | similarity_matrix_exp = torch.exp(similarity_matrix) 131 | similarity_matrix_exp = similarity_matrix_exp * (1 - torch.eye(N2,N2)).cuda() 132 | NT_xent_loss = - torch.log(similarity_matrix_exp/(torch.sum(similarity_matrix_exp,dim=1).view(N2,1) + 1e-8) + 1e-8) 133 | 134 | if adv_type =='None': 135 | NT_xent_loss_total = (1./float(N2)) * torch.sum(torch.diag(NT_xent_loss[0:N,N:]) + torch.diag(NT_xent_loss[N:,0:N])) 136 | elif adv_type =='Rep': 137 | if args.stop_grad: 138 | NT_xent_loss_total = (1./float(N2)) * torch.sum(torch.diag(NT_xent_loss[0:N,N:2*N]) + torch.diag(NT_xent_loss[N:2*N,0:N]) 139 | + args.adv_weight * 140 | ((torch.diag(NT_xent_loss[0:N,2*N:]) + 
torch.diag(NT_xent_loss[N:2*N,2*N:])) * args.stpg_degree 141 | + (torch.diag(NT_xent_loss[2*N:,0:N]) + torch.diag(NT_xent_loss[2*N:,N:2*N])) * (1-args.stpg_degree)) 142 | ) 143 | else: 144 | NT_xent_loss_total = (1./float(N2)) * torch.sum(torch.diag(NT_xent_loss[0:N,N:2*N]) + torch.diag(NT_xent_loss[N:2*N,0:N]) 145 | + torch.diag(NT_xent_loss[0:N,2*N:]) + torch.diag(NT_xent_loss[2*N:,0:N]) 146 | + torch.diag(NT_xent_loss[N:2*N,2*N:]) + torch.diag(NT_xent_loss[2*N:,N:2*N])) 147 | return NT_xent_loss_total 148 | 149 | def NT_xent_HN(similarity_matrix, adv_type, args): 150 | """ 151 | Compute NT_xent loss 152 | input: pairwise-similarity matrix 153 | return: NT xent loss 154 | """ 155 | 156 | N2 = len(similarity_matrix) 157 | if adv_type=='None': 158 | N = int(len(similarity_matrix) / 2) 159 | contrast_num = 2 160 | elif adv_type=='Rep': # [inp1, inp2, adv1] 161 | N = int(len(similarity_matrix) / 3) 162 | contrast_num = 3 163 | 164 | # Removing diagonal # 165 | similarity_matrix_exp = torch.exp(similarity_matrix) 166 | # similarity_matrix_exp = similarity_matrix_exp * (1 - torch.eye(N2,N2)).cuda() 167 | # NT_xent_loss = - torch.log(similarity_matrix_exp/(torch.sum(similarity_matrix_exp,dim=1).view(N2,1) + 1e-8) + 1e-8) 168 | # tau_plus = 0.1 169 | # beta = 1.0 170 | tau_plus = args.tau 171 | beta = args.beta 172 | temperature = 0.5 173 | N_neg = (N - 1) * contrast_num 174 | mask = torch.eye(N, dtype=torch.float32).cuda() 175 | mask = mask.repeat(contrast_num, contrast_num) 176 | logits_mask = torch.scatter( 177 | torch.ones_like(mask), 178 | 1, 179 | torch.arange(N * contrast_num).view(-1, 1).cuda(), 180 | 0 181 | ) 182 | mask = mask * logits_mask 183 | # =============== reweight neg ================= 184 | # for numerical stability 185 | exp_logits_neg = similarity_matrix_exp * (1 - mask) * logits_mask 186 | exp_logits_pos = similarity_matrix_exp * mask 187 | pos = exp_logits_pos.sum(dim=1) / mask.sum(1) 188 | 189 | imp = (beta * (exp_logits_neg + 1e-9).log()).exp() 
190 | # imp = exp_logits_neg ** beta 191 | # imp = exp_logits_neg 192 | reweight_logits_neg = (imp * exp_logits_neg) / imp.mean(dim=-1) 193 | Ng = (-tau_plus * N_neg * pos + reweight_logits_neg.sum(dim=-1)) / (1 - tau_plus) # [4 bsz, 1] 194 | # constrain (optional) 195 | Ng = torch.clamp(Ng, min=N_neg * np.e**(-1 / temperature)) 196 | NT_xent_loss = - torch.log(similarity_matrix_exp / ((pos + Ng).view(N2,1))) 197 | # =============================================== 198 | 199 | if adv_type =='None': 200 | NT_xent_loss_total = (1./float(N2)) * torch.sum(torch.diag(NT_xent_loss[0:N,N:]) + torch.diag(NT_xent_loss[N:,0:N])) 201 | elif adv_type =='Rep': 202 | if args.stop_grad: 203 | NT_xent_loss_total = (1./float(N2)) * torch.sum(torch.diag(NT_xent_loss[0:N,N:2*N]) + torch.diag(NT_xent_loss[N:2*N,0:N]) 204 | + args.adv_weight * 205 | ((torch.diag(NT_xent_loss[0:N,2*N:]) + torch.diag(NT_xent_loss[N:2*N,2*N:])) * args.stpg_degree 206 | + (torch.diag(NT_xent_loss[2*N:,0:N]) + torch.diag(NT_xent_loss[2*N:,N:2*N])) * (1-args.stpg_degree)) 207 | ) 208 | else: 209 | NT_xent_loss_total = (1./float(N2)) * torch.sum(torch.diag(NT_xent_loss[0:N,N:2*N]) + torch.diag(NT_xent_loss[N:2*N,0:N]) 210 | + torch.diag(NT_xent_loss[0:N,2*N:]) + torch.diag(NT_xent_loss[2*N:,0:N]) 211 | + torch.diag(NT_xent_loss[N:2*N,2*N:]) + torch.diag(NT_xent_loss[2*N:,N:2*N])) 212 | return NT_xent_loss_total -------------------------------------------------------------------------------- /Asym-RoCL/data/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import hashlib 4 | import gzip 5 | import errno 6 | import tarfile 7 | import zipfile 8 | 9 | import torch 10 | from torch.utils.model_zoo import tqdm 11 | from torch._six import PY37 12 | 13 | 14 | def gen_bar_updater(): 15 | pbar = tqdm(total=None) 16 | 17 | def bar_update(count, block_size, total_size): 18 | if pbar.total is None and total_size: 19 | pbar.total = total_size 20 | 
progress_bytes = count * block_size 21 | pbar.update(progress_bytes - pbar.n) 22 | 23 | return bar_update 24 | 25 | 26 | def calculate_md5(fpath, chunk_size=1024 * 1024): 27 | md5 = hashlib.md5() 28 | with open(fpath, 'rb') as f: 29 | for chunk in iter(lambda: f.read(chunk_size), b''): 30 | md5.update(chunk) 31 | return md5.hexdigest() 32 | 33 | 34 | def check_md5(fpath, md5, **kwargs): 35 | return md5 == calculate_md5(fpath, **kwargs) 36 | 37 | 38 | def check_integrity(fpath, md5=None): 39 | if not os.path.isfile(fpath): 40 | return False 41 | if md5 is None: 42 | return True 43 | return check_md5(fpath, md5) 44 | 45 | 46 | def makedir_exist_ok(dirpath): 47 | """ 48 | Python2 support for os.makedirs(.., exist_ok=True) 49 | """ 50 | try: 51 | os.makedirs(dirpath) 52 | except OSError as e: 53 | if e.errno == errno.EEXIST: 54 | pass 55 | else: 56 | raise 57 | 58 | 59 | def download_url(url, root, filename=None, md5=None): 60 | """Download a file from a url and place it in root. 61 | Args: 62 | url (str): URL to download file from 63 | root (str): Directory to place downloaded file in 64 | filename (str, optional): Name to save the file under. If None, use the basename of the URL 65 | md5 (str, optional): MD5 checksum of the download. If None, do not check 66 | """ 67 | from six.moves import urllib 68 | 69 | root = os.path.expanduser(root) 70 | if not filename: 71 | filename = os.path.basename(url) 72 | fpath = os.path.join(root, filename) 73 | 74 | makedir_exist_ok(root) 75 | 76 | # check if file is already present locally 77 | if check_integrity(fpath, md5): 78 | print('Using downloaded and verified file: ' + fpath) 79 | else: # download the file 80 | try: 81 | print('Downloading ' + url + ' to ' + fpath) 82 | urllib.request.urlretrieve( 83 | url, fpath, 84 | reporthook=gen_bar_updater() 85 | ) 86 | except (urllib.error.URLError, IOError) as e: 87 | if url[:5] == 'https': 88 | url = url.replace('https:', 'http:') 89 | print('Failed download. 
Trying https -> http instead.' 90 | ' Downloading ' + url + ' to ' + fpath) 91 | urllib.request.urlretrieve( 92 | url, fpath, 93 | reporthook=gen_bar_updater() 94 | ) 95 | else: 96 | raise e 97 | # check integrity of downloaded file 98 | if not check_integrity(fpath, md5): 99 | raise RuntimeError("File not found or corrupted.") 100 | 101 | 102 | def list_dir(root, prefix=False): 103 | """List all directories at a given root 104 | Args: 105 | root (str): Path to directory whose folders need to be listed 106 | prefix (bool, optional): If true, prepends the path to each result, otherwise 107 | only returns the name of the directories found 108 | """ 109 | root = os.path.expanduser(root) 110 | directories = list( 111 | filter( 112 | lambda p: os.path.isdir(os.path.join(root, p)), 113 | os.listdir(root) 114 | ) 115 | ) 116 | 117 | if prefix is True: 118 | directories = [os.path.join(root, d) for d in directories] 119 | 120 | return directories 121 | 122 | 123 | def list_files(root, suffix, prefix=False): 124 | """List all files ending with a suffix at a given root 125 | Args: 126 | root (str): Path to directory whose folders need to be listed 127 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 128 | It uses the Python "str.endswith" method and is passed directly 129 | prefix (bool, optional): If true, prepends the path to each result, otherwise 130 | only returns the name of the files found 131 | """ 132 | root = os.path.expanduser(root) 133 | files = list( 134 | filter( 135 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 136 | os.listdir(root) 137 | ) 138 | ) 139 | 140 | if prefix is True: 141 | files = [os.path.join(root, d) for d in files] 142 | 143 | return files 144 | 145 | 146 | def download_file_from_google_drive(file_id, root, filename=None, md5=None): 147 | """Download a Google Drive file from and place it in root. 
148 | Args: 149 | file_id (str): id of file to be downloaded 150 | root (str): Directory to place downloaded file in 151 | filename (str, optional): Name to save the file under. If None, use the id of the file. 152 | md5 (str, optional): MD5 checksum of the download. If None, do not check 153 | """ 154 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url 155 | import requests 156 | url = "https://docs.google.com/uc?export=download" 157 | 158 | root = os.path.expanduser(root) 159 | if not filename: 160 | filename = file_id 161 | fpath = os.path.join(root, filename) 162 | 163 | makedir_exist_ok(root) 164 | 165 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 166 | print('Using downloaded and verified file: ' + fpath) 167 | else: 168 | session = requests.Session() 169 | 170 | response = session.get(url, params={'id': file_id}, stream=True) 171 | token = _get_confirm_token(response) 172 | 173 | if token: 174 | params = {'id': file_id, 'confirm': token} 175 | response = session.get(url, params=params, stream=True) 176 | 177 | _save_response_content(response, fpath) 178 | 179 | 180 | def _get_confirm_token(response): 181 | for key, value in response.cookies.items(): 182 | if key.startswith('download_warning'): 183 | return value 184 | 185 | return None 186 | 187 | 188 | def _save_response_content(response, destination, chunk_size=32768): 189 | with open(destination, "wb") as f: 190 | pbar = tqdm(total=None) 191 | progress = 0 192 | for chunk in response.iter_content(chunk_size): 193 | if chunk: # filter out keep-alive new chunks 194 | f.write(chunk) 195 | progress += len(chunk) 196 | pbar.update(progress - pbar.n) 197 | pbar.close() 198 | 199 | 200 | def _is_tarxz(filename): 201 | return filename.endswith(".tar.xz") 202 | 203 | 204 | def _is_tar(filename): 205 | return filename.endswith(".tar") 206 | 207 | 208 | def _is_targz(filename): 209 | return filename.endswith(".tar.gz") 210 | 211 | 212 | def 
_is_tgz(filename): 213 | return filename.endswith(".tgz") 214 | 215 | 216 | def _is_gzip(filename): 217 | return filename.endswith(".gz") and not filename.endswith(".tar.gz") 218 | 219 | 220 | def _is_zip(filename): 221 | return filename.endswith(".zip") 222 | 223 | # Unpack a .tar/.tar.gz/.tgz/.tar.xz/.gz/.zip archive into to_path (default: the archive's own directory); delete the archive afterwards when remove_finished is set. 224 | def extract_archive(from_path, to_path=None, remove_finished=False): 225 | if to_path is None: 226 | to_path = os.path.dirname(from_path) 227 | # NOTE(review): tarfile's extractall does not sanitize member names (absolute paths, '..'); only pass archives from trusted sources. 228 | if _is_tar(from_path): 229 | with tarfile.open(from_path, 'r') as tar: 230 | tar.extractall(path=to_path) 231 | elif _is_targz(from_path) or _is_tgz(from_path): 232 | with tarfile.open(from_path, 'r:gz') as tar: 233 | tar.extractall(path=to_path) 234 | elif _is_tarxz(from_path) and PY3: # NOTE(review): PY3 is defined outside this view - confirm it exists in this module 235 | # .tar.xz archive only supported in Python 3.x 236 | with tarfile.open(from_path, 'r:xz') as tar: 237 | tar.extractall(path=to_path) 238 | elif _is_gzip(from_path): 239 | to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0]) # single .gz file: to_path now names the output file (archive name minus its last extension) 240 | with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f: 241 | out_f.write(zip_f.read()) 242 | elif _is_zip(from_path): 243 | with zipfile.ZipFile(from_path, 'r') as z: 244 | z.extractall(to_path) 245 | else: 246 | raise ValueError("Extraction of {} not supported".format(from_path)) 247 | 248 | if remove_finished: 249 | os.remove(from_path) 250 | 251 | # Download url into download_root (integrity-checked by download_url), then extract it into extract_root (default: download_root). 252 | def download_and_extract_archive(url, download_root, extract_root=None, filename=None, 253 | md5=None, remove_finished=False): 254 | download_root = os.path.expanduser(download_root) 255 | if extract_root is None: 256 | extract_root = download_root 257 | if not filename: 258 | filename = os.path.basename(url) 259 | 260 | download_url(url, download_root, filename, md5) 261 | 262 | archive = os.path.join(download_root, filename) 263 | print("Extracting {} to {}".format(archive, extract_root)) 264 | extract_archive(archive, extract_root, remove_finished) 265 | 266 | # Format items as a quoted, comma-separated string, e.g. ['a', 'b'] -> "'a', 'b'". 267 | def iterable_to_str(iterable): 268 | return "'" + "', '".join([str(item)
for item in iterable]) + "'" 269 | 270 | 271 | def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None): 272 | if not isinstance(value, torch._six.string_classes): 273 | if arg is None: 274 | msg = "Expected type str, but got type {type}." 275 | else: 276 | msg = "Expected type str for argument {arg}, but got type {type}." 277 | msg = msg.format(type=type(value), arg=arg) 278 | raise ValueError(msg) 279 | 280 | if valid_values is None: 281 | return value 282 | 283 | if value not in valid_values: 284 | if custom_msg is not None: 285 | msg = custom_msg 286 | else: 287 | msg = ("Unknown value '{value}' for argument {arg}. " 288 | "Valid values are {{{valid_values}}}.") 289 | msg = msg.format(value=value, arg=arg, 290 | valid_values=iterable_to_str(valid_values)) 291 | raise ValueError(msg) 292 | 293 | return value 294 | -------------------------------------------------------------------------------- /Asym-RoCL/linear_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import csv 7 | import os 8 | import json 9 | import copy 10 | 11 | import numpy as np 12 | import torch 13 | from torch.autograd import Variable 14 | import torch.backends.cudnn as cudnn 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.optim as optim 18 | import torchvision.transforms as transforms 19 | import torchvision.datasets as datasets 20 | 21 | import data_loader 22 | import model_loader 23 | import models 24 | from models.projector import Projector 25 | 26 | from argument import linear_parser, print_args 27 | from utils import progress_bar, checkpoint_train 28 | from collections import OrderedDict 29 | from attack_lib import FastGradientSignUntargeted 30 | from loss import pairwise_similarity, NT_xent 31 | 32 | args = linear_parser() 33 | use_cuda = torch.cuda.is_available() 34 | if use_cuda: 35 | 
ngpus_per_node = torch.cuda.device_count() 36 | 37 | if args.local_rank % ngpus_per_node==0: 38 | print_args(args) 39 | 40 | def print_status(string): 41 | if args.local_rank % ngpus_per_node == 0: 42 | print(string) 43 | 44 | print_status(torch.cuda.device_count()) 45 | print_status('Using CUDA..') 46 | 47 | best_acc = 0 # best test accuracy 48 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 49 | 50 | if args.seed != 0: 51 | torch.manual_seed(args.seed) 52 | 53 | # Data 54 | print_status('==> Preparing data..') 55 | if not (args.train_type=='linear_eval'): 56 | assert('wrong train phase...') 57 | else: 58 | trainloader, traindst, testloader, testdst = data_loader.get_dataset(args) 59 | 60 | if args.dataset == 'cifar-10' or args.dataset=='mnist': 61 | num_outputs = 10 62 | elif args.dataset == 'cifar-100': 63 | num_outputs = 100 64 | 65 | if args.model == 'ResNet50': 66 | expansion = 4 67 | else: 68 | expansion = 1 69 | 70 | # Model 71 | print_status('==> Building model..') 72 | train_type = args.train_type 73 | 74 | def load(args, epoch): 75 | model = model_loader.get_model(args) 76 | 77 | if epoch == 0: 78 | add = '' 79 | else: 80 | add = '_epoch_'+str(epoch) 81 | 82 | checkpoint_ = torch.load(args.load_checkpoint+add) 83 | 84 | new_state_dict = OrderedDict() 85 | for k, v in checkpoint_['model'].items(): 86 | name = k[7:] 87 | new_state_dict[name] = v 88 | 89 | model.load_state_dict(new_state_dict) 90 | 91 | if args.ss: 92 | projector = Projector(expansion=expansion) 93 | checkpoint_p = torch.load(args.load_checkpoint+'_projector'+add) 94 | new_state_dict = OrderedDict() 95 | for k, v in checkpoint_p['model'].items(): 96 | name = k[7:] 97 | new_state_dict[name] = v 98 | projector.load_state_dict(new_state_dict) 99 | 100 | if args.dataset=='cifar-10': 101 | Linear = nn.Sequential(nn.Linear(512*expansion, 10)) 102 | elif args.dataset=='cifar-100': 103 | Linear = nn.Sequential(nn.Linear(512*expansion, 100)) 104 | 105 | model_params = [] 106 | if 
args.finetune: 107 | model_params += model.parameters() 108 | if args.ss: 109 | model_params += projector.parameters() 110 | model_params += Linear.parameters() 111 | loptim = torch.optim.SGD(model_params, lr = args.lr, momentum=0.9, weight_decay=5e-4) 112 | 113 | use_cuda = torch.cuda.is_available() 114 | if use_cuda: 115 | ngpus_per_node = torch.cuda.device_count() 116 | model.cuda() 117 | Linear.cuda() 118 | model = nn.DataParallel(model) 119 | Linear = nn.DataParallel(Linear) 120 | if args.ss: 121 | projector.cuda() 122 | projector = nn.DataParallel(projector) 123 | else: 124 | assert("Need to use GPU...") 125 | 126 | print_status('Using CUDA..') 127 | cudnn.benchmark = True 128 | 129 | if args.adv_img: 130 | attack_info = 'Adv_train_epsilon_'+str(args.epsilon)+'_alpha_'+ str(args.alpha) + '_min_val_' + str(args.min) + '_max_val_' + str(args.max) + '_max_iters_' + str(args.k) + '_type_' + str(args.attack_type) + '_randomstart_' + str(args.random_start) 131 | print_status("Adversarial training info...") 132 | print_status(attack_info) 133 | 134 | attacker = FastGradientSignUntargeted(model, linear=Linear, epsilon=args.epsilon, alpha=args.alpha, min_val=args.min, max_val=args.max, max_iters=args.k, _type=args.attack_type) 135 | 136 | if args.adv_img: 137 | if args.ss: 138 | return model, Linear, projector, loptim, attacker 139 | return model, Linear, 'None', loptim, attacker 140 | if args.ss: 141 | return model, Linear, projector, loptim, 'None' 142 | return model, Linear, 'None', loptim, 'None' 143 | 144 | criterion = nn.CrossEntropyLoss() 145 | 146 | 147 | def linear_train(epoch, model, Linear, projector, loptim, attacker=None): 148 | Linear.train() 149 | if args.finetune: 150 | model.train() 151 | if args.ss: 152 | projector.train() 153 | else: 154 | model.eval() 155 | 156 | total_loss = 0 157 | correct = 0 158 | total = 0 159 | 160 | for batch_idx, (ori, inputs, inputs_2, target) in enumerate(trainloader): 161 | ori, inputs_1, inputs_2, target = ori.cuda(), 
inputs.cuda(), inputs_2.cuda(), target.cuda() 162 | input_flag = False 163 | if args.trans: 164 | inputs = inputs_1 165 | else: 166 | inputs = ori 167 | 168 | if args.adv_img: 169 | advinputs = attacker.perturb(original_images=inputs, labels=target, random_start=args.random_start) 170 | 171 | if args.clean: 172 | total_inputs = inputs 173 | total_targets = target 174 | input_flag = True 175 | 176 | if args.ss: 177 | total_inputs = torch.cat((inputs, inputs_2)) 178 | total_targets = torch.cat((target, target)) 179 | 180 | if args.adv_img: 181 | if input_flag: 182 | total_inputs = torch.cat((total_inputs, advinputs)) 183 | total_targets = torch.cat((total_targets, target)) 184 | else: 185 | total_inputs = advinputs 186 | total_targets = target 187 | input_flag = True 188 | 189 | if not input_flag: 190 | assert('choose the linear evaluation data type (clean, adv_img)') 191 | 192 | feat = model(total_inputs) 193 | if args.ss: 194 | output_p = projector(feat) 195 | B = ori.size(0) 196 | 197 | similarity, _ = pairwise_similarity(output_p[:2*B,:2*B], temperature=args.temperature, multi_gpu=False, adv_type = 'None') 198 | simloss = NT_xent(similarity, 'None') 199 | 200 | output = Linear(feat) 201 | 202 | _, predx = torch.max(output.data, 1) 203 | loss = criterion(output, total_targets) 204 | 205 | if args.ss: 206 | loss += simloss 207 | 208 | correct += predx.eq(total_targets.data).cpu().sum().item() 209 | total += total_targets.size(0) 210 | acc = 100.*correct/total 211 | 212 | total_loss += loss.data 213 | 214 | loptim.zero_grad() 215 | loss.backward() 216 | loptim.step() 217 | 218 | progress_bar(batch_idx, len(trainloader), 219 | 'Loss: {:.4f} | Acc: {:.2f}'.format(total_loss/(batch_idx+1), acc)) 220 | 221 | print ("Epoch: {}, train accuracy: {}".format(epoch, acc)) 222 | 223 | return acc, model, Linear, projector, loptim 224 | 225 | def test(model, Linear): 226 | global best_acc 227 | 228 | model.eval() 229 | Linear.eval() 230 | 231 | test_loss = 0 232 | correct = 0 
233 | total = 0 234 | 235 | for idx, (image, label) in enumerate(testloader): 236 | img = image.cuda() 237 | y = label.cuda() 238 | 239 | out = Linear(model(img)) 240 | 241 | _, predx = torch.max(out.data, 1) 242 | loss = criterion(out, y) 243 | 244 | correct += predx.eq(y.data).cpu().sum().item() 245 | total += y.size(0) 246 | acc = 100.*correct/total 247 | 248 | test_loss += loss.data 249 | if args.local_rank % ngpus_per_node == 0: 250 | progress_bar(idx, len(testloader),'Testing Loss {:.3f}, acc {:.3f}'.format(test_loss/(idx+1), acc)) 251 | 252 | print ("Test accuracy: {0}".format(acc)) 253 | 254 | return (acc, model, Linear) 255 | 256 | def adjust_lr(epoch, optim): 257 | lr = args.lr 258 | if args.dataset=='cifar-10' or args.dataset=='cifar-100': 259 | lr_list = [30,50,100] 260 | if epoch>=lr_list[0]: 261 | lr = lr/10 262 | if epoch>=lr_list[1]: 263 | lr = lr/10 264 | if epoch>=lr_list[2]: 265 | lr = lr/10 266 | 267 | for param_group in optim.param_groups: 268 | param_group['lr'] = lr 269 | 270 | ##### Log file for training selected tasks ##### 271 | if not os.path.isdir('results'): 272 | os.mkdir('results') 273 | 274 | args.name += ('_Evaluate_'+ args.train_type + '_' +args.model + '_' + args.dataset) 275 | loginfo = 'results/log_generalization_' + args.name + '_' + str(args.seed) 276 | logname = (loginfo+ '.csv') 277 | 278 | with open(logname, 'w') as logfile: 279 | logwriter = csv.writer(logfile, delimiter=',') 280 | logwriter.writerow(['epoch', 'train acc','test acc']) 281 | 282 | if args.epochwise: 283 | for k in range(100,1000,100): 284 | model, linear, projector, loptim, attacker = load(args, k) 285 | print('loading.......epoch ', str(k)) 286 | ##### Linear evaluation ##### 287 | for i in range(args.epoch): 288 | print('Epoch ', i) 289 | train_acc, model, linear, projector, loptim = linear_train(i, model, linear, projector, loptim, attacker) 290 | test_acc, model, linear = test(model, linear) 291 | adjust_lr(i, loptim) 292 | 293 | checkpoint_train(model, 
test_acc, args.epoch, args, loptim, save_name_add='epochwise'+str(k)) 294 | checkpoint_train(linear, test_acc, args.epoch, args, loptim, save_name_add='epochwise'+str(k)+'_linear') 295 | if args.local_rank % ngpus_per_node == 0: 296 | with open(logname, 'a') as logfile: 297 | logwriter = csv.writer(logfile, delimiter=',') 298 | logwriter.writerow([k, train_acc, test_acc]) 299 | # Final run: evaluate the checkpoint saved without an epoch suffix (load(args, 0) appends nothing to args.load_checkpoint). 300 | model, linear, projector, loptim, attacker = load(args, 0) 301 | 302 | ##### Linear evaluation ##### 303 | for epoch in range(args.epoch): 304 | print('Epoch ', epoch) 305 | 306 | train_acc, model, linear, projector, loptim = linear_train(epoch, model=model, Linear=linear, projector=projector, loptim=loptim, attacker=attacker) 307 | test_acc, model, linear = test(model, linear) 308 | adjust_lr(epoch, loptim) 309 | 310 | if args.local_rank % ngpus_per_node == 0: 311 | with open(logname, 'a') as logfile: 312 | logwriter = csv.writer(logfile, delimiter=',') 313 | logwriter.writerow([epoch, train_acc, test_acc]) 314 | 315 | checkpoint_train(model, test_acc, args.epoch, args, loptim) 316 | checkpoint_train(linear, test_acc, args.epoch, args, loptim, save_name_add='_linear') 317 | 318 | if args.local_rank % ngpus_per_node == 0: 319 | with open(logname, 'a') as logfile: 320 | logwriter = csv.writer(logfile, delimiter=',') 321 | logwriter.writerow([1000, train_acc, test_acc]) # NOTE(review): 1000 appears to tag the final-checkpoint row (epochwise rows use 100..900) - confirm 322 | -------------------------------------------------------------------------------- /Asym-AdvCL/losses.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import random 4 | 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | 9 | 10 | class SupConLoss(nn.Module): 11 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
12 | It also supports the unsupervised contrastive loss in SimCLR""" 13 | 14 | def __init__(self, args, temperature=0.07, contrast_mode='all', 15 | base_temperature=0.07): 16 | super(SupConLoss, self).__init__() 17 | self.temperature = temperature 18 | self.contrast_mode = contrast_mode 19 | self.base_temperature = base_temperature 20 | 21 | self.args = args 22 | # NOTE(review): forward reads args.tau_plus, args.beta and args.adv_weight, so args must provide them. 23 | def forward(self, features, labels=None, mask=None, stop_grad=False, stop_grad_sd=-1.0): 24 | """Compute loss for model. If both `labels` and `mask` are None, 25 | it degenerates to SimCLR unsupervised loss: 26 | https://arxiv.org/pdf/2002.05709.pdf 27 | Args: 28 | features: hidden vector of shape [bsz, n_views, ...]. 29 | labels: ground truth of shape [bsz]. 30 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 31 | has the same class as sample i. Can be asymmetric. 32 | Returns: 33 | A loss scalar. 34 | """ 35 | # stop_grad=True also builds logits against a detached contrast_feature; stop_grad_sd then weights the transposed [bsz:, :bsz] block against the [:bsz, bsz:] block of that per-pair loss. 36 | device = (torch.device('cuda') 37 | if features.is_cuda 38 | else torch.device('cpu')) 39 | 40 | if len(features.shape) < 3: 41 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 42 | 'at least 3 dimensions are required') 43 | if len(features.shape) > 3: 44 | features = features.view(features.shape[0], features.shape[1], -1) 45 | 46 | batch_size = features.shape[0] 47 | if labels is not None and mask is not None: 48 | raise ValueError('Cannot define both `labels` and `mask`') 49 | elif labels is None and mask is None: 50 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 51 | elif labels is not None: 52 | labels = labels.contiguous().view(-1, 1) 53 | if labels.shape[0] != batch_size: 54 | raise ValueError('Num of labels does not match num of features') 55 | mask = torch.eq(labels, labels.T).float().to(device) 56 | else: 57 | mask = mask.float().to(device) 58 | 59 | contrast_count = features.shape[1] 60 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 61 | if self.contrast_mode == 'one': 62 | anchor_feature = features[:,
0] 63 | anchor_count = 1 64 | elif self.contrast_mode == 'all': 65 | anchor_feature = contrast_feature 66 | anchor_count = contrast_count 67 | else: 68 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 69 | 70 | # compute logits 71 | anchor_dot_contrast = torch.div( 72 | torch.matmul(anchor_feature, contrast_feature.T), 73 | self.temperature) 74 | if stop_grad: 75 | anchor_dot_contrast_stpg = torch.div( 76 | torch.matmul(anchor_feature, contrast_feature.T.detach()), 77 | self.temperature) 78 | 79 | # tile mask 80 | mask = mask.repeat(anchor_count, contrast_count) 81 | # mask-out self-contrast cases 82 | logits_mask = torch.scatter( 83 | torch.ones_like(mask), 84 | 1, 85 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 86 | 0 87 | ) 88 | mask = mask * logits_mask 89 | 90 | # For hard negatives, code adapted from HCL (https://github.com/joshr17/HCL) 91 | # =============== hard neg params ================= 92 | tau_plus = self.args.tau_plus 93 | beta = self.args.beta 94 | temperature = 0.5 95 | N = (batch_size - 1) * contrast_count 96 | # =============== reweight neg ================= 97 | # for numerical stability 98 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 99 | logits = anchor_dot_contrast - logits_max.detach() 100 | exp_logits = torch.exp(logits) 101 | exp_logits_neg = exp_logits * (1 - mask) * logits_mask 102 | exp_logits_pos = exp_logits * mask 103 | pos = exp_logits_pos.sum(dim=1) / mask.sum(1) 104 | 105 | imp = (beta * (exp_logits_neg + 1e-9).log()).exp() 106 | reweight_logits_neg = (imp * exp_logits_neg) / imp.mean(dim=-1) 107 | Ng = (-tau_plus * N * pos + reweight_logits_neg.sum(dim=-1)) / (1 - tau_plus) 108 | # constrain (optional) 109 | Ng = torch.clamp(Ng, min=N * np.e ** (-1 / temperature)) 110 | log_prob = -torch.log(exp_logits / (pos + Ng)) 111 | # =============================================== 112 | 113 | loss_square = mask * log_prob # only positive positions have elements 114 | 115 | # 
mix_square = exp_logits 116 | mix_square = loss_square 117 | 118 | if stop_grad: 119 | logits_max_stpg, _ = torch.max(anchor_dot_contrast_stpg, dim=1, keepdim=True) 120 | logits_stpg = anchor_dot_contrast_stpg - logits_max_stpg.detach() 121 | # =============== reweight neg ================= 122 | exp_logits_stpg = torch.exp(logits_stpg) 123 | exp_logits_neg_stpg = exp_logits_stpg * (1 - mask) * logits_mask 124 | exp_logits_pos_stpg = exp_logits_stpg * mask 125 | pos_stpg = exp_logits_pos_stpg.sum(dim=1) / mask.sum(1) 126 | 127 | imp_stpg = (beta * (exp_logits_neg_stpg + 1e-9).log()).exp() 128 | reweight_logits_neg_stpg = (imp_stpg * exp_logits_neg_stpg) / imp_stpg.mean(dim=-1) 129 | Ng_stpg = ((-tau_plus * N * pos_stpg + reweight_logits_neg_stpg.sum(dim=-1)) / (1 - tau_plus)) 130 | 131 | # constrain (optional) 132 | Ng_stpg = torch.clamp(Ng_stpg, min=N * np.e ** (-1 / temperature)) 133 | log_prob_stpg = -torch.log(exp_logits_stpg / (pos_stpg + Ng_stpg)) 134 | # =============================================== 135 | tmp_square = mask * log_prob_stpg 136 | else: 137 | # tmp_square = exp_logits 138 | tmp_square = loss_square 139 | if stop_grad: 140 | ac_square = stop_grad_sd * tmp_square[batch_size:, 0:batch_size].T + (1 - stop_grad_sd) * tmp_square[ 141 | 0:batch_size, 142 | batch_size:] 143 | else: 144 | ac_square = tmp_square[0:batch_size, batch_size:] 145 | 146 | mix_square[0:batch_size, batch_size:] = ac_square * self.args.adv_weight 147 | mix_square[batch_size:, 0:batch_size] = ac_square.T * self.args.adv_weight 148 | 149 | # compute mean of log-likelihood over positive 150 | mean_log_prob_pos = mix_square.sum(1) / mask.sum(1) 151 | 152 | # loss 153 | loss = (self.temperature / self.base_temperature) * mean_log_prob_pos 154 | loss = loss.view(anchor_count, batch_size).mean() 155 | 156 | return loss 157 | 158 | 159 | class ori_SupConLoss(nn.Module): 160 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 
161 | It also supports the unsupervised contrastive loss in SimCLR""" 162 | 163 | def __init__(self, args, temperature=0.07, contrast_mode='all', 164 | base_temperature=0.07): 165 | super(ori_SupConLoss, self).__init__() 166 | self.temperature = temperature 167 | self.contrast_mode = contrast_mode 168 | self.base_temperature = base_temperature 169 | 170 | self.args = args 171 | # NOTE(review): forward reads args.adv_weight, so args must provide it. 172 | def forward(self, features, labels=None, mask=None, stop_grad=False, stop_grad_sd=-1.0): 173 | """Compute loss for model. If both `labels` and `mask` are None, 174 | it degenerates to SimCLR unsupervised loss: 175 | https://arxiv.org/pdf/2002.05709.pdf 176 | Args: 177 | features: hidden vector of shape [bsz, n_views, ...]. 178 | labels: ground truth of shape [bsz]. 179 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 180 | has the same class as sample i. Can be asymmetric. 181 | Returns: 182 | A loss scalar. 183 | """ 184 | # stop_grad=True also builds log-probs against a detached contrast_feature; stop_grad_sd then weights the transposed [bsz:, :bsz] block against the [:bsz, bsz:] block. 185 | device = (torch.device('cuda') 186 | if features.is_cuda 187 | else torch.device('cpu')) 188 | 189 | if len(features.shape) < 3: 190 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 191 | 'at least 3 dimensions are required') 192 | if len(features.shape) > 3: 193 | features = features.view(features.shape[0], features.shape[1], -1) 194 | 195 | batch_size = features.shape[0] 196 | if labels is not None and mask is not None: 197 | raise ValueError('Cannot define both `labels` and `mask`') 198 | elif labels is None and mask is None: 199 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 200 | elif labels is not None: 201 | labels = labels.contiguous().view(-1, 1) 202 | if labels.shape[0] != batch_size: 203 | raise ValueError('Num of labels does not match num of features') 204 | mask = torch.eq(labels, labels.T).float().to(device) 205 | else: 206 | mask = mask.float().to(device) 207 | 208 | contrast_count = features.shape[1] 209 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 210 | if
self.contrast_mode == 'one': 211 | anchor_feature = features[:, 0] 212 | anchor_count = 1 213 | elif self.contrast_mode == 'all': 214 | anchor_feature = contrast_feature 215 | anchor_count = contrast_count 216 | else: 217 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 218 | 219 | # compute logits 220 | anchor_dot_contrast = torch.div( 221 | torch.matmul(anchor_feature, contrast_feature.T), 222 | self.temperature) 223 | if stop_grad: 224 | anchor_dot_contrast_stpg = torch.div( 225 | torch.matmul(anchor_feature, contrast_feature.T.detach()), 226 | self.temperature) 227 | # for numerical stability 228 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 229 | logits = anchor_dot_contrast - logits_max.detach() 230 | 231 | # tile mask 232 | mask = mask.repeat(anchor_count, contrast_count) 233 | # mask-out self-contrast cases 234 | logits_mask = torch.scatter( 235 | torch.ones_like(mask), 236 | 1, 237 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 238 | 0 239 | ) 240 | mask = mask * logits_mask 241 | 242 | # compute log_prob 243 | exp_logits = torch.exp(logits) * logits_mask 244 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 245 | 246 | loss_square = mask * log_prob # only positive position have elements 247 | mix_square = loss_square # alias (same tensor object), not a copy: the slice writes into mix_square below also change loss_square 248 | if stop_grad: 249 | logits_max_stpg, _ = torch.max(anchor_dot_contrast_stpg, dim=1, keepdim=True) 250 | logits_stpg = anchor_dot_contrast_stpg - logits_max_stpg.detach() 251 | # compute log_prob 252 | exp_logits_stpg = torch.exp(logits_stpg) * logits_mask 253 | log_prob_stpg = logits_stpg - torch.log(exp_logits_stpg.sum(1, keepdim=True)) 254 | loss_square_stpg = mask * log_prob_stpg 255 | tmp_square = loss_square_stpg 256 | else: 257 | tmp_square = loss_square 258 | if stop_grad: 259 | ac_square = stop_grad_sd * tmp_square[batch_size:, 0:batch_size].T + (1 - stop_grad_sd) * tmp_square[ 260 | 0:batch_size, 261 | batch_size:] 262 | else: 263 | ac_square =
tmp_square[0:batch_size, batch_size:] 264 | 265 | mix_square[0:batch_size, batch_size:] = ac_square * self.args.adv_weight 266 | mix_square[batch_size:, 0:batch_size] = ac_square.T * self.args.adv_weight 267 | 268 | # compute mean of log-likelihood over positive 269 | mean_log_prob_pos = mix_square.sum(1) / mask.sum(1) 270 | 271 | # loss 272 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 273 | loss = loss.view(anchor_count, batch_size).mean() 274 | 275 | return loss 276 | -------------------------------------------------------------------------------- /Asym-RoCL/argument.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def test_parser(): 4 | parser = argparse.ArgumentParser(description='linear eval test') 5 | parser.add_argument('--train_type', default='linear_eval', type=str, help='standard') 6 | parser.add_argument('--dataset', default='cifar-10', type=str, help='cifar-10/cifar-100') 7 | parser.add_argument('--load_checkpoint', default='./checkpoint/ckpt.t7one_task_0', type=str, help='PATH TO CHECKPOINT') 8 | parser.add_argument('--model', default='ResNet18', type=str, 9 | help='model type ResNet18/ResNet50') 10 | parser.add_argument('--name', default='', type=str, help='name of run') 11 | parser.add_argument('--seed', default=2342, type=int, help='random seed') 12 | parser.add_argument('--batch-size', default=128, type=int, help='batch size / multi-gpu setting: batch per gpu') 13 | 14 | ##### arguments for data augmentation ##### 15 | parser.add_argument('--color_jitter_strength', default=0.5, type=float, help='0.5 for CIFAR, 1.0 for ImageNet') 16 | parser.add_argument('--temperature', default=0.5, type=float, help='temperature for pairwise-similarity') 17 | 18 | ##### arguments for distributted parallel ##### 19 | parser.add_argument('--local_rank', type=int, default=0) 20 | parser.add_argument('--ngpu', type=int, default=1) 21 | 22 | parser.add_argument('--attack_type', 
type=str, default='linf', 23 | help='adversarial l_p') 24 | parser.add_argument('--epsilon', type=float, default=0.0314, 25 | help='maximum perturbation of adversaries (8/255 for cifar-10)') 26 | parser.add_argument('--alpha', type=float, default=0.007, 27 | help='movement multiplier per iteration when generating adversarial examples (2/255=0.00784)') 28 | parser.add_argument('--k', type=int, default=10, 29 | help='maximum iteration when generating adversarial examples') 30 | parser.add_argument('--random_start', type=bool, default=True, 31 | help='True for PGD') 32 | 33 | args = parser.parse_args() 34 | 35 | return args 36 | 37 | def linear_parser(): 38 | parser = argparse.ArgumentParser(description='RoCL linear training') 39 | 40 | ##### arguments for RoCL Linear eval (LE) or Robust Linear eval (r-LE)##### 41 | parser.add_argument('--train_type', default='linear_eval', type=str, help='contrastive/linear eval/test') 42 | parser.add_argument('--finetune', default=False, type=bool, help='finetune the model') 43 | parser.add_argument('--epochwise', type=bool, default=False, help='epochwise saving...') 44 | parser.add_argument('--ss', default=False, type=bool, help='using self-supervised learning loss') 45 | 46 | parser.add_argument('--trans', default=False, type=bool, help='use transformed sample') 47 | parser.add_argument('--clean', default=False, type=bool, help='use clean sample') 48 | parser.add_argument('--adv_img', default=False, type=bool, help='use adversarial sample') 49 | 50 | ##### arguments for training ##### 51 | parser.add_argument('--lr', default=0.2, type=float, help='learning rate') 52 | parser.add_argument('--lr_multiplier', default=15.0, type=float, help='learning rate multiplier') 53 | 54 | parser.add_argument('--dataset', default='cifar-10', type=str, help='cifar-10/cifar-100') 55 | parser.add_argument('--load_checkpoint', default='./checkpoint/ckpt.t7one_task_0', type=str, help='PATH TO CHECKPOINT') 56 | parser.add_argument('--model', 
default="ResNet18", type=str, 57 | help='model type ResNet18/ResNet50') 58 | 59 | parser.add_argument('--name', default='', type=str, help='name of run') 60 | parser.add_argument('--seed', default=2342, type=int, help='random seed') 61 | parser.add_argument('--batch-size', default=128, type=int, help='batch size / multi-gpu setting: batch per gpu') 62 | parser.add_argument('--epoch', default=150, type=int, 63 | help='total epochs to run') 64 | 65 | ##### arguments for data augmentation ##### 66 | parser.add_argument('--color_jitter_strength', default=0.5, type=float, help='0.5 for CIFAR') 67 | parser.add_argument('--temperature', default=0.5, type=float, help='temperature for pairwise-similarity') 68 | 69 | ##### arguments for distributted parallel ##### 70 | parser.add_argument('--local_rank', type=int, default=0) 71 | parser.add_argument('--ngpu', type=int, default=1) 72 | 73 | ##### arguments for PGD attack & Adversarial Training ##### 74 | parser.add_argument('--min', type=float, default=0.0, 75 | help='min for cliping image') 76 | parser.add_argument('--max', type=float, default=1.0, 77 | help='max for cliping image') 78 | parser.add_argument('--attack_type', type=str, default='linf', 79 | help='adversarial l_p') 80 | parser.add_argument('--epsilon', type=float, default=0.0314, 81 | help='maximum perturbation of adversaries (8/255 for cifar-10)') 82 | parser.add_argument('--alpha', type=float, default=0.007, 83 | help='movement multiplier per iteration when generating adversarial examples (2/255=0.00784)') 84 | parser.add_argument('--k', type=int, default=10, 85 | help='maximum iteration when generating adversarial examples') 86 | 87 | parser.add_argument('--random_start', type=bool, default=True, 88 | help='True for PGD') 89 | args = parser.parse_args() 90 | 91 | return args 92 | 93 | 94 | def parser(): 95 | parser = argparse.ArgumentParser(description='PyTorch RoCL training') 96 | parser.add_argument('--module',action='store_true') 97 | 98 | ##### arguments 
for RoCL ##### 99 | parser.add_argument('--lamda', default=256, type=float) 100 | 101 | parser.add_argument('--regularize_to', default='other', type=str, help='original/other') 102 | parser.add_argument('--attack_to', default='other', type=str, help='original/other') 103 | 104 | parser.add_argument('--loss_type', type=str, default='sim', help='loss type for Rep') 105 | parser.add_argument('--advtrain_type', default='Rep', type=str, help='Rep/None') 106 | 107 | ##### arguments for Training Self-Sup ##### 108 | parser.add_argument('--train_type', default='contrastive', type=str, help='contrastive/linear eval/test') 109 | 110 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 111 | parser.add_argument('--lr_multiplier', default=15.0, type=float, help='learning rate multiplier') 112 | parser.add_argument('--decay', default=1e-6, type=float, help='weight decay') 113 | 114 | parser.add_argument('--dataset', default='cifar-10', type=str, help='cifar-10/cifar-100') 115 | 116 | parser.add_argument('--load_checkpoint', default='./checkpoint/ckpt.t7one_task_0', type=str, help='PATH TO CHECKPOINT') 117 | parser.add_argument('--resume', '-r', action='store_true', 118 | help='resume from checkpoint') 119 | parser.add_argument('--model', default='ResNet18', type=str, 120 | help='model type ResNet18/ResNet50') 121 | 122 | parser.add_argument('--name', default='', type=str, help='name of run') 123 | parser.add_argument('--seed', default=0, type=int, help='random seed') 124 | parser.add_argument('--batch-size', default=256, type=int, help='batch size / multi-gpu setting: batch per gpu') 125 | parser.add_argument('--epoch', default=1000, type=int, 126 | help='total epochs to run') 127 | 128 | ##### arguments for data augmentation ##### 129 | parser.add_argument('--color_jitter_strength', default=0.5, type=float, help='0.5 for CIFAR') 130 | parser.add_argument('--temperature', default=0.5, type=float, help='temperature for pairwise-similarity') 131 | 132 | 
##### arguments for distributted parallel ##### 133 | parser.add_argument('--local_rank', type=int, default=0) 134 | parser.add_argument('--ngpu', type=int, default=2) 135 | 136 | ##### arguments for PGD attack & Adversarial Training ##### 137 | parser.add_argument('--min', type=float, default=0.0, help='min for cliping image') 138 | parser.add_argument('--max', type=float, default=1.0, help='max for cliping image') 139 | parser.add_argument('--attack_type', type=str, default='linf', help='adversarial l_p') 140 | parser.add_argument('--epsilon', type=float, default=0.0314, 141 | help='maximum perturbation of adversaries (8/255 for cifar-10)') 142 | parser.add_argument('--alpha', type=float, default=0.007, help='movement multiplier per iteration when generating adversarial examples (2/255=0.00784)') 143 | parser.add_argument('--k', type=int, default=7, help='maximum iteration when generating adversarial examples') 144 | parser.add_argument('--random_start', type=bool, default=True, 145 | help='True for PGD') 146 | 147 | ##### new 148 | parser.add_argument('--stop_grad', action='store_true') 149 | parser.add_argument('--adv_weight', type=float, default=1, help="weight of adv loss: stop_grad stpg-weight") 150 | # must use with --stop_grad 151 | parser.add_argument('--stpg_degree', type=float, default=-1, 152 | help="stop degree, range from 0 to 1, 0 is totally stop for clean branch") 153 | 154 | ##### HN 155 | parser.add_argument('--HN', action='store_true', default=False, 156 | help='use HN') 157 | parser.add_argument('--tau', type=float, default=0.1, help='tau in HN') 158 | parser.add_argument('--beta', type=float, default=1.0, help='beta in HN') 159 | args = parser.parse_args() 160 | 161 | return args 162 | 163 | def parser_load(): 164 | parser = argparse.ArgumentParser(description='PyTorch RoCL training') 165 | parser.add_argument('--module',action='store_true') 166 | 167 | parser.add_argument('--load_checkpoint', type=str, 
default='./checkpoint/ckpt.t7one_task_0', help='load checkpoint') 168 | ##### arguments for RoCL ##### 169 | parser.add_argument('--lamda', default=256, type=float) 170 | 171 | parser.add_argument('--regularize_to', default='other', type=str, help='original/other') 172 | parser.add_argument('--attack_to', default='other', type=str, help='original/other') 173 | 174 | parser.add_argument('--loss_type', type=str, default='sim', help='loss type for Rep') 175 | parser.add_argument('--advtrain_type', default='Rep', type=str, help='Rep/None') 176 | 177 | ##### arguments for Training Self-Sup ##### 178 | parser.add_argument('--train_type', default='contrastive', type=str, help='contrastive/linear eval/test') 179 | 180 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 181 | parser.add_argument('--lr_multiplier', default=15.0, type=float, help='learning rate multiplier') 182 | parser.add_argument('--decay', default=1e-6, type=float, help='weight decay') 183 | 184 | parser.add_argument('--dataset', default='cifar-10', type=str, help='cifar-10/cifar-100') 185 | 186 | # parser.add_argument('--load_checkpoint', default='./checkpoint/ckpt.t7one_task_0', type=str, help='PATH TO CHECKPOINT') 187 | parser.add_argument('--resume', '-r', action='store_true', 188 | help='resume from checkpoint') 189 | parser.add_argument('--model', default='ResNet18', type=str, 190 | help='model type ResNet18/ResNet50') 191 | 192 | parser.add_argument('--name', default='', type=str, help='name of run') 193 | parser.add_argument('--seed', default=0, type=int, help='random seed') 194 | parser.add_argument('--batch-size', default=256, type=int, help='batch size / multi-gpu setting: batch per gpu') 195 | parser.add_argument('--epoch', default=1000, type=int, 196 | help='total epochs to run') 197 | 198 | ##### arguments for data augmentation ##### 199 | parser.add_argument('--color_jitter_strength', default=0.5, type=float, help='0.5 for CIFAR') 200 | 
parser.add_argument('--temperature', default=0.5, type=float, help='temperature for pairwise-similarity') 201 | 202 | ##### arguments for distributted parallel ##### 203 | parser.add_argument('--local_rank', type=int, default=0) 204 | parser.add_argument('--ngpu', type=int, default=2) 205 | 206 | ##### arguments for PGD attack & Adversarial Training ##### 207 | parser.add_argument('--min', type=float, default=0.0, help='min for cliping image') 208 | parser.add_argument('--max', type=float, default=1.0, help='max for cliping image') 209 | parser.add_argument('--attack_type', type=str, default='linf', help='adversarial l_p') 210 | parser.add_argument('--epsilon', type=float, default=0.0314, 211 | help='maximum perturbation of adversaries (8/255 for cifar-10)') 212 | parser.add_argument('--alpha', type=float, default=0.007, help='movement multiplier per iteration when generating adversarial examples (2/255=0.00784)') 213 | parser.add_argument('--k', type=int, default=7, help='maximum iteration when generating adversarial examples') 214 | parser.add_argument('--random_start', type=bool, default=True, 215 | help='True for PGD') 216 | 217 | ##### new 218 | parser.add_argument('--stop_grad', action='store_true') 219 | parser.add_argument('--adv_weight', type=float, default=1, help="weight of adv loss: stop_grad stpg-weight") 220 | # must use with --stop_grad 221 | parser.add_argument('--stpg_degree', type=float, default=-1, 222 | help="stop degree, range from 0 to 1, 0 is totally stop for clean branch") 223 | 224 | ##### HN 225 | parser.add_argument('--HN', action='store_true', default=False, 226 | help='use HN') 227 | parser.add_argument('--tau', type=float, default=0.1, help='tau in HN') 228 | parser.add_argument('--beta', type=float, default=1.0, help='beta in HN') 229 | args = parser.parse_args() 230 | 231 | return args 232 | 233 | def print_args(args): 234 | for k, v in vars(args).items(): 235 | print('{:<16} : {}'.format(k, v)) 236 | 237 | 
-------------------------------------------------------------------------------- /Asym-AdvCL/finetuning_advCL_SLF.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | import argparse 5 | import time 6 | import math 7 | 8 | import torch 9 | import os 10 | import torch.backends.cudnn as cudnn 11 | 12 | from utils import AverageMeter 13 | from utils import accuracy 14 | from torch import optim 15 | from torchvision import transforms, datasets 16 | import models 17 | import torch.nn.functional as F 18 | from torch import nn 19 | from models.resnet_cifar import ResNet18 20 | from models.linear import LinearClassifier 21 | import tensorboard_logger as tb_logger 22 | from utils import load_BN_checkpoint 23 | from trades import trades_loss 24 | 25 | 26 | def set_loader(opt): 27 | # construct data loader 28 | train_transform = transforms.Compose([ 29 | transforms.Resize(32), 30 | transforms.RandomCrop(32, padding=4), 31 | transforms.RandomHorizontalFlip(), 32 | transforms.ToTensor(), 33 | ]) 34 | val_transform = transforms.Compose([ 35 | transforms.Resize(32), 36 | transforms.ToTensor(), 37 | ]) 38 | 39 | if opt.dataset == "cifar10": 40 | opt.data_folder += 'cifar10/' 41 | train_dataset = datasets.CIFAR10(root=opt.data_folder, 42 | transform=train_transform, 43 | download=True) 44 | val_dataset = datasets.CIFAR10(root=opt.data_folder, 45 | train=False, 46 | transform=val_transform) 47 | elif opt.dataset == "stl10": 48 | train_dataset = datasets.STL10(root=opt.data_folder, 49 | transform=train_transform, 50 | download=True) 51 | val_dataset = datasets.STL10(root=opt.data_folder, 52 | train=False, 53 | transform=val_transform) 54 | elif opt.dataset == "cifar100": 55 | opt.data_folder += 'cifar100/' 56 | train_dataset = datasets.CIFAR100(root=opt.data_folder, 57 | transform=train_transform, 58 | download=True) 59 | val_dataset = datasets.CIFAR100(root=opt.data_folder, 60 | 
train=False, 61 | transform=val_transform) 62 | else: 63 | raise NotImplementedError("Dataset Not Supported") 64 | 65 | train_sampler = None 66 | train_loader = torch.utils.data.DataLoader( 67 | train_dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None), 68 | num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler) 69 | val_loader = torch.utils.data.DataLoader( 70 | val_dataset, batch_size=256, shuffle=False, 71 | num_workers=8, pin_memory=True) 72 | 73 | return train_loader, val_loader 74 | 75 | 76 | def parse_option(): 77 | parser = argparse.ArgumentParser('argument for training') 78 | 79 | parser.add_argument('--print_freq', type=int, default=10, 80 | help='print frequency') 81 | parser.add_argument('--save_freq', type=int, default=50, 82 | help='save frequency') 83 | parser.add_argument('--batch_size', type=int, default=512, 84 | help='batch_size') 85 | parser.add_argument('--num_workers', type=int, default=16, 86 | help='num of workers to use') 87 | parser.add_argument('--epochs', type=int, default=25, 88 | help='number of training epochs') 89 | # optimization 90 | parser.add_argument('--learning_rate', type=float, default=0.1, 91 | help='learning rate') 92 | parser.add_argument('--weight_decay', type=float, default=2e-4, 93 | help='weight decay') 94 | parser.add_argument('--momentum', type=float, default=0.9, 95 | help='momentum') 96 | 97 | # model dataset 98 | parser.add_argument('--dataset', type=str, default='cifar10', 99 | choices=['cifar10', 'cifar100', 'stl10'], help='dataset') 100 | # other setting 101 | parser.add_argument('--ckpt', type=str, default='', 102 | help='path to pre-trained model') 103 | parser.add_argument('--name', type=str, default='advcl_slf', 104 | help='name of the exp') 105 | 106 | # control the finetuning type 107 | parser.add_argument('--finetune_type', type=str, default='SLF', choices=['SLF', 'ALF', 'AFF_trades']) 108 | 109 | opt = parser.parse_args() 110 | print(opt) 111 | 112 | # set the path 
according to the environment 113 | opt.data_folder = '~/data/' 114 | if opt.dataset == 'cifar10' or opt.dataset == 'stl10': 115 | opt.n_cls = 10 116 | elif opt.dataset == 'cifar100': 117 | opt.n_cls = 100 118 | else: 119 | raise ValueError('dataset not supported: {}'.format(opt.dataset)) 120 | 121 | return opt 122 | 123 | 124 | # PGD attack model 125 | class AttackPGD(nn.Module): 126 | def __init__(self, model, classifier, config): 127 | super(AttackPGD, self).__init__() 128 | self.model = model 129 | self.classifier = classifier 130 | self.rand = config['random_start'] 131 | self.step_size = config['step_size'] 132 | self.epsilon = config['epsilon'] 133 | assert config['loss_func'] == 'xent', 'Plz use xent for loss function.' 134 | 135 | def forward(self, inputs, targets, train=True, finetune_type="AFF"): 136 | x = inputs.detach() 137 | if self.rand: 138 | x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon) 139 | if train: 140 | num_step = 10 141 | else: 142 | num_step = 20 143 | for i in range(num_step): 144 | x.requires_grad_() 145 | with torch.enable_grad(): 146 | features = self.model(x, return_feat=True) 147 | logits = self.classifier(features) 148 | loss = F.cross_entropy(logits, targets, size_average=False) 149 | grad = torch.autograd.grad(loss, [x])[0] 150 | x = x.detach() + self.step_size * torch.sign(grad.detach()) 151 | x = torch.min(torch.max(x, inputs - self.epsilon), inputs + self.epsilon) 152 | x = torch.clamp(x, 0, 1) 153 | if finetune_type == "SLF" or finetune_type == "ALF": 154 | with torch.no_grad(): 155 | features = self.model(x, return_feat=True) 156 | else: 157 | features = self.model(x, return_feat=True) 158 | return self.classifier(features), x 159 | 160 | 161 | def set_model(opt): 162 | model = ResNet18() 163 | criterion = torch.nn.CrossEntropyLoss() 164 | classifier = LinearClassifier(name=opt.name, feat_dim=512, num_classes=opt.n_cls) 165 | if len(opt.ckpt) > 2: 166 | print('loading from {}'.format(opt.ckpt)) 167 | ckpt = 
torch.load(opt.ckpt, map_location='cpu') 168 | if 'state_dict' in ckpt.keys(): 169 | state_dict = ckpt['state_dict'] 170 | else: 171 | state_dict = ckpt['model'] 172 | print(torch.cuda.device_count()) 173 | if torch.cuda.is_available(): 174 | if torch.cuda.device_count() > 1: 175 | model = torch.nn.DataParallel(model) 176 | state_dict, _ = load_BN_checkpoint(state_dict) 177 | model = model.cuda() 178 | classifier = classifier.cuda() 179 | criterion = criterion.cuda() 180 | 181 | config = { 182 | 'epsilon': 8.0 / 255., 183 | 'num_steps': 10, 184 | 'step_size': 2.0 / 255, 185 | 'random_start': True, 186 | 'loss_func': 'xent', 187 | } 188 | net = AttackPGD(model, classifier, config) 189 | net = net.cuda() 190 | 191 | cudnn.benchmark = True 192 | model.load_state_dict(state_dict, strict=False) 193 | else: 194 | print("only GPU version supported") 195 | raise NotImplementedError 196 | else: 197 | print("please specify pretrained model") 198 | raise NotImplementedError 199 | return model, classifier, net, criterion 200 | 201 | 202 | def train(train_loader, model, classifier, net, criterion, optimizer, epoch, opt): 203 | """one epoch training""" 204 | model.eval() 205 | classifier.train() 206 | 207 | batch_time = AverageMeter() 208 | data_time = AverageMeter() 209 | losses = AverageMeter() 210 | top1 = AverageMeter() 211 | 212 | end = time.time() 213 | for idx, (images, labels) in enumerate(train_loader): 214 | data_time.update(time.time() - end) 215 | 216 | images = images.cuda(non_blocking=True) 217 | labels = labels.cuda(non_blocking=True) 218 | bsz = labels.shape[0] 219 | 220 | if opt.finetune_type == "AFF": 221 | _, x_adv = net(images, labels, train=False, finetune_type="AFF") 222 | # calculate robust loss 223 | beta = 6.0 224 | trainmode = 'adv' 225 | 226 | # fixmode f1: fix nothing, f2: fix previous 3 stages, f3: fix all except fc') 227 | loss, output = trades_loss(model=model, 228 | classifier=classifier, 229 | x_natural=images, 230 | x_adv=x_adv, 231 | y=labels, 
232 | optimizer=optimizer, 233 | beta=beta, 234 | trainmode=trainmode, 235 | fixmode='f1') 236 | elif opt.finetune_type == "AFF_trades": 237 | # calculate robust loss 238 | step_size = 2. / 255. 239 | epsilon = 8. / 255. 240 | num_steps_train = 10 241 | beta = 6.0 242 | trainmode = 'adv' 243 | 244 | # fixmode f1: fix nothing, f2: fix previous 3 stages, f3: fix all except fc') 245 | loss, output = trades_loss(model=model, 246 | classifier=classifier, 247 | x_natural=images, 248 | x_adv="", 249 | y=labels, 250 | optimizer=optimizer, 251 | step_size=step_size, 252 | epsilon=epsilon, 253 | perturb_steps=num_steps_train, 254 | beta=beta, 255 | trainmode=trainmode, 256 | fixmode='f1', 257 | trades=True) 258 | elif opt.finetune_type == "SLF": 259 | with torch.no_grad(): 260 | features = model(images, return_feat=True) 261 | output = classifier(features.detach()) 262 | loss = criterion(output, labels) 263 | else: # ALF, use PGD examples to train the classifier 264 | output, _ = net(images, labels, train=False, finetune_type="ALF") 265 | loss = criterion(output, labels) 266 | 267 | # update metric 268 | losses.update(loss.item(), bsz) 269 | acc1, acc5 = accuracy(output, labels, topk=(1, 5)) 270 | top1.update(acc1[0], bsz) 271 | 272 | # SGD 273 | optimizer.zero_grad() 274 | loss.backward() 275 | optimizer.step() 276 | 277 | # measure elapsed time 278 | batch_time.update(time.time() - end) 279 | end = time.time() 280 | 281 | # print info 282 | if (idx + 1) % opt.print_freq == 0: 283 | print('Train: [{0}][{1}/{2}]\t' 284 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 285 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' 286 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t' 287 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 288 | epoch, idx + 1, len(train_loader), batch_time=batch_time, 289 | data_time=data_time, loss=losses, top1=top1)) 290 | sys.stdout.flush() 291 | 292 | return losses.avg, top1.avg 293 | 294 | 295 | def validate(val_loader, model, classifier, net, 
criterion, opt): 296 | """validation""" 297 | model.eval() 298 | classifier.eval() 299 | 300 | batch_time = AverageMeter() 301 | losses = AverageMeter() 302 | top1 = AverageMeter() 303 | 304 | top1_clean = AverageMeter() 305 | 306 | with torch.no_grad(): 307 | end = time.time() 308 | for idx, (images, labels) in enumerate(val_loader): 309 | images = images.float().cuda() 310 | labels = labels.cuda() 311 | bsz = labels.shape[0] 312 | 313 | # forward 314 | output, _ = net(images, labels, train=False) 315 | loss = criterion(output, labels) 316 | 317 | features_clean = model(images, return_feat=True) 318 | output_clean = classifier(features_clean) 319 | acc1_clean, acc5_clean = accuracy(output_clean, labels, topk=(1, 5)) 320 | 321 | # update metric 322 | losses.update(loss.item(), bsz) 323 | acc1, acc5 = accuracy(output, labels, topk=(1, 5)) 324 | top1.update(acc1[0], bsz) 325 | top1_clean.update(acc1_clean[0], bsz) 326 | 327 | # measure elapsed time 328 | batch_time.update(time.time() - end) 329 | end = time.time() 330 | 331 | if idx % opt.print_freq == 0: 332 | print('Test: [{0}/{1}]\t' 333 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 334 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 335 | 'Acc@1 Clean {top1_clean.val:.4f} ({top1_clean.avg:.4f})\t' 336 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 337 | idx, len(val_loader), batch_time=batch_time, 338 | loss=losses, top1=top1, top1_clean=top1_clean)) 339 | 340 | print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1)) 341 | print(' * Acc@1 Clean {top1_clean.avg:.3f}'.format(top1_clean=top1_clean)) 342 | return losses.avg, top1.avg, top1_clean.avg 343 | 344 | 345 | def adjust_lr(lr, optimizer, epoch): 346 | if epoch >= 15: 347 | lr /= 10 348 | if epoch >= 20: 349 | lr /= 10 350 | for param_group in optimizer.param_groups: 351 | param_group['lr'] = lr 352 | 353 | 354 | def main(): 355 | best_acc = 0 356 | best_acc_clean = 0 357 | opt = parse_option() 358 | 359 | log_ra = 0 360 | log_ta = 0 361 | 362 | # build data 
loader 363 | train_loader, val_loader = set_loader(opt) 364 | 365 | # build model and criterion 366 | model, classifier, net, criterion = set_model(opt) 367 | 368 | # build optimizer 369 | if opt.finetune_type == "SLF" or opt.finetune_type == "ALF": 370 | params = list(classifier.parameters()) 371 | else: 372 | params = list(model.parameters()) + list(classifier.parameters()) 373 | optimizer = optim.SGD(params, 374 | lr=opt.learning_rate, 375 | momentum=opt.momentum, 376 | weight_decay=opt.weight_decay) 377 | logname = ('logger/' + '{}_linearnormal'.format(opt.name+str(opt.learning_rate))) 378 | logger = tb_logger.Logger(logdir=logname, flush_secs=2) 379 | 380 | # training routine 381 | for epoch in range(1, opt.epochs + 1): 382 | adjust_lr(opt.learning_rate, optimizer, epoch - 1) 383 | # train for one epoch 384 | time1 = time.time() 385 | loss, acc = train(train_loader, model, classifier, net, criterion, 386 | optimizer, epoch, opt) 387 | logger.log_value('train_loss', loss, epoch) 388 | logger.log_value('train_acc', acc, epoch) 389 | logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch) 390 | time2 = time.time() 391 | print('Train epoch {}, total time {:.2f}, accuracy:{:.2f}'.format( 392 | epoch, time2 - time1, acc)) 393 | # eval for one epoch 394 | loss, val_acc, val_acc_clean = validate(val_loader, model, classifier, net, criterion, opt) 395 | logger.log_value('val_loss', loss, epoch) 396 | logger.log_value('val_acc', val_acc, epoch) 397 | logger.log_value('val_acc_clean', val_acc_clean, epoch) 398 | 399 | logger.log_value('best_val_acc', log_ra, epoch) 400 | logger.log_value('best_val_acc_clean', log_ta, epoch) 401 | 402 | if val_acc > best_acc: 403 | best_acc = val_acc 404 | best_acc_clean = val_acc_clean 405 | log_ra = val_acc 406 | log_ta = val_acc_clean 407 | state = { 408 | 'model': model.state_dict(), 409 | 'classifier': classifier.state_dict(), 410 | 'epoch': epoch, 411 | 'rng_state': torch.get_rng_state() 412 | } 413 | if not 
os.path.exists('./checkpoint/{}/'.format(opt.name+str(opt.learning_rate))): 414 | os.makedirs('./checkpoint/{}/'.format(opt.name+str(opt.learning_rate))) 415 | if opt.finetune_type == "AFF_trades": 416 | torch.save(state, './checkpoint/{}/best_aff.ckpt'.format(opt.name+str(opt.learning_rate), epoch)) 417 | elif opt.finetune_type == "ALF": 418 | torch.save(state, './checkpoint/{}/best_alf.ckpt'.format(opt.name+str(opt.learning_rate), epoch)) 419 | else: 420 | torch.save(state, './checkpoint/{}/best.ckpt'.format(opt.name+str(opt.learning_rate), epoch)) 421 | 422 | if epoch % opt.save_freq == 0: 423 | state = { 424 | 'model': model.state_dict(), 425 | 'classifier': classifier.state_dict(), 426 | 'epoch': epoch, 427 | 'rng_state': torch.get_rng_state() 428 | } 429 | if not os.path.exists('./checkpoint/{}/'.format(opt.name+str(opt.learning_rate))): 430 | os.makedirs('./checkpoint/{}/'.format(opt.name+str(opt.learning_rate))) 431 | torch.save(state, './checkpoint/{}/ep_{}.ckpt'.format(opt.name+str(opt.learning_rate), epoch)) 432 | 433 | print('best accuracy: {:.2f}'.format(best_acc)) 434 | print('best accuracy clean: {:.2f}'.format(best_acc_clean)) 435 | 436 | 437 | if __name__ == '__main__': 438 | main() 439 | -------------------------------------------------------------------------------- /Asym-AdvCL/pretraining_advCL.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import numpy as np 5 | import os, csv 6 | from dataset import CIFAR10IndexPseudoLabelEnsemble, CIFAR100IndexPseudoLabelEnsemble 7 | import pickle 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torchvision.transforms as transforms 12 | import torch.utils.data as Data 13 | import torch.backends.cudnn as cudnn 14 | 15 | from utils import progress_bar, TwoCropTransformAdv 16 | from losses import SupConLoss, ori_SupConLoss 17 | import tensorboard_logger as 
tb_logger 18 | from models.resnet_cifar_multibn_ensembleFC import resnet18 as ResNet18 19 | import random 20 | from fr_util import generate_high 21 | from utils import adjust_learning_rate, warmup_learning_rate, AverageMeter 22 | import apex 23 | 24 | # ================================================================== # 25 | # Inputs and Pre-definition # 26 | # ================================================================== # 27 | 28 | # Arguments 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--name', type=str, default='advcl_cifar10', 31 | help='name of the run') 32 | parser.add_argument('--cname', type=str, default='imagenet_clPretrain', 33 | help='') 34 | parser.add_argument('--batch_size', type=int, default=512, 35 | help='batch size') 36 | parser.add_argument('--epoch', type=int, default=1000, 37 | help='total epochs') 38 | parser.add_argument('--save-epoch', type=int, default=100, 39 | help='save epochs') 40 | parser.add_argument('--epsilon', type=float, default=8, 41 | help='The upper bound change of L-inf norm on input pixels') 42 | parser.add_argument('--iter', type=int, default=5, 43 | help='The number of iterations for iterative attacks') 44 | parser.add_argument('--radius', type=int, default=8, 45 | help='radius of low freq images') 46 | parser.add_argument('--ce_weight', type=float, default=0.2, 47 | help='cross entp weight') 48 | 49 | # contrastive related 50 | parser.add_argument('-t', '--nce_t', default=0.5, type=float, 51 | help='temperature') 52 | parser.add_argument('--seed', default=0, type=float, 53 | help='random seed') 54 | parser.add_argument('--dataset', type=str, default='cifar10', help='dataset') 55 | parser.add_argument('--cosine', action='store_true', 56 | help='using cosine annealing') 57 | parser.add_argument('--warm', action='store_true', 58 | help='warm-up for large batch training') 59 | parser.add_argument('--learning_rate', type=float, default=0.5, 60 | help='learning rate') 61 | 
parser.add_argument('--lr_decay_epochs', type=str, default='700,800,900',
                    help='where to decay lr, can be a list')
parser.add_argument('--lr_decay_rate', type=float, default=0.1,
                    help='decay rate for learning rate')
parser.add_argument('--weight_decay', type=float, default=1e-4,
                    help='weight decay')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='momentum')

# ====== HN ======
parser.add_argument('--attack_ori', action='store_true', default=False,
                    help='attack use ori criterion')
parser.add_argument('--HN', action='store_true', default=False,
                    help='use HN')
parser.add_argument('--tau_plus', type=float, default=0.1, help="tau-plus in HN")
parser.add_argument('--beta', type=float, default=1.0, help="beta in HN")

# ====== stop grad ======
parser.add_argument('--stop_grad', action='store_true', default=False, help="whether to stop gradient")
parser.add_argument('--adv_weight', type=float, default=1, help="weight of adv loss")
parser.add_argument('--d_min', type=float, default=0.4, help="min distance in adaptive grad stopping")
parser.add_argument('--d_max', type=float, default=0, help="max distance in adaptive grad stopping")
# must use with --stop_grad
parser.add_argument('--stpg_degree', type=float, default=-1.0,
                    help="stop degree, range from 0 to 1, 0 is totally stop for clean branch")
parser.add_argument('--stop_grad_adaptive', type=int, default=-1, help="adaptively stop grad")

args = parser.parse_args()
args.epochs = args.epoch
args.decay = args.weight_decay
args.cosine = True  # cosine annealing is always on, regardless of --cosine
import math

# Large batches get a linear warm-up; with cosine annealing the warm-up target
# is the cosine-schedule value at the end of the warm-up period.
if args.batch_size > 256:
    args.warm = True
if args.warm:
    args.warmup_from = 0.01
    args.warm_epochs = 10
    if args.cosine:
        eta_min = args.learning_rate * (args.lr_decay_rate ** 3)
        args.warmup_to = eta_min + (args.learning_rate - eta_min) * (
                1 + math.cos(math.pi * args.warm_epochs / args.epochs)) / 2
    else:
        args.warmup_to = args.learning_rate

print(args)

start_epoch = 0  # start from epoch 0 or last checkpoint epoch
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = {
    'epsilon': args.epsilon / 255.,
    'num_steps': args.iter,
    'step_size': 2.0 / 255,
    'random_start': True,
    'loss_func': 'xent',
}
# ================================================================== #
#                      Data and Pre-processing                       #
# ================================================================== #
print('=====> Preparing data...')
# Multi-cuda
if torch.cuda.is_available():
    n_gpu = torch.cuda.device_count()
batch_size = args.batch_size

transform_train = transforms.Compose([
    transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
])
train_transform_org = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
transform_train = TwoCropTransformAdv(transform_train, train_transform_org)

transform_test = transforms.Compose([
    transforms.ToTensor(),
])

# One pseudo-label set per clustering granularity (2 ... 500 clusters).
label_pseudo_train_list = []
num_classes_list = [2, 10, 50, 100, 500]

if args.dataset == "cifar10":
    dict_name = 'data/{}_pseudo_labels.pkl'.format(args.cname)
elif args.dataset == "cifar100":
    dict_name = 'data/cifar100_pseudo_labels.pkl'
else:
    # previously fell through to a NameError on dict_name
    raise ValueError('dataset not supported: {}'.format(args.dataset))
# pickle can execute code on load: only point this at the trusted files in data/.
with open(dict_name, 'rb') as f:
    feat_label_dict = pickle.load(f)
for class_num in num_classes_list:
    key_train = 'pseudo_train_{}'.format(class_num)
    label_pseudo_train_list.append(feat_label_dict[key_train])

data_path = "~/data/"

if args.dataset == "cifar10":
    train_dataset = CIFAR10IndexPseudoLabelEnsemble(root=data_path + 'cifar10/',
                                                    transform=transform_train,
                                                    pseudoLabel_002=label_pseudo_train_list[0],
                                                    pseudoLabel_010=label_pseudo_train_list[1],
                                                    pseudoLabel_050=label_pseudo_train_list[2],
                                                    pseudoLabel_100=label_pseudo_train_list[3],
                                                    pseudoLabel_500=label_pseudo_train_list[4],
                                                    download=True)
elif args.dataset == "cifar100":
    train_dataset = CIFAR100IndexPseudoLabelEnsemble(root=data_path + 'cifar100/',
                                                     transform=transform_train,
                                                     pseudoLabel_002=label_pseudo_train_list[0],
                                                     pseudoLabel_010=label_pseudo_train_list[1],
                                                     pseudoLabel_050=label_pseudo_train_list[2],
                                                     pseudoLabel_100=label_pseudo_train_list[3],
                                                     pseudoLabel_500=label_pseudo_train_list[4],
                                                     download=True)
# Data Loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=n_gpu * 4)


# ================================================================== #
#                     Model, Loss and Optimizer                      #
# ================================================================== #

# PGD attack model
class AttackPGD(nn.Module):
    """Joint PGD attack on the contrastive loss and the pseudo-label CE loss.

    Two adversarial images are produced from the same clean image: ``x_cl``
    (fed through the 'pgd' BN branch) and ``x_ce`` (fed through 'pgd_ce').
    """

    def __init__(self, model, config):
        super(AttackPGD, self).__init__()
        self.model = model
        self.rand = config['random_start']
        self.step_size = config['step_size']
        self.epsilon = config['epsilon']
        self.num_steps = config['num_steps']
        assert config['loss_func'] == 'xent', 'Plz use xent for loss function.'

    def forward(self, images_t1, images_t2, images_org, targets, criterion):
        """Attack ``images_org`` and return the views used for the training loss.

        Args:
            images_t1, images_t2: two augmented views of each image.
            images_org: the weakly-augmented image that is attacked.
            targets: pseudo-label tensors, indexed 0-4 below.
            criterion: contrastive loss applied to the stacked features.

        Returns:
            (x1, x2, x_cl, x_ce, x_HFC) where x_HFC is the high-frequency
            component of ``images_org`` from ``generate_high``.
        """
        x1 = images_t1.clone().detach()
        x2 = images_t2.clone().detach()
        x_cl = images_org.clone().detach()
        x_ce = images_org.clone().detach()

        images_org_high = generate_high(x_cl.clone(), r=args.radius)
        x_HFC = images_org_high.clone().detach()

        if self.rand:
            x_cl = x_cl + torch.zeros_like(x_cl).uniform_(-self.epsilon, self.epsilon)
            x_ce = x_ce + torch.zeros_like(x_ce).uniform_(-self.epsilon, self.epsilon)

        for _ in range(self.num_steps):
            x_cl.requires_grad_()
            x_ce.requires_grad_()
            with torch.enable_grad():
                f_proj, f_pred = self.model(x_cl, bn_name='pgd', contrast=True)
                fce_proj, fce_pred, logits_ce = self.model(x_ce, bn_name='pgd_ce', contrast=True, CF=True,
                                                           return_logits=True, nonlinear=False)
                f1_proj, f1_pred = self.model(x1, bn_name='normal', contrast=True)
                f2_proj, f2_pred = self.model(x2, bn_name='normal', contrast=True)
                f_high_proj, f_high_pred = self.model(x_HFC, bn_name='normal', contrast=True)
                features = torch.cat(
                    [f_proj.unsqueeze(1), f1_proj.unsqueeze(1), f2_proj.unsqueeze(1), f_high_proj.unsqueeze(1)], dim=1)
                loss_contrast = criterion(features, stop_grad=False)
                loss_ce = 0
                for label_idx in range(5):
                    tgt = targets[label_idx].long()
                    lgt = logits_ce[label_idx]
                    # reduction='sum' is the non-deprecated form of size_average=False
                    loss_ce += F.cross_entropy(lgt, tgt, reduction='sum', ignore_index=-1) / 5.
                loss = loss_contrast + loss_ce * args.ce_weight
                grad_x_cl, grad_x_ce = torch.autograd.grad(loss, [x_cl, x_ce])
            # signed-gradient ascent step, then projection into the epsilon-ball and [0, 1]
            x_cl = x_cl.detach() + self.step_size * torch.sign(grad_x_cl.detach())
            x_cl = torch.min(torch.max(x_cl, images_org - self.epsilon), images_org + self.epsilon)
            x_cl = torch.clamp(x_cl, 0, 1)
            x_ce = x_ce.detach() + self.step_size * torch.sign(grad_x_ce.detach())
            x_ce = torch.min(torch.max(x_ce, images_org - self.epsilon), images_org + self.epsilon)
            x_ce = torch.clamp(x_ce, 0, 1)
        return x1, x2, x_cl, x_ce, x_HFC


print('=====> Building model...')
bn_names = ['normal', 'pgd', 'pgd_ce']
model = ResNet18(bn_names=bn_names)
model = model.cuda()
# tb_logger
if not os.path.exists('./logger'):
    os.makedirs('./logger')
logname = ('./logger/pretrain_{}'.format(args.name))
logger = tb_logger.Logger(logdir=logname, flush_secs=2)
if torch.cuda.device_count() > 1:
    print("=====> Let's use", torch.cuda.device_count(), "GPUs!")
    model = apex.parallel.convert_syncbn_model(model)
    model = nn.DataParallel(model)
    model = model.cuda()
    cudnn.benchmark = True
else:
    print('single gpu version is not supported, please use multiple GPUs!')
    raise NotImplementedError
net = AttackPGD(model, config)
# Loss and optimizer
ce_criterion = nn.CrossEntropyLoss(ignore_index=-1)
if args.HN:  # loss of hard negative
    contrast_criterion = SupConLoss(args, temperature=args.nce_t)
else:  # loss of inferior positive
    contrast_criterion = ori_SupConLoss(args, temperature=args.nce_t)
# --momentum was parsed but ignored (0.9 was hard-coded); default is unchanged
optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=args.momentum,
                            weight_decay=args.decay)


# ================================================================== #
#                           Train and Test                           #
# ================================================================== #
def train(epoch):
    """Run one epoch of adversarial contrastive pre-training.

    Relies on the module-level ``net``, ``model``, ``train_loader``,
    ``optimizer``, the criteria, ``logger`` and ``dis_avg``.

    Returns:
        (mean loss per batch, 0.) -- no accuracy is tracked during pre-training.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0  # never incremented; kept for the progress-bar format
    total = 0
    num_batches = len(train_loader)

    for batch_idx, (inputs, _, targets, ind) in enumerate(train_loader):
        targets = [tt_.to(device).long() for tt_ in targets]
        image_t1, image_t2, image_org = inputs
        image_t1 = image_t1.cuda(non_blocking=True)
        image_t2 = image_t2.cuda(non_blocking=True)
        image_org = image_org.cuda(non_blocking=True)
        warmup_learning_rate(args, epoch + 1, batch_idx, num_batches, optimizer)
        # attack contrast
        optimizer.zero_grad()
        if args.attack_ori:
            attack_criterion = ori_SupConLoss(args, temperature=args.nce_t)
        else:
            attack_criterion = contrast_criterion
        x1, x2, x_cl, x_ce, x_HFC = net(image_t1, image_t2, image_org, targets, attack_criterion)
        f_proj, f_pred = model(x_cl, bn_name='pgd', contrast=True)
        fce_proj, fce_pred, logits_ce = model(x_ce, bn_name='pgd_ce', contrast=True, CF=True, return_logits=True,
                                              nonlinear=False)
        # ======== aug1&aug2&HF ========
        f1_proj, f1_pred = model(x1, bn_name='normal', contrast=True)
        f2_proj, f2_pred = model(x2, bn_name='normal', contrast=True)
        f_high_proj, f_high_pred = model(x_HFC, bn_name='normal', contrast=True)
        features = torch.cat(
            [f_proj.unsqueeze(1), f1_proj.unsqueeze(1), f2_proj.unsqueeze(1), f_high_proj.unsqueeze(1)], dim=1)

        # ======== adaptive grad stopping ========
        # For the first `stop_grad_adaptive` epochs, collect statistics of the
        # clean-vs-adversarial feature distance; afterwards, use them to pick
        # the stop-gradient degree per batch.
        stop_grad_sd = args.stpg_degree
        if args.stop_grad_adaptive >= 0:
            with torch.no_grad():
                f_ori = model(image_org, bn_name='normal', return_feat=True)
                f_cl = model(x_cl, bn_name='normal', return_feat=True)
                distance = F.pairwise_distance(f_ori, f_cl)
                mean_distance = distance.mean()
            if epoch + 1 > args.stop_grad_adaptive:
                if mean_distance > dis_avg.avg:
                    pass
                elif mean_distance < args.d_min:
                    stop_grad_sd = 0.5
                else:
                    # linear interpolation between stpg_degree and 0.5
                    stop_grad_sd = (dis_avg.avg - mean_distance) / (dis_avg.avg - args.d_min) * (
                            0.5 - args.stpg_degree) + args.stpg_degree
            else:
                if mean_distance > args.d_max:
                    args.d_max = mean_distance
                dis_avg.update(mean_distance, image_org.shape[0])

            stop_grad_sd_print = stop_grad_sd if isinstance(stop_grad_sd, float) else stop_grad_sd.mean()

            logger.log_value('stop-degree', stop_grad_sd_print, epoch * num_batches + batch_idx)
            if batch_idx % 40 == 0:
                print(
                    "dis_avg: {:.3f}, d_max: {:.3f}, distance: {:.3f}, stop_grad_sd: {}".format(dis_avg.avg, args.d_max,
                                                                                               mean_distance,
                                                                                               stop_grad_sd_print))

        contrast_loss = contrast_criterion(features, stop_grad=args.stop_grad, stop_grad_sd=stop_grad_sd)
        ce_loss = 0
        for label_idx in range(5):
            tgt = targets[label_idx].long()
            lgt = logits_ce[label_idx]
            ce_loss += ce_criterion(lgt, tgt) / 5.

        loss = contrast_loss + ce_loss * args.ce_weight
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        total += targets[0].size(0)

        progress_bar(batch_idx, num_batches,
                     'Loss: %.3f (%d/%d)'
                     % (train_loss / (batch_idx + 1), correct, total))

    # Average over all batches: dividing by the last batch_idx was off by one
    # and raised ZeroDivisionError for a single-batch loader.
    return train_loss / max(num_batches, 1), 0.
# ================================================================== #
#                            Checkpoint                              #
# ================================================================== #

# Save checkpoint
def checkpoint(epoch):
    """Save the model weights, the epoch and the CPU RNG state.

    Writes ``./checkpoint/<args.name>/epoch_<epoch>.ckpt``, creating the
    directory if it does not exist yet.

    Args:
        epoch: epoch index used in the file name and stored in the state.
    """
    state = {
        'model': model.state_dict(),
        'epoch': epoch,
        'rng_state': torch.get_rng_state(),
    }
    save_dir = './checkpoint/{}'.format(args.name)
    # exist_ok avoids the isdir/makedirs race of checking first.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(state, '{}/epoch_{}.ckpt'.format(save_dir, epoch))
    print('=====> Saving checkpoint to {}/epoch_{}.ckpt'.format(save_dir, epoch))


# ================================================================== #
#                           Run the model                            #
# ================================================================== #

# NOTE(review): seeds are set after the model and optimizer were built above,
# so weight initialization is not controlled by args.seed.
np.random.seed(args.seed)
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

# Running mean of the clean/adversarial feature distance; read and updated by train().
dis_avg = AverageMeter()
last_epoch = args.epoch + 1
for epoch in range(start_epoch, last_epoch + 1):
    adjust_learning_rate(args, optimizer, epoch + 1)
    train_loss, train_acc = train(epoch)
    logger.log_value('train_loss', train_loss, epoch)
    logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)
    # Always keep the final epoch, even when it is not a multiple of save_epoch.
    if epoch % args.save_epoch == 0 or epoch == last_epoch:
        checkpoint(epoch)
# --------------------------------------------------------------------------------