├── Asym-RoCL
├── data
│ ├── __init__.py
│ ├── .DS_Store
│ ├── vision.py
│ ├── cifar.py
│ └── utils.py
├── models
│ ├── __init__.py
│ ├── .DS_Store
│ ├── projector.py
│ └── resnet.py
├── warmup_scheduler
│ ├── __init__.py
│ ├── .DS_Store
│ ├── run.py
│ └── scheduler.py
├── init_paths.py
├── model_loader.py
├── shell
│ ├── rocl-cifar10.sh
│ ├── rocl-cifar100.sh
│ ├── rocl-HN-cifar100.sh
│ ├── rocl-IP-cifar10.sh
│ ├── rocl-IP-cifar100.sh
│ ├── rocl-HN-cifar10.sh
│ ├── rocl-IPHN-cifar100.sh
│ └── rocl-IPHN-cifar10.sh
├── trades.py
├── robustness_test.py
├── data_loader.py
├── attack_lib.py
├── utils.py
├── rocl_train.py
├── loss.py
├── linear_eval.py
└── argument.py
├── figures
│ ├── .DS_Store
│ └── intro.png
├── Asym-AdvCL
├── models
│ ├── .DS_Store
│ ├── linear.py
│ └── resnet_cifar.py
├── shell
│ ├── .DS_Store
│ ├── cifar10
│ │ ├── .DS_Store
│ │ ├── advcl-HN-cifar10.sh
│ │ ├── advcl-IP-cifar10.sh
│ │ ├── advcl-cifar10.sh
│ │ └── advcl-IPHN-cifar10.sh
│ └── cifar100
│ │ ├── .DS_Store
│ │ ├── advcl-HN-cifar100.sh
│ │ ├── advcl-cifar100.sh
│ │ ├── advcl-IP-cifar100.sh
│ │ └── advcl-IPHN-cifar100.sh
├── data
│ ├── cifar100_pseudo_labels.pkl
│ └── imagenet_clPretrain_pseudo_labels.pkl
├── .idea
│ ├── inspectionProfiles
│ │ ├── profiles_settings.xml
│ │ └── Project_Default.xml
│ ├── modules.xml
│ └── Asym-AdvCL.iml
├── fr_util.py
├── dataset.py
├── trades.py
├── utils.py
├── losses.py
├── finetuning_advCL_SLF.py
└── pretraining_advCL.py
└── README.md
/Asym-RoCL/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Asym-RoCL/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 | from .projector import *
3 |
4 |
--------------------------------------------------------------------------------
/figures/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/figures/.DS_Store
--------------------------------------------------------------------------------
/figures/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/figures/intro.png
--------------------------------------------------------------------------------
/Asym-RoCL/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/Asym-RoCL/data/.DS_Store
--------------------------------------------------------------------------------
/Asym-AdvCL/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/Asym-AdvCL/models/.DS_Store
--------------------------------------------------------------------------------
/Asym-AdvCL/shell/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/Asym-AdvCL/shell/.DS_Store
--------------------------------------------------------------------------------
/Asym-RoCL/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/Asym-RoCL/models/.DS_Store
--------------------------------------------------------------------------------
/Asym-RoCL/warmup_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from warmup_scheduler.scheduler import GradualWarmupScheduler
3 |
--------------------------------------------------------------------------------
/Asym-AdvCL/shell/cifar10/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/Asym-AdvCL/shell/cifar10/.DS_Store
--------------------------------------------------------------------------------
/Asym-AdvCL/shell/cifar100/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/Asym-AdvCL/shell/cifar100/.DS_Store
--------------------------------------------------------------------------------
/Asym-RoCL/init_paths.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import matplotlib
3 | matplotlib.use('Agg')
4 | sys.path.insert(0, 'lib')
5 |
6 |
--------------------------------------------------------------------------------
/Asym-RoCL/warmup_scheduler/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/Asym-RoCL/warmup_scheduler/.DS_Store
--------------------------------------------------------------------------------
/Asym-AdvCL/data/cifar100_pseudo_labels.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/Asym-AdvCL/data/cifar100_pseudo_labels.pkl
--------------------------------------------------------------------------------
/Asym-AdvCL/data/imagenet_clPretrain_pseudo_labels.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yqy2001/A-InfoNCE/HEAD/Asym-AdvCL/data/imagenet_clPretrain_pseudo_labels.pkl
--------------------------------------------------------------------------------
/Asym-AdvCL/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/Asym-AdvCL/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/Asym-AdvCL/.idea/Asym-AdvCL.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/Asym-RoCL/warmup_scheduler/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from warmup_scheduler import GradualWarmupScheduler
4 |
5 |
6 | if __name__ == '__main__':
7 | v = torch.zeros(10)
8 | optim = torch.optim.SGD([v], lr=0.01)
9 | scheduler = GradualWarmupScheduler(optim, multiplier=8, total_epoch=10)
10 |
11 | for epoch in range(1, 20):
12 | scheduler.step(epoch)
13 |
14 | print(epoch, optim.param_groups[0]['lr'])
15 |
--------------------------------------------------------------------------------
/Asym-RoCL/models/projector.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class Projector(nn.Module):
6 | def __init__(self, expansion=0):
7 | super(Projector, self).__init__()
8 |
9 | self.linear_1 = nn.Linear(512*expansion, 2048)
10 | self.linear_2 = nn.Linear(2048, 128)
11 |
12 | def forward(self, x):
13 |
14 | output = self.linear_1(x)
15 | output = F.relu(output)
16 |
17 | output = self.linear_2(output)
18 |
19 | return output
20 |
21 |
--------------------------------------------------------------------------------
/Asym-AdvCL/shell/cifar10/advcl-HN-cifar10.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python pretraining_advCL.py --learning_rate 1.2 --dataset cifar10 --cosine --name $1 --epoch 400 --attack_ori --HN --beta 1.0 --tau_plus 0.12
4 |
5 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.01 --finetune_type SLF
6 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.01 --finetune_type ALF
7 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type AFF_trades
8 |
9 | # sh shell/cifar10/advcl-HN-cifar10.sh advcl-HN-cifar10 | tee logs/advcl-HN-cifar10.out
--------------------------------------------------------------------------------
/Asym-AdvCL/shell/cifar100/advcl-HN-cifar100.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python pretraining_advCL.py --learning_rate 2. --dataset cifar100 --cosine --name $1 --epoch 400 --attack_ori --HN --beta 1. --tau_plus 0.01
4 |
5 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type SLF
6 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type ALF
7 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type AFF_trades
8 |
9 | # sh shell/cifar100/advcl-HN-cifar100.sh advcl-HN-cifar100 | tee logs/advcl-HN-cifar100.out
--------------------------------------------------------------------------------
/Asym-AdvCL/shell/cifar10/advcl-IP-cifar10.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python pretraining_advCL.py --dataset cifar10 --cosine --name $1 --epoch 400 --stpg_degree 0.2 --stop_grad --adv_weight 1.2 --stop_grad_adaptive 30 --learning_rate 1.
4 |
5 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.01 --finetune_type SLF
6 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.01 --finetune_type ALF
7 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type AFF_trades
8 |
9 | # sh shell/cifar10/advcl-IP-cifar10.sh advcl-IP-cifar10 | tee logs/advcl-IP-cifar10.out
--------------------------------------------------------------------------------
/Asym-AdvCL/shell/cifar10/advcl-cifar10.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # follow the settings in [AdvCL](https://arxiv.org/pdf/2111.01124.pdf)
4 |
5 | python pretraining_advCL.py --learning_rate 0.5 --dataset cifar10 --cosine --name $1 --epoch 400
6 |
7 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type SLF
8 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type ALF
9 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type AFF_trades
10 |
11 | # sh shell/cifar10/advcl-cifar10.sh advcl-cifar10 | tee logs/advcl-cifar10.out
--------------------------------------------------------------------------------
/Asym-AdvCL/shell/cifar100/advcl-cifar100.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # follow the settings in [AdvCL](https://arxiv.org/pdf/2111.01124.pdf)
4 |
5 | python pretraining_advCL.py --learning_rate 0.5 --dataset cifar100 --cosine --name $1 --epoch 400
6 |
7 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type SLF
8 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type ALF
9 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type AFF_trades
10 |
11 | # sh shell/cifar100/advcl-cifar100.sh advcl-cifar100 | tee logs/advcl-cifar100.out
--------------------------------------------------------------------------------
/Asym-AdvCL/shell/cifar100/advcl-IP-cifar100.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python pretraining_advCL.py --learning_rate 2. --dataset cifar100 --cosine --name $1 --epoch 400 --stpg_degree 0.3 --stop_grad --adv_weight 1.5 --stop_grad_adaptive 20 --d_min 0.5
4 |
5 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type SLF
6 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type ALF
7 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type AFF_trades
8 |
9 | # sh shell/cifar100/advcl-IP-cifar100.sh advcl-IP-cifar100 | tee logs/advcl-IP-cifar100.out
--------------------------------------------------------------------------------
/Asym-AdvCL/shell/cifar10/advcl-IPHN-cifar10.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python pretraining_advCL.py --learning_rate 2. --dataset cifar10 --cosine --name $1 --epoch 400 --stpg_degree 0.2 --stop_grad --adv_weight 0.7 --stop_grad_adaptive 30 --attack_ori --HN
4 |
5 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.01 --finetune_type SLF
6 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.01 --finetune_type ALF
7 | python finetuning_advCL_SLF.py --dataset cifar10 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type AFF_trades
8 |
9 | # sh shell/cifar10/advcl-IPHN-cifar10.sh advcl-IPHN-cifar10 | tee logs/advcl-IPHN-cifar10.out
--------------------------------------------------------------------------------
/Asym-RoCL/model_loader.py:
--------------------------------------------------------------------------------
1 | from models.resnet import ResNet18,ResNet50
2 | def get_model(args):
3 |
4 | if args.dataset == 'cifar-10':
5 | num_classes=10
6 | elif args.dataset == 'cifar-100':
7 | num_classes=100
8 | else:
9 | raise NotImplementedError
10 |
11 | if 'contrastive' in args.train_type or 'linear_eval' in args.train_type:
12 | contrastive_learning=True
13 | else:
14 | contrastive_learning=False
15 |
16 | if args.model == 'ResNet18':
17 | model = ResNet18(num_classes,contrastive_learning)
18 | print('ResNet18 is loading ...')
19 | elif args.model == 'ResNet50':
20 | model = ResNet50(num_classes,contrastive_learning)
21 | print('ResNet 50 is loading ...')
22 | return model
23 |
24 |
--------------------------------------------------------------------------------
/Asym-AdvCL/shell/cifar100/advcl-IPHN-cifar100.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python pretraining_advCL.py --learning_rate 2. --dataset cifar100 --cosine --name $1 --epoch 400 --stpg_degree 0.3 --stop_grad --adv_weight 1.5 --stop_grad_adaptive 20 --d_min 0.5 --attack_ori --HN --beta 1.0 --tau 0.01
4 |
5 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type SLF
6 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type ALF
7 | python finetuning_advCL_SLF.py --dataset cifar100 --ckpt checkpoint/$1/epoch_400.ckpt --name $1 --learning_rate 0.1 --finetune_type AFF_trades
8 |
9 | # sh shell/cifar100/advcl-IPHN-cifar100.sh advcl-IPHN-cifar100 | tee logs/advcl-IPHN-cifar100.out
--------------------------------------------------------------------------------
/Asym-AdvCL/fr_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def distance(i, j, imageSize, r):
6 | dis = np.sqrt((i - imageSize / 2) ** 2 + (j - imageSize / 2) ** 2)
7 | if dis < r:
8 | return 1.0
9 | else:
10 | return 0
11 |
12 |
13 | def mask_radial(img, r):
14 | rows, cols = img.shape
15 | mask = torch.zeros((rows, cols))
16 | for i in range(rows):
17 | for j in range(cols):
18 | mask[i, j] = distance(i, j, imageSize=rows, r=r)
19 | return mask.cuda()
20 |
21 |
22 | def generate_high(Images, r):
23 | # Image: bsxcxhxw, input batched images
24 | # r: int, radius
25 | mask = mask_radial(torch.zeros([Images.shape[2], Images.shape[3]]), r)
26 | bs, c, h, w = Images.shape
27 | x = Images.reshape([bs * c, h, w])
28 | fd = torch.fft.fftshift(torch.fft.fftn(x, dim=(-2, -1)))
29 | mask = mask.unsqueeze(0).repeat([bs * c, 1, 1])
30 | fd = fd * (1. - mask)
31 | fd = torch.fft.ifftn(torch.fft.ifftshift(fd), dim=(-2, -1))
32 | fd = torch.real(fd)
33 | fd = fd.reshape([bs, c, h, w])
34 | return fd
35 |
--------------------------------------------------------------------------------
/Asym-AdvCL/models/linear.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 |
4 | class LinearClassifier(nn.Module):
5 | """Linear classifier"""
6 | def __init__(self, name='resnet50', feat_dim=512, num_classes=10):
7 | super(LinearClassifier, self).__init__()
8 | # _, feat_dim = model_dict[name]
9 | self.fc = nn.Linear(feat_dim, num_classes)
10 |
11 | def forward(self, features):
12 | return self.fc(features)
13 |
14 |
15 | class NonLinearClassifier(nn.Module):
16 | """Linear classifier"""
17 | def __init__(self, name='resnet50', feat_dim=512, num_classes=10):
18 | super(NonLinearClassifier, self).__init__()
19 | # _, feat_dim = model_dict[name]
20 | self.fc1 = nn.Linear(feat_dim, feat_dim)
21 | # self.bn1 = nn.BatchNorm1d(512)
22 | self.fc2 = nn.Linear(feat_dim, feat_dim)
23 | # self.bn2 = nn.BatchNorm1d(512)
24 | self.fc3 = nn.Linear(feat_dim, num_classes)
25 |
26 | def forward(self, features):
27 | # features = F.relu(self.bn1(self.fc1(features)))
28 | # features = F.relu(self.bn2(self.fc2(features)))
29 | features = F.relu(self.fc1(features))
30 | features = F.relu(self.fc2(features))
31 | return self.fc3(features)
--------------------------------------------------------------------------------
/Asym-RoCL/shell/rocl-cifar10.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12562 \
4 | rocl_train.py --ngpu 2 --batch-size=256 --model=$2 --k=7 --loss_type=sim --advtrain_type=Rep --attack_type=linf \
5 | --name=$1 --regularize_to=other --attack_to=other --train_type=contrastive --dataset=$3
6 |
7 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12562 \
8 | linear_eval.py --ngpu 1 --batch-size=1024 --train_type=linear_eval --model=$2 --epoch 150 --lr 0.1 --name $1 \
9 | --load_checkpoint=checkpoint/ckpt.t7rocl-cifar10Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU2_l256 \
10 | --clean=True --dataset=$3
11 |
12 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12562 \
13 | robustness_test.py --ngpu 1 --train_type=linear_eval --name=$1 --batch-size=1024 --model=$2 \
14 | --load_checkpoint='./checkpoint/ckpt.t7'$1'_Evaluate_linear_eval_ResNet18_'$3 --attack_type=linf --epsilon=0.0314 --alpha=0.00314 --k=20 --dataset=$3
15 |
16 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12352 \
17 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type ALF --model=$2 --epoch 25 --lr 0.1 --name $1 \
18 | --load_checkpoint=checkpoint/ckpt.t7rocl-cifar10Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU2_l256 \
19 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20
20 |
21 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12352 \
22 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type AFF_trades --model=$2 --epoch 25 --lr 0.1 --name $1 \
23 | --load_checkpoint=checkpoint/ckpt.t7rocl-cifar10Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU2_l256 \
24 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20
25 |
26 | # sh shell/rocl-cifar10.sh rocl-cifar10 ResNet18 cifar-10 | tee logs/rocl-cifar10.out
27 |
--------------------------------------------------------------------------------
/Asym-RoCL/shell/rocl-cifar100.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12561 \
4 | rocl_train.py --ngpu 2 --batch-size=256 --model=$2 --k=7 --loss_type=sim --advtrain_type=Rep --attack_type=linf \
5 | --name=$1 --regularize_to=other --attack_to=other --train_type=contrastive --dataset=$3
6 |
7 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12561 \
8 | linear_eval.py --ngpu 1 --batch-size=1024 --train_type=linear_eval --model=$2 --epoch 150 --lr 0.1 --name $1 \
9 | --load_checkpoint=checkpoint/ckpt.t7rocl-cifar100Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256 \
10 | --clean=True --dataset=$3
11 |
12 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12561 \
13 | robustness_test.py --ngpu 1 --train_type=linear_eval --name=$1 --batch-size=1024 --model=$2 \
14 | --load_checkpoint='./checkpoint/ckpt.t7'$1'_Evaluate_linear_eval_ResNet18_'$3 --attack_type=linf --epsilon=0.0314 --alpha=0.00314 --k=20 --dataset=$3
15 |
16 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12351 \
17 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type ALF --model=$2 --epoch 25 --lr 0.1 --name $1 \
18 | --load_checkpoint=checkpoint/ckpt.t7rocl-cifar100Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256 \
19 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20
20 |
21 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12351 \
22 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type AFF_trades --model=$2 --epoch 25 --lr 0.1 --name $1 \
23 | --load_checkpoint=checkpoint/ckpt.t7rocl-cifar100Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256 \
24 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20
25 |
26 | # sh shell/rocl-cifar100.sh rocl-cifar100 ResNet18 cifar-100 | tee logs/rocl-cifar100.out
27 |
--------------------------------------------------------------------------------
/Asym-RoCL/shell/rocl-HN-cifar100.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12565 \
4 | rocl_train.py --ngpu 2 --batch-size=256 --model=$2 --k=7 --loss_type=sim --advtrain_type=Rep --attack_type=linf \
5 | --name=$1 --regularize_to=other --attack_to=other --train_type=contrastive --dataset=$3 --HN
6 |
7 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12565 \
8 | linear_eval.py --ngpu 1 --batch-size=1024 --train_type=linear_eval --model=$2 --epoch 150 --lr 0.1 --name $1 \
9 | --load_checkpoint=checkpoint/ckpt.t7rocl-HN-cifar100-2Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256_HN \
10 | --clean=True --dataset=$3
11 |
12 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12565 \
13 | robustness_test.py --ngpu 1 --train_type=linear_eval --name=$1 --batch-size=1024 --model=$2 \
14 | --load_checkpoint='./checkpoint/ckpt.t7'$1'_Evaluate_linear_eval_ResNet18_'$3 --attack_type=linf --epsilon=0.0314 --alpha=0.00314 --k=20 --dataset=$3
15 |
16 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12355 \
17 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type ALF --model=$2 --epoch 25 --lr 0.1 --name $1 \
18 | --load_checkpoint=checkpoint/ckpt.t7rocl-HN-cifar100-2Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256_HN \
19 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20
20 |
21 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12355 \
22 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type AFF_trades --model=$2 --epoch 25 --lr 0.1 --name $1 \
23 | --load_checkpoint=checkpoint/ckpt.t7rocl-HN-cifar100-2Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256_HN \
24 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20
25 |
26 | # sh shell/rocl-HN-cifar100.sh rocl-HN-cifar100 ResNet18 cifar-100 | tee logs/rocl-HN-cifar100.out
27 |
--------------------------------------------------------------------------------
/Asym-RoCL/shell/rocl-IP-cifar10.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12564 \
4 | rocl_train.py --ngpu 2 --batch-size=256 --model=$2 --k=7 --loss_type=sim --advtrain_type=Rep --attack_type=linf \
5 | --name=$1 --regularize_to=other --attack_to=other --train_type=contrastive --dataset=$3 --stop_grad --stpg_degree 0.2
6 |
7 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12564 \
8 | linear_eval.py --ngpu 1 --batch-size=1024 --train_type=linear_eval --model=$2 --epoch 150 --lr 0.1 --name $1 \
9 | --load_checkpoint=checkpoint/ckpt.t7rocl-IP-cifar10Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU2_l256 \
10 | --clean=True --dataset=$3
11 |
12 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12564 \
13 | robustness_test.py --ngpu 1 --train_type=linear_eval --name=$1 --batch-size=1024 --model=$2 \
14 | --load_checkpoint='./checkpoint/ckpt.t7'$1'_Evaluate_linear_eval_ResNet18_'$3 --attack_type=linf --epsilon=0.0314 --alpha=0.00314 --k=20 --dataset=$3
15 |
16 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12354 \
17 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type ALF --model=$2 --epoch 25 --lr 0.1 --name $1 \
18 | --load_checkpoint=checkpoint/ckpt.t7rocl-IP-cifar10Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU2_l256 \
19 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20
20 |
21 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12354 \
22 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type AFF_trades --model=$2 --epoch 25 --lr 0.1 --name $1 \
23 | --load_checkpoint=checkpoint/ckpt.t7rocl-IP-cifar10Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU2_l256 \
24 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20
25 |
26 | # sh shell/rocl-IP-cifar10.sh rocl-IP-cifar10 ResNet18 cifar-10 | tee logs/rocl-IP-cifar10.out
27 |
--------------------------------------------------------------------------------
/Asym-RoCL/shell/rocl-IP-cifar100.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12567 \
4 | rocl_train.py --ngpu 2 --batch-size=256 --model=$2 --k=7 --loss_type=sim --advtrain_type=Rep --attack_type=linf \
5 | --name=$1 --regularize_to=other --attack_to=other --train_type=contrastive --dataset=$3 --stop_grad --stpg_degree 0.2
6 |
7 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12567 \
8 | linear_eval.py --ngpu 1 --batch-size=1024 --train_type=linear_eval --model=$2 --epoch 150 --lr 0.1 --name $1 \
9 | --load_checkpoint=checkpoint/ckpt.t7rocl-IP-cifar100Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256 \
10 | --clean=True --dataset=$3
11 |
12 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12567 \
13 | robustness_test.py --ngpu 1 --train_type=linear_eval --name=$1 --batch-size=1024 --model=$2 \
14 | --load_checkpoint='./checkpoint/ckpt.t7'$1'_Evaluate_linear_eval_ResNet18_'$3 --attack_type=linf --epsilon=0.0314 --alpha=0.00314 --k=20 --dataset=$3
15 |
16 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12357 \
17 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type ALF --model=$2 --epoch 25 --lr 0.1 --name $1 \
18 | --load_checkpoint=checkpoint/ckpt.t7rocl-IP-cifar100Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256 \
19 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20
20 |
21 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12357 \
22 | rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type AFF_trades --model=$2 --epoch 25 --lr 0.1 --name $1 \
23 | --load_checkpoint=checkpoint/ckpt.t7rocl-IP-cifar100Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-100_b256_nGPU2_l256 \
24 | --dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20
25 |
26 | # sh shell/rocl-IP-cifar100.sh rocl-IP-cifar100 ResNet18 cifar-100 | tee logs/rocl-IP-cifar100.out
27 |
--------------------------------------------------------------------------------
/Asym-RoCL/shell/rocl-HN-cifar10.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# RoCL-HN: adversarial contrastive pre-training with hard negatives (--HN),
# followed by linear evaluation, robustness testing, and adversarial
# finetuning (ALF and AFF-TRADES).
#
# Usage: sh shell/rocl-HN-cifar10.sh <name> <model> <dataset>
#   e.g. sh shell/rocl-HN-cifar10.sh rocl-HN-cifar10 ResNet18 cifar-10

# Pre-training checkpoint written by rocl_train.py. Derived from the run
# name ($1), model ($2) and dataset ($3) instead of being hard-coded in
# three separate places.
CKPT=checkpoint/ckpt.t7$1'Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_'$2'_'$3'_b256_nGPU2_l256_HN'

# 1) Contrastive pre-training with hard negatives.
#    (Fixed: the original had "--tau 0.12\" with no space before the
#    backslash, which glued "0.12" onto the next line's "--lr".)
python -m torch.distributed.launch --nproc_per_node=2 --master_port 12563 \
rocl_train.py --ngpu 2 --batch-size=256 --model=$2 --k=7 --loss_type=sim --advtrain_type=Rep --attack_type=linf \
--name=$1 --regularize_to=other --attack_to=other --train_type=contrastive --dataset=$3 --HN --beta 1.0 --tau 0.12 \
--lr 0.2

# 2) Linear evaluation on frozen features.
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12563 \
linear_eval.py --ngpu 1 --batch-size=1024 --train_type=linear_eval --model=$2 --epoch 150 --lr 0.1 --name $1 \
--load_checkpoint=$CKPT \
--clean=True --dataset=$3

# 3) Robustness (PGD) evaluation of the linear-eval model.
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12563 \
robustness_test.py --ngpu 1 --train_type=linear_eval --name=$1 --batch-size=1024 --model=$2 \
--load_checkpoint='./checkpoint/ckpt.t7'$1'_Evaluate_linear_eval_'$2'_'$3 --attack_type=linf --epsilon=0.0314 --alpha=0.00314 --k=20 --dataset=$3

# 4) Adversarial linear finetuning (ALF).
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12353 \
rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type ALF --model=$2 --epoch 25 --lr 0.1 --name $1 \
--load_checkpoint=$CKPT \
--dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20

# 5) Adversarial full finetuning with TRADES (AFF).
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12353 \
rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type AFF_trades --model=$2 --epoch 25 --lr 0.1 --name $1 \
--load_checkpoint=$CKPT \
--dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20

# sh shell/rocl-HN-cifar10.sh rocl-HN-cifar10 ResNet18 cifar-10 | tee logs/rocl-HN-cifar10.out
--------------------------------------------------------------------------------
/Asym-RoCL/shell/rocl-IPHN-cifar100.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# RoCL-IPHN: adversarial contrastive pre-training with inferior positives
# (--stop_grad/--stpg_degree) and hard negatives (--HN), followed by linear
# evaluation, robustness testing, and adversarial finetuning.
#
# Usage: sh shell/rocl-IPHN-cifar100.sh <name> <model> <dataset>
#   e.g. sh shell/rocl-IPHN-cifar100.sh rocl-IPHN-cifar100 ResNet18 cifar-100

# Pre-training checkpoint written by rocl_train.py. Derived from the run
# name ($1), model ($2) and dataset ($3) instead of being hard-coded.
CKPT=checkpoint/ckpt.t7$1'Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_'$2'_'$3'_b256_nGPU2_l256_HN'

# 1) Contrastive pre-training with inferior positives + hard negatives.
python -m torch.distributed.launch --nproc_per_node=2 --master_port 12560 \
rocl_train.py --ngpu 2 --batch-size=256 --model=$2 --k=7 --loss_type=sim --advtrain_type=Rep --attack_type=linf \
--name=$1 --regularize_to=other --attack_to=other --train_type=contrastive --dataset=$3 --HN --stop_grad --stpg_degree 0.2 --lr 0.2

# 2) Linear evaluation on frozen features.
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12560 \
linear_eval.py --ngpu 1 --batch-size=1024 --train_type=linear_eval --model=$2 --epoch 150 --lr 0.1 --name $1 \
--load_checkpoint=$CKPT \
--clean=True --dataset=$3

# 3) Robustness (PGD) evaluation of the linear-eval model.
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12560 \
robustness_test.py --ngpu 1 --train_type=linear_eval --name=$1 --batch-size=1024 --model=$2 \
--load_checkpoint='./checkpoint/ckpt.t7'$1'_Evaluate_linear_eval_'$2'_'$3 --attack_type=linf --epsilon=0.0314 --alpha=0.00314 --k=20 --dataset=$3

# 4) Adversarial linear finetuning (ALF).
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12350 \
rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type ALF --model=$2 --epoch 25 --lr 0.1 --name $1 \
--load_checkpoint=$CKPT \
--dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20

# 5) Adversarial full finetuning with TRADES (AFF).
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12350 \
rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type AFF_trades --model=$2 --epoch 25 --lr 0.1 --name $1 \
--load_checkpoint=$CKPT \
--dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20

# sh shell/rocl-IPHN-cifar100.sh rocl-IPHN-cifar100 ResNet18 cifar-100 | tee logs/rocl-IPHN-cifar100.out
--------------------------------------------------------------------------------
/Asym-RoCL/shell/rocl-IPHN-cifar10.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# RoCL-IPHN: adversarial contrastive pre-training with inferior positives
# (--stop_grad/--stpg_degree) and hard negatives (--HN), followed by linear
# evaluation, robustness testing, and adversarial finetuning.
#
# Usage: sh shell/rocl-IPHN-cifar10.sh <name> <model> <dataset>
#   e.g. sh shell/rocl-IPHN-cifar10.sh rocl-IPHN-cifar10 ResNet18 cifar-10

# Pre-training checkpoint written by rocl_train.py. Derived from the run
# name ($1), model ($2) and dataset ($3) instead of being hard-coded.
CKPT=checkpoint/ckpt.t7$1'Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_'$2'_'$3'_b256_nGPU2_l256_HN'

# 1) Contrastive pre-training with inferior positives + hard negatives.
python -m torch.distributed.launch --nproc_per_node=2 --master_port 12566 \
rocl_train.py --ngpu 2 --batch-size=256 --model=$2 --k=7 --loss_type=sim --advtrain_type=Rep --attack_type=linf \
--name=$1 --regularize_to=other --attack_to=other --train_type=contrastive --dataset=$3 \
--HN --beta 1.0 --tau 0.12 --stop_grad --stpg_degree 0.2 --lr 0.3

# 2) Linear evaluation on frozen features.
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12566 \
linear_eval.py --ngpu 1 --batch-size=1024 --train_type=linear_eval --model=$2 --epoch 150 --lr 0.1 --name $1 \
--load_checkpoint=$CKPT \
--clean=True --dataset=$3

# 3) Robustness (PGD) evaluation of the linear-eval model.
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12566 \
robustness_test.py --ngpu 1 --train_type=linear_eval --name=$1 --batch-size=1024 --model=$2 \
--load_checkpoint='./checkpoint/ckpt.t7'$1'_Evaluate_linear_eval_'$2'_'$3 --attack_type=linf --epsilon=0.0314 --alpha=0.00314 --k=20 --dataset=$3

# 4) Adversarial linear finetuning (ALF).
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12356 \
rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type ALF --model=$2 --epoch 25 --lr 0.1 --name $1 \
--load_checkpoint=$CKPT \
--dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20

# 5) Adversarial full finetuning with TRADES (AFF).
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12356 \
rocl_finetune.py --ngpu 1 --batch-size=1024 --finetune_type AFF_trades --model=$2 --epoch 25 --lr 0.1 --name $1 \
--load_checkpoint=$CKPT \
--dataset=$3 --epsilon=0.0314 --alpha=0.00314 --k=20

# sh shell/rocl-IPHN-cifar10.sh rocl-IPHN-cifar10 ResNet18 cifar-10 | tee logs/rocl-IPHN-cifar10.out
--------------------------------------------------------------------------------
/Asym-AdvCL/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
37 |
38 |
39 |
45 |
46 |
47 |
--------------------------------------------------------------------------------
/Asym-AdvCL/dataset.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms, datasets
2 | from torch.utils.data import Dataset, DataLoader
3 |
4 |
class CIFAR10IndexPseudoLabelEnsemble(Dataset):
    """CIFAR-10 wrapper returning (image, label, pseudo-label tuple, index).

    Each sample carries an ensemble of five pseudo labels (one per
    clustering granularity: 002/010/050/100/500) alongside its true label
    and dataset index.
    """

    def __init__(self, root='', transform=None, download=False, train=True,
                 pseudoLabel_002=None,
                 pseudoLabel_010=None,
                 pseudoLabel_050=None,
                 pseudoLabel_100=None,
                 pseudoLabel_500=None,
                 ):
        # Underlying torchvision dataset; indexing yields (image, target).
        self.cifar10 = datasets.CIFAR10(root=root,
                                        download=download,
                                        train=train,
                                        transform=transform)

        self.pseudo_label_002 = pseudoLabel_002
        self.pseudo_label_010 = pseudoLabel_010
        self.pseudo_label_050 = pseudoLabel_050
        self.pseudo_label_100 = pseudoLabel_100
        self.pseudo_label_500 = pseudoLabel_500

    def __getitem__(self, index):
        data, target = self.cifar10[index]
        # Collect the five pseudo labels in a fixed, documented order.
        label_p = tuple(labels[index] for labels in (self.pseudo_label_002,
                                                     self.pseudo_label_010,
                                                     self.pseudo_label_050,
                                                     self.pseudo_label_100,
                                                     self.pseudo_label_500))
        return data, target, label_p, index

    def __len__(self):
        return len(self.cifar10)
41 |
42 |
class CIFAR100IndexPseudoLabelEnsemble(Dataset):
    """CIFAR-100 wrapper returning (image, label, pseudo-label tuple, index).

    Each sample carries an ensemble of five pseudo labels (one per
    clustering granularity: 002/010/050/100/500) alongside its true label
    and dataset index.
    """

    def __init__(self, root='', transform=None, download=False, train=True,
                 pseudoLabel_002=None,
                 pseudoLabel_010=None,
                 pseudoLabel_050=None,
                 pseudoLabel_100=None,
                 pseudoLabel_500=None,
                 ):
        # Underlying torchvision dataset; indexing yields (image, target).
        self.cifar100 = datasets.CIFAR100(root=root,
                                          download=download,
                                          train=train,
                                          transform=transform)

        self.pseudo_label_002 = pseudoLabel_002
        self.pseudo_label_010 = pseudoLabel_010
        self.pseudo_label_050 = pseudoLabel_050
        self.pseudo_label_100 = pseudoLabel_100
        self.pseudo_label_500 = pseudoLabel_500

    def __getitem__(self, index):
        data, target = self.cifar100[index]
        # Collect the five pseudo labels in a fixed, documented order.
        label_p = tuple(labels[index] for labels in (self.pseudo_label_002,
                                                     self.pseudo_label_010,
                                                     self.pseudo_label_050,
                                                     self.pseudo_label_100,
                                                     self.pseudo_label_500))
        return data, target, label_p, index

    def __len__(self):
        return len(self.cifar100)
79 |
--------------------------------------------------------------------------------
/Asym-RoCL/trades.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 |
def trades_loss(model,
                classifier,
                x_natural,
                x_adv,
                y,
                optimizer,
                beta=6.0,
                distance='l_inf',
                step_size=2./255.,
                epsilon=8./255.,
                perturb_steps=10,
                trainmode='adv',
                fixmode='',
                trades=True
                ):
    """Cross-entropy loss with the TRADES KL robustness regularizer.

    Args:
        model: feature encoder; called as ``model(x)``.
        classifier: head mapping encoder output to logits.
        x_natural: clean input batch (values expected in [0, 1]).
        x_adv: pre-computed adversarial batch, used when ``trades`` is False.
        y: ground-truth labels.
        optimizer: optimizer whose gradients are zeroed before the loss is built.
        beta: weight of the KL robustness term.
        distance: perturbation norm; only 'l_inf' is implemented.
        step_size, epsilon, perturb_steps: PGD hyper-parameters for the
            TRADES attack.
        trainmode: 'adv' adds the robustness term; any other value yields
            plain cross-entropy on x_natural.
        fixmode: 'f1' trains all layers, 'f2' freezes everything except
            layer4 and fc, 'f3' freezes everything except fc.
        trades: when True, regenerate x_adv with PGD on the KL objective;
            when False, use the x_adv passed in.

    Returns:
        (loss, logits) where logits are computed on x_natural.
    """
    if trainmode == "adv":
        batch_size = len(x_natural)
        # Summed KL divergence; normalized by batch_size below.
        # ('size_average=False' is deprecated; reduction='sum' is equivalent.)
        criterion_kl = nn.KLDivLoss(reduction='sum')
        model.eval()

    if trades:
        # Start PGD from a small random perturbation of the clean input.
        x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach()
        if distance == 'l_inf':
            for _ in range(perturb_steps):
                x_adv.requires_grad_()
                with torch.enable_grad():
                    model.eval()
                    loss_kl = criterion_kl(F.log_softmax(classifier(model(x_adv)), dim=1),
                                           F.softmax(classifier(model(x_natural)), dim=1))
                grad = torch.autograd.grad(loss_kl, [x_adv])[0]
                # Ascend the KL objective, then project back into the
                # epsilon-ball and the valid pixel range.
                x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
                x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
                x_adv = torch.clamp(x_adv, 0.0, 1.0)
        else:
            # Only the l_inf attack is implemented (covers the former
            # 'l_2' assert-False branch as well); raise instead of assert
            # so the guard survives `python -O`.
            raise ValueError('unsupported distance: {}'.format(distance))

    # Freeze the attack result (torch.autograd.Variable is deprecated;
    # detach() gives the same requires_grad=False tensor).
    x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()

    # zero gradient
    optimizer.zero_grad()
    # calculate robust loss
    model.train()

    if fixmode == 'f1':
        # Train every encoder parameter.
        for name, param in model.named_parameters():
            param.requires_grad = True
    elif fixmode == 'f2':
        # fix previous three layers
        for name, param in model.named_parameters():
            if not ("layer4" in name or "fc" in name):
                param.requires_grad = False
            else:
                param.requires_grad = True
    elif fixmode == 'f3':
        # fix every layer except fc
        for name, param in model.named_parameters():
            if not ("fc" in name):
                param.requires_grad = False
            else:
                param.requires_grad = True
    else:
        raise ValueError('unsupported fixmode: {}'.format(fixmode))

    logits = classifier(model(x_natural))

    loss = F.cross_entropy(logits, y)

    if trainmode == "adv":
        # Batch-mean KL between adversarial and clean predictions.
        logits_adv = classifier(model(x_adv))
        loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(logits_adv, dim=1),
                                                        F.softmax(logits, dim=1))
        loss += beta * loss_robust
    return loss, logits
--------------------------------------------------------------------------------
/Asym-RoCL/data/vision.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data as data
4 |
5 |
class VisionDataset(data.Dataset):
    """Abstract base class for vision datasets.

    Handles root-path expansion and the mutually exclusive ``transforms``
    vs. ``transform``/``target_transform`` arguments. Subclasses must
    implement ``__getitem__`` and ``__len__``.
    """
    # Indentation (in spaces) applied to body lines in __repr__.
    _repr_indent = 4

    def __init__(self, root, transforms=None, transform=None, target_transform=None):
        # `torch._six.string_classes` was removed in modern PyTorch; under
        # Python 3 it was just (str,), so a plain isinstance check is
        # equivalent and version-proof.
        if isinstance(root, str):
            root = os.path.expanduser(root)
        self.root = root

        has_transforms = transforms is not None
        has_separate_transform = transform is not None or target_transform is not None
        if has_transforms and has_separate_transform:
            raise ValueError("Only transforms or transform/target_transform can "
                             "be passed as argument")

        # for backwards-compatibility
        self.transform = transform
        self.target_transform = target_transform

        # Wrap separate transforms into a single joint callable.
        if has_separate_transform:
            transforms = StandardTransform(transform, target_transform)
        self.transforms = transforms

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __repr__(self):
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        if self.root is not None:
            body.append("Root location: {}".format(self.root))
        body += self.extra_repr().splitlines()
        if hasattr(self, "transforms") and self.transforms is not None:
            body += [repr(self.transforms)]
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return '\n'.join(lines)

    def _format_transform_repr(self, transform, head):
        # First repr line gets the head; the rest are aligned underneath it.
        lines = transform.__repr__().splitlines()
        return (["{}{}".format(head, lines[0])] +
                ["{}{}".format(" " * len(head), line) for line in lines[1:]])

    def extra_repr(self):
        # Hook for subclasses to contribute extra __repr__ lines.
        return ""
52 |
53 |
class StandardTransform(object):
    """Bundle an input transform and a target transform into one callable.

    Either transform may be None, in which case the corresponding value is
    passed through unchanged.
    """

    def __init__(self, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, input, target):
        transformed_input = input if self.transform is None else self.transform(input)
        transformed_target = target if self.target_transform is None else self.target_transform(target)
        return transformed_input, transformed_target

    def _format_transform_repr(self, transform, head):
        # First repr line gets the head; remaining lines are aligned under it.
        first, *rest = transform.__repr__().splitlines()
        pad = " " * len(head)
        return ["{}{}".format(head, first)] + ["{}{}".format(pad, line) for line in rest]

    def __repr__(self):
        pieces = [self.__class__.__name__]
        if self.transform is not None:
            pieces.extend(self._format_transform_repr(self.transform,
                                                      "Transform: "))
        if self.target_transform is not None:
            pieces.extend(self._format_transform_repr(self.target_transform,
                                                      "Target transform: "))

        return '\n'.join(pieces)
81 |
--------------------------------------------------------------------------------
/Asym-RoCL/warmup_scheduler/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 |
4 |
class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            # Warm-up only scales the lr upward; shrinking is unsupported.
            # (Fixed typo in the message: 'thant' -> 'than'.)
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False  # set once warm-up hands over to after_scheduler
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        # Past the warm-up window: delegate to the follow-up scheduler if
        # one was given, otherwise hold the fully warmed-up lr.
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # Hand over with the warmed-up lrs as the new base.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            # Linear ramp from 0 up to base_lr.
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            # Linear ramp from base_lr up to base_lr * multiplier.
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        """Warm up manually, then forward `metrics` to ReduceLROnPlateau."""
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            # After warm-up, epochs are re-based so the plateau scheduler
            # starts counting from zero.
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        # ReduceLROnPlateau needs `metrics`, so it takes a dedicated path;
        # every other scheduler goes through the standard step protocol.
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
64 |
--------------------------------------------------------------------------------
/Asym-AdvCL/trades.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 |
6 |
def trades_loss(model,
                classifier,
                x_natural,
                x_adv,
                y,
                optimizer,
                beta=6.0,
                distance='l_inf',
                step_size=2./255.,
                epsilon=8./255.,
                perturb_steps=10,
                trainmode='adv',
                fixmode='',
                trades=True
                ):
    """Cross-entropy loss with the TRADES KL robustness regularizer (AdvCL).

    Identical to the Asym-RoCL variant except that the encoder is called
    with ``return_feat=True``.

    Args:
        model: feature encoder; called as ``model(x, return_feat=True)``.
        classifier: head mapping encoder output to logits.
        x_natural: clean input batch (values expected in [0, 1]).
        x_adv: pre-computed adversarial batch, used when ``trades`` is False.
        y: ground-truth labels.
        optimizer: optimizer whose gradients are zeroed before the loss is built.
        beta: weight of the KL robustness term.
        distance: perturbation norm; only 'l_inf' is implemented.
        step_size, epsilon, perturb_steps: PGD hyper-parameters.
        trainmode: 'adv' adds the robustness term; any other value yields
            plain cross-entropy on x_natural.
        fixmode: 'f1' trains all layers, 'f2' freezes everything except
            layer4 and fc, 'f3' freezes everything except fc.
        trades: when True, regenerate x_adv with PGD on the KL objective;
            when False, use the x_adv passed in.

    Returns:
        (loss, logits) where logits are computed on x_natural.
    """
    if trainmode == "adv":
        batch_size = len(x_natural)
        # Summed KL divergence; normalized by batch_size below.
        # ('size_average=False' is deprecated; reduction='sum' is equivalent.)
        criterion_kl = nn.KLDivLoss(reduction='sum')
        model.eval()

    if trades:
        # Start PGD from a small random perturbation of the clean input.
        x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach()
        if distance == 'l_inf':
            for _ in range(perturb_steps):
                x_adv.requires_grad_()
                with torch.enable_grad():
                    model.eval()
                    loss_kl = criterion_kl(F.log_softmax(classifier(model(x_adv, return_feat=True)), dim=1),
                                           F.softmax(classifier(model(x_natural, return_feat=True)), dim=1))
                grad = torch.autograd.grad(loss_kl, [x_adv])[0]
                # Ascend the KL objective, then project back into the
                # epsilon-ball and the valid pixel range.
                x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
                x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
                x_adv = torch.clamp(x_adv, 0.0, 1.0)
        else:
            # Only the l_inf attack is implemented (covers the former
            # 'l_2' assert-False branch as well); raise instead of assert
            # so the guard survives `python -O`.
            raise ValueError('unsupported distance: {}'.format(distance))

    # Freeze the attack result (torch.autograd.Variable is deprecated;
    # detach() gives the same requires_grad=False tensor).
    x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()

    # zero gradient
    optimizer.zero_grad()
    # calculate robust loss
    model.train()

    if fixmode == 'f1':
        # Train every encoder parameter.
        for name, param in model.named_parameters():
            param.requires_grad = True
    elif fixmode == 'f2':
        # fix previous three layers
        for name, param in model.named_parameters():
            if not ("layer4" in name or "fc" in name):
                param.requires_grad = False
            else:
                param.requires_grad = True
    elif fixmode == 'f3':
        # fix every layer except fc
        for name, param in model.named_parameters():
            if not ("fc" in name):
                param.requires_grad = False
            else:
                param.requires_grad = True
    else:
        raise ValueError('unsupported fixmode: {}'.format(fixmode))

    logits = classifier(model(x_natural, return_feat=True))

    loss = F.cross_entropy(logits, y)

    if trainmode == "adv":
        # Batch-mean KL between adversarial and clean predictions.
        logits_adv = classifier(model(x_adv, return_feat=True))
        loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(logits_adv, dim=1),
                                                        F.softmax(logits, dim=1))
        loss += beta * loss_robust
    return loss, logits
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Adversarial Contrastive Learning via Asymmetric InfoNCE
2 |
3 |
4 |
5 | This is a PyTorch implementation of the paper [Adversarial Contrastive Learning via Asymmetric InfoNCE](https://arxiv.org/abs/2207.08374) (ECCV 2022).
6 |
7 | ## Preparation
8 |
9 | In each sub-directory (`./Asym-AdvCL`, `./Asym-RoCL`) for each baseline, please first create `./logs` directories, then download CIFAR datasets into `~/data`. All models will be stored in `./checkpoint`, and logs stored in `./logs`.
10 |
11 | ### Environment requirements
12 |
13 | - PyTorch >= 1.8
14 | - NVIDIA apex
15 | - numpy
16 | - tensorboard_logger
17 | - [torchlars](https://github.com/kakaobrain/torchlars) == 0.1.2
18 | - [pytorch-gradual-warmup-lr](https://github.com/ildoonet/pytorch-gradual-warmup-lr)
19 | - [diffdist](https://github.com/ag14774/diffdist) == 0.1
20 |
This repo is a modification of the [RoCL](https://github.com/Kim-Minseon/RoCL) and [AdvCL](https://github.com/LijieFan/AdvCL) repos. Environments can also be installed according to the requirements of [RoCL](https://github.com/Kim-Minseon/RoCL) and [AdvCL](https://github.com/LijieFan/AdvCL) for the experiments of each baseline.
22 |
23 | ## Training
24 |
25 | We provide shells for reproducing our main results in Table 1. The hyperparameter settings of baselines are the same as their original papers reported. All experiments were conducted on 2 Tesla V100 GPUs.
26 |
27 | First, for [RoCL](https://arxiv.org/abs/2006.07589):
28 |
29 | * RoCL on CIFAR10
30 |
31 | ```bash
32 | sh shell/rocl-cifar10.sh rocl-cifar10 ResNet18 cifar-10 | tee logs/rocl-cifar10.log
33 | ```
34 |
35 | * RoCL-IP (RoCL with inferior positive) on CIFAR10
36 |
37 | ```bash
38 | sh shell/rocl-IP-cifar10.sh rocl-IP-cifar10 ResNet18 cifar-10 | tee logs/rocl-IP-cifar10.log
39 | ```
40 |
41 | * RoCL-HN (RoCL with hard negative) on CIFAR10
42 |
43 | ```bash
44 | sh shell/rocl-HN-cifar10.sh rocl-HN-cifar10 ResNet18 cifar-10 | tee logs/rocl-HN-cifar10.log
45 | ```
46 |
47 | * RoCL-IPHN (RoCL with both) on CIFAR10
48 |
49 | ```bash
50 | sh shell/rocl-IPHN-cifar10.sh rocl-IPHN-cifar10 ResNet18 cifar-10 | tee logs/rocl-IPHN-cifar10.log
51 | ```
52 |
53 | * RoCL on CIFAR100
54 |
55 | ```bash
56 | sh shell/rocl-cifar100.sh rocl-cifar100 ResNet18 cifar-100 | tee logs/rocl-cifar100.log
57 | ```
58 |
59 | * RoCL-IP on CIFAR100
60 |
61 | ```bash
62 | sh shell/rocl-IP-cifar100.sh rocl-IP-cifar100 ResNet18 cifar-100 | tee logs/rocl-IP-cifar100.log
63 | ```
64 |
65 | * RoCL-HN on CIFAR100
66 |
67 | ```bash
68 | sh shell/rocl-HN-cifar100.sh rocl-HN-cifar100 ResNet18 cifar-100 | tee logs/rocl-HN-cifar100.log
69 | ```
70 |
71 | * RoCL-IPHN on CIFAR100
72 |
73 | ```bash
74 | sh shell/rocl-IPHN-cifar100.sh rocl-IPHN-cifar100 ResNet18 cifar-100 | tee logs/rocl-IPHN-cifar100.log
75 | ```
76 |
77 | For [AdvCL](https://arxiv.org/pdf/2111.01124.pdf):
78 |
79 | * AdvCL on CIFAR10
80 |
81 | ```bash
82 | sh shell/cifar10/advcl-cifar10.sh advcl-cifar10 | tee logs/advcl-cifar10.log
83 | ```
84 |
85 | * AdvCL-IP on CIFAR10
86 |
87 | ```bash
88 | sh shell/cifar10/advcl-IP-cifar10.sh advcl-IP-cifar10 | tee logs/advcl-IP-cifar10.log
89 | ```
90 |
91 | * AdvCL-HN on CIFAR10
92 |
93 | ```bash
94 | sh shell/cifar10/advcl-HN-cifar10.sh advcl-HN-cifar10 | tee logs/advcl-HN-cifar10.log
95 | ```
96 |
97 | * AdvCL-IPHN on CIFAR10
98 |
99 | ```bash
100 | sh shell/cifar10/advcl-IPHN-cifar10.sh advcl-IPHN-cifar10 | tee logs/advcl-IPHN-cifar10.log
101 | ```
102 |
103 | * AdvCL on CIFAR100
104 |
105 | ```bash
106 | sh shell/cifar100/advcl-cifar100.sh advcl-cifar100 | tee logs/advcl-cifar100.log
107 | ```
108 |
109 | * AdvCL-IP on CIFAR100
110 |
111 | ```bash
112 | sh shell/cifar100/advcl-IP-cifar100.sh advcl-IP-cifar100 | tee logs/advcl-IP-cifar100.log
113 | ```
114 |
115 | * AdvCL-HN on CIFAR100
116 |
117 | ```bash
118 | sh shell/cifar100/advcl-HN-cifar100.sh advcl-HN-cifar100 | tee logs/advcl-HN-cifar100.log
119 | ```
120 |
121 | * AdvCL-IPHN on CIFAR100
122 |
123 | ```bash
124 | sh shell/cifar100/advcl-IPHN-cifar100.sh advcl-IPHN-cifar100 | tee logs/advcl-IPHN-cifar100.log
125 | ```
126 |
127 | ## Citation
128 |
If you find our work useful, or if it provides some new insights about adversarial contrastive learning :blush:, please consider citing:
130 |
131 | ```
132 | @inproceedings{yu2022adversarial,
133 | title={Adversarial Contrastive Learning via Asymmetric InfoNCE},
134 | author={Yu, Qiying and Lou, Jieming and Zhan, Xianyuan and Li, Qizhang and Zuo, Wangmeng and Liu, Yang and Liu, Jingjing},
135 | booktitle={European Conference on Computer Vision},
136 | pages={53--69},
137 | year={2022},
138 | organization={Springer}
139 | }
140 | ```
141 |
142 | ## Acknowledgements
143 |
144 | We thank for the code implementation from [RoCL](https://github.com/Kim-Minseon/RoCL), [AdvCL](https://github.com/LijieFan/AdvCL), [HCL](https://github.com/joshr17/HCL) and [SupContrast](https://github.com/HobbitLong/SupContrast).
145 |
--------------------------------------------------------------------------------
/Asym-RoCL/robustness_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 |
3 | from __future__ import print_function
4 |
5 | import csv
6 | import os
7 |
8 | import torch
9 | import torch.backends.cudnn as cudnn
10 | import torch.nn as nn
11 | import data_loader
12 | import model_loader
13 | from argument import test_parser
14 |
15 | from utils import progress_bar
16 | from collections import OrderedDict
17 |
18 | from attack_lib import FastGradientSignUntargeted
19 |
# Parse robustness-evaluation arguments and probe for CUDA devices.
args = test_parser()
use_cuda = torch.cuda.is_available()
if use_cuda:
    # NOTE(review): ngpus_per_node is only defined when CUDA is available,
    # yet later rank checks use it unconditionally — CPU-only runs would
    # NameError. Confirm the script is GPU-only by design.
    ngpus_per_node = torch.cuda.device_count()
24 |
def print_status(string):
    """Print only on the local rank-0 process to avoid duplicated logs."""
    if args.local_rank % ngpus_per_node != 0:
        return
    print(string)
28 |
# Runtime banner, printed on rank 0 only.
if args.local_rank % ngpus_per_node == 0:
    print(torch.cuda.device_count())
    print('Using CUDA..')
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch

# Optional determinism: only seed when explicitly requested (seed != 0).
if args.seed != 0:
    torch.manual_seed(args.seed)

# Data
print_status('==> Preparing data..')

trainloader, traindst, testloader, testdst = data_loader.get_dataset(args)

# Number of classes follows the dataset choice.
if args.dataset == 'cifar-10':
    num_outputs = 10
elif args.dataset == 'cifar-100':
    num_outputs = 100

# ResNet50 features are 4x wider than ResNet18 (Bottleneck expansion);
# everything else uses expansion 1.
if args.model == 'ResNet50':
    expansion = 4
else:
    expansion = 1

# Model
print_status('==> Building model..')
train_type = args.train_type

model = model_loader.get_model(args)#models.__dict__[args.model]()
# Linear evaluation head on top of the encoder's 512*expansion features.
if args.dataset=='cifar-10':
    Linear = nn.Sequential(nn.Linear(512*expansion, 10))
elif args.dataset=='cifar-100':
    Linear = nn.Sequential(nn.Linear(512*expansion, 100))

# Restore encoder weights. Checkpoint keys carry a DataParallel 'module.'
# prefix, so the first 7 characters of every key are stripped.
checkpoint_ = torch.load(args.load_checkpoint)
new_state_dict = OrderedDict()
for k, v in checkpoint_['model'].items():
    name = k[7:]
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)

# Restore the linear head from the companion '<checkpoint>_linear' file.
linearcheckpoint_ = torch.load(args.load_checkpoint+'_linear')
new_state_dict = OrderedDict()
for k, v in linearcheckpoint_['model'].items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
Linear.load_state_dict(new_state_dict)

criterion = nn.CrossEntropyLoss()

# Move everything to GPU when available.
use_cuda = torch.cuda.is_available()
if use_cuda:
    ngpus_per_node = torch.cuda.device_count()
    model.cuda()
    Linear.cuda()
    print_status(torch.cuda.device_count())
    print_status('Using CUDA..')
    cudnn.benchmark = True

# Describe and build the untargeted FGSM/PGD attacker used for evaluation.
attack_info = 'epsilon_'+str(args.epsilon)+'_alpha_'+ str(args.alpha) + '_min_val_' + str(0.0) + '_max_val_' + str(1.0) + '_max_iters_' + str(args.k) + '_type_' + str(args.attack_type) + '_randomstart_' + str(args.random_start)
print_status("Attack information...")
print_status(attack_info)
attacker = FastGradientSignUntargeted(model, Linear, epsilon=args.epsilon, alpha=args.alpha, min_val=0.0, max_val=1.0, max_iters=args.k, _type=args.attack_type)
93 |
def test(attacker):
    """Evaluate clean and adversarial accuracy of encoder + linear head.

    Args:
        attacker: attack object exposing ``perturb(original_images, labels,
            random_start)``.

    Returns:
        (clean_acc, adv_acc) as percentages over the test set.

    Fix: ``adv_acc`` is now initialized up front. Previously it was only
    assigned inside the ResNet18/epsilon branch, so configurations that
    never entered that branch raised NameError at the progress bar and the
    final return.
    """
    global best_acc

    model.eval()
    Linear.eval()

    test_clean_loss = 0
    test_adv_loss = 0
    clean_correct = 0
    adv_correct = 0
    clean_acc = 0
    adv_acc = 0
    total = 0

    for idx, (image, label) in enumerate(testloader):

        img = image.cuda()
        y = label.cuda()
        total += y.size(0)
        # NOTE(review): metrics are only accumulated for ResNet18 models at
        # these two epsilon values; other configurations report zeros.
        if 'ResNet18' in args.model:
            if args.epsilon==0.0314 or args.epsilon==0.047:
                # Clean forward pass.
                out = Linear(model(img))
                _, predx = torch.max(out.data, 1)
                clean_loss = criterion(out, y)

                clean_correct += predx.eq(y.data).cpu().sum().item()

                clean_acc = 100.*clean_correct/total

                test_clean_loss += clean_loss.data

                # Adversarial forward pass on attacker-perturbed inputs.
                adv_inputs = attacker.perturb(original_images=img, labels=y, random_start=args.random_start)

                out = Linear(model(adv_inputs))

                _, predx = torch.max(out.data, 1)
                adv_loss = criterion(out, y)

                adv_correct += predx.eq(y.data).cpu().sum().item()
                adv_acc = 100.*adv_correct/total

                test_adv_loss += adv_loss.data
        if args.local_rank % ngpus_per_node == 0:
            progress_bar(idx, len(testloader),'Testing Loss {:.3f}, acc {:.3f} , adv Loss {:.3f}, adv acc {:.3f}'.format(test_clean_loss/(idx+1), clean_acc, test_adv_loss/(idx+1), adv_acc))

    print ("Test accuracy: {0}/{1}".format(clean_acc, adv_acc))

    return (clean_acc, adv_acc)
142 |
# Run the evaluation, then persist results to a per-run CSV file.
test_acc, adv_acc = test(attacker)

if not os.path.isdir('results'):
    os.mkdir('results')

args.name += ('_Robust'+ args.train_type + '_' +args.model + '_' + args.dataset)
loginfo = 'results/log_' + args.name + '_' + str(args.seed)
logname = (loginfo+ '.csv')

# Fix: previously every process truncated the file with the header while
# only rank 0 appended the data row; now both writes are rank-guarded so
# non-zero ranks never clobber the log.
if args.local_rank % ngpus_per_node == 0:
    with open(logname, 'w') as logfile:
        logwriter = csv.writer(logfile, delimiter=',')
        logwriter.writerow(['random_start', 'attack_type','epsilon','k','adv_acc'])
    with open(logname, 'a') as logfile:
        logwriter = csv.writer(logfile, delimiter=',')
        logwriter.writerow([args.random_start, args.attack_type, args.epsilon, args.k, adv_acc])
160 |
161 |
--------------------------------------------------------------------------------
/Asym-AdvCL/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import math
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.init as init
10 |
# _, term_width = os.popen('stty size', 'r').read().split()
# term_width = int(term_width)
# Fixed fallback terminal width (the stty probe above is disabled).
term_width = 80

# NOTE(review): bar length exceeds term_width (86 > 80), so the padding
# computed in progress_bar can go negative — confirm this is intentional.
TOTAL_BAR_LENGTH = 86.
# Timing state shared with progress_bar: last step time and bar start time.
last_time = time.time()
begin_time = last_time
18 |
19 |
def progress_bar(current, total, msg=None):
    """Render an in-place text progress bar for step `current` of `total`."""
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # reset timers for a fresh bar

    filled = int(TOTAL_BAR_LENGTH * current / total)
    remaining = int(TOTAL_BAR_LENGTH - filled) - 1

    # Bar body: ' [====>....]'
    sys.stdout.write(' [' + '=' * filled + '>' + '.' * remaining + ']')

    now = time.time()
    step_time = now - last_time
    last_time = now
    total_time = now - begin_time

    # Timing tail, optionally followed by the caller's message.
    tail = ' Step: %s' % format_time(step_time)
    tail += ' | Tot: %s' % format_time(total_time)
    if msg:
        tail += ' | ' + msg

    sys.stdout.write(tail)
    # Pad to the terminal edge (no-op when the computed width is negative).
    pad = term_width - int(TOTAL_BAR_LENGTH) - len(tail) - 3
    if pad > 0:
        sys.stdout.write(' ' * pad)

    # Go back to the center of the bar and overlay the step counter.
    sys.stdout.write('\b' * (term_width - int(TOTAL_BAR_LENGTH / 2)))
    sys.stdout.write(' %d/%d ' % (current + 1, total))

    sys.stdout.write('\r' if current < total - 1 else '\n')
    sys.stdout.flush()
62 |
63 |
def load_BN_checkpoint(state_dict):
    """Split a dual-batch-norm checkpoint into two single-BN state dicts.

    Keys under `.bn_list.0` go to the "normal" dict, keys under
    `.bn_list.1` go to the adversarial dict; `downsample`/`shortcut.conv`
    keys are renamed to the `shortcut.N` layout, and all remaining keys are
    shared by both dicts. Returns (adversarial_state, normal_state).
    """
    adv_state = {}
    normal_state = {}
    for key, value in state_dict.items():
        if 'downsample' in key:
            key = key.replace('downsample', 'shortcut')
        if 'shortcut.bn.bn_list.0' in key:
            normal_state[key.replace('shortcut.bn.bn_list.0', 'shortcut.1')] = value
        elif 'shortcut.bn.bn_list.1' in key:
            adv_state[key.replace('shortcut.bn.bn_list.1', 'shortcut.1')] = value
        elif '.bn_list.0' in key:
            normal_state[key.replace('.bn_list.0', '')] = value
        elif '.bn_list.1' in key:
            adv_state[key.replace('.bn_list.1', '')] = value
        else:
            # Non-BN keys are shared; shortcut convs get the Sequential index.
            if 'shortcut.conv' in key:
                key = key.replace('shortcut.conv', 'shortcut.0')
            normal_state[key] = value
            adv_state[key] = value
    return adv_state, normal_state
90 |
91 |
def format_time(seconds):
    """Render a duration in seconds as a compact string of at most two units.

    Units are days/hours/minutes/seconds/milliseconds in that order; the
    first two non-zero units are kept (e.g. ``'1D5m'``, ``'3s420ms'``).
    Durations under one millisecond render as ``'0ms'``.
    """
    remaining = seconds
    days = int(remaining / 3600 / 24)
    remaining -= days * 3600 * 24
    hours = int(remaining / 3600)
    remaining -= hours * 3600
    minutes = int(remaining / 60)
    remaining -= minutes * 60
    whole_seconds = int(remaining)
    millis = int((remaining - whole_seconds) * 1000)

    parts = []
    for amount, unit in ((days, 'D'), (hours, 'h'), (minutes, 'm'),
                         (whole_seconds, 's'), (millis, 'ms')):
        if amount > 0:
            parts.append('%d%s' % (amount, unit))
        if len(parts) == 2:
            break
    return ''.join(parts) if parts else '0ms'
123 |
124 |
class TwoCropTransform:
    """Apply the wrapped transform twice, producing two independent views."""

    def __init__(self, transform):
        # Base augmentation applied separately for each of the two views.
        self.transform = transform

    def __call__(self, x):
        first_view = self.transform(x)
        second_view = self.transform(x)
        return [first_view, second_view]
133 |
134 |
class TwoCropTransformAdv:
    """Produce two standard views plus one adversarial-pipeline view."""

    def __init__(self, transform, transform_adv):
        self.transform = transform          # augmentation for the two clean views
        self.transform_adv = transform_adv  # augmentation for the third view

    def __call__(self, x):
        views = [self.transform(x) for _ in range(2)]
        views.append(self.transform_adv(x))
        return views
144 |
145 |
class AverageMeter(object):
    """Maintains the most recent value plus a running weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop all accumulated statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Fold in ``val`` observed ``n`` times and recompute the average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
163 |
164 |
def adjust_learning_rate(args, optimizer, epoch):
    """Set the learning rate for ``epoch`` on every param group of ``optimizer``.

    With ``args.cosine`` the rate follows a cosine schedule from
    ``args.learning_rate`` down to ``lr * lr_decay_rate ** 3``; otherwise it
    is multiplied by ``lr_decay_rate`` once per milestone in
    ``args.lr_decay_epochs`` already passed.
    """
    base_lr = args.learning_rate
    if args.cosine:
        floor = base_lr * (args.lr_decay_rate ** 3)
        cosine_factor = (1 + math.cos(math.pi * epoch / args.epochs)) / 2
        new_lr = floor + (base_lr - floor) * cosine_factor
    else:
        num_decays = np.sum(epoch > np.asarray(args.lr_decay_epochs))
        new_lr = base_lr * (args.lr_decay_rate ** num_decays) if num_decays > 0 else base_lr

    for group in optimizer.param_groups:
        group['lr'] = new_lr
178 |
179 |
def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
    """Linearly ramp the learning rate during the first ``args.warm_epochs`` epochs.

    The rate moves from ``args.warmup_from`` to ``args.warmup_to`` in
    proportion to the fraction of warmup batches already processed. Does
    nothing once warmup is over or when ``args.warm`` is false.
    """
    if not (args.warm and epoch <= args.warm_epochs):
        return
    progress = (batch_id + (epoch - 1) * total_batches) / \
        (args.warm_epochs * total_batches)
    warm_lr = args.warmup_from + progress * (args.warmup_to - args.warmup_from)

    for group in optimizer.param_groups:
        group['lr'] = warm_lr
188 |
189 |
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracies (in percent) of ``output`` logits vs ``target``.

    Args:
        output: (batch, num_classes) score tensor.
        target: (batch,) tensor of ground-truth class indices.
        topk: iterable of k values to evaluate.

    Returns:
        List of 1-element tensors, one accuracy per requested k.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch = target.size(0)

        # (maxk, batch) matrix of hit flags: row r marks samples whose
        # (r+1)-th best prediction equals the target.
        _, top_pred = output.topk(maxk, 1, True, True)
        top_pred = top_pred.t()
        hits = top_pred.eq(target.view(1, -1).expand_as(top_pred))

        return [
            hits[:k].flatten().float().sum(0, keepdim=True).mul_(100.0 / batch)
            for k in topk
        ]
205 |
--------------------------------------------------------------------------------
/Asym-RoCL/data_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from data.cifar import CIFAR10, CIFAR100
3 | from torchvision import transforms
4 |
def get_dataset(args):
    """Build the train/val datasets and loaders for the requested mode.

    Args:
        args: parsed CLI namespace. Reads ``dataset`` ('cifar-10' or
            'cifar-100'), ``train_type`` ('contrastive', 'linear_eval' or
            'test'), ``color_jitter_strength`` and ``batch_size``; contrastive
            mode additionally reads ``ngpu`` and ``local_rank`` for the
            distributed sampler.

    Returns:
        ``(train_loader, train_dst, val_loader, val_dst, train_sampler)`` in
        contrastive mode, else ``(train_loader, train_dst, val_loader,
        val_dst)``. Returns ``None`` for an unrecognized dataset name.

    Raises:
        ValueError: if ``args.train_type`` is not a recognized learning type.
    """

    ### color augmentation ###
    color_jitter = transforms.ColorJitter(0.8*args.color_jitter_strength, 0.8*args.color_jitter_strength, 0.8*args.color_jitter_strength, 0.2*args.color_jitter_strength)
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)

    learning_type = args.train_type

    if args.dataset == 'cifar-10':

        if learning_type == 'contrastive':
            transform_train = transforms.Compose([
                rnd_color_jitter,
                rnd_gray,
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(32),
                transforms.ToTensor(),
            ])

            # Contrastive pretraining evaluates on augmented views as well.
            transform_test = transform_train

        elif learning_type == 'linear_eval':
            transform_train = transforms.Compose([
                rnd_color_jitter,
                rnd_gray,
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(32),
                transforms.ToTensor(),
            ])

            transform_test = transforms.Compose([
                transforms.ToTensor(),
            ])

        elif learning_type == 'test':
            transform_train = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(32),
                transforms.ToTensor(),
            ])

            transform_test = transforms.Compose([
                transforms.ToTensor(),
            ])
        else:
            # BUG FIX: was `assert('wrong learning type')`, which never fires
            # (a non-empty string is always truthy) and let execution fall
            # through to a NameError on the undefined transforms.
            raise ValueError('wrong learning type: %r' % learning_type)

        train_dst = CIFAR10(root='~/data/cifar10/', train=True, download=True,
                            transform=transform_train, contrastive_learning=learning_type)
        val_dst = CIFAR10(root='~/data/cifar10/', train=False, download=True,
                          transform=transform_test, contrastive_learning=learning_type)

        if learning_type == 'contrastive':
            # One sampler shard per process so each GPU sees a distinct split.
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dst,
                num_replicas=args.ngpu,
                rank=args.local_rank,
            )
            train_loader = torch.utils.data.DataLoader(train_dst, batch_size=args.batch_size, num_workers=4,
                                                       pin_memory=False,
                                                       shuffle=(train_sampler is None),
                                                       sampler=train_sampler,
                                                       )

            val_loader = torch.utils.data.DataLoader(val_dst, batch_size=100,
                                                     num_workers=4,
                                                     pin_memory=False,
                                                     shuffle=False,
                                                     )

            return train_loader, train_dst, val_loader, val_dst, train_sampler
        else:
            train_loader = torch.utils.data.DataLoader(train_dst,
                                                       batch_size=args.batch_size,
                                                       shuffle=True, num_workers=4)
            val_batch = 100
            val_loader = torch.utils.data.DataLoader(val_dst, batch_size=val_batch,
                                                     shuffle=False, num_workers=4)

            return train_loader, train_dst, val_loader, val_dst

    if args.dataset == 'cifar-100':

        if learning_type == 'contrastive':
            transform_train = transforms.Compose([
                rnd_color_jitter,
                rnd_gray,
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(32),
                transforms.ToTensor()
            ])

            transform_test = transform_train

        elif learning_type == 'linear_eval':
            transform_train = transforms.Compose([
                rnd_color_jitter,
                rnd_gray,
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(32),
                transforms.ToTensor()
            ])

            transform_test = transforms.Compose([
                transforms.ToTensor()
            ])

        elif learning_type == 'test':
            transform_train = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ])

            transform_test = transforms.Compose([
                transforms.ToTensor()
            ])
        else:
            # Same fix as above: raise instead of a no-op assert.
            raise ValueError('wrong learning type: %r' % learning_type)

        train_dst = CIFAR100(root='~/data/cifar100/', train=True, download=True,
                             transform=transform_train, contrastive_learning=learning_type)
        val_dst = CIFAR100(root='~/data/cifar100/', train=False, download=True,
                           transform=transform_test, contrastive_learning=learning_type)

        if learning_type == 'contrastive':
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dst,
                num_replicas=args.ngpu,
                rank=args.local_rank,
            )
            train_loader = torch.utils.data.DataLoader(train_dst, batch_size=args.batch_size, num_workers=4,
                                                       pin_memory=True,
                                                       shuffle=(train_sampler is None),
                                                       sampler=train_sampler,
                                                       )

            val_loader = torch.utils.data.DataLoader(val_dst, batch_size=100, num_workers=4,
                                                     pin_memory=True,
                                                     )
            return train_loader, train_dst, val_loader, val_dst, train_sampler

        else:
            train_loader = torch.utils.data.DataLoader(train_dst,
                                                       batch_size=args.batch_size,
                                                       shuffle=True, num_workers=4)

            val_loader = torch.utils.data.DataLoader(val_dst, batch_size=100,
                                                     shuffle=False, num_workers=4)

            return train_loader, train_dst, val_loader, val_dst
    # NOTE(review): unknown dataset names fall through and return None.
157 |
158 |
--------------------------------------------------------------------------------
/Asym-RoCL/models/resnet.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 |
3 | BasicBlock and Bottleneck module is from the original ResNet paper:
4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
6 |
7 | PreActBlock and PreActBottleneck module is from the later paper:
8 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
9 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027
10 | '''
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding 1 and no bias (ResNet building block)."""
    return nn.Conv2d(in_planes, out_planes,
                     kernel_size=3, stride=stride, padding=1, bias=False)
16 |
17 |
class BasicBlock(nn.Module):
    """Standard post-activation ResNet block: conv-bn-relu-conv-bn + residual."""
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        # Projection shortcut only when spatial size or channel count changes;
        # otherwise the residual is the identity.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes * self.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * self.expansion),
            )

    def forward(self, x):
        residual = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + residual)
41 |
42 |
class PreActBlock(nn.Module):
    '''Pre-activation ResNet block (He et al., 2016): bn-relu precedes each conv.'''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)

        # 1x1 projection when the residual shape would not match; note it has
        # no BatchNorm in the pre-activation design.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes * self.expansion,
                          kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        pre = F.relu(self.bn1(x))
        # The shortcut branches off AFTER the first bn-relu, per the paper.
        residual = self.shortcut(pre)
        out = self.conv2(F.relu(self.bn2(self.conv1(pre))))
        # No trailing ReLU in pre-activation blocks.
        return out + residual
67 |
68 |
class Bottleneck(nn.Module):
    """Post-activation bottleneck block: 1x1 reduce, 3x3, 1x1 expand (x4)."""
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        out_channels = planes * self.expansion
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        # Projection shortcut when the shape changes, identity otherwise.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        residual = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return F.relu(out + residual)
95 |
96 |
class PreActBottleneck(nn.Module):
    '''Pre-activation bottleneck: bn-relu before every conv, no trailing relu.'''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBottleneck, self).__init__()
        out_channels = planes * self.expansion
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_channels, kernel_size=1, bias=False)

        # Bare 1x1 projection (no BN) when the residual shape would not match.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_channels,
                          kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        pre = F.relu(self.bn1(x))
        # The shortcut taps the pre-activated input, per the pre-act design.
        residual = self.shortcut(pre)
        out = self.conv1(pre)
        out = self.conv2(F.relu(self.bn2(out)))
        out = self.conv3(F.relu(self.bn3(out)))
        return out + residual
124 |
125 |
class ResNet(nn.Module):
    """CIFAR-style ResNet backbone.

    When ``contrastive_learning`` is true the classifier head is omitted and
    ``forward`` returns the pooled ``512 * block.expansion``-dim feature;
    otherwise a linear layer maps the feature to ``num_classes`` logits.
    With ``internal_outputs=True`` the per-stage activations are returned
    alongside the output.
    """

    def __init__(self, block, num_blocks, num_classes=10, contrastive_learning=True):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = conv3x3(3, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        # NOTE: attribute keeps the original (misspelled) name so existing
        # checkpoints and callers stay compatible.
        self.contranstive_learning = contrastive_learning

        if not contrastive_learning:
            self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first may downsample."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, internal_outputs=False):
        stage_outputs = []

        out = F.relu(self.bn1(self.conv1(x)))
        stage_outputs.append(out)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
            stage_outputs.append(out)

        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)

        if not self.contranstive_learning:
            out = self.linear(out)

        if internal_outputs:
            return out, stage_outputs
        return out
183 |
def ResNet18(num_classes, contrastive_learning):
    """ResNet-18: BasicBlock with [2, 2, 2, 2] blocks per stage."""
    return ResNet(BasicBlock, [2, 2, 2, 2],
                  num_classes=num_classes,
                  contrastive_learning=contrastive_learning)
186 |
def ResNet34(num_classes, contrastive_learning):
    """ResNet-34: BasicBlock with [3, 4, 6, 3] blocks per stage."""
    return ResNet(BasicBlock, [3, 4, 6, 3],
                  num_classes=num_classes,
                  contrastive_learning=contrastive_learning)
189 |
def ResNet50(num_classes, contrastive_learning):
    """ResNet-50: Bottleneck with [3, 4, 6, 3] blocks per stage."""
    return ResNet(Bottleneck, [3, 4, 6, 3],
                  num_classes=num_classes,
                  contrastive_learning=contrastive_learning)
192 |
--------------------------------------------------------------------------------
/Asym-RoCL/attack_lib.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is modified from
3 |
4 | https://github.com/utkuozbulak/pytorch-cnn-adversarial-attacks
5 | https://github.com/louis2889184/pytorch-adversarial-training
6 | https://github.com/MadryLab/robustness
7 | https://github.com/yaodongyu/TRADES
8 |
9 | """
10 |
11 | import torch
12 | import torch.nn.functional as F
13 | from loss import pairwise_similarity, NT_xent
14 |
def project(x, original_x, epsilon, _type='linf'):
    """Clip ``x`` back into the epsilon-ball centered at ``original_x``.

    Only the L-infinity ball is supported: every element of ``x`` is clamped
    to ``[original_x - epsilon, original_x + epsilon]``.

    Raises:
        NotImplementedError: for any norm type other than ``'linf'``.
    """
    if _type != 'linf':
        raise NotImplementedError

    upper = original_x + epsilon
    lower = original_x - epsilon
    return torch.min(torch.max(x, lower), upper)
26 |
class FastGradientSignUntargeted():
    """
    Fast gradient sign untargeted adversarial attack, minimizes the initial class activation
    with iterative grad sign updates
    """
    def __init__(self, model, linear, epsilon, alpha, min_val, max_val, max_iters, _type='linf'):

        # Model; `linear` is an optional classification head. The string
        # 'None' (not the None object) is used throughout as the "no head"
        # sentinel, in which case `model` must emit logits directly.
        self.model = model
        self.linear = linear
        # Maximum perturbation
        self.epsilon = epsilon
        # Movement multiplier per iteration
        self.alpha = alpha
        # Minimum value of the pixels
        self.min_val = min_val
        # Maximum value of the pixels
        self.max_val = max_val
        # Maximum numbers of iteration to generated adversaries
        self.max_iters = max_iters
        # The perturbation of epsilon
        self._type = _type

    def perturb(self, original_images, labels, reduction4loss='mean', random_start=True):
        """PGD-style attack: return adversarial examples for ``original_images``.

        Maximizes cross-entropy against ``labels`` for ``max_iters`` sign
        steps of size ``alpha``, clamping to the pixel range and projecting
        back into the epsilon linf-ball after each step. Requires CUDA (the
        random start tensor is created on the GPU).
        """
        # original_images: values are within self.min_val and self.max_val
        # The adversaries created from random close points to the original data
        if random_start:
            rand_perturb = torch.FloatTensor(original_images.shape).uniform_(
                -self.epsilon, self.epsilon)
            rand_perturb = rand_perturb.cuda()
            x = original_images.clone() + rand_perturb
            x = torch.clamp(x,self.min_val, self.max_val)
        else:
            x = original_images.clone()

        x.requires_grad = True

        # Attack in eval mode so BatchNorm/Dropout statistics are untouched.
        self.model.eval()
        if not self.linear=='None':
            self.linear.eval()

        with torch.enable_grad():
            for _iter in range(self.max_iters):

                self.model.zero_grad()
                if not self.linear=='None':
                    self.linear.zero_grad()

                if self.linear=='None':
                    outputs = self.model(x)
                else:
                    outputs = self.linear(self.model(x))

                loss = F.cross_entropy(outputs, labels, reduction=reduction4loss)

                grad_outputs = None
                grads = torch.autograd.grad(loss, x, grad_outputs=grad_outputs, only_inputs=True, retain_graph=False)[0]

                # NOTE(review): `scaled_g` is only assigned for 'linf'; any
                # other `_type` would raise NameError on the next line.
                if self._type == 'linf':
                    scaled_g = torch.sign(grads.data)

                x.data += self.alpha * scaled_g

                # Clamp to the valid pixel range, then back into the eps-ball.
                x = torch.clamp(x, self.min_val, self.max_val)
                x = project(x, original_images, self.epsilon, self._type)


        return x.detach()
95 |
class RepresentationAdv():
    """Instance-wise attack in the contrastive representation space.

    Generates perturbations that maximize a representation-level loss
    (NT-Xent similarity, MSE, L1 or cosine) between the perturbed image and a
    target view, then returns the adversarial images together with the
    training loss computed in train mode.
    """

    def __init__(self, model, projector, epsilon, alpha, min_val, max_val, max_iters, args, _type='linf', loss_type='sim', regularize='original'):

        # Model (encoder) and projection head used to embed the images.
        self.model = model
        self.projector = projector
        # For loss_type 'sim': pair the adversary with the original image
        # ('original') or with the target view (anything else).
        self.regularize = regularize
        # Maximum perturbation
        self.epsilon = epsilon
        # Movement multiplier per iteration
        self.alpha = alpha
        # Minimum value of the pixels
        self.min_val = min_val
        # Maximum value of the pixels
        self.max_val = max_val
        # Maximum numbers of iteration to generated adversaries
        self.max_iters = max_iters
        # The perturbation of epsilon
        self._type = _type
        # loss type: one of 'sim', 'mse', 'l1', 'cos'
        self.loss_type = loss_type

        self.args = args


    def get_loss(self, original_images, target, optimizer, weight, random_start=True):
        """Run the inner attack, then compute the outer training loss.

        Args:
            original_images: clean batch (CUDA tensor).
            target: paired augmented view used as the attack anchor.
            optimizer: training optimizer; its gradients are zeroed here
                before the returned loss is used for the update step.
            weight: divisor applied to the 'sim' training loss.
            random_start: start from a random point inside the epsilon-ball.

        Returns:
            (adversarial images, detached; training loss tensor).
        """
        if random_start:
            rand_perturb = torch.FloatTensor(original_images.shape).uniform_(
                -self.epsilon, self.epsilon)
            rand_perturb = rand_perturb.float().cuda()
            x = original_images.float().clone() + rand_perturb
            x = torch.clamp(x,self.min_val, self.max_val)
        else:
            x = original_images.clone()

        x.requires_grad = True

        # Attack in eval mode so normalization statistics stay fixed.
        self.model.eval()
        self.projector.eval()
        batch_size = len(x)

        with torch.enable_grad():
            for _iter in range(self.max_iters):

                self.model.zero_grad()
                self.projector.zero_grad()

                if self.loss_type == 'mse':
                    loss = F.mse_loss(self.projector(self.model(x)),self.projector(self.model(target)))
                elif self.loss_type == 'sim':
                    inputs = torch.cat((x, target))
                    output = self.projector(self.model(inputs))
                    similarity,_ = pairwise_similarity(output, temperature=0.5, multi_gpu=False, adv_type = 'None')
                    loss = NT_xent(similarity, 'None', self.args)
                elif self.loss_type == 'l1':
                    loss = F.l1_loss(self.projector(self.model(x)), self.projector(self.model(target)))
                elif self.loss_type =='cos':
                    loss = 1-F.cosine_similarity(self.projector(self.model(x)), self.projector(self.model(target))).mean()

                grads = torch.autograd.grad(loss, x, grad_outputs=None, only_inputs=True, retain_graph=False)[0]

                # NOTE(review): `scaled_g` is only defined for 'linf'; other
                # `_type` values would raise NameError on the next line.
                if self._type == 'linf':
                    scaled_g = torch.sign(grads.data)

                x.data += self.alpha * scaled_g

                # Clamp to valid pixels, then project back into the eps-ball.
                x = torch.clamp(x,self.min_val,self.max_val)
                x = project(x, original_images, self.epsilon, self._type)

        # Switch back to train mode and compute the loss actually optimized.
        self.model.train()
        self.projector.train()
        optimizer.zero_grad()

        if self.loss_type == 'mse':
            loss = F.mse_loss(self.projector(self.model(x)),self.projector(self.model(target))) * (1.0/batch_size)
        elif self.loss_type == 'sim':
            if self.regularize== 'original':
                inputs = torch.cat((x, original_images))
            else:
                inputs = torch.cat((x, target))
            output = self.projector(self.model(inputs))
            similarity, _ = pairwise_similarity(output, temperature=0.5, multi_gpu=False, adv_type = 'None')
            loss = (1.0/weight) * NT_xent(similarity, 'None', self.args)
        elif self.loss_type == 'l1':
            loss = F.l1_loss(self.projector(self.model(x)), self.projector(self.model(target))) * (1.0/batch_size)
        elif self.loss_type == 'cos':
            loss = 1-F.cosine_similarity(self.projector(self.model(x)), self.projector(self.model(target))).sum() * (1.0/batch_size)

        return x.detach(), loss
186 |
--------------------------------------------------------------------------------
/Asym-RoCL/data/cifar.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from PIL import Image
3 | import os
4 | import os.path
5 | import numpy as np
6 | import sys
7 | from torchvision import transforms
8 | if sys.version_info[0] == 2:
9 | import cPickle as pickle
10 | else:
11 | import pickle
12 |
13 | from .vision import VisionDataset
14 | from .utils import check_integrity, download_and_extract_archive
15 |
16 |
class CIFAR10(VisionDataset):
    """`CIFAR10 `_ Dataset.
    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        contrastive_learning (str, optional): learning mode ('contrastive',
            'linear_eval', or any other value for plain supervised access);
            controls how many augmented views ``__getitem__`` returns.
    """
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    # (file name, md5) pairs used for the integrity check.
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False, contrastive_learning=False):

        super(CIFAR10, self).__init__(root, transform=transform,
                                      target_transform=target_transform)

        self.train = train  # training set or test set
        self.learning_type = contrastive_learning

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data = []
        self.targets = []

        # now load the picked numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    # CIFAR-100 batches store targets under 'fine_labels'.
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

    def _load_meta(self):
        """Load the class-name metadata file and build the name->index map."""
        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
        if not check_integrity(path, self.meta['md5']):
            raise RuntimeError('Dataset metadata file not found or corrupted.' +
                               ' You can use download=True to download it')
        with open(path, 'rb') as infile:
            if sys.version_info[0] == 2:
                data = pickle.load(infile)
            else:
                data = pickle.load(infile, encoding='latin1')
            self.classes = data[self.meta['key']]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
            In 'contrastive' mode (and 'linear_eval' training) the return is
            (ori_img, img, img_2, target): the untransformed tensor image
            plus two independently augmented views.
        """
        img, target = self.data[index], self.targets[index]
        # Keep an unaugmented tensor copy of the raw image.
        ori_img = img
        toTensor = transforms.ToTensor()
        ori_img = toTensor(ori_img)

        if self.learning_type=='contrastive':
            img_2 = img.copy()

        elif self.learning_type=='linear_eval':
            if self.train:
                # NOTE(review): this ndarray copy is overwritten below once
                # `img` has been converted to a PIL image.
                img_2 = img.copy()

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)
        if self.learning_type=='contrastive':
            img_2 = Image.fromarray(img_2)
        elif self.learning_type=='linear_eval':
            if self.train:
                img_2 = img.copy()

        if self.transform is not None:
            img = self.transform(img)
            if self.learning_type=='contrastive':
                img_2 = self.transform(img_2)
            elif self.learning_type=='linear_eval':
                if self.train:
                    img_2 = self.transform(img_2)

        if self.target_transform is not None:
            target = self.target_transform(target)

        if self.learning_type=='contrastive':
            return ori_img, img, img_2, target
        elif self.learning_type=='linear_eval':
            if self.train:
                return ori_img, img, img_2, target
            else:
                return img, target
        else:
            return img, target

    def __len__(self):
        """Number of images in the selected split."""
        return len(self.data)

    def _check_integrity(self):
        """Return True iff every split file exists with the expected md5."""
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self):
        """Fetch and extract the archive unless it is already present and valid."""
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)

    def extra_repr(self):
        """One-line split description used by VisionDataset.__repr__."""
        return "Split: {}".format("Train" if self.train is True else "Test")
178 |
179 |
class CIFAR100(CIFAR10):
    """`CIFAR100 `_ Dataset.
    This is a subclass of the `CIFAR10` Dataset.

    Only the archive layout, checksums and metadata key differ; all loading
    and __getitem__ behavior is inherited unchanged from CIFAR10.
    """
    base_folder = 'cifar-100-python'
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    # CIFAR-100 ships a single training file rather than five batches.
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]
    meta = {
        'filename': 'meta',
        'key': 'fine_label_names',
        'md5': '7973b15100ade9c7d40fb424638fde48',
    }
200 |
201 |
--------------------------------------------------------------------------------
/Asym-RoCL/utils.py:
--------------------------------------------------------------------------------
1 | '''Some helper functions for PyTorch, including:
2 | - get_mean_and_std: calculate the mean and std value of dataset.
3 | - msr_init: net parameter initialization.
4 | - progress_bar: progress bar mimic xlua.progress.
5 | '''
6 | import os
7 | import sys
8 | import time
9 | import math
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 | import torch.nn.functional as F
15 |
16 | import numpy as np
17 | # import cv2
18 | import scipy.misc
19 | from itertools import chain
20 |
class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all running statistics."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += n * val
        self.count += n
        self.avg = self.sum / self.count
37 |
38 |
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: (batch, num_classes) score tensor.
        target: (batch,) tensor of ground-truth class indices.
        topk: iterable of k values to evaluate.

    Returns:
        List of 1-element tensors, one accuracy (in percent) per k.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # BUG FIX: `correct[:k].view(-1)` raises a RuntimeError whenever
            # maxk > 1, because `correct` is a slice of a transposed
            # (non-contiguous) tensor; flatten() handles that (and matches
            # the sibling accuracy() implementation in this repo).
            correct_k = correct[:k].flatten().float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
54 |
def get_mean_and_std(dataset):
    '''Compute per-channel mean and std estimates of a 3-channel image dataset.

    Iterates the dataset one sample at a time and averages each channel's
    per-image mean/std over the dataset (an approximation of the global
    statistics, matching the original behavior).
    '''
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, _targets in loader:
        for channel in range(3):
            mean[channel] += inputs[:, channel, :, :].mean()
            std[channel] += inputs[:, channel, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std
68 |
def init_params(net):
    '''Init layer parameters.

    Conv2d weights get Kaiming-normal (fan_out) init, BatchNorm2d is reset to
    weight=1 / bias=0, Linear weights are drawn from N(0, 1e-3), and every
    existing bias is zeroed.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            # In-place `kaiming_normal_` replaces the deprecated alias.
            init.kaiming_normal_(m.weight, mode='fan_out')
            # BUG FIX: `if m.bias:` raises "Boolean value of Tensor with more
            # than one element is ambiguous" for any real bias; test for None.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant_(m.weight, 1)
            init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)
83 |
84 |
# _, term_width = os.popen('stty size', 'r').read().split()
term_width = int(80)  # fixed console width; the stty-based probe above is disabled

TOTAL_BAR_LENGTH = 86.  # number of character cells occupied by the bar itself
last_time = time.time()  # module-global: timestamp of the previous progress_bar call
begin_time = last_time   # module-global: timestamp when the current bar started
def progress_bar(current, total, msg=None):
    """Draw an in-place text progress bar on stdout.

    Args:
        current: zero-based index of the step just completed.
        total: total number of steps in the bar.
        msg: optional extra text appended after the timing info.

    Uses module globals ``last_time``/``begin_time`` for per-step and
    cumulative timings, and ``term_width``/``TOTAL_BAR_LENGTH`` for layout.
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    # Completed ('=') vs remaining ('.') cells; one cell is taken by '>'.
    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append(' Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    # Pad with spaces so leftovers from a longer previous line are erased.
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    # Redraw on the same line ('\r') until the final step finishes the bar.
    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()
133 |
def format_time(seconds):
    """Format a duration in seconds as e.g. '1D2h', '3m4s', '500ms'.

    At most the two most-significant non-zero units are shown; a zero
    duration renders as '0ms'.
    """
    remainder = seconds
    days = int(remainder / 3600 / 24)
    remainder = remainder - days * 3600 * 24
    hours = int(remainder / 3600)
    remainder = remainder - hours * 3600
    minutes = int(remainder / 60)
    remainder = remainder - minutes * 60
    whole_seconds = int(remainder)
    millis = int((remainder - whole_seconds) * 1000)

    units = [(days, 'D'), (hours, 'h'), (minutes, 'm'),
             (whole_seconds, 's'), (millis, 'ms')]
    nonzero = [(value, suffix) for value, suffix in units if value > 0]
    rendered = ''.join('%d%s' % (value, suffix) for value, suffix in nonzero[:2])
    return rendered if rendered else '0ms'
165 |
def checkpoint_train(model, acc, epoch, args, optimizer, save_name_add=''):
    """Persist model/optimizer/RNG state to ./checkpoint/ckpt.t7<name><suffix>."""
    print('Saving..')
    snapshot = {
        'epoch': epoch,
        'acc': acc,
        'model': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'rng_state': torch.get_rng_state(),
    }

    destination = './checkpoint/ckpt.t7' + args.name + save_name_add

    if not os.path.isdir('./checkpoint'):
        os.mkdir('./checkpoint')
    torch.save(snapshot, destination)
183 |
def checkpoint(model, finetune_type, acc, acc_adv, epoch, args, optimizer, save_name_add=''):
    """Persist a fine-tuning checkpoint, tagging both the metadata and the
    file name with the finetune type; includes clean and adversarial accuracy."""
    print('Saving..')
    snapshot = {
        'epoch': epoch,
        'acc': acc,
        'acc_adv': acc_adv,
        'finetune_type': finetune_type,
        'model': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'rng_state': torch.get_rng_state(),
    }

    destination = './checkpoint/ckpt.t7' + args.name + finetune_type + save_name_add

    if not os.path.isdir('./checkpoint'):
        os.mkdir('./checkpoint')
    torch.save(snapshot, destination)
203 |
def learning_rate_warmup(optimizer, epoch, args):
    """Linear learning-rate warmup for the first 10 epochs.

    Epoch e (0-based) gets lr = args.lr / 10 * (e + 1), applied to every
    parameter group of the optimizer.
    """
    warm_lr = args.lr / 10 * (epoch + 1)
    for group in optimizer.param_groups:
        group['lr'] = warm_lr
213 |
class LabelDict():
    """Bidirectional mapping between integer labels and class-name strings.

    Only 'cifar-10' is currently supported. Unsupported dataset names now
    raise ValueError up front; previously `label_dict` was simply never set,
    so construction crashed later with a confusing AttributeError.
    """

    def __init__(self, dataset='cifar-10'):
        self.dataset = dataset
        if dataset == 'cifar-10':
            self.label_dict = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat',
                               4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse',
                               8: 'ship', 9: 'truck'}
        else:
            raise ValueError('unsupported dataset: %s' % dataset)

        # Inverse mapping: class name -> integer label.
        self.class_dict = {v: k for k, v in self.label_dict.items()}

    def label2class(self, label):
        """Return the class name for integer `label`."""
        assert label in self.label_dict, 'the label %d is not in %s' % (label, self.dataset)
        return self.label_dict[label]

    def class2label(self, _class):
        """Return the integer label for class-name string `_class`."""
        assert isinstance(_class, str)
        assert _class in self.class_dict, 'the class %s is not in %s' % (_class, self.dataset)
        return self.class_dict[_class]
232 |
def get_highest_incorrect_predict(outputs, targets):
    """Return, per sample, the highest-scoring class that is NOT the target.

    If the top-1 prediction equals the target the runner-up class is used;
    otherwise the top-1 prediction itself already is the highest incorrect one.
    """
    _, sorted_prediction = torch.topk(outputs.data, k=2, dim=1)

    top1, top2 = sorted_prediction[:, 0], sorted_prediction[:, 1]
    # .long() keeps everything on outputs' device; the original cast through
    # torch.cuda.LongTensor crashed on CPU tensors.
    is_correct = (top1 == targets).long()
    highest_incorrect_predict = (is_correct * top2 + (1 - is_correct) * top1).detach()

    return highest_incorrect_predict
241 |
242 |
--------------------------------------------------------------------------------
/Asym-AdvCL/models/resnet_cifar.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 |
3 | For Pre-activation ResNet, see 'preact_resnet.py'.
4 |
5 | Reference:
6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
8 | '''
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from torch.autograd import Variable
14 | import numpy as np
15 |
16 |
class BasicBlock(nn.Module):
    """Two 3x3 conv+BN layers with a ReLU residual connection.

    A 1x1 projection shortcut is inserted whenever the block changes the
    spatial resolution (stride != 1) or the channel count.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != self.expansion * planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + residual)
40 |
41 |
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block with 4x channel expansion."""
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        # Projection shortcut when the residual and main paths disagree in shape.
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return F.relu(out + residual)
68 |
69 |
class ResNet(nn.Module):
    """CIFAR-style ResNet backbone with contrastive projection/prediction heads
    and optional frequency-domain input filtering.

    Args:
        block: residual block class (BasicBlock or Bottleneck).
        num_blocks: four ints — number of blocks per stage.
        num_classes: output size of the supervised `linear` head.
        feat_dim: NOTE(review): unused — the projection width is hard-coded to 128.
        low_freq / high_freq: if set, keep only the low/high frequency band of
            the input (cutoff `radius`) before the backbone.
        radius: cutoff radius in the centered FFT plane for the filters.
    """
    def __init__(self, block, num_blocks, num_classes=10, feat_dim=128, low_freq=False, high_freq=False, radius=0):
        super(ResNet, self).__init__()
        self.in_planes = 64

        # self.normalize = NormalizeByChannelMeanStd(
        #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # 3x3 stem (no 7x7/maxpool): standard for 32x32 CIFAR inputs.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)


        # self.linear_contrast = nn.Linear(512*block.expansion, 128)
        dim_in = 512*block.expansion
        # Two-layer projection head: backbone feature -> 128-d embedding.
        self.head_proj = nn.Sequential(
            nn.Linear(dim_in, dim_in),
            # nn.BatchNorm1d(dim_in),
            nn.ReLU(inplace=True),
            nn.Linear(dim_in, 128)
        )

        # Prediction head applied on top of the projection (128 -> 128).
        self.head_pred = nn.Sequential(
            nn.Linear(128, 128),
            # nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128)
        )

        # Frequency-filtering configuration used in forward().
        self.low_freq = low_freq
        self.high_freq = high_freq
        self.radius = radius

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack `num_blocks` blocks; only the first one downsamples with `stride`."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def distance(self, i, j, imageSize, r):
        """Return 1.0 if pixel (i, j) lies within radius `r` of the image center, else 0."""
        dis = np.sqrt((i - imageSize / 2) ** 2 + (j - imageSize / 2) ** 2)
        if dis < r:
            return 1.0
        else:
            return 0
    def mask_radial(self, img, r):
        """Build a binary circular (low-pass) mask of radius `r`, shaped like `img`.

        NOTE(review): the mask is moved to .cuda() unconditionally, so the
        frequency filters require a GPU — confirm CPU inference is not expected.
        """
        rows, cols = img.shape
        mask = torch.zeros((rows, cols))
        for i in range(rows):
            for j in range(cols):
                mask[i, j] = self.distance(i, j, imageSize=rows, r=r)
        return mask.cuda()
    def filter_low(self, Images, r):
        """Low-pass filter: keep only frequencies within radius `r`, per channel."""
        mask = self.mask_radial(torch.zeros([Images.shape[2], Images.shape[3]]), r)
        bs, c, h, w = Images.shape
        x = Images.reshape([bs*c, h, w])
        # Centered FFT -> mask the spectrum -> inverse FFT; keep the real part.
        fd = torch.fft.fftshift(torch.fft.fftn(x, dim=(-2, -1)))
        mask = mask.unsqueeze(0).repeat([bs*c, 1, 1])
        fd = fd * mask
        fd = torch.fft.ifftn(torch.fft.ifftshift(fd), dim=(-2, -1))
        fd = torch.real(fd)
        fd = fd.reshape([bs, c, h, w])
        return fd

    def filter_high(self, Images, r):
        """High-pass filter: suppress frequencies within radius `r` (complement of filter_low)."""
        mask = self.mask_radial(torch.zeros([Images.shape[2], Images.shape[3]]), r)
        bs, c, h, w = Images.shape
        x = Images.reshape([bs * c, h, w])
        fd = torch.fft.fftshift(torch.fft.fftn(x, dim=(-2, -1)))
        mask = mask.unsqueeze(0).repeat([bs * c, 1, 1])
        fd = fd * (1. - mask)
        fd = torch.fft.ifftn(torch.fft.ifftshift(fd), dim=(-2, -1))
        fd = torch.real(fd)
        fd = fd.reshape([bs, c, h, w])
        return fd

        # return np.array(Images_freq_low), np.array(Images_freq_high)

    def forward(self, x, contrast=False, return_feat=False, CF=False):
        """Backbone forward pass.

        Returns (in priority order):
            return_feat=True: the pooled backbone feature (B, 512*expansion);
            contrast=True: L2-normalized (proj, pred), plus feat when CF=True;
            otherwise: classification logits from `linear`.
        """
        # img_org = x[0]
        # x = self.normalize(x)


        # Optional frequency filtering; results are clamped back to image range
        # (presumably inputs are in [0, 1] — TODO confirm against callers).
        if self.low_freq:
            x = self.filter_low(x, self.radius)
            x = torch.clamp(x, 0, 1)

        if self.high_freq:
            x = self.filter_high(x, self.radius)
            x = torch.clamp(x, 0, 1)

        # img_filter = x[0]
        # import cv2
        # img_org = img_org.detach().cpu().numpy()*255.
        # img_filter = img_filter.detach().cpu().numpy()*255.
        # cv2.imwrite('org.jpg', img_org.transpose([1,2,0]))
        # cv2.imwrite('filter.jpg', img_filter.transpose([1,2,0]))
        # exit(0)

        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        feat = out
        if return_feat:
            return out
        if contrast:
            # out = self.linear_contrast(out)
            proj = self.head_proj(out)
            pred = self.head_pred(proj)
            proj = F.normalize(proj, dim=1)
            pred = F.normalize(pred, dim=1)
            if CF:
                return proj, pred, feat
            else:
                return proj, pred
        else:
            out = self.linear(out)
            return out
198 |
199 |
def ResNet18(num_class=10, radius=8, low_freq=False, high_freq=False):
    """ResNet-18 (BasicBlock x [2,2,2,2]) with optional frequency filtering."""
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_class,
                  radius=radius, low_freq=low_freq, high_freq=high_freq)

def ResNet34(num_class=10):
    """ResNet-34 (BasicBlock x [3,4,6,3])."""
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_class)

def ResNet50(num_class=10):
    """ResNet-50 (Bottleneck x [3,4,6,3])."""
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_class)

def ResNet101():
    """ResNet-101 (Bottleneck x [3,4,23,3]) with the default head sizes."""
    return ResNet(Bottleneck, [3, 4, 23, 3])

def ResNet152():
    """ResNet-152 (Bottleneck x [3,8,36,3]) with the default head sizes."""
    return ResNet(Bottleneck, [3, 8, 36, 3])
214 |
215 |
def test():
    """Smoke test: push one random CIFAR-sized batch through ResNet18."""
    net = ResNet18()
    out = net(Variable(torch.randn(1, 3, 32, 32)))
    print(out.size())
220 |
221 | # test()
--------------------------------------------------------------------------------
/Asym-RoCL/rocl_train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 |
3 | from __future__ import print_function
4 |
5 | import csv
6 | import os
7 |
8 | import torch
9 | import torch.backends.cudnn as cudnn
10 |
11 | import torch.optim as optim
12 | import data_loader
13 | import model_loader
14 |
15 | from attack_lib import FastGradientSignUntargeted,RepresentationAdv
16 |
17 | from models.projector import Projector
18 | from argument import parser, print_args
19 | from utils import progress_bar, checkpoint, checkpoint_train, AverageMeter, accuracy
20 |
21 | from loss import pairwise_similarity_train, NT_xent, NT_xent_HN
22 | from torchlars import LARS
23 | from warmup_scheduler import GradualWarmupScheduler
24 |
# Parse the command-line configuration once; the whole script reads `args`.
args = parser()

def print_status(string):
    """Print only from the node-chief rank, to avoid duplicated log lines."""
    if args.local_rank % ngpus_per_node == 0:
        print(string)
30 |
# Local GPU topology; each process owns the device matching its local_rank.
ngpus_per_node = torch.cuda.device_count()
if args.ngpu>1:
    multi_gpu=True
elif args.ngpu==1:
    multi_gpu=False
else:
    # NOTE(review): assert on a non-empty string literal is always True, so
    # this branch never actually aborts when no GPU is requested.
    assert("Need GPU....")
if args.local_rank % ngpus_per_node == 0:
    print_args(args)

start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Seed only when explicitly requested (seed 0 means "leave unseeded").
if args.seed != 0:
    torch.manual_seed(args.seed)

# One process per GPU; rendezvous via environment variables (env://), NCCL backend.
world_size = args.ngpu
torch.distributed.init_process_group(
    'nccl',
    init_method='env://',
    world_size=world_size,
    rank=args.local_rank,
)
53 |
# Data
print_status('==> Preparing data..')
# NOTE(review): assert('...') below is a no-op (truthy string); an invalid
# train_type silently skips dataset construction instead of failing fast.
if not (args.train_type=='contrastive'):
    assert('wrong train phase...')
else:
    trainloader, traindst, testloader, testdst ,train_sampler = data_loader.get_dataset(args)

# Model
print_status('==> Building model..')
torch.cuda.set_device(args.local_rank)
model = model_loader.get_model(args)

# Projector width depends on the backbone's block expansion (1 or 4).
if args.model=='ResNet18':
    expansion=1
elif args.model=='ResNet50':
    expansion=4
else:
    # NOTE(review): no-op assert, same pattern as above.
    assert('wrong model type')
projector = Projector(expansion=expansion)

# Instance-wise representation attack (RoCL); its settings are folded into
# args.name so checkpoints/logs are self-describing.
if 'Rep' in args.advtrain_type:
    Rep_info = 'Rep_attack_ep_'+str(args.epsilon)+'_alpha_'+ str(args.alpha) + '_min_val_' + str(args.min) + '_max_val_' + str(args.max) + '_max_iters_' + str(args.k) + '_type_' + str(args.attack_type) + '_randomstart_' + str(args.random_start)
    args.name += Rep_info

    print_status("Representation attack info ...")
    print_status(Rep_info)
    Rep = RepresentationAdv(model, projector, epsilon=args.epsilon, alpha=args.alpha, min_val=args.min, max_val=args.max, max_iters=args.k, args=args, _type=args.attack_type, loss_type=args.loss_type, regularize = args.regularize_to)
else:
    # NOTE(review): no-op assert, same pattern as above.
    assert('wrong adversarial train type')

# Model upload to GPU #
model.cuda()
projector.cuda()

# Sync BN statistics across ranks, then wrap both modules for DDP training.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[args.local_rank],
    output_device=args.local_rank,
    find_unused_parameters=True,
)
projector = torch.nn.parallel.DistributedDataParallel(
    projector,
    device_ids=[args.local_rank],
    output_device=args.local_rank,
    find_unused_parameters=True,
)

cudnn.benchmark = True
print_status('Using CUDA..')

# Aggregating model parameter & projection parameter #
model_params = []
model_params += model.parameters()
model_params += projector.parameters()

# LARS optimizer from KAKAO-BRAIN github "pip install torchlars" or git from https://github.com/kakaobrain/torchlars
base_optimizer = optim.SGD(model_params, lr=args.lr, momentum=0.9, weight_decay=args.decay)
optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)

# Cosine learning rate annealing (SGDR) & Learning rate warmup git from https://github.com/ildoonet/pytorch-gradual-warmup-lr #
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epoch)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=args.lr_multiplier, total_epoch=10, after_scheduler=scheduler_cosine)
117 |
def train(epoch):
    """Run one epoch of adversarial contrastive (RoCL) training.

    Returns:
        (avg_total_loss, avg_sim_loss): per-batch averages over the epoch.
    """
    if args.local_rank % ngpus_per_node == 0:
        print('\nEpoch: %d' % epoch)

    model.train()
    projector.train()

    train_sampler.set_epoch(epoch)  # reshuffle this rank's shard each epoch
    scheduler_warmup.step()

    total_loss = 0
    reg_simloss = 0
    reg_loss = 0

    for batch_idx, (ori, inputs_1, inputs_2, label) in enumerate(trainloader):
        ori, inputs_1, inputs_2 = ori.cuda(), inputs_1.cuda(), inputs_2.cuda()

        # Which augmented view the adversarial example is generated against.
        if args.attack_to == 'original':
            attack_target = inputs_1
        else:
            attack_target = inputs_2

        # NOTE: advinputs/adv_loss are only defined on this path; any
        # advtrain_type other than 'None' is expected to contain 'Rep'.
        if 'Rep' in args.advtrain_type:
            advinputs, adv_loss = Rep.get_loss(original_images=inputs_1, target=attack_target,
                                               optimizer=optimizer, weight=args.lamda,
                                               random_start=args.random_start)
            reg_loss += adv_loss.data

        if not (args.advtrain_type == 'None'):
            inputs = torch.cat((inputs_1, inputs_2, advinputs))
        else:
            inputs = torch.cat((inputs_1, inputs_2))

        outputs = projector(model(inputs))

        similarity, similarity_grad, gathered_outputs = pairwise_similarity_train(
            outputs, args, temperature=args.temperature,
            multi_gpu=multi_gpu, adv_type=args.advtrain_type)

        # Hard-negative-weighted or plain NT-Xent contrastive loss.
        if args.HN:
            simloss = NT_xent_HN(similarity, args.advtrain_type, args)
        else:
            simloss = NT_xent(similarity, args.advtrain_type, args)

        if not (args.advtrain_type == 'None'):
            loss = simloss + adv_loss
        else:
            loss = simloss

        optimizer.zero_grad()
        loss.backward()
        total_loss += loss.data
        reg_simloss += simloss.data

        optimizer.step()

        if (args.local_rank % ngpus_per_node == 0):
            if 'Rep' in args.advtrain_type:
                progress_bar(batch_idx, len(trainloader),
                             'Loss: %.3f | SimLoss: %.3f | Adv: %.2f'
                             % (total_loss / (batch_idx + 1), reg_simloss / (batch_idx + 1), reg_loss / (batch_idx + 1)))
            else:
                # Fixed label: this branch has no adversarial loss to report;
                # the second value is the similarity loss, not 'Adv'.
                progress_bar(batch_idx, len(trainloader),
                             'Loss: %.3f | SimLoss: %.3f'
                             % (total_loss / (batch_idx + 1), reg_simloss / (batch_idx + 1)))

    # Fixed normalization: batch_idx ends at len(trainloader) - 1, so dividing
    # by it under-counted a batch and raised ZeroDivisionError for one-batch
    # loaders; divide by the actual number of batches instead.
    num_batches = batch_idx + 1
    return (total_loss / num_batches, reg_simloss / num_batches)
181 |
182 |
def test(epoch, train_loss):
    """Checkpoint the encoder and projector; no evaluation is performed here."""
    model.eval()
    projector.eval()

    is_chief = args.local_rank % ngpus_per_node == 0

    if epoch == args.epoch - 1 and is_chief:
        # Final snapshot at the last epoch.
        checkpoint_train(model, train_loss, epoch, args, optimizer)
        checkpoint_train(projector, train_loss, epoch, args, optimizer, save_name_add='_projector')
    elif epoch % 100 == 0 and is_chief:
        # Periodic snapshot every 100 epochs.
        checkpoint_train(model, train_loss, epoch, args, optimizer, save_name_add='_epoch_' + str(epoch))
        checkpoint_train(projector, train_loss, epoch, args, optimizer, save_name_add=('_projector_epoch_' + str(epoch)))
196 |
197 |
# Log and saving checkpoint information #
if not os.path.isdir('results') and args.local_rank % ngpus_per_node == 0:
    os.mkdir('results')

# Fold the main hyper-parameters into the run name used for logs/checkpoints.
args.name += (args.train_type + '_' +args.model + '_' + args.dataset + '_b' + str(args.batch_size)+'_nGPU'+str(args.ngpu)+'_l'+str(args.lamda))
if args.HN:
    args.name += '_HN'
loginfo = 'results/log_' + args.name + '_' + str(args.seed)
logname = (loginfo+ '.csv')
print_status('Training info...')
print_status(loginfo)

##### Log file #####
# Only the chief rank writes the CSV header (and later rows).
if args.local_rank % ngpus_per_node == 0:
    with open(logname, 'w') as logfile:
        logwriter = csv.writer(logfile, delimiter=',')
        logwriter.writerow(['epoch', 'train loss', 'reg loss'])

    print(args.name)

##### Training #####
for epoch in range(start_epoch, args.epoch):
    train_loss, reg_loss = train(epoch)
    # test() only checkpoints; it does not evaluate accuracy.
    test(epoch, train_loss)

    if args.local_rank % ngpus_per_node == 0:
        with open(logname, 'a') as logfile:
            logwriter = csv.writer(logfile, delimiter=',')
            logwriter.writerow([epoch, train_loss.item(), reg_loss.item()])
228 |
229 |
--------------------------------------------------------------------------------
/Asym-RoCL/loss.py:
--------------------------------------------------------------------------------
import nturl2path  # NOTE(review): appears unused in this module — candidate for removal
import diffdist.functional as distops
import torch
import torch.distributed as dist
import numpy as np
6 |
def pairwise_similarity(outputs,temperature=0.5,multi_gpu=False, adv_type='None'):
    '''
    Compute pairwise similarity and return the matrix
    input: aggregated outputs & temperature for scaling
    return: pairwise cosine similarity

    With multi_gpu, each rank holds only its shard of every view; the per-view
    slices are all-gathered (diffdist keeps the gather differentiable) so the
    similarity matrix covers the global batch.
    NOTE(review): if adv_type contains 'Rep' but is not exactly 'Rep', N is
    never assigned and the code below raises — presumably only 'None'/'Rep'
    occur; confirm against callers.
    '''
    if multi_gpu and adv_type=='None':

        # Two views per sample: rows are [view1 | view2].
        B = int(outputs.shape[0]/2)

        outputs_1 = outputs[0:B]
        outputs_2 = outputs[B:]

        # Differentiable all_gather of each view across ranks.
        gather_t_1 = [torch.empty_like(outputs_1) for i in range(dist.get_world_size())]
        gather_t_1 = distops.all_gather(gather_t_1, outputs_1)

        gather_t_2 = [torch.empty_like(outputs_2) for i in range(dist.get_world_size())]
        gather_t_2 = distops.all_gather(gather_t_2, outputs_2)

        outputs_1 = torch.cat(gather_t_1)
        outputs_2 = torch.cat(gather_t_2)
        outputs = torch.cat((outputs_1,outputs_2))
    elif multi_gpu and 'Rep' in adv_type:
        if adv_type == 'Rep':
            # Three views per sample: [view1 | view2 | adversarial].
            N=3
        B = int(outputs.shape[0]/N)

        outputs_1 = outputs[0:B]
        outputs_2 = outputs[B:2*B]
        outputs_3 = outputs[2*B:3*B]

        gather_t_1 = [torch.empty_like(outputs_1) for i in range(dist.get_world_size())]
        gather_t_1 = distops.all_gather(gather_t_1, outputs_1)

        gather_t_2 = [torch.empty_like(outputs_2) for i in range(dist.get_world_size())]
        gather_t_2 = distops.all_gather(gather_t_2, outputs_2)

        gather_t_3 = [torch.empty_like(outputs_3) for i in range(dist.get_world_size())]
        gather_t_3 = distops.all_gather(gather_t_3, outputs_3)

        outputs_1 = torch.cat(gather_t_1)
        outputs_2 = torch.cat(gather_t_2)
        outputs_3 = torch.cat(gather_t_3)

        if N==3:
            outputs = torch.cat((outputs_1,outputs_2,outputs_3))

    # Cosine similarity: L2-normalize rows, scaled dot products. The .detach()
    # on the right operand means gradients flow only through the row side.
    B = outputs.shape[0]
    outputs_norm = outputs/(outputs.norm(dim=1).view(B,1) + 1e-8)
    similarity_matrix = (1./temperature) * torch.mm(outputs_norm,outputs_norm.transpose(0,1).detach())

    return similarity_matrix, outputs
59 |
def pairwise_similarity_train(outputs, args, temperature=0.5,multi_gpu=False, adv_type='None'):
    '''
    Compute pairwise similarity and return the matrix
    input: aggregated outputs & temperature for scaling
    return: pairwise cosine similarity

    Training-time twin of pairwise_similarity(): identical gather + similarity
    computation, but returns a 3-tuple whose middle element is always None
    (kept for the caller's unpacking interface). `args` is accepted but not
    read here.
    NOTE(review): same caveat as pairwise_similarity — N is only assigned when
    adv_type is exactly 'Rep'.
    '''
    if multi_gpu and adv_type=='None':

        # Two views per sample: rows are [view1 | view2].
        B = int(outputs.shape[0]/2)

        outputs_1 = outputs[0:B]
        outputs_2 = outputs[B:]

        # Differentiable all_gather of each view across ranks.
        gather_t_1 = [torch.empty_like(outputs_1) for i in range(dist.get_world_size())]
        gather_t_1 = distops.all_gather(gather_t_1, outputs_1)

        gather_t_2 = [torch.empty_like(outputs_2) for i in range(dist.get_world_size())]
        gather_t_2 = distops.all_gather(gather_t_2, outputs_2)

        outputs_1 = torch.cat(gather_t_1)
        outputs_2 = torch.cat(gather_t_2)
        outputs = torch.cat((outputs_1,outputs_2))
    elif multi_gpu and 'Rep' in adv_type:
        if adv_type == 'Rep':
            # Three views per sample: [view1 | view2 | adversarial].
            N=3
        B = int(outputs.shape[0]/N)

        outputs_1 = outputs[0:B]
        outputs_2 = outputs[B:2*B]
        outputs_3 = outputs[2*B:3*B]

        gather_t_1 = [torch.empty_like(outputs_1) for i in range(dist.get_world_size())]
        gather_t_1 = distops.all_gather(gather_t_1, outputs_1)

        gather_t_2 = [torch.empty_like(outputs_2) for i in range(dist.get_world_size())]
        gather_t_2 = distops.all_gather(gather_t_2, outputs_2)

        gather_t_3 = [torch.empty_like(outputs_3) for i in range(dist.get_world_size())]
        gather_t_3 = distops.all_gather(gather_t_3, outputs_3)

        outputs_1 = torch.cat(gather_t_1)
        outputs_2 = torch.cat(gather_t_2)
        outputs_3 = torch.cat(gather_t_3)

        if N==3:
            outputs = torch.cat((outputs_1,outputs_2,outputs_3))

    # Cosine similarity with detached right operand (see pairwise_similarity).
    B = outputs.shape[0]
    outputs_norm = outputs/(outputs.norm(dim=1).view(B,1) + 1e-8)
    similarity_matrix = (1./temperature) * torch.mm(outputs_norm,outputs_norm.transpose(0,1).detach())

    return similarity_matrix, None, outputs
113 |
def NT_xent(similarity_matrix, adv_type, args):
    """Compute the NT-Xent contrastive loss from a pairwise-similarity matrix.

    Args:
        similarity_matrix: (M, M) temperature-scaled similarity matrix; rows
            ordered [view1, view2] for adv_type 'None' (M = 2N) or
            [view1, view2, adv] for 'Rep' (M = 3N).
        adv_type: 'None' or 'Rep'.
        args: for 'Rep' must provide .stop_grad, .adv_weight, .stpg_degree.

    Returns:
        Scalar tensor: mean NT-Xent loss over all positive pairs.

    Raises:
        ValueError: for an unsupported adv_type (previously this fell through
        silently and crashed later with a NameError on N).
    """
    N2 = len(similarity_matrix)
    if adv_type == 'None':
        N = N2 // 2
    elif adv_type == 'Rep':  # [inp1, inp2, adv1]
        N = N2 // 3
    else:
        raise ValueError('unsupported adv_type: %s' % adv_type)

    # Exponentiate and zero the diagonal so self-similarity never contributes.
    # The mask is built on similarity_matrix's device: the original hard-coded
    # .cuda() and therefore crashed on CPU tensors.
    similarity_matrix_exp = torch.exp(similarity_matrix)
    diag_mask = 1 - torch.eye(N2, N2, device=similarity_matrix.device)
    similarity_matrix_exp = similarity_matrix_exp * diag_mask
    NT_xent_loss = - torch.log(similarity_matrix_exp/(torch.sum(similarity_matrix_exp,dim=1).view(N2,1) + 1e-8) + 1e-8)

    if adv_type == 'None':
        # Positives: (view1_i, view2_i) in both directions.
        NT_xent_loss_total = (1./float(N2)) * torch.sum(torch.diag(NT_xent_loss[0:N,N:]) + torch.diag(NT_xent_loss[N:,0:N]))
    else:  # 'Rep'
        if args.stop_grad:
            # Asymmetric weighting: clean->adv terms scaled by stpg_degree,
            # adv->clean by (1 - stpg_degree), both weighted by adv_weight.
            NT_xent_loss_total = (1./float(N2)) * torch.sum(torch.diag(NT_xent_loss[0:N,N:2*N]) + torch.diag(NT_xent_loss[N:2*N,0:N])
                            + args.adv_weight *
                            ((torch.diag(NT_xent_loss[0:N,2*N:]) + torch.diag(NT_xent_loss[N:2*N,2*N:])) * args.stpg_degree
                            + (torch.diag(NT_xent_loss[2*N:,0:N]) + torch.diag(NT_xent_loss[2*N:,N:2*N])) * (1-args.stpg_degree))
                            )
        else:
            # Symmetric: all clean/adversarial positive pairs weigh equally.
            NT_xent_loss_total = (1./float(N2)) * torch.sum(torch.diag(NT_xent_loss[0:N,N:2*N]) + torch.diag(NT_xent_loss[N:2*N,0:N])
                                                    + torch.diag(NT_xent_loss[0:N,2*N:]) + torch.diag(NT_xent_loss[2*N:,0:N])
                                                    + torch.diag(NT_xent_loss[N:2*N,2*N:]) + torch.diag(NT_xent_loss[2*N:,N:2*N]))
    return NT_xent_loss_total
148 |
def NT_xent_HN(similarity_matrix, adv_type, args):
    """NT-Xent loss with hard-negative reweighting and debiasing.

    Negatives are importance-weighted by (exp-logit)**beta and debiased with a
    class prior tau_plus, in the style of hard-negative contrastive learning
    (Robinson et al., "Contrastive Learning with Hard Negative Samples").

    Args:
        similarity_matrix: (M, M) temperature-scaled similarities; rows ordered
            [view1, view2] ('None', M = 2N) or [view1, view2, adv] ('Rep', M = 3N).
        adv_type: 'None' or 'Rep'.
        args: needs .tau and .beta; for 'Rep' also .stop_grad, .adv_weight,
            .stpg_degree.

    Returns:
        Scalar tensor: mean loss over all positive pairs.

    Raises:
        ValueError: for an unsupported adv_type.
    """
    device = similarity_matrix.device
    N2 = len(similarity_matrix)
    if adv_type == 'None':
        N = N2 // 2
        contrast_num = 2
    elif adv_type == 'Rep':  # [inp1, inp2, adv1]
        N = N2 // 3
        contrast_num = 3
    else:
        raise ValueError('unsupported adv_type: %s' % adv_type)

    similarity_matrix_exp = torch.exp(similarity_matrix)

    tau_plus = args.tau  # prior probability that a "negative" shares the class
    beta = args.beta     # hard-negative concentration parameter
    temperature = 0.5    # must match the scaling used to build the matrix
    N_neg = (N - 1) * contrast_num

    # mask: 1 at positive-pair positions (same sample, different view).
    # Built on the input's device — the original hard-coded .cuda() calls
    # broke CPU execution.
    mask = torch.eye(N, dtype=torch.float32, device=device)
    mask = mask.repeat(contrast_num, contrast_num)
    # logits_mask: 1 everywhere except the self-similarity diagonal.
    logits_mask = torch.scatter(
        torch.ones_like(mask),
        1,
        torch.arange(N * contrast_num, device=device).view(-1, 1),
        0
    )
    mask = mask * logits_mask

    # =============== reweight neg =================
    exp_logits_neg = similarity_matrix_exp * (1 - mask) * logits_mask
    exp_logits_pos = similarity_matrix_exp * mask
    pos = exp_logits_pos.sum(dim=1) / mask.sum(1)

    # Importance weights ~ exp_logits_neg ** beta (log-exp form for stability).
    imp = (beta * (exp_logits_neg + 1e-9).log()).exp()
    reweight_logits_neg = (imp * exp_logits_neg) / imp.mean(dim=-1)
    # Debiased negative mass, clamped at its theoretical minimum to keep the
    # denominator positive.
    Ng = (-tau_plus * N_neg * pos + reweight_logits_neg.sum(dim=-1)) / (1 - tau_plus)
    Ng = torch.clamp(Ng, min=N_neg * np.e ** (-1 / temperature))
    NT_xent_loss = - torch.log(similarity_matrix_exp / ((pos + Ng).view(N2, 1)))
    # ===============================================

    if adv_type == 'None':
        NT_xent_loss_total = (1./float(N2)) * torch.sum(torch.diag(NT_xent_loss[0:N,N:]) + torch.diag(NT_xent_loss[N:,0:N]))
    else:  # 'Rep'
        if args.stop_grad:
            NT_xent_loss_total = (1./float(N2)) * torch.sum(torch.diag(NT_xent_loss[0:N,N:2*N]) + torch.diag(NT_xent_loss[N:2*N,0:N])
                            + args.adv_weight *
                            ((torch.diag(NT_xent_loss[0:N,2*N:]) + torch.diag(NT_xent_loss[N:2*N,2*N:])) * args.stpg_degree
                            + (torch.diag(NT_xent_loss[2*N:,0:N]) + torch.diag(NT_xent_loss[2*N:,N:2*N])) * (1-args.stpg_degree))
                            )
        else:
            NT_xent_loss_total = (1./float(N2)) * torch.sum(torch.diag(NT_xent_loss[0:N,N:2*N]) + torch.diag(NT_xent_loss[N:2*N,0:N])
                                                    + torch.diag(NT_xent_loss[0:N,2*N:]) + torch.diag(NT_xent_loss[2*N:,0:N])
                                                    + torch.diag(NT_xent_loss[N:2*N,2*N:]) + torch.diag(NT_xent_loss[2*N:,N:2*N]))
    return NT_xent_loss_total
--------------------------------------------------------------------------------
/Asym-RoCL/data/utils.py:
--------------------------------------------------------------------------------
import os
import os.path
import hashlib
import gzip
import errno
import tarfile
import zipfile

import torch
from torch.utils.model_zoo import tqdm
# NOTE(review): torch._six was removed in modern PyTorch (>= 2.0) and PY37 is
# not used anywhere visible in this module — this import breaks on new torch.
from torch._six import PY37
12 |
13 |
def gen_bar_updater():
    """Return a urlretrieve-style reporthook that drives a tqdm progress bar."""
    pbar = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        # Late-bind the total: urlretrieve only learns it once headers arrive.
        if pbar.total is None and total_size:
            pbar.total = total_size
        downloaded = count * block_size
        pbar.update(downloaded - pbar.n)

    return bar_update
24 |
25 |
def calculate_md5(fpath, chunk_size=1024 * 1024):
    """Return the hex MD5 digest of the file at `fpath`, read in chunks."""
    digest = hashlib.md5()
    with open(fpath, 'rb') as stream:
        while True:
            chunk = stream.read(chunk_size)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
32 |
33 |
def check_md5(fpath, md5, **kwargs):
    """Return True iff the file's MD5 digest equals `md5`."""
    return calculate_md5(fpath, **kwargs) == md5
36 |
37 |
def check_integrity(fpath, md5=None):
    """Return True if `fpath` exists and (when a checksum is given) matches it."""
    if not os.path.isfile(fpath):
        return False
    return True if md5 is None else check_md5(fpath, md5)
44 |
45 |
def makedir_exist_ok(dirpath):
    """
    Python2 support for os.makedirs(.., exist_ok=True)

    Creates `dirpath` (with parents); an already-existing path is silently
    accepted, while any other OS error is re-raised.
    """
    try:
        os.makedirs(dirpath)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
57 |
58 |
def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check

    Raises:
        RuntimeError: if the downloaded file is missing or fails the MD5 check.
    """
    # Stdlib urllib replaces the former `from six.moves import urllib`: this
    # codebase runs on Python 3 only, so the extra `six` dependency is unneeded.
    import urllib.request
    import urllib.error

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir_exist_ok(root)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
    else:  # download the file
        try:
            print('Downloading ' + url + ' to ' + fpath)
            urllib.request.urlretrieve(
                url, fpath,
                reporthook=gen_bar_updater()
            )
        except (urllib.error.URLError, IOError) as e:
            if url[:5] == 'https':
                # Some mirrors only answer over plain HTTP; retry once downgraded.
                url = url.replace('https:', 'http:')
                print('Failed download. Trying https -> http instead.'
                      ' Downloading ' + url + ' to ' + fpath)
                urllib.request.urlretrieve(
                    url, fpath,
                    reporthook=gen_bar_updater()
                )
            else:
                raise e
        # check integrity of downloaded file
        if not check_integrity(fpath, md5):
            raise RuntimeError("File not found or corrupted.")
100 |
101 |
def list_dir(root, prefix=False):
    """List all directories at a given root
    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    directories = [entry for entry in os.listdir(root)
                   if os.path.isdir(os.path.join(root, entry))]
    if prefix is True:
        directories = [os.path.join(root, entry) for entry in directories]
    return directories
121 |
122 |
def list_files(root, suffix, prefix=False):
    """List all files ending with a suffix at a given root
    Args:
        root (str): Path to directory whose folders need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    files = [entry for entry in os.listdir(root)
             if os.path.isfile(os.path.join(root, entry)) and entry.endswith(suffix)]
    if prefix is True:
        files = [os.path.join(root, entry) for entry in files]
    return files
144 |
145 |
def download_file_from_google_drive(file_id, root, filename=None, md5=None):
    """Download a Google Drive file from and place it in root.
    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the id of the file.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
    import requests
    url = "https://docs.google.com/uc?export=download"

    root = os.path.expanduser(root)
    filename = filename or file_id
    fpath = os.path.join(root, filename)

    makedir_exist_ok(root)

    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    session = requests.Session()
    response = session.get(url, params={'id': file_id}, stream=True)

    # Large files trigger a "can't virus-scan" interstitial; confirm via cookie token.
    token = _get_confirm_token(response)
    if token:
        response = session.get(url, params={'id': file_id, 'confirm': token}, stream=True)

    _save_response_content(response, fpath)
178 |
179 |
180 | def _get_confirm_token(response):
181 | for key, value in response.cookies.items():
182 | if key.startswith('download_warning'):
183 | return value
184 |
185 | return None
186 |
187 |
def _save_response_content(response, destination, chunk_size=32768):
    """Stream an HTTP response body into *destination*, updating a tqdm bar."""
    written = 0
    with open(destination, "wb") as out:
        bar = tqdm(total=None)
        for chunk in response.iter_content(chunk_size):
            if not chunk:  # keep-alive chunks are empty; skip them
                continue
            out.write(chunk)
            written += len(chunk)
            # advance the bar to the absolute byte count
            bar.update(written - bar.n)
    bar.close()
198 |
199 |
200 | def _is_tarxz(filename):
201 | return filename.endswith(".tar.xz")
202 |
203 |
204 | def _is_tar(filename):
205 | return filename.endswith(".tar")
206 |
207 |
208 | def _is_targz(filename):
209 | return filename.endswith(".tar.gz")
210 |
211 |
212 | def _is_tgz(filename):
213 | return filename.endswith(".tgz")
214 |
215 |
216 | def _is_gzip(filename):
217 | return filename.endswith(".gz") and not filename.endswith(".tar.gz")
218 |
219 |
220 | def _is_zip(filename):
221 | return filename.endswith(".zip")
222 |
223 |
def extract_archive(from_path, to_path=None, remove_finished=False):
    """Unpack *from_path* (.tar/.tar.gz/.tgz/.tar.xz/.gz/.zip) into *to_path*.

    *to_path* defaults to the archive's own directory; the archive is deleted
    afterwards when *remove_finished* is set.  Unknown extensions raise
    ValueError.

    NOTE(review): tarfile.extractall performs no member-path sanitisation;
    only call this on trusted archives.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    if _is_tar(from_path):
        with tarfile.open(from_path, 'r') as archive:
            archive.extractall(path=to_path)
    elif _is_targz(from_path) or _is_tgz(from_path):
        with tarfile.open(from_path, 'r:gz') as archive:
            archive.extractall(path=to_path)
    elif _is_tarxz(from_path) and PY3:
        # .tar.xz archive only supported in Python 3.x
        with tarfile.open(from_path, 'r:xz') as archive:
            archive.extractall(path=to_path)
    elif _is_gzip(from_path):
        # Single gzipped file: write it into to_path, dropping the ".gz" suffix.
        stem = os.path.splitext(os.path.basename(from_path))[0]
        to_path = os.path.join(to_path, stem)
        with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
            out_f.write(zip_f.read())
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError("Extraction of {} not supported".format(from_path))

    if remove_finished:
        os.remove(from_path)
250 |
251 |
def download_and_extract_archive(url, download_root, extract_root=None, filename=None,
                                 md5=None, remove_finished=False):
    """Fetch *url* into *download_root*, then unpack it into *extract_root*."""
    download_root = os.path.expanduser(download_root)
    # By default extract next to the downloaded archive.
    extract_root = download_root if extract_root is None else extract_root
    filename = filename or os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    extract_archive(archive, extract_root, remove_finished)
265 |
266 |
def iterable_to_str(iterable):
    """Render items as a comma-separated list of single-quoted strings.

    An empty iterable yields "''" (two quote characters), matching the
    historical behaviour of wrapping the joined body in outer quotes.
    """
    body = "', '".join(str(item) for item in iterable)
    return "'" + body + "'"
269 |
270 |
def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None):
    """Validate that *value* is a string and, optionally, an allowed one.

    Args:
        value: argument value to check.
        arg (str, optional): argument name used in error messages.
        valid_values (iterable, optional): allowed values; no membership
            check is performed when None.
        custom_msg (str, optional): replaces the default "unknown value"
            message; it is still passed through str.format with ``value``,
            ``arg`` and ``valid_values`` fields available.

    Returns:
        The validated *value*, unchanged.

    Raises:
        ValueError: if *value* is not a str, or not among *valid_values*.
    """
    # BUG FIX: the original tested `torch._six.string_classes`, a private
    # compatibility shim removed from PyTorch in 1.13; on Python 3 it was
    # simply (str,), so isinstance(value, str) is the equivalent check.
    if not isinstance(value, str):
        if arg is None:
            msg = "Expected type str, but got type {type}."
        else:
            msg = "Expected type str for argument {arg}, but got type {type}."
        msg = msg.format(type=type(value), arg=arg)
        raise ValueError(msg)

    if valid_values is None:
        return value

    if value not in valid_values:
        if custom_msg is not None:
            msg = custom_msg
        else:
            msg = ("Unknown value '{value}' for argument {arg}. "
                   "Valid values are {{{valid_values}}}.")
        msg = msg.format(value=value, arg=arg,
                         valid_values=iterable_to_str(valid_values))
        raise ValueError(msg)

    return value
294 |
--------------------------------------------------------------------------------
/Asym-RoCL/linear_eval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 |
3 | from __future__ import print_function
4 |
5 | import argparse
6 | import csv
7 | import os
8 | import json
9 | import copy
10 |
11 | import numpy as np
12 | import torch
13 | from torch.autograd import Variable
14 | import torch.backends.cudnn as cudnn
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | import torch.optim as optim
18 | import torchvision.transforms as transforms
19 | import torchvision.datasets as datasets
20 |
21 | import data_loader
22 | import model_loader
23 | import models
24 | from models.projector import Projector
25 |
26 | from argument import linear_parser, print_args
27 | from utils import progress_bar, checkpoint_train
28 | from collections import OrderedDict
29 | from attack_lib import FastGradientSignUntargeted
30 | from loss import pairwise_similarity, NT_xent
31 |
# Parse CLI options for linear evaluation and set up multi-GPU bookkeeping.
args = linear_parser()
use_cuda = torch.cuda.is_available()
if use_cuda:
    ngpus_per_node = torch.cuda.device_count()

# Only the (per-node) rank-0 process prints the argument summary.
# NOTE(review): when CUDA is unavailable, ngpus_per_node is never bound and
# the modulo below raises NameError — this script effectively requires a GPU.
if args.local_rank % ngpus_per_node==0:
    print_args(args)
39 |
def print_status(string):
    """Print *string* only on the rank-0 process to avoid duplicated logs."""
    rank_zero = args.local_rank % ngpus_per_node == 0
    if rank_zero:
        print(string)
43 |
print_status(torch.cuda.device_count())
print_status('Using CUDA..')

best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Seed only when explicitly requested (0 means "leave torch's RNG alone").
if args.seed != 0:
    torch.manual_seed(args.seed)

# Data
print_status('==> Preparing data..')
# NOTE(review): assert('...') asserts a non-empty string, which is always
# truthy — this branch never aborts; a `raise` was probably intended.
if not (args.train_type=='linear_eval'):
    assert('wrong train phase...')
else:
    trainloader, traindst, testloader, testdst = data_loader.get_dataset(args)

# Number of classes for the linear head.
if args.dataset == 'cifar-10' or args.dataset=='mnist':
    num_outputs = 10
elif args.dataset == 'cifar-100':
    num_outputs = 100

# Bottleneck blocks (ResNet50) widen the final features 4x; BasicBlock uses 1.
if args.model == 'ResNet50':
    expansion = 4
else:
    expansion = 1

# Model
print_status('==> Building model..')
train_type = args.train_type
73 |
def load(args, epoch):
    """Build the encoder/linear head (and optional projector/attacker) from a checkpoint.

    Args:
        args: parsed CLI namespace (checkpoint path, dataset, flags ...).
        epoch: 0 loads `args.load_checkpoint` as-is; otherwise the
            '_epoch_<epoch>' suffixed snapshot.

    Returns:
        (model, Linear, projector-or-'None', loptim, attacker-or-'None').
        Note the string sentinel 'None' is returned instead of the None object.
    """
    model = model_loader.get_model(args)

    if epoch == 0:
        add = ''
    else:
        add = '_epoch_'+str(epoch)

    checkpoint_ = torch.load(args.load_checkpoint+add)

    # Strip the first 7 characters of every key — presumably the 'module.'
    # prefix written by nn.DataParallel checkpoints (TODO confirm).
    new_state_dict = OrderedDict()
    for k, v in checkpoint_['model'].items():
        name = k[7:]
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict)

    # Optional self-supervised branch: load the matching projector snapshot.
    if args.ss:
        projector = Projector(expansion=expansion)
        checkpoint_p = torch.load(args.load_checkpoint+'_projector'+add)
        new_state_dict = OrderedDict()
        for k, v in checkpoint_p['model'].items():
            name = k[7:]
            new_state_dict[name] = v
        projector.load_state_dict(new_state_dict)

    # Fresh linear classification head on top of the (512 * expansion)-dim features.
    if args.dataset=='cifar-10':
        Linear = nn.Sequential(nn.Linear(512*expansion, 10))
    elif args.dataset=='cifar-100':
        Linear = nn.Sequential(nn.Linear(512*expansion, 100))

    # The optimizer always trains the head; backbone/projector only with --finetune.
    model_params = []
    if args.finetune:
        model_params += model.parameters()
        if args.ss:
            model_params += projector.parameters()
    model_params += Linear.parameters()
    loptim = torch.optim.SGD(model_params, lr = args.lr, momentum=0.9, weight_decay=5e-4)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        ngpus_per_node = torch.cuda.device_count()
        model.cuda()
        Linear.cuda()
        model = nn.DataParallel(model)
        Linear = nn.DataParallel(Linear)
        if args.ss:
            projector.cuda()
            projector = nn.DataParallel(projector)
    else:
        # NOTE(review): assert("...") on a non-empty string never fails — on a
        # CPU-only machine execution silently continues past this point.
        assert("Need to use GPU...")

    print_status('Using CUDA..')
    cudnn.benchmark = True

    # Optional PGD-style attacker for robust linear evaluation (r-LE).
    if args.adv_img:
        attack_info = 'Adv_train_epsilon_'+str(args.epsilon)+'_alpha_'+ str(args.alpha) + '_min_val_' + str(args.min) + '_max_val_' + str(args.max) + '_max_iters_' + str(args.k) + '_type_' + str(args.attack_type) + '_randomstart_' + str(args.random_start)
        print_status("Adversarial training info...")
        print_status(attack_info)

        attacker = FastGradientSignUntargeted(model, linear=Linear, epsilon=args.epsilon, alpha=args.alpha, min_val=args.min, max_val=args.max, max_iters=args.k, _type=args.attack_type)

    if args.adv_img:
        if args.ss:
            return model, Linear, projector, loptim, attacker
        return model, Linear, 'None', loptim, attacker
    if args.ss:
        return model, Linear, projector, loptim, 'None'
    return model, Linear, 'None', loptim, 'None'
143 |
144 | criterion = nn.CrossEntropyLoss()
145 |
146 |
def linear_train(epoch, model, Linear, projector, loptim, attacker=None):
    """Train the linear head (and optionally backbone/projector) for one epoch.

    Data sources are selected by global CLI flags: --clean (clean images),
    --adv_img (PGD-perturbed images via *attacker*), --ss (adds the two
    contrastive views and an NT-Xent term).  Relies on the module-level
    `trainloader`, `args` and `criterion`.

    Returns:
        (train_acc, model, Linear, projector, loptim)

    Raises:
        ValueError: if none of --clean / --adv_img / --ss is enabled.
    """
    Linear.train()
    if args.finetune:
        model.train()
        if args.ss:
            projector.train()
    else:
        model.eval()

    total_loss = 0
    correct = 0
    total = 0

    for batch_idx, (ori, inputs, inputs_2, target) in enumerate(trainloader):
        ori, inputs_1, inputs_2, target = ori.cuda(), inputs.cuda(), inputs_2.cuda(), target.cuda()
        input_flag = False
        # Base image: augmented view (--trans) or the original image.
        if args.trans:
            inputs = inputs_1
        else:
            inputs = ori

        if args.adv_img:
            advinputs = attacker.perturb(original_images=inputs, labels=target, random_start=args.random_start)

        if args.clean:
            total_inputs = inputs
            total_targets = target
            input_flag = True

        # Self-supervised branch feeds both views (note: replaces any clean batch).
        if args.ss:
            total_inputs = torch.cat((inputs, inputs_2))
            total_targets = torch.cat((target, target))

        if args.adv_img:
            if input_flag:
                total_inputs = torch.cat((total_inputs, advinputs))
                total_targets = torch.cat((total_targets, target))
            else:
                total_inputs = advinputs
                total_targets = target
                input_flag = True

        if not input_flag and not args.ss:
            # BUG FIX: the original `assert('choose ...')` asserted a
            # non-empty string, which is always truthy and never fired, so a
            # misconfigured run crashed below with a NameError on
            # `total_inputs`.  Fail loudly and clearly instead.
            raise ValueError('choose the linear evaluation data type (clean, adv_img)')

        feat = model(total_inputs)
        if args.ss:
            output_p = projector(feat)
            B = ori.size(0)

            similarity, _ = pairwise_similarity(output_p[:2*B,:2*B], temperature=args.temperature, multi_gpu=False, adv_type = 'None')
            simloss = NT_xent(similarity, 'None')

        output = Linear(feat)

        _, predx = torch.max(output.data, 1)
        loss = criterion(output, total_targets)

        if args.ss:
            loss += simloss

        correct += predx.eq(total_targets.data).cpu().sum().item()
        total += total_targets.size(0)
        acc = 100.*correct/total

        total_loss += loss.data

        loptim.zero_grad()
        loss.backward()
        loptim.step()

        progress_bar(batch_idx, len(trainloader),
                     'Loss: {:.4f} | Acc: {:.2f}'.format(total_loss/(batch_idx+1), acc))

    print ("Epoch: {}, train accuracy: {}".format(epoch, acc))

    return acc, model, Linear, projector, loptim
224 |
def test(model, Linear):
    """Evaluate encoder + linear head on the clean test set.

    Uses the module-level `testloader` and `criterion`; returns
    (accuracy, model, Linear).
    """
    global best_acc

    model.eval()
    Linear.eval()

    test_loss = 0
    correct = 0
    total = 0

    for idx, (image, label) in enumerate(testloader):
        img, y = image.cuda(), label.cuda()

        out = Linear(model(img))
        loss = criterion(out, y)
        test_loss += loss.data

        preds = out.data.max(1)[1]
        correct += preds.eq(y.data).cpu().sum().item()
        total += y.size(0)
        acc = 100.*correct/total

        if args.local_rank % ngpus_per_node == 0:
            progress_bar(idx, len(testloader),'Testing Loss {:.3f}, acc {:.3f}'.format(test_loss/(idx+1), acc))

    print ("Test accuracy: {0}".format(acc))

    return (acc, model, Linear)
255 |
def adjust_lr(epoch, optim):
    """Step-decay schedule: divide the base LR by 10 at epochs 30, 50 and 100
    (CIFAR datasets only), then write the result into every param group."""
    lr = args.lr
    if args.dataset=='cifar-10' or args.dataset=='cifar-100':
        for milestone in [30, 50, 100]:
            if epoch >= milestone:
                lr = lr/10
    for group in optim.param_groups:
        group['lr'] = lr
269 |
##### Log file for training selected tasks #####
if not os.path.isdir('results'):
    os.mkdir('results')

# CSV log keyed by run name + seed; the header is (re)written on every start.
args.name += ('_Evaluate_'+ args.train_type + '_' +args.model + '_' + args.dataset)
loginfo = 'results/log_generalization_' + args.name + '_' + str(args.seed)
logname = (loginfo+ '.csv')

with open(logname, 'w') as logfile:
    logwriter = csv.writer(logfile, delimiter=',')
    logwriter.writerow(['epoch', 'train acc','test acc'])

# Optionally evaluate a series of pretraining snapshots (epochs 100..900),
# retraining a fresh linear head for each and logging the snapshot epoch.
if args.epochwise:
    for k in range(100,1000,100):
        model, linear, projector, loptim, attacker = load(args, k)
        print('loading.......epoch ', str(k))
        ##### Linear evaluation #####
        for i in range(args.epoch):
            print('Epoch ', i)
            train_acc, model, linear, projector, loptim = linear_train(i, model, linear, projector, loptim, attacker)
            test_acc, model, linear = test(model, linear)
            adjust_lr(i, loptim)

        checkpoint_train(model, test_acc, args.epoch, args, loptim, save_name_add='epochwise'+str(k))
        checkpoint_train(linear, test_acc, args.epoch, args, loptim, save_name_add='epochwise'+str(k)+'_linear')
        if args.local_rank % ngpus_per_node == 0:
            with open(logname, 'a') as logfile:
                logwriter = csv.writer(logfile, delimiter=',')
                logwriter.writerow([k, train_acc, test_acc])

# Main run: evaluate the final (unsuffixed) checkpoint.
model, linear, projector, loptim, attacker = load(args, 0)

##### Linear evaluation #####
for epoch in range(args.epoch):
    print('Epoch ', epoch)

    train_acc, model, linear, projector, loptim = linear_train(epoch, model=model, Linear=linear, projector=projector, loptim=loptim, attacker=attacker)
    test_acc, model, linear = test(model, linear)
    adjust_lr(epoch, loptim)

    if args.local_rank % ngpus_per_node == 0:
        with open(logname, 'a') as logfile:
            logwriter = csv.writer(logfile, delimiter=',')
            logwriter.writerow([epoch, train_acc, test_acc])

checkpoint_train(model, test_acc, args.epoch, args, loptim)
checkpoint_train(linear, test_acc, args.epoch, args, loptim, save_name_add='_linear')

# Final summary row uses the sentinel epoch value 1000.
if args.local_rank % ngpus_per_node == 0:
    with open(logname, 'a') as logfile:
        logwriter = csv.writer(logfile, delimiter=',')
        logwriter.writerow([1000, train_acc, test_acc])
322 |
--------------------------------------------------------------------------------
/Asym-AdvCL/losses.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import random
4 |
5 | import torch
6 | import torch.nn as nn
7 | import numpy as np
8 |
9 |
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR.

    This variant reweights negatives with the hard-negative scheme of HCL
    (tau_plus / beta read from ``args``) and rescales the cross-view
    (clean-vs-adversarial) quadrants of the loss matrix by ``args.adv_weight``.
    """

    def __init__(self, args, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

        self.args = args

    def forward(self, features, labels=None, mask=None, stop_grad=False, stop_grad_sd=-1.0):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
            stop_grad: when True, the cross-view term is recomputed against
                detached contrast features and blended by stop_grad_sd.
            stop_grad_sd: blending coefficient for the stop-gradient term.
        Returns:
            A loss scalar.
        """

        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            # Unsupervised case: each sample is its own (only) positive class.
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        # Stack views along the batch dim: [bsz * n_views, feat_dim].
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        if stop_grad:
            # Same logits, but gradients do not flow through the contrast side.
            anchor_dot_contrast_stpg = torch.div(
                torch.matmul(anchor_feature, contrast_feature.T.detach()),
                self.temperature)

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # For hard negatives, code adapted from HCL (https://github.com/joshr17/HCL)
        # =============== hard neg params =================
        tau_plus = self.args.tau_plus
        beta = self.args.beta
        # NOTE(review): local hard-coded temperature, used only in the clamp
        # below; it is independent of self.temperature — confirm intended.
        temperature = 0.5
        N = (batch_size - 1) * contrast_count
        # =============== reweight neg =================
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()
        exp_logits = torch.exp(logits)
        exp_logits_neg = exp_logits * (1 - mask) * logits_mask
        exp_logits_pos = exp_logits * mask
        pos = exp_logits_pos.sum(dim=1) / mask.sum(1)

        # Importance weights emphasise hard (high-similarity) negatives.
        imp = (beta * (exp_logits_neg + 1e-9).log()).exp()
        reweight_logits_neg = (imp * exp_logits_neg) / imp.mean(dim=-1)
        Ng = (-tau_plus * N * pos + reweight_logits_neg.sum(dim=-1)) / (1 - tau_plus)
        # constrain (optional)
        Ng = torch.clamp(Ng, min=N * np.e ** (-1 / temperature))
        log_prob = -torch.log(exp_logits / (pos + Ng))
        # ===============================================

        loss_square = mask * log_prob  # only positive positions have elements

        # mix_square = exp_logits
        # NOTE: mix_square is an alias of loss_square (no copy); the in-place
        # quadrant writes below mutate both names.
        mix_square = loss_square

        if stop_grad:
            logits_max_stpg, _ = torch.max(anchor_dot_contrast_stpg, dim=1, keepdim=True)
            logits_stpg = anchor_dot_contrast_stpg - logits_max_stpg.detach()
            # =============== reweight neg =================
            exp_logits_stpg = torch.exp(logits_stpg)
            exp_logits_neg_stpg = exp_logits_stpg * (1 - mask) * logits_mask
            exp_logits_pos_stpg = exp_logits_stpg * mask
            pos_stpg = exp_logits_pos_stpg.sum(dim=1) / mask.sum(1)

            imp_stpg = (beta * (exp_logits_neg_stpg + 1e-9).log()).exp()
            reweight_logits_neg_stpg = (imp_stpg * exp_logits_neg_stpg) / imp_stpg.mean(dim=-1)
            Ng_stpg = ((-tau_plus * N * pos_stpg + reweight_logits_neg_stpg.sum(dim=-1)) / (1 - tau_plus))

            # constrain (optional)
            Ng_stpg = torch.clamp(Ng_stpg, min=N * np.e ** (-1 / temperature))
            log_prob_stpg = -torch.log(exp_logits_stpg / (pos_stpg + Ng_stpg))
            # ===============================================
            tmp_square = mask * log_prob_stpg
        else:
            # tmp_square = exp_logits
            tmp_square = loss_square
        # Blend (or copy) the cross-view quadrant used for the asymmetric term.
        if stop_grad:
            ac_square = stop_grad_sd * tmp_square[batch_size:, 0:batch_size].T + (1 - stop_grad_sd) * tmp_square[
                0:batch_size,
                batch_size:]
        else:
            ac_square = tmp_square[0:batch_size, batch_size:]

        # Rescale the clean-vs-adversarial quadrants by the asymmetry weight.
        mix_square[0:batch_size, batch_size:] = ac_square * self.args.adv_weight
        mix_square[batch_size:, 0:batch_size] = ac_square.T * self.args.adv_weight

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = mix_square.sum(1) / mask.sum(1)

        # loss
        # (log_prob above already carries the minus sign, hence no negation here)
        loss = (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
157 |
158 |
class ori_SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR.

    Unlike the hard-negative variant above, this is the standard SupCon
    softmax loss, with the cross-view quadrants rescaled by args.adv_weight.
    """

    def __init__(self, args, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(ori_SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

        self.args = args

    def forward(self, features, labels=None, mask=None, stop_grad=False, stop_grad_sd=-1.0):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
            stop_grad: when True, the cross-view term is recomputed against
                detached contrast features and blended by stop_grad_sd.
            stop_grad_sd: blending coefficient for the stop-gradient term.
        Returns:
            A loss scalar.
        """

        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            # Unsupervised case: each sample is its own (only) positive class.
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        # Stack views along the batch dim: [bsz * n_views, feat_dim].
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        if stop_grad:
            # Same logits, but gradients do not flow through the contrast side.
            anchor_dot_contrast_stpg = torch.div(
                torch.matmul(anchor_feature, contrast_feature.T.detach()),
                self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        loss_square = mask * log_prob  # only positive position have elements
        # NOTE: mix_square is an alias of loss_square (no copy); the in-place
        # quadrant writes below mutate both names.
        mix_square = loss_square
        if stop_grad:
            logits_max_stpg, _ = torch.max(anchor_dot_contrast_stpg, dim=1, keepdim=True)
            logits_stpg = anchor_dot_contrast_stpg - logits_max_stpg.detach()
            # compute log_prob
            exp_logits_stpg = torch.exp(logits_stpg) * logits_mask
            log_prob_stpg = logits_stpg - torch.log(exp_logits_stpg.sum(1, keepdim=True))
            loss_square_stpg = mask * log_prob_stpg
            tmp_square = loss_square_stpg
        else:
            tmp_square = loss_square
        # Blend (or copy) the cross-view quadrant used for the asymmetric term.
        if stop_grad:
            ac_square = stop_grad_sd * tmp_square[batch_size:, 0:batch_size].T + (1 - stop_grad_sd) * tmp_square[
                0:batch_size,
                batch_size:]
        else:
            ac_square = tmp_square[0:batch_size, batch_size:]

        # Rescale the clean-vs-adversarial quadrants by the asymmetry weight.
        mix_square[0:batch_size, batch_size:] = ac_square * self.args.adv_weight
        mix_square[batch_size:, 0:batch_size] = ac_square.T * self.args.adv_weight

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = mix_square.sum(1) / mask.sum(1)

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
276 |
--------------------------------------------------------------------------------
/Asym-RoCL/argument.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def test_parser():
4 | parser = argparse.ArgumentParser(description='linear eval test')
5 | parser.add_argument('--train_type', default='linear_eval', type=str, help='standard')
6 | parser.add_argument('--dataset', default='cifar-10', type=str, help='cifar-10/cifar-100')
7 | parser.add_argument('--load_checkpoint', default='./checkpoint/ckpt.t7one_task_0', type=str, help='PATH TO CHECKPOINT')
8 | parser.add_argument('--model', default='ResNet18', type=str,
9 | help='model type ResNet18/ResNet50')
10 | parser.add_argument('--name', default='', type=str, help='name of run')
11 | parser.add_argument('--seed', default=2342, type=int, help='random seed')
12 | parser.add_argument('--batch-size', default=128, type=int, help='batch size / multi-gpu setting: batch per gpu')
13 |
14 | ##### arguments for data augmentation #####
15 | parser.add_argument('--color_jitter_strength', default=0.5, type=float, help='0.5 for CIFAR, 1.0 for ImageNet')
16 | parser.add_argument('--temperature', default=0.5, type=float, help='temperature for pairwise-similarity')
17 |
18 | ##### arguments for distributted parallel #####
19 | parser.add_argument('--local_rank', type=int, default=0)
20 | parser.add_argument('--ngpu', type=int, default=1)
21 |
22 | parser.add_argument('--attack_type', type=str, default='linf',
23 | help='adversarial l_p')
24 | parser.add_argument('--epsilon', type=float, default=0.0314,
25 | help='maximum perturbation of adversaries (8/255 for cifar-10)')
26 | parser.add_argument('--alpha', type=float, default=0.007,
27 | help='movement multiplier per iteration when generating adversarial examples (2/255=0.00784)')
28 | parser.add_argument('--k', type=int, default=10,
29 | help='maximum iteration when generating adversarial examples')
30 | parser.add_argument('--random_start', type=bool, default=True,
31 | help='True for PGD')
32 |
33 | args = parser.parse_args()
34 |
35 | return args
36 |
def linear_parser():
    """Build and parse the CLI arguments for RoCL linear training.

    Returns:
        argparse.Namespace with the parsed options.
    """
    def _str2bool(v):
        # BUG FIX: argparse's `type=bool` is a classic pitfall — bool('False')
        # is True because any non-empty string is truthy, so e.g.
        # `--finetune False` silently enabled finetuning.  Parse the text
        # explicitly instead; passing 'True'/'False' keeps working.
        if isinstance(v, bool):
            return v
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        if v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('boolean value expected, got %r' % (v,))

    parser = argparse.ArgumentParser(description='RoCL linear training')

    ##### arguments for RoCL Linear eval (LE) or Robust Linear eval (r-LE)#####
    parser.add_argument('--train_type', default='linear_eval', type=str, help='contrastive/linear eval/test')
    parser.add_argument('--finetune', default=False, type=_str2bool, help='finetune the model')
    parser.add_argument('--epochwise', type=_str2bool, default=False, help='epochwise saving...')
    parser.add_argument('--ss', default=False, type=_str2bool, help='using self-supervised learning loss')

    parser.add_argument('--trans', default=False, type=_str2bool, help='use transformed sample')
    parser.add_argument('--clean', default=False, type=_str2bool, help='use clean sample')
    parser.add_argument('--adv_img', default=False, type=_str2bool, help='use adversarial sample')

    ##### arguments for training #####
    parser.add_argument('--lr', default=0.2, type=float, help='learning rate')
    parser.add_argument('--lr_multiplier', default=15.0, type=float, help='learning rate multiplier')

    parser.add_argument('--dataset', default='cifar-10', type=str, help='cifar-10/cifar-100')
    parser.add_argument('--load_checkpoint', default='./checkpoint/ckpt.t7one_task_0', type=str, help='PATH TO CHECKPOINT')
    parser.add_argument('--model', default="ResNet18", type=str,
                        help='model type ResNet18/ResNet50')

    parser.add_argument('--name', default='', type=str, help='name of run')
    parser.add_argument('--seed', default=2342, type=int, help='random seed')
    parser.add_argument('--batch-size', default=128, type=int, help='batch size / multi-gpu setting: batch per gpu')
    parser.add_argument('--epoch', default=150, type=int,
                        help='total epochs to run')

    ##### arguments for data augmentation #####
    parser.add_argument('--color_jitter_strength', default=0.5, type=float, help='0.5 for CIFAR')
    parser.add_argument('--temperature', default=0.5, type=float, help='temperature for pairwise-similarity')

    ##### arguments for distributted parallel #####
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--ngpu', type=int, default=1)

    ##### arguments for PGD attack & Adversarial Training #####
    parser.add_argument('--min', type=float, default=0.0,
                        help='min for cliping image')
    parser.add_argument('--max', type=float, default=1.0,
                        help='max for cliping image')
    parser.add_argument('--attack_type', type=str, default='linf',
                        help='adversarial l_p')
    parser.add_argument('--epsilon', type=float, default=0.0314,
                        help='maximum perturbation of adversaries (8/255 for cifar-10)')
    parser.add_argument('--alpha', type=float, default=0.007,
                        help='movement multiplier per iteration when generating adversarial examples (2/255=0.00784)')
    parser.add_argument('--k', type=int, default=10,
                        help='maximum iteration when generating adversarial examples')

    parser.add_argument('--random_start', type=_str2bool, default=True,
                        help='True for PGD')
    args = parser.parse_args()

    return args
92 |
93 |
def parser():
    """Build and parse the command-line arguments for RoCL adversarial
    contrastive pre-training.

    Returns:
        argparse.Namespace: the parsed arguments (from ``sys.argv``).
    """
    parser = argparse.ArgumentParser(description='PyTorch RoCL training')
    parser.add_argument('--module', action='store_true')

    ##### arguments for RoCL #####
    parser.add_argument('--lamda', default=256, type=float)

    parser.add_argument('--regularize_to', default='other', type=str, help='original/other')
    parser.add_argument('--attack_to', default='other', type=str, help='original/other')

    parser.add_argument('--loss_type', type=str, default='sim', help='loss type for Rep')
    parser.add_argument('--advtrain_type', default='Rep', type=str, help='Rep/None')

    ##### arguments for Training Self-Sup #####
    parser.add_argument('--train_type', default='contrastive', type=str, help='contrastive/linear eval/test')

    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--lr_multiplier', default=15.0, type=float, help='learning rate multiplier')
    parser.add_argument('--decay', default=1e-6, type=float, help='weight decay')

    parser.add_argument('--dataset', default='cifar-10', type=str, help='cifar-10/cifar-100')

    parser.add_argument('--load_checkpoint', default='./checkpoint/ckpt.t7one_task_0', type=str, help='PATH TO CHECKPOINT')
    parser.add_argument('--resume', '-r', action='store_true',
                        help='resume from checkpoint')
    parser.add_argument('--model', default='ResNet18', type=str,
                        help='model type ResNet18/ResNet50')

    parser.add_argument('--name', default='', type=str, help='name of run')
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    parser.add_argument('--batch-size', default=256, type=int, help='batch size / multi-gpu setting: batch per gpu')
    parser.add_argument('--epoch', default=1000, type=int,
                        help='total epochs to run')

    ##### arguments for data augmentation #####
    parser.add_argument('--color_jitter_strength', default=0.5, type=float, help='0.5 for CIFAR')
    parser.add_argument('--temperature', default=0.5, type=float, help='temperature for pairwise-similarity')

    ##### arguments for distributted parallel #####
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--ngpu', type=int, default=2)

    ##### arguments for PGD attack & Adversarial Training #####
    parser.add_argument('--min', type=float, default=0.0, help='min for cliping image')
    parser.add_argument('--max', type=float, default=1.0, help='max for cliping image')
    parser.add_argument('--attack_type', type=str, default='linf', help='adversarial l_p')
    parser.add_argument('--epsilon', type=float, default=0.0314,
                        help='maximum perturbation of adversaries (8/255 for cifar-10)')
    parser.add_argument('--alpha', type=float, default=0.007, help='movement multiplier per iteration when generating adversarial examples (2/255=0.00784)')
    parser.add_argument('--k', type=int, default=7, help='maximum iteration when generating adversarial examples')
    # BUG FIX: `type=bool` converts every non-empty string to True
    # (bool('False') is True), so `--random_start False` silently kept the
    # random start enabled.  Parse the string explicitly instead; the option
    # name, dest, and default are unchanged, so existing callers keep working.
    parser.add_argument('--random_start',
                        type=lambda s: s.lower() in ('true', '1', 'yes'),
                        default=True,
                        help='True for PGD')

    ##### new
    parser.add_argument('--stop_grad', action='store_true')
    parser.add_argument('--adv_weight', type=float, default=1, help="weight of adv loss: stop_grad stpg-weight")
    # must use with --stop_grad
    parser.add_argument('--stpg_degree', type=float, default=-1,
                        help="stop degree, range from 0 to 1, 0 is totally stop for clean branch")

    ##### HN
    parser.add_argument('--HN', action='store_true', default=False,
                        help='use HN')
    parser.add_argument('--tau', type=float, default=0.1, help='tau in HN')
    parser.add_argument('--beta', type=float, default=1.0, help='beta in HN')
    args = parser.parse_args()

    return args
162 |
def parser_load():
    """Build and parse the command-line arguments for loading/evaluating a
    RoCL checkpoint (same options as ``parser`` plus ``--load_checkpoint``
    declared up front).

    Returns:
        argparse.Namespace: the parsed arguments (from ``sys.argv``).
    """
    parser = argparse.ArgumentParser(description='PyTorch RoCL training')
    parser.add_argument('--module', action='store_true')

    parser.add_argument('--load_checkpoint', type=str, default='./checkpoint/ckpt.t7one_task_0', help='load checkpoint')
    ##### arguments for RoCL #####
    parser.add_argument('--lamda', default=256, type=float)

    parser.add_argument('--regularize_to', default='other', type=str, help='original/other')
    parser.add_argument('--attack_to', default='other', type=str, help='original/other')

    parser.add_argument('--loss_type', type=str, default='sim', help='loss type for Rep')
    parser.add_argument('--advtrain_type', default='Rep', type=str, help='Rep/None')

    ##### arguments for Training Self-Sup #####
    parser.add_argument('--train_type', default='contrastive', type=str, help='contrastive/linear eval/test')

    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--lr_multiplier', default=15.0, type=float, help='learning rate multiplier')
    parser.add_argument('--decay', default=1e-6, type=float, help='weight decay')

    parser.add_argument('--dataset', default='cifar-10', type=str, help='cifar-10/cifar-100')

    parser.add_argument('--resume', '-r', action='store_true',
                        help='resume from checkpoint')
    parser.add_argument('--model', default='ResNet18', type=str,
                        help='model type ResNet18/ResNet50')

    parser.add_argument('--name', default='', type=str, help='name of run')
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    parser.add_argument('--batch-size', default=256, type=int, help='batch size / multi-gpu setting: batch per gpu')
    parser.add_argument('--epoch', default=1000, type=int,
                        help='total epochs to run')

    ##### arguments for data augmentation #####
    parser.add_argument('--color_jitter_strength', default=0.5, type=float, help='0.5 for CIFAR')
    parser.add_argument('--temperature', default=0.5, type=float, help='temperature for pairwise-similarity')

    ##### arguments for distributted parallel #####
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--ngpu', type=int, default=2)

    ##### arguments for PGD attack & Adversarial Training #####
    parser.add_argument('--min', type=float, default=0.0, help='min for cliping image')
    parser.add_argument('--max', type=float, default=1.0, help='max for cliping image')
    parser.add_argument('--attack_type', type=str, default='linf', help='adversarial l_p')
    parser.add_argument('--epsilon', type=float, default=0.0314,
                        help='maximum perturbation of adversaries (8/255 for cifar-10)')
    parser.add_argument('--alpha', type=float, default=0.007, help='movement multiplier per iteration when generating adversarial examples (2/255=0.00784)')
    parser.add_argument('--k', type=int, default=7, help='maximum iteration when generating adversarial examples')
    # BUG FIX: `type=bool` converts every non-empty string to True
    # (bool('False') is True), so `--random_start False` silently kept the
    # random start enabled.  Parse the string explicitly instead; the option
    # name, dest, and default are unchanged, so existing callers keep working.
    parser.add_argument('--random_start',
                        type=lambda s: s.lower() in ('true', '1', 'yes'),
                        default=True,
                        help='True for PGD')

    ##### new
    parser.add_argument('--stop_grad', action='store_true')
    parser.add_argument('--adv_weight', type=float, default=1, help="weight of adv loss: stop_grad stpg-weight")
    # must use with --stop_grad
    parser.add_argument('--stpg_degree', type=float, default=-1,
                        help="stop degree, range from 0 to 1, 0 is totally stop for clean branch")

    ##### HN
    parser.add_argument('--HN', action='store_true', default=False,
                        help='use HN')
    parser.add_argument('--tau', type=float, default=0.1, help='tau in HN')
    parser.add_argument('--beta', type=float, default=1.0, help='beta in HN')
    args = parser.parse_args()

    return args
232 |
def print_args(args):
    """Pretty-print every parsed argument as ``name : value``, one per line,
    with the name left-aligned in a 16-character column."""
    for name, value in vars(args).items():
        print(f'{name:<16} : {value}')
236 |
237 |
--------------------------------------------------------------------------------
/Asym-AdvCL/finetuning_advCL_SLF.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import sys
4 | import argparse
5 | import time
6 | import math
7 |
8 | import torch
9 | import os
10 | import torch.backends.cudnn as cudnn
11 |
12 | from utils import AverageMeter
13 | from utils import accuracy
14 | from torch import optim
15 | from torchvision import transforms, datasets
16 | import models
17 | import torch.nn.functional as F
18 | from torch import nn
19 | from models.resnet_cifar import ResNet18
20 | from models.linear import LinearClassifier
21 | import tensorboard_logger as tb_logger
22 | from utils import load_BN_checkpoint
23 | from trades import trades_loss
24 |
25 |
def set_loader(opt):
    """Construct the train/validation DataLoaders for ``opt.dataset``.

    Side effect: for the CIFAR datasets, ``opt.data_folder`` is extended in
    place with a dataset-specific subdirectory.

    Returns:
        (train_loader, val_loader)

    Raises:
        NotImplementedError: for unsupported dataset names.
    """
    # construct data loader
    train_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    val_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
    ])

    if opt.dataset == "cifar10":
        opt.data_folder += 'cifar10/'
        train_dataset = datasets.CIFAR10(root=opt.data_folder,
                                         transform=train_transform,
                                         download=True)
        val_dataset = datasets.CIFAR10(root=opt.data_folder,
                                       train=False,
                                       transform=val_transform)
    elif opt.dataset == "stl10":
        # BUG FIX: torchvision's STL10 has no `train` keyword — subsets are
        # selected via `split=`.  The original `train=False` raised a
        # TypeError; use split='train' / split='test' instead (split='train'
        # was already the implicit default for the training set).
        train_dataset = datasets.STL10(root=opt.data_folder,
                                       split='train',
                                       transform=train_transform,
                                       download=True)
        val_dataset = datasets.STL10(root=opt.data_folder,
                                     split='test',
                                     transform=val_transform)
    elif opt.dataset == "cifar100":
        opt.data_folder += 'cifar100/'
        train_dataset = datasets.CIFAR100(root=opt.data_folder,
                                          transform=train_transform,
                                          download=True)
        val_dataset = datasets.CIFAR100(root=opt.data_folder,
                                        train=False,
                                        transform=val_transform)
    else:
        raise NotImplementedError("Dataset Not Supported")

    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None),
        num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=256, shuffle=False,
        num_workers=8, pin_memory=True)

    return train_loader, val_loader
74 |
75 |
def parse_option():
    """Parse command-line options for linear/adversarial finetuning and
    derive the dataset-dependent settings (data folder, class count).

    Returns:
        argparse.Namespace: the parsed options with ``data_folder`` and
        ``n_cls`` filled in.
    """
    arg_parser = argparse.ArgumentParser('argument for training')

    arg_parser.add_argument('--print_freq', type=int, default=10,
                            help='print frequency')
    arg_parser.add_argument('--save_freq', type=int, default=50,
                            help='save frequency')
    arg_parser.add_argument('--batch_size', type=int, default=512,
                            help='batch_size')
    arg_parser.add_argument('--num_workers', type=int, default=16,
                            help='num of workers to use')
    arg_parser.add_argument('--epochs', type=int, default=25,
                            help='number of training epochs')
    # optimization
    arg_parser.add_argument('--learning_rate', type=float, default=0.1,
                            help='learning rate')
    arg_parser.add_argument('--weight_decay', type=float, default=2e-4,
                            help='weight decay')
    arg_parser.add_argument('--momentum', type=float, default=0.9,
                            help='momentum')

    # model dataset
    arg_parser.add_argument('--dataset', type=str, default='cifar10',
                            choices=['cifar10', 'cifar100', 'stl10'], help='dataset')
    # other setting
    arg_parser.add_argument('--ckpt', type=str, default='',
                            help='path to pre-trained model')
    arg_parser.add_argument('--name', type=str, default='advcl_slf',
                            help='name of the exp')

    # control the finetuning type
    arg_parser.add_argument('--finetune_type', type=str, default='SLF', choices=['SLF', 'ALF', 'AFF_trades'])

    opt = arg_parser.parse_args()
    print(opt)

    # set the path according to the environment
    opt.data_folder = '~/data/'
    # Dataset -> number of classes (CLI `choices` already restricts the
    # value, so the ValueError is only a safety net).
    class_counts = {'cifar10': 10, 'stl10': 10, 'cifar100': 100}
    if opt.dataset not in class_counts:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))
    opt.n_cls = class_counts[opt.dataset]

    return opt
122 |
123 |
124 | # PGD attack model
class AttackPGD(nn.Module):
    """L-inf PGD adversary wrapped around a feature extractor + linear head.

    Runs projected gradient descent on the cross-entropy loss of
    ``classifier(model(x))`` and returns the logits on the adversarial
    examples together with the examples themselves.

    Args:
        model: backbone; called as ``model(x, return_feat=True)``.
        classifier: linear head mapping features to logits.
        config: dict with ``random_start``, ``step_size``, ``epsilon`` and
            ``loss_func`` (must be ``'xent'``).
    """

    def __init__(self, model, classifier, config):
        super(AttackPGD, self).__init__()
        self.model = model
        self.classifier = classifier
        self.rand = config['random_start']
        self.step_size = config['step_size']
        self.epsilon = config['epsilon']
        assert config['loss_func'] == 'xent', 'Plz use xent for loss function.'

    def forward(self, inputs, targets, train=True, finetune_type="AFF"):
        """Craft PGD examples for ``inputs`` and classify them.

        Args:
            inputs: clean images in [0, 1].
            targets: ground-truth labels used for the attack loss.
            train: use 10 attack steps if True, 20 if False (evaluation).
            finetune_type: for "SLF"/"ALF" the final forward pass runs under
                ``torch.no_grad()`` so no gradients flow into the backbone.

        Returns:
            (logits, x_adv): logits on the adversarial images, and the
            adversarial images themselves.
        """
        x = inputs.detach()
        if self.rand:
            # Random start inside the eps-ball.
            # NOTE(review): the randomly perturbed point is not clamped to
            # [0, 1] before the first iteration — confirm this is intended.
            x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon)
        if train:
            num_step = 10
        else:
            num_step = 20
        for i in range(num_step):
            x.requires_grad_()
            with torch.enable_grad():
                features = self.model(x, return_feat=True)
                logits = self.classifier(features)
                # FIX: `reduction='sum'` replaces the deprecated
                # `size_average=False` legacy argument; both mean the
                # un-averaged summed loss, so the attack is unchanged.
                loss = F.cross_entropy(logits, targets, reduction='sum')
            grad = torch.autograd.grad(loss, [x])[0]
            # Signed-gradient ascent step, then projection onto the
            # eps-ball around the clean input and the valid pixel range.
            x = x.detach() + self.step_size * torch.sign(grad.detach())
            x = torch.min(torch.max(x, inputs - self.epsilon), inputs + self.epsilon)
            x = torch.clamp(x, 0, 1)
        if finetune_type == "SLF" or finetune_type == "ALF":
            # Linear-only finetuning: keep the backbone gradient-free.
            with torch.no_grad():
                features = self.model(x, return_feat=True)
        else:
            features = self.model(x, return_feat=True)
        return self.classifier(features), x
159 |
160 |
def set_model(opt):
    """Build the backbone, linear classifier, PGD attack wrapper and loss,
    load the pre-trained checkpoint from ``opt.ckpt``, and move everything
    to the GPU.

    Returns:
        (model, classifier, net, criterion)

    Raises:
        NotImplementedError: when no checkpoint path is given or CUDA is
        unavailable (only the GPU version is supported).
    """
    model = ResNet18()
    criterion = torch.nn.CrossEntropyLoss()
    classifier = LinearClassifier(name=opt.name, feat_dim=512, num_classes=opt.n_cls)

    # Guard: a real checkpoint path is mandatory.
    if not len(opt.ckpt) > 2:
        print("please specify pretrained model")
        raise NotImplementedError

    print('loading from {}'.format(opt.ckpt))
    checkpoint = torch.load(opt.ckpt, map_location='cpu')
    # Checkpoints may store the weights under either key.
    if 'state_dict' in checkpoint.keys():
        pretrained_weights = checkpoint['state_dict']
    else:
        pretrained_weights = checkpoint['model']
    print(torch.cuda.device_count())

    # Guard: GPU-only code path.
    if not torch.cuda.is_available():
        print("only GPU version supported")
        raise NotImplementedError

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    # Remap BN-related keys from the pre-training checkpoint layout.
    pretrained_weights, _ = load_BN_checkpoint(pretrained_weights)
    model = model.cuda()
    classifier = classifier.cuda()
    criterion = criterion.cuda()

    pgd_config = {
        'epsilon': 8.0 / 255.,
        'num_steps': 10,
        'step_size': 2.0 / 255,
        'random_start': True,
        'loss_func': 'xent',
    }
    net = AttackPGD(model, classifier, pgd_config)
    net = net.cuda()

    cudnn.benchmark = True
    # strict=False: tolerate missing/extra keys (e.g. the projection head).
    model.load_state_dict(pretrained_weights, strict=False)
    return model, classifier, net, criterion
200 |
201 |
def train(train_loader, model, classifier, net, criterion, optimizer, epoch, opt):
    """one epoch training

    Behavior depends on ``opt.finetune_type``:
      - "SLF": standard linear finetuning — classifier trained on clean
        features from the frozen backbone.
      - "ALF": adversarial linear finetuning — classifier trained on PGD
        examples produced by ``net`` (backbone frozen).
      - "AFF"/"AFF_trades": full finetuning with the TRADES-style loss.

    Returns:
        (average loss, average top-1 accuracy) over the epoch.
    """
    # Backbone kept in eval mode (frozen BN statistics); only the linear
    # head is in train mode.  For SLF/ALF the optimizer built in main()
    # holds only the classifier's parameters, so the backbone never updates.
    model.eval()
    classifier.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    end = time.time()
    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)

        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        if opt.finetune_type == "AFF":
            # NOTE(review): "AFF" is not among the argparse choices in
            # parse_option() (SLF/ALF/AFF_trades), so this branch is
            # unreachable from the CLI; kept as-is.
            _, x_adv = net(images, labels, train=False, finetune_type="AFF")
            # calculate robust loss
            beta = 6.0
            trainmode = 'adv'

            # fixmode f1: fix nothing, f2: fix previous 3 stages, f3: fix all except fc')
            loss, output = trades_loss(model=model,
                                       classifier=classifier,
                                       x_natural=images,
                                       x_adv=x_adv,
                                       y=labels,
                                       optimizer=optimizer,
                                       beta=beta,
                                       trainmode=trainmode,
                                       fixmode='f1')
        elif opt.finetune_type == "AFF_trades":
            # calculate robust loss
            # With trades=True and x_adv="" the perturbations are presumably
            # generated inside trades_loss itself — see trades.py to confirm.
            step_size = 2. / 255.
            epsilon = 8. / 255.
            num_steps_train = 10
            beta = 6.0
            trainmode = 'adv'

            # fixmode f1: fix nothing, f2: fix previous 3 stages, f3: fix all except fc')
            loss, output = trades_loss(model=model,
                                       classifier=classifier,
                                       x_natural=images,
                                       x_adv="",
                                       y=labels,
                                       optimizer=optimizer,
                                       step_size=step_size,
                                       epsilon=epsilon,
                                       perturb_steps=num_steps_train,
                                       beta=beta,
                                       trainmode=trainmode,
                                       fixmode='f1',
                                       trades=True)
        elif opt.finetune_type == "SLF":
            # Clean features from the frozen backbone; only the linear head
            # receives gradients.
            with torch.no_grad():
                features = model(images, return_feat=True)
            output = classifier(features.detach())
            loss = criterion(output, labels)
        else:  # ALF, use PGD examples to train the classifier
            output, _ = net(images, labels, train=False, finetune_type="ALF")
            loss = criterion(output, labels)

        # update metric
        losses.update(loss.item(), bsz)
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))  # acc5 is unused
        top1.update(acc1[0], bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1))
            sys.stdout.flush()

    return losses.avg, top1.avg
293 |
294 |
def validate(val_loader, model, classifier, net, criterion, opt):
    """validation

    Evaluates the model + classifier on the validation set, reporting both
    robust accuracy (on 20-step PGD examples, since ``train=False``) and
    clean accuracy on the same batches.

    Returns:
        (average adversarial loss, robust top-1 avg, clean top-1 avg)
    """
    model.eval()
    classifier.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    top1_clean = AverageMeter()

    with torch.no_grad():
        end = time.time()
        for idx, (images, labels) in enumerate(val_loader):
            images = images.float().cuda()
            labels = labels.cuda()
            bsz = labels.shape[0]

            # forward
            # PGD attack still works here: AttackPGD.forward re-enables
            # gradients internally via torch.enable_grad().
            output, _ = net(images, labels, train=False)
            loss = criterion(output, labels)

            # Clean accuracy on the same batch for comparison.
            features_clean = model(images, return_feat=True)
            output_clean = classifier(features_clean)
            acc1_clean, acc5_clean = accuracy(output_clean, labels, topk=(1, 5))

            # update metric
            losses.update(loss.item(), bsz)
            acc1, acc5 = accuracy(output, labels, topk=(1, 5))
            top1.update(acc1[0], bsz)
            top1_clean.update(acc1_clean[0], bsz)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if idx % opt.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 Clean {top1_clean.val:.4f} ({top1_clean.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                       idx, len(val_loader), batch_time=batch_time,
                       loss=losses, top1=top1, top1_clean=top1_clean))

    print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
    print(' * Acc@1 Clean {top1_clean.avg:.3f}'.format(top1_clean=top1_clean))
    return losses.avg, top1.avg, top1_clean.avg
343 |
344 |
def adjust_lr(lr, optimizer, epoch):
    """Step-decay schedule: divide the base ``lr`` by 10 at epoch 15 and
    again at epoch 20 (cumulatively /100), then write the decayed value
    into every parameter group of ``optimizer``."""
    decayed = lr
    for milestone in (15, 20):
        if epoch >= milestone:
            decayed /= 10
    for group in optimizer.param_groups:
        group['lr'] = decayed
352 |
353 |
def main():
    """Entry point: finetune a pre-trained AdvCL model.

    Builds the loaders, model and optimizer from CLI options, then runs
    ``opt.epochs`` epochs of train + validate, logging to TensorBoard and
    checkpointing the best (robust-accuracy) model plus periodic snapshots.
    """
    best_acc = 0        # best robust (adversarial) top-1 so far
    best_acc_clean = 0  # clean top-1 at the best-robust epoch
    opt = parse_option()

    # Values logged as "best_*" each epoch (updated only on improvement).
    log_ra = 0
    log_ta = 0

    # build data loader
    train_loader, val_loader = set_loader(opt)

    # build model and criterion
    model, classifier, net, criterion = set_model(opt)

    # build optimizer
    # SLF/ALF train only the linear head; otherwise (AFF_trades) the
    # backbone is finetuned as well.
    if opt.finetune_type == "SLF" or opt.finetune_type == "ALF":
        params = list(classifier.parameters())
    else:
        params = list(model.parameters()) + list(classifier.parameters())
    optimizer = optim.SGD(params,
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)
    logname = ('logger/' + '{}_linearnormal'.format(opt.name+str(opt.learning_rate)))
    logger = tb_logger.Logger(logdir=logname, flush_secs=2)

    # training routine
    for epoch in range(1, opt.epochs + 1):
        # Note: the schedule is driven by the 0-based epoch (epoch - 1).
        adjust_lr(opt.learning_rate, optimizer, epoch - 1)
        # train for one epoch
        time1 = time.time()
        loss, acc = train(train_loader, model, classifier, net, criterion,
                          optimizer, epoch, opt)
        logger.log_value('train_loss', loss, epoch)
        logger.log_value('train_acc', acc, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)
        time2 = time.time()
        print('Train epoch {}, total time {:.2f}, accuracy:{:.2f}'.format(
            epoch, time2 - time1, acc))
        # eval for one epoch
        loss, val_acc, val_acc_clean = validate(val_loader, model, classifier, net, criterion, opt)
        logger.log_value('val_loss', loss, epoch)
        logger.log_value('val_acc', val_acc, epoch)
        logger.log_value('val_acc_clean', val_acc_clean, epoch)

        # NOTE(review): log_ra/log_ta are updated *after* this logging call,
        # so the "best_*" curves lag one epoch behind — confirm intended.
        logger.log_value('best_val_acc', log_ra, epoch)
        logger.log_value('best_val_acc_clean', log_ta, epoch)

        # Checkpoint whenever robust accuracy improves.
        if val_acc > best_acc:
            best_acc = val_acc
            best_acc_clean = val_acc_clean
            log_ra = val_acc
            log_ta = val_acc_clean
            state = {
                'model': model.state_dict(),
                'classifier': classifier.state_dict(),
                'epoch': epoch,
                'rng_state': torch.get_rng_state()
            }
            if not os.path.exists('./checkpoint/{}/'.format(opt.name+str(opt.learning_rate))):
                os.makedirs('./checkpoint/{}/'.format(opt.name+str(opt.learning_rate)))
            # The extra `epoch` argument to .format() is ignored by
            # str.format (the templates have a single placeholder).
            if opt.finetune_type == "AFF_trades":
                torch.save(state, './checkpoint/{}/best_aff.ckpt'.format(opt.name+str(opt.learning_rate), epoch))
            elif opt.finetune_type == "ALF":
                torch.save(state, './checkpoint/{}/best_alf.ckpt'.format(opt.name+str(opt.learning_rate), epoch))
            else:
                torch.save(state, './checkpoint/{}/best.ckpt'.format(opt.name+str(opt.learning_rate), epoch))

        # Periodic snapshot every save_freq epochs.
        if epoch % opt.save_freq == 0:
            state = {
                'model': model.state_dict(),
                'classifier': classifier.state_dict(),
                'epoch': epoch,
                'rng_state': torch.get_rng_state()
            }
            if not os.path.exists('./checkpoint/{}/'.format(opt.name+str(opt.learning_rate))):
                os.makedirs('./checkpoint/{}/'.format(opt.name+str(opt.learning_rate)))
            torch.save(state, './checkpoint/{}/ep_{}.ckpt'.format(opt.name+str(opt.learning_rate), epoch))

    print('best accuracy: {:.2f}'.format(best_acc))
    print('best accuracy clean: {:.2f}'.format(best_acc_clean))
435 |
436 |
# Standard script entry point: run the finetuning routine when executed
# directly (not when imported).
if __name__ == '__main__':
    main()
439 |
--------------------------------------------------------------------------------
/Asym-AdvCL/pretraining_advCL.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import argparse
4 | import numpy as np
5 | import os, csv
6 | from dataset import CIFAR10IndexPseudoLabelEnsemble, CIFAR100IndexPseudoLabelEnsemble
7 | import pickle
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torchvision.transforms as transforms
12 | import torch.utils.data as Data
13 | import torch.backends.cudnn as cudnn
14 |
15 | from utils import progress_bar, TwoCropTransformAdv
16 | from losses import SupConLoss, ori_SupConLoss
17 | import tensorboard_logger as tb_logger
18 | from models.resnet_cifar_multibn_ensembleFC import resnet18 as ResNet18
19 | import random
20 | from fr_util import generate_high
21 | from utils import adjust_learning_rate, warmup_learning_rate, AverageMeter
22 | import apex
23 |
24 | # ================================================================== #
25 | # Inputs and Pre-definition #
26 | # ================================================================== #
27 |
28 | # Arguments
# Command-line arguments for AdvCL adversarial contrastive pre-training.
parser = argparse.ArgumentParser()
parser.add_argument('--name', type=str, default='advcl_cifar10',
                    help='name of the run')
parser.add_argument('--cname', type=str, default='imagenet_clPretrain',
                    help='')
parser.add_argument('--batch_size', type=int, default=512,
                    help='batch size')
parser.add_argument('--epoch', type=int, default=1000,
                    help='total epochs')
# Dest is `save_epoch` (argparse converts the dash to an underscore).
parser.add_argument('--save-epoch', type=int, default=100,
                    help='save epochs')
# epsilon is given in 0-255 pixel units; it is divided by 255 below.
parser.add_argument('--epsilon', type=float, default=8,
                    help='The upper bound change of L-inf norm on input pixels')
parser.add_argument('--iter', type=int, default=5,
                    help='The number of iterations for iterative attacks')
parser.add_argument('--radius', type=int, default=8,
                    help='radius of low freq images')
parser.add_argument('--ce_weight', type=float, default=0.2,
                    help='cross entp weight')

# contrastive related
parser.add_argument('-t', '--nce_t', default=0.5, type=float,
                    help='temperature')
# NOTE(review): seed is parsed as float — presumably should be int; confirm
# the downstream seeding code accepts a float.
parser.add_argument('--seed', default=0, type=float,
                    help='random seed')
parser.add_argument('--dataset', type=str, default='cifar10', help='dataset')
parser.add_argument('--cosine', action='store_true',
                    help='using cosine annealing')
parser.add_argument('--warm', action='store_true',
                    help='warm-up for large batch training')
parser.add_argument('--learning_rate', type=float, default=0.5,
                    help='learning rate')
parser.add_argument('--lr_decay_epochs', type=str, default='700,800,900',
                    help='where to decay lr, can be a list')
parser.add_argument('--lr_decay_rate', type=float, default=0.1,
                    help='decay rate for learning rate')
parser.add_argument('--weight_decay', type=float, default=1e-4,
                    help='weight decay')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='momentum')

# ====== HN ======
parser.add_argument('--attack_ori', action='store_true', default=False,
                    help='attack use ori criterion')
parser.add_argument('--HN', action='store_true', default=False,
                    help='use HN')
parser.add_argument('--tau_plus', type=float, default=0.1, help="tau-plus in HN")
parser.add_argument('--beta', type=float, default=1.0, help="beta in HN")

# ====== stop grad ======
parser.add_argument('--stop_grad', action='store_true', default=False, help="whether to stop gradient")
parser.add_argument('--adv_weight', type=float, default=1, help="weight of adv loss")
parser.add_argument('--d_min', type=float, default=0.4, help="min distance in adaptive grad stopping")
parser.add_argument('--d_max', type=float, default=0, help="max distance in adaptive grad stopping")
# must use with --stop_grad
parser.add_argument('--stpg_degree', type=float, default=-1.0,
                    help="stop degree, range from 0 to 1, 0 is totally stop for clean branch")
parser.add_argument('--stop_grad_adaptive', type=int, default=-1, help="adaptively stop grad")

args = parser.parse_args()
# Aliases expected by helpers such as adjust_learning_rate().
args.epochs = args.epoch
args.decay = args.weight_decay
# NOTE(review): this unconditionally overrides the --cosine CLI flag, so
# cosine annealing is always on — confirm intended.
args.cosine = True
import math

# Large-batch warm-up schedule (SupCon-style): ramp the lr from warmup_from
# to warmup_to over the first warm_epochs epochs.
if args.batch_size > 256:
    args.warm = True
if args.warm:
    args.warmup_from = 0.01
    args.warm_epochs = 10
    if args.cosine:
        # Warm-up target matches the cosine schedule's value at warm_epochs.
        eta_min = args.learning_rate * (args.lr_decay_rate ** 3)
        args.warmup_to = eta_min + (args.learning_rate - eta_min) * (
            1 + math.cos(math.pi * args.warm_epochs / args.epochs)) / 2
    else:
        args.warmup_to = args.learning_rate

print(args)
107 |
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# PGD attack configuration consumed by the AttackPGD wrapper defined below;
# epsilon is converted here from 0-255 pixel units to the [0, 1] range.
config = {
    'epsilon': args.epsilon / 255.,
    'num_steps': args.iter,
    'step_size': 2.0 / 255,
    'random_start': True,
    'loss_func': 'xent',
}
118 | # ================================================================== #
119 | # Data and Pre-processing #
120 | # ================================================================== #
print('=====> Preparing data...')
# Multi-cuda
# NOTE(review): if CUDA is unavailable, n_gpu and batch_size are never
# bound and the DataLoader construction below raises NameError — this
# script is effectively GPU-only.
if torch.cuda.is_available():
    n_gpu = torch.cuda.device_count()
    batch_size = args.batch_size

# Contrastive view: SimCLR-style heavy augmentation (random resized crop,
# flip, color jitter, grayscale).
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
])
# "Original" view: standard CIFAR crop/flip augmentation.
train_transform_org = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
# Wraps both pipelines; presumably yields two contrastive views plus the
# original-augmentation view per sample — see utils.TwoCropTransformAdv.
transform_train = TwoCropTransformAdv(transform_train, train_transform_org)

transform_test = transforms.Compose([
    transforms.ToTensor(),
])

# Precomputed pseudo labels at 5 granularities (2/10/50/100/500 classes),
# loaded from a pickle produced offline.
label_pseudo_train_list = []
num_classes_list = [2, 10, 50, 100, 500]

# NOTE(review): dict_name is only bound for cifar10/cifar100; any other
# --dataset value raises NameError at the open() below.
if args.dataset == "cifar10":
    dict_name = 'data/{}_pseudo_labels.pkl'.format(args.cname)
elif args.dataset == "cifar100":
    dict_name = 'data/cifar100_pseudo_labels.pkl'
f = open(dict_name, 'rb')
feat_label_dict = pickle.load(f)  # dump data to f
f.close()
for i in range(5):
    class_num = num_classes_list[i]
    key_train = 'pseudo_train_{}'.format(class_num)
    label_pseudo_train = feat_label_dict[key_train]
    label_pseudo_train_list.append(label_pseudo_train)

data_path = "~/data/"

# Dataset variants that attach the per-sample pseudo-label ensemble; the
# same NameError caveat as above applies for unsupported --dataset values.
if args.dataset == "cifar10":
    train_dataset = CIFAR10IndexPseudoLabelEnsemble(root=data_path + 'cifar10/',
                                                    transform=transform_train,
                                                    pseudoLabel_002=label_pseudo_train_list[0],
                                                    pseudoLabel_010=label_pseudo_train_list[1],
                                                    pseudoLabel_050=label_pseudo_train_list[2],
                                                    pseudoLabel_100=label_pseudo_train_list[3],
                                                    pseudoLabel_500=label_pseudo_train_list[4],
                                                    download=True)
elif args.dataset == "cifar100":
    train_dataset = CIFAR100IndexPseudoLabelEnsemble(root=data_path + 'cifar100/',
                                                     transform=transform_train,
                                                     pseudoLabel_002=label_pseudo_train_list[0],
                                                     pseudoLabel_010=label_pseudo_train_list[1],
                                                     pseudoLabel_050=label_pseudo_train_list[2],
                                                     pseudoLabel_100=label_pseudo_train_list[3],
                                                     pseudoLabel_500=label_pseudo_train_list[4],
                                                     download=True)
# Data Loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=n_gpu * 4)
188 |
189 |
190 | # ================================================================== #
191 | # Model, Loss and Optimizer #
192 | # ================================================================== #
193 |
194 | # PGD attack model
# PGD attack model
class AttackPGD(nn.Module):
    """Crafts PGD adversarial examples for AdvCL pre-training.

    Two adversarial views are generated simultaneously from the un-augmented
    image: ``x_cl`` ascends the contrastive loss (through the 'pgd' BN
    branch) and ``x_ce`` ascends the pseudo-label cross-entropy loss
    (through the 'pgd_ce' BN branch). Clean views go through 'normal' BN.
    """

    def __init__(self, model, config):
        super(AttackPGD, self).__init__()
        self.model = model
        self.rand = config['random_start']    # random init inside the eps-ball
        self.step_size = config['step_size']  # per-iteration l_inf step
        self.epsilon = config['epsilon']      # l_inf perturbation budget
        self.num_steps = config['num_steps']  # number of PGD iterations
        assert config['loss_func'] == 'xent', 'Plz use xent for loss function.'

    def forward(self, images_t1, images_t2, images_org, targets, criterion):
        """Run `num_steps` of PGD and return clean and adversarial views.

        Args:
            images_t1, images_t2: the two strongly-augmented views.
            images_org: un-augmented image batch that the attack perturbs.
            targets: list of 5 pseudo-label tensors (one per granularity).
            criterion: contrastive loss applied to stacked projections.

        Returns:
            (x1, x2, x_cl, x_ce, x_HFC): the two clean views, the two
            adversarial examples, and the high-frequency component of the
            original images.
        """
        x1 = images_t1.clone().detach()
        x2 = images_t2.clone().detach()
        x_cl = images_org.clone().detach()
        x_ce = images_org.clone().detach()

        # High-frequency component of the clean image (extra positive view).
        images_org_high = generate_high(x_cl.clone(), r=args.radius)
        x_HFC = images_org_high.clone().detach()

        if self.rand:
            x_cl = x_cl + torch.zeros_like(x1).uniform_(-self.epsilon, self.epsilon)
            x_ce = x_ce + torch.zeros_like(x1).uniform_(-self.epsilon, self.epsilon)

        for i in range(self.num_steps):
            x_cl.requires_grad_()
            x_ce.requires_grad_()
            with torch.enable_grad():
                f_proj, f_pred = self.model(x_cl, bn_name='pgd', contrast=True)
                fce_proj, fce_pred, logits_ce = self.model(x_ce, bn_name='pgd_ce', contrast=True, CF=True,
                                                           return_logits=True, nonlinear=False)
                # Clean views are re-encoded each step through the 'normal'
                # BN branch; their features appear in the contrastive loss
                # but gradients are only taken w.r.t. x_cl / x_ce below.
                f1_proj, f1_pred = self.model(x1, bn_name='normal', contrast=True)
                f2_proj, f2_pred = self.model(x2, bn_name='normal', contrast=True)
                f_high_proj, f_high_pred = self.model(x_HFC, bn_name='normal', contrast=True)
                features = torch.cat(
                    [f_proj.unsqueeze(1), f1_proj.unsqueeze(1), f2_proj.unsqueeze(1), f_high_proj.unsqueeze(1)], dim=1)
                loss_contrast = criterion(features, stop_grad=False)
                loss_ce = 0
                for label_idx in range(5):
                    tgt = targets[label_idx].long()
                    lgt = logits_ce[label_idx]
                    # reduction='sum' is the modern equivalent of the
                    # deprecated size_average=False.
                    loss_ce += F.cross_entropy(lgt, tgt, reduction='sum', ignore_index=-1) / 5.
                loss = loss_contrast + loss_ce * args.ce_weight
            # One signed-gradient ascent step per view, then project back
            # into the eps-ball around the clean image and clip to [0, 1].
            grad_x_cl, grad_x_ce = torch.autograd.grad(loss, [x_cl, x_ce])
            x_cl = x_cl.detach() + self.step_size * torch.sign(grad_x_cl.detach())
            x_cl = torch.min(torch.max(x_cl, images_org - self.epsilon), images_org + self.epsilon)
            x_cl = torch.clamp(x_cl, 0, 1)
            x_ce = x_ce.detach() + self.step_size * torch.sign(grad_x_ce.detach())
            x_ce = torch.min(torch.max(x_ce, images_org - self.epsilon), images_org + self.epsilon)
            x_ce = torch.clamp(x_ce, 0, 1)
        return x1, x2, x_cl, x_ce, x_HFC
246 |
247 |
print('=====> Building model...')
# One BatchNorm branch per input type: clean / adv-contrastive / adv-CE.
bn_names = ['normal', 'pgd', 'pgd_ce']
model = ResNet18(bn_names=bn_names).cuda()

# tb_logger
os.makedirs('./logger', exist_ok=True)
logname = './logger/pretrain_{}'.format(args.name)
logger = tb_logger.Logger(logdir=logname, flush_secs=2)

gpu_count = torch.cuda.device_count()
if gpu_count <= 1:
    print('single gpu version is not supported, please use multiple GPUs!')
    raise NotImplementedError
print("=====> Let's use", gpu_count, "GPUs!")
model = apex.parallel.convert_syncbn_model(model)  # sync BN stats across GPUs
model = nn.DataParallel(model).cuda()
cudnn.benchmark = True

net = AttackPGD(model, config)

# Loss and optimizer
ce_criterion = nn.CrossEntropyLoss(ignore_index=-1)
if args.HN:  # loss of hard negative
    contrast_criterion = SupConLoss(args, temperature=args.nce_t)
else:  # loss of inferior positive
    contrast_criterion = ori_SupConLoss(args, temperature=args.nce_t)
optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.decay)
274 |
275 |
276 | # ================================================================== #
277 | # Train and Test #
278 | # ================================================================== #
279 |
def train(epoch):
    """Run one epoch of adversarial contrastive pre-training.

    Crafts PGD views via `net`, then minimizes the contrastive loss (with
    adaptive gradient stopping) plus the weighted pseudo-label CE loss.

    Args:
        epoch: 0-based epoch index, used for LR warm-up and logging.

    Returns:
        (mean_loss, 0.): average loss per batch; accuracy is not tracked
        during pre-training, so the second element is always 0.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0  # NOTE(review): never incremented — progress bar always shows 0
    total = 0

    for batch_idx, (inputs, _, targets, ind) in enumerate(train_loader):
        # Move the five pseudo-label tensors (one per granularity) to device.
        targets = [tt_.to(device).long() for tt_ in targets]
        image_t1, image_t2, image_org = inputs
        image_t1 = image_t1.cuda(non_blocking=True)
        image_t2 = image_t2.cuda(non_blocking=True)
        image_org = image_org.cuda(non_blocking=True)
        warmup_learning_rate(args, epoch + 1, batch_idx, len(train_loader), optimizer)
        # attack contrast
        optimizer.zero_grad()
        if args.attack_ori:
            attack_criterion = ori_SupConLoss(args, temperature=args.nce_t)
        else:
            attack_criterion = contrast_criterion
        x1, x2, x_cl, x_ce, x_HFC = net(image_t1, image_t2, image_org, targets, attack_criterion)
        f_proj, f_pred = model(x_cl, bn_name='pgd', contrast=True)
        fce_proj, fce_pred, logits_ce = model(x_ce, bn_name='pgd_ce', contrast=True, CF=True, return_logits=True,
                                              nonlinear=False)
        # ======== aug1&aug2&HF ========
        f1_proj, f1_pred = model(x1, bn_name='normal', contrast=True)
        f2_proj, f2_pred = model(x2, bn_name='normal', contrast=True)
        f_high_proj, f_high_pred = model(x_HFC, bn_name='normal', contrast=True)
        features = torch.cat(
            [f_proj.unsqueeze(1), f1_proj.unsqueeze(1), f2_proj.unsqueeze(1), f_high_proj.unsqueeze(1)], dim=1)

        # ======== adaptive grad stopping ========
        # Scale the stop-gradient degree by how far the adversarial features
        # drifted from the clean features (distance in 'normal'-BN space).
        stop_grad_sd = args.stpg_degree
        if args.stop_grad_adaptive >= 0:
            with torch.no_grad():
                f_ori = model(image_org, bn_name='normal', return_feat=True)
                f_cl = model(x_cl, bn_name='normal', return_feat=True)
                distance = F.pairwise_distance(f_ori, f_cl)
                mean_distance = distance.mean()
                if epoch + 1 > args.stop_grad_adaptive:
                    if mean_distance > dis_avg.avg:
                        pass  # above running average: keep the base degree
                    elif mean_distance < args.d_min:
                        stop_grad_sd = 0.5
                    else:
                        # Linear interpolation between base degree and 0.5.
                        stop_grad_sd = (dis_avg.avg - mean_distance) / (dis_avg.avg - args.d_min) * (
                                0.5 - args.stpg_degree) + args.stpg_degree
                else:
                    # Warm-up phase: only collect distance statistics.
                    if mean_distance > args.d_max:
                        args.d_max = mean_distance
                    dis_avg.update(mean_distance, image_org.shape[0])

            stop_grad_sd_print = stop_grad_sd if isinstance(stop_grad_sd, float) else stop_grad_sd.mean()

            logger.log_value('stop-degree', stop_grad_sd_print, epoch * len(train_loader) + batch_idx)
            if batch_idx % 40 == 0:
                print(
                    "dis_avg: {:.3f}, d_max: {:.3f}, distance: {:.3f}, stop_grad_sd: {}".format(dis_avg.avg, args.d_max,
                                                                                                mean_distance,
                                                                                                stop_grad_sd_print))

        contrast_loss = contrast_criterion(features, stop_grad=args.stop_grad, stop_grad_sd=stop_grad_sd)
        ce_loss = 0
        for label_idx in range(5):
            tgt = targets[label_idx].long()
            lgt = logits_ce[label_idx]
            ce_loss += ce_criterion(lgt, tgt) / 5.

        loss = contrast_loss + ce_loss * args.ce_weight
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        total += targets[0].size(0)

        progress_bar(batch_idx, len(train_loader),
                     'Loss: %.3f (%d/%d)'
                     % (train_loss / (batch_idx + 1), correct, total))

    # Bug fix: average over the number of batches (batch_idx + 1). The old
    # `train_loss / batch_idx` divided by the last batch *index*, which
    # over-states the mean and raises ZeroDivisionError for a 1-batch loader.
    return train_loss / (batch_idx + 1), 0.
364 |
365 |
366 | # ================================================================== #
367 | # Checkpoint #
368 | # ================================================================== #
369 |
370 | # Save checkpoint
# Save checkpoint
def checkpoint(epoch):
    """Persist model weights, the epoch index and the CPU RNG state."""
    save_dir = './checkpoint/{}'.format(args.name)
    os.makedirs(save_dir, exist_ok=True)
    ckpt_path = '{}/epoch_{}.ckpt'.format(save_dir, epoch)
    torch.save({
        'model': model.state_dict(),
        'epoch': epoch,
        'rng_state': torch.get_rng_state(),
    }, ckpt_path)
    print('=====> Saving checkpoint to {}'.format(ckpt_path))
382 |
383 |
384 | # ================================================================== #
385 | # Run the model #
386 | # ================================================================== #
387 |
388 |
# Seed every RNG source for reproducibility.
np.random.seed(args.seed)
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

# Running mean of the clean-vs-adversarial feature distance, consumed by
# train()'s adaptive gradient-stopping schedule.
dis_avg = AverageMeter()
# NOTE(review): the `+ 2` upper bound looks off-by-convention (it runs one
# extra epoch past args.epoch) — confirm it is intentional.
for epoch in range(start_epoch, args.epoch + 2):
    adjust_learning_rate(args, optimizer, epoch + 1)
    train_loss, train_acc = train(epoch)
    logger.log_value('train_loss', train_loss, epoch)
    logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)
    if epoch % args.save_epoch == 0:
        checkpoint(epoch)
402 |
--------------------------------------------------------------------------------