├── IDAA.png ├── LICENSE ├── README.md ├── SimCLR ├── eval_knn.py ├── eval_lr.py ├── main.py ├── model.py ├── modules │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── lars.cpython-38.pyc │ │ ├── logistic_regression.cpython-38.pyc │ │ ├── nt_xent.cpython-38.pyc │ │ ├── resnet_BN.cpython-38.pyc │ │ ├── resnet_BN_imagenet.cpython-38.pyc │ │ └── simclr_BN.cpython-38.pyc │ ├── lars.py │ ├── logistic_regression.py │ ├── nt_xent.py │ ├── resnet_BN.py │ ├── resnet_BN_imagenet.py │ ├── simclr_BN.py │ └── transformations │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ └── simclr.cpython-38.pyc │ │ └── simclr.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── masks.cpython-38.pyc │ └── masks.py ├── set.py ├── train_vae.py └── vae.py /IDAA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/IDAA.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Kaiwen Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IDAA 2 | 3 | Official implementation: 4 | - Identity-Disentangled Adversarial Augmentation for Self-Supervised Learning, ICML 2022. ([Paper](https://proceedings.mlr.press/v162/yang22s/yang22s.pdf)) 5 | 6 | 7 |
8 | ![IDAA](IDAA.png)
9 | *Architecture and pipeline of Identity-Disentangled Adversarial Augmentation (IDAA)*
10 |
11 |
12 | For questions, contact kwyang@mail.ustc.edu.cn.
13 |
14 | ## Requirements
15 |
16 | 1. [Python](https://www.python.org/)
17 | 2. [PyTorch](https://pytorch.org/)
18 | 3. [Wandb](https://wandb.ai/site)
19 | 4. [Torchvision](https://pytorch.org/vision/stable/index.html)
20 | 5. [Apex (optional)](https://github.com/NVIDIA/apex)
21 |
22 | ## Pretrain a VAE
23 |
24 |
25 | ```
26 | python train_vae.py --dim 512 --kl 0.1 --save_dir ./results/vae_cifar10_dim512_kl0.1_simclr --mode simclr --dataset cifar10
27 | ```
28 |
29 | ## Apply IDAA to SimCLR
30 | ```
31 | cd SimCLR
32 | ```
33 |
34 | SimCLR training and evaluation:
35 | ```
36 | python main.py --seed 1 --gpu 0 --dataset cifar10 --resnet resnet18;
37 | python eval_lr.py --seed 1 --gpu 0 --dataset cifar10 --resnet resnet18
38 | ```
39 | SimCLR+IDAA training and evaluation:
40 | ```
41 | python main.py --adv --eps 0.1 --seed 1 --gpu 0 --dataset cifar10 --dim 512 --vae_path ../results/vae_cifar10_dim512_kl0.1_simclr/model_epoch292.pth --resnet resnet18;
42 | python eval_lr.py --adv --eps 0.1 --seed 1 --gpu 0 --dataset cifar10 --dim 512 --resnet resnet18
43 | ```
44 |
45 | ## References
46 | We borrow some code from https://github.com/chihhuiho/CLAE.
47 |
48 |
49 | ## Citation
50 |
51 | If you find this repo useful for your research, please consider citing the paper:
52 | ```
53 | @inproceedings{yang2022identity,
54 | title={Identity-Disentangled Adversarial Augmentation for Self-supervised Learning},
55 | author={Yang, Kaiwen and Zhou, Tianyi and Tian, Xinmei and Tao, Dacheng},
56 | booktitle={International Conference on Machine Learning},
57 | pages={25364--25381},
58 | year={2022},
59 | organization={PMLR}
60 | }
61 | ```
62 |
--------------------------------------------------------------------------------
/SimCLR/eval_knn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torchvision.transforms as transforms
4 | import argparse
5 | import sys
6 | import os
7 | #from experiment import ex
8 | from model import load_model, save_model
9 | import wandb
10 | from modules import LogisticRegression
11 | import numpy as np
12 | from tqdm import tqdm
13 |
14 |
15 | def kNN(epoch, net, trainloader, testloader, K, sigma, ndata, low_dim = 128):
16 | net.eval()
17 | total = 0
18 | correct_t = 0
19 | testsize = testloader.dataset.__len__()
20 |
21 | try:
22 | trainLabels = torch.LongTensor(trainloader.dataset.targets).cuda()
23 | except AttributeError:
24 | trainLabels = torch.LongTensor(trainloader.dataset.labels).cuda()
25 | trainFeatures = np.zeros((low_dim, ndata))
26 | trainFeatures = torch.Tensor(trainFeatures).cuda()
27 | C = trainLabels.max() + 1
28 | C = int(C)  # plain int(); np.int was deprecated in NumPy 1.20 and later removed
29 |
30 | with torch.no_grad():
31 | transform_bak = trainloader.dataset.transform
32 | trainloader.dataset.transform = testloader.dataset.transform
33 | temploader = torch.utils.data.DataLoader(trainloader.dataset, batch_size=256, shuffle=False, num_workers=4)
34 | for batch_idx, (inputs, targets) in tqdm(enumerate(temploader)):
35 | targets = targets.cuda()
36 | batchSize = inputs.size(0)
37 | _, features = net(inputs.cuda())
38 | trainFeatures[:, batch_idx*batchSize:batch_idx*batchSize+batchSize] = features.t()
39 |
40 |
41 | trainloader.dataset.transform = transform_bak
42 | #
43 |
44 |
45 | top1 = 0.
46 | top5 = 0.
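# The evaluation loop below implements a weighted kNN classifier in the style of
# Wu et al. (2018, instance discrimination): for each test image it retrieves the
# K most similar training features (the projections are L2-normalized when
# --normalize is set, so the dot products computed by torch.mm are cosine
# similarities), weights each neighbor's one-hot label by exp(similarity / sigma),
# and sums the weighted votes per class to rank predictions.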
47 | with torch.no_grad():
48 | retrieval_one_hot = torch.zeros(K, C).cuda()
49 | for batch_idx, (inputs, targets) in enumerate(testloader):
50 |
51 | targets = targets.cuda()
52 | batchSize = inputs.size(0)
53 | _, features = net(inputs.cuda())
54 | total += targets.size(0)
55 |
56 | dist = torch.mm(features, trainFeatures)
57 | yd, yi = dist.topk(K, dim=1, largest=True, sorted=True)
58 | candidates = trainLabels.view(1,-1).expand(batchSize, -1)
59 | retrieval = torch.gather(candidates, 1, yi)
60 | retrieval_one_hot.resize_(batchSize * K, C).zero_()
61 | retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1)
62 | yd_transform = yd.clone().div_(sigma).exp_()
63 | probs = torch.sum(torch.mul(retrieval_one_hot.view(batchSize, -1 , C), yd_transform.view(batchSize, -1, 1)), 1)
64 |
65 | _, predictions = probs.sort(1, True)
66 | # Find which predictions match the target
67 | correct = predictions.eq(targets.data.view(-1,1))
68 |
69 | top1 = top1 + correct.narrow(1,0,1).sum().item()
70 |
71 | print(top1*100./total)
72 |
73 | return top1*100./total
--------------------------------------------------------------------------------
/SimCLR/eval_lr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torchvision.transforms as transforms
4 | import argparse
5 | import os
6 | from model import load_model, save_model
7 | import sys
8 | import wandb
9 | from modules import LogisticRegression
10 | sys.path.append('.')
11 | sys.path.append('..')
12 | from set import *
13 | from utils import *
14 |
15 |
16 | parser = argparse.ArgumentParser(description='PyTorch Seen Testing Category Training')
17 | parser.add_argument('--batch_size', default=256, type=int,
18 | metavar='B', help='training batch size')
19 | parser.add_argument('--logistic_batch_size', default=256, type=int,
20 | metavar='B', help='logistic regression batch size')
21 | parser.add_argument('--logistic_epochs', default=1000, type=int, help='logistic_epochs')
22 | parser.add_argument('--workers', default=4, type=int, help='workers')
23 | parser.add_argument('--epochs', default=300, type=int, help='epochs')
24 | parser.add_argument('--resnet', default="resnet18", type=str, help="resnet")
25 | parser.add_argument('--normalize', default=True, action='store_true', help='normalize')
26 | parser.add_argument('--projection_dim', default=64, type=int, help='projection_dim')
27 | parser.add_argument('--optimizer', default="Adam", type=str, help="optimizer")
28 | parser.add_argument('--weight_decay', default=1.0e-6, type=float, help='weight_decay')
29 | parser.add_argument('--temperature', default=0.5, type=float, help='temperature')
30 | parser.add_argument('--model_path', default='checkpoint/', type=str,
31 | help='model save path')
32 | parser.add_argument('--model_dir', default='checkpoint/', type=str,
33 | help='model checkpoint directory')
34 | parser.add_argument('--lr', default=3e-4, type=float, help='learning rate')
35 | parser.add_argument('--dataset', default='cifar10',
36 | help='[cifar10, cifar100]')
37 | parser.add_argument('--gpu', default='0', type=str,
38 | help='gpu device ids for CUDA_VISIBLE_DEVICES')
39 | parser.add_argument('--trial', type=int, help='trial')
40 | parser.add_argument('--adv', default=False, action='store_true', help='adversarial example')
41 | parser.add_argument('--eps', default=0.01, type=float, help='eps for adversarial')
42 | parser.add_argument('--bn_adv_momentum', default=0.01, type=float, help='batch norm momentum for advprop')
43 |
parser.add_argument('--alpha', default=1.0, type=float, help='weight for contrastive loss with adversarial example') 44 | parser.add_argument('--debug', default=False, action='store_true', help='debug mode') 45 | parser.add_argument('--seed', default=1, type=int, help='seed') 46 | parser.add_argument('--dim', default=512, type=int, help='CNN_embed_dim') 47 | args = parser.parse_args() 48 | set_random_seed(args.seed) 49 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 50 | 51 | 52 | def train(args, loader, simclr_model, model, criterion, optimizer): 53 | loss_epoch = 0 54 | accuracy_epoch = 0 55 | for step, (x, y) in enumerate(loader): 56 | optimizer.zero_grad() 57 | 58 | x = x.to(args.device) 59 | y = y.to(args.device) 60 | 61 | # get encoding 62 | with torch.no_grad(): 63 | h, z = simclr_model(x) 64 | # h = 512 65 | # z = 64 66 | 67 | output = model(h) 68 | loss = criterion(output, y) 69 | 70 | predicted = output.argmax(1) 71 | acc = (predicted == y).sum().item() / y.size(0) 72 | accuracy_epoch += acc 73 | 74 | loss.backward() 75 | optimizer.step() 76 | 77 | loss_epoch += loss.item() 78 | if step % 100 == 0: 79 | print(f"Step [{step}/{len(loader)}]\t Loss: {loss.item()}\t Accuracy: {acc}") 80 | 81 | if args.debug: 82 | break 83 | 84 | return loss_epoch, accuracy_epoch 85 | 86 | 87 | def test(args, loader, simclr_model, model, criterion, optimizer): 88 | loss_epoch = 0 89 | accuracy_epoch = 0 90 | model.eval() 91 | for step, (x, y) in enumerate(loader): 92 | model.zero_grad() 93 | 94 | x = x.to(args.device) 95 | y = y.to(args.device) 96 | 97 | # get encoding 98 | with torch.no_grad(): 99 | h, z = simclr_model(x) 100 | # h = 512 101 | # z = 64 102 | 103 | output = model(h) 104 | loss = criterion(output, y) 105 | 106 | predicted = output.argmax(1) 107 | acc = (predicted == y).sum().item() / y.size(0) 108 | accuracy_epoch += acc 109 | 110 | loss_epoch += loss.item() 111 | 112 | return loss_epoch, accuracy_epoch 113 | 114 | 115 | def main(): 116 | args.device = device = 'cuda' if torch.cuda.is_available() else 'cpu' 117 | 118 | root = "../../data" 119 | 120 | if args.dataset == 'tinyImagenet': 121 | transform = transforms.Compose([ 122 | torchvision.transforms.Resize((224, 224)), 123 | transforms.ToTensor(), 124 | #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 125 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 126 | std=[0.229, 0.224, 0.225]) 127 | ]) 128 | data = 'imagenet' 129 | elif args.dataset == 'miniImagenet': 130 | transform = transforms.Compose([ 131 | torchvision.transforms.Resize((84, 84)), 132 | transforms.ToTensor(), 133 | # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 134 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 135 | std=[0.229, 0.224, 0.225]) 136 | ]) 137 | data = 'imagenet' 138 | elif args.dataset == 'imagenet100': 139 | transform = transforms.Compose([ 140 | torchvision.transforms.Resize((224, 224)), 141 | transforms.ToTensor(), 142 | # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 143 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 144 | std=[0.229, 0.224, 0.225]) 145 | ]) 146 | data = 'imagenet' 147 | else: 148 | transform = transforms.Compose([ 149 | torchvision.transforms.Resize(size=32), 150 | transforms.ToTensor(), 151 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 152 | ]) 153 | data = 'non_imagenet' 154 | 155 | if args.dataset == "cifar10" : 156 | train_dataset = torchvision.datasets.CIFAR10( 157 | root, train=True, download=True, transform=transform 158 | ) 159 | test_dataset = 
torchvision.datasets.CIFAR10( 160 | root, train=False, download=True, transform=transform 161 | ) 162 | elif args.dataset == "cifar100": 163 | train_dataset = torchvision.datasets.CIFAR100( 164 | root, train=True, download=True, transform=transform 165 | ) 166 | test_dataset = torchvision.datasets.CIFAR100( 167 | root, train=False, download=True, transform=transform 168 | ) 169 | else: 170 | raise NotImplementedError 171 | 172 | train_loader = torch.utils.data.DataLoader( 173 | train_dataset, 174 | batch_size=args.logistic_batch_size, 175 | shuffle=True, 176 | drop_last=True, 177 | num_workers=args.workers, 178 | ) 179 | 180 | test_loader = torch.utils.data.DataLoader( 181 | test_dataset, 182 | batch_size=args.logistic_batch_size, 183 | shuffle=False, 184 | drop_last=True, 185 | num_workers=args.workers, 186 | ) 187 | 188 | log_dir = "log_eval/" + args.dataset + '_LR_log/' 189 | 190 | if not os.path.isdir(log_dir): 191 | os.makedirs(log_dir) 192 | 193 | suffix = args.dataset + '_{}_batch_{}'.format(args.resnet, args.batch_size) 194 | if args.adv: 195 | suffix = suffix + '_alpha_{}_adv_eps_{}'.format(args.alpha, args.eps) 196 | 197 | suffix = suffix + '_proj_dim_{}'.format(args.projection_dim) 198 | suffix = suffix + '_bn_adv_momentum_{}_seed_{}'.format(args.bn_adv_momentum, args.seed) 199 | wandb.init(config=args, name='LR/' + suffix.replace("_log/", '')) 200 | args.model_dir = args.model_dir + args.dataset + '/' 201 | print("Loading {}".format(args.model_dir + suffix + '_epoch_{}.pt'.format(args.epochs))) 202 | if args.adv: 203 | simclr_model, _, _ = load_model(args, train_loader, reload_model=True , load_path = args.model_dir + suffix + '_epoch_{}.pt'.format(args.epochs), bn_adv_flag = True, bn_adv_momentum = args.bn_adv_momentum, data=data) 204 | else: 205 | simclr_model, _, _ = load_model(args, train_loader, reload_model=True , load_path = args.model_dir + suffix + '_epoch_{}.pt'.format(args.epochs), bn_adv_flag = False, bn_adv_momentum = args.bn_adv_momentum, data=data) 206 | 207 | test_log_file = open(log_dir + suffix + '.txt', "w") 208 | simclr_model = simclr_model.to(args.device) 209 | simclr_model.eval() 210 | 211 | ## Logistic Regression 212 | if args.dataset == "cifar100": 213 | n_classes = 100 # stl-10 214 | else: 215 | n_classes = 10 216 | 217 | model = LogisticRegression(simclr_model.n_features, n_classes) 218 | model = model.to(args.device) 219 | 220 | optimizer = torch.optim.Adam(model.parameters(), lr=3e-4) 221 | criterion = torch.nn.CrossEntropyLoss() 222 | 223 | 224 | best_acc = 0 225 | for epoch in range(args.logistic_epochs): 226 | loss_epoch, accuracy_epoch = train(args, train_loader, simclr_model, model, criterion, optimizer) 227 | print("Train Epoch [{}]\t Loss: {}\t Accuracy: {}".format(epoch, loss_epoch / len(train_loader), accuracy_epoch / len(train_loader)), file = test_log_file) 228 | test_log_file.flush() 229 | wandb.log({'Train/Loss': loss_epoch / len(train_loader), 230 | 'Train/ACC': accuracy_epoch / len(train_loader)}) 231 | 232 | # final testing 233 | test_loss_epoch, test_accuracy_epoch = test(args, test_loader, simclr_model, model, criterion, optimizer) 234 | test_current_acc = test_accuracy_epoch / len(test_loader) 235 | if test_current_acc > best_acc: 236 | best_acc = test_current_acc 237 | print("Test Epoch [{}]\t Loss: {}\t Accuracy: {}\t Best Accuracy: {}".format(epoch, test_loss_epoch / len(test_loader), test_current_acc, best_acc), file = test_log_file) 238 | wandb.log({'Test/Loss': test_loss_epoch / len(test_loader), 239 | 'Test/ACC': 
test_current_acc,
240 | 'Test/BestACC': best_acc})
241 | test_log_file.flush()
242 |
243 | if args.debug:
244 | break
245 | print("Final Epoch [{}]\t Best Accuracy: {}".format(epoch, best_acc), file = test_log_file)
246 | test_log_file.flush()
247 | if not os.path.isdir("checkpoint/" + args.dataset + '_eval/'):
248 | os.makedirs("checkpoint/" + args.dataset + '_eval/')
249 | save_model("checkpoint/" + args.dataset + '_eval/' + suffix, model, optimizer, 0)
250 |
251 |
252 |
253 | if __name__ == "__main__":
254 | main()
255 |
--------------------------------------------------------------------------------
/SimCLR/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision
4 | import argparse
5 | import sys
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import wandb
9 | import torchvision.transforms as transforms
10 | from model import load_model, save_model
11 | from modules import NT_Xent
12 | from modules.transformations import TransformsSimCLR
13 | from utils import mask_correlated_samples
14 | from eval_knn import kNN
15 | sys.path.append('..')
16 | from set import *
17 | from vae import *
18 | from apex import amp
19 |
20 |
21 | parser = argparse.ArgumentParser(description='Seen Testing Category Training')
22 | parser.add_argument('--batch_size', default=256, type=int,
23 | metavar='B', help='training batch size')
24 | parser.add_argument('--dim', default=512, type=int, help='CNN_embed_dim')
25 | parser.add_argument('--workers', default=4, type=int, help='workers')
26 | parser.add_argument('--epochs', default=300, type=int, help='epochs')
27 | parser.add_argument('--save_epochs', default=100, type=int, help='save epochs')
28 | parser.add_argument('--resnet', default="resnet18", type=str, help="resnet")
29 | parser.add_argument('--normalize', default=True, action='store_true', help='normalize')
30 | parser.add_argument('--projection_dim', default=64, type=int, help='projection_dim')
31 | parser.add_argument('--optimizer', default="Adam", type=str, help="optimizer")
32 | parser.add_argument('--weight_decay', default=1.0e-6, type=float, help='weight_decay')
33 | parser.add_argument('--temperature', default=0.5, type=float, help='temperature')
34 | parser.add_argument('--model_path', default='log/', type=str,
35 | help='model save path')
36 | parser.add_argument('--model_dir', default='checkpoint/', type=str,
37 | help='model save path')
38 |
39 | parser.add_argument('--dataset', default='cifar10',
40 | help='[cifar10, cifar100]')
41 | parser.add_argument('--gpu', default='0', type=str,
42 | help='gpu device ids for CUDA_VISIBLE_DEVICES')
43 | parser.add_argument('--adv', default=False, action='store_true', help='adversarial example')
44 | parser.add_argument('--eps', default=0.01, type=float, help='eps for adversarial')
45 | parser.add_argument('--bn_adv_momentum', default=0.01, type=float, help='batch norm momentum for advprop')
46 | parser.add_argument('--alpha', default=1.0, type=float, help='weight for contrastive loss with adversarial example')
47 | parser.add_argument('--debug', default=False, action='store_true', help='debug mode')
48 | parser.add_argument('--vae_path',
49 | default='../results/vae_dim512_kl0.1_simclr/model_epoch92.pth',
50 | type=str, help='vae_path')
51 | parser.add_argument('--seed', default=1, type=int, help='seed')
52 | parser.add_argument("--amp", action="store_true",
53 | help="use 16-bit (mixed) precision through NVIDIA apex AMP")
54 |
parser.add_argument("--opt_level", type=str, default="O1", 55 | help="apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 56 | "See details at https://nvidia.github.io/apex/amp.html") 57 | args = parser.parse_args() 58 | print(args) 59 | set_random_seed(args.seed) 60 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 61 | 62 | 63 | def gen_adv(model, vae, x_i, criterion, optimizer): 64 | x_i = x_i.detach() 65 | h_i, z_i = model(x_i, adv=True) 66 | 67 | with torch.no_grad(): 68 | z, gx, _, _ = vae(x_i) 69 | variable_bottle = Variable(z.detach(), requires_grad=True) 70 | adv_gx = vae(variable_bottle, True) 71 | x_j_adv = adv_gx + (x_i - gx).detach() 72 | h_j_adv, z_j_adv = model(x_j_adv, adv=True) 73 | tmp_loss = criterion(z_i, z_j_adv) 74 | if args.amp: 75 | with amp.scale_loss(tmp_loss, optimizer) as scaled_loss: 76 | scaled_loss.backward() 77 | else: 78 | tmp_loss.backward() 79 | 80 | with torch.no_grad(): 81 | sign_grad = variable_bottle.grad.data.sign() 82 | variable_bottle.data = variable_bottle.data + args.eps * sign_grad 83 | adv_gx = vae(variable_bottle, True) 84 | x_j_adv = adv_gx + (x_i - gx).detach() 85 | x_j_adv.requires_grad = False 86 | x_j_adv.detach() 87 | return x_j_adv, gx 88 | 89 | 90 | def train(args, epoch, train_loader, model, vae, criterion, optimizer): 91 | model.train() 92 | loss_epoch = 0 93 | for step, ((x_i, x_j), _) in enumerate(train_loader): 94 | 95 | optimizer.zero_grad() 96 | x_i = x_i.to(args.device) 97 | x_j = x_j.to(args.device) 98 | 99 | # positive pair, with encoding 100 | h_i, z_i = model(x_i) 101 | if args.adv: 102 | x_j_adv, gx = gen_adv(model, vae, x_i, criterion, optimizer) 103 | 104 | optimizer.zero_grad() 105 | h_j, z_j = model(x_j) 106 | loss_og = criterion(z_i, z_j) 107 | if args.adv: 108 | _, z_j_adv = model(x_j_adv, adv=True) 109 | loss_adv = criterion(z_i, z_j_adv) 110 | loss = loss_og + args.alpha * loss_adv 111 | else: 112 | loss = loss_og 113 | loss_adv = loss_og 114 | if args.amp: 115 | with amp.scale_loss(loss, optimizer) as scaled_loss: 116 | scaled_loss.backward() 117 | else: 118 | loss.backward() 119 | 120 | optimizer.step() 121 | 122 | if step % 50 == 0: 123 | print(f"[Epoch]: {epoch} [{step}/{len(train_loader)}]\t Loss: {loss.item():.3f} Loss_og: {loss_og.item():.3f} Loss_adv: {loss_adv.item():.3f}") 124 | 125 | loss_epoch += loss.item() 126 | args.global_step += 1 127 | 128 | if args.debug: 129 | break 130 | if step % 10 == 0: 131 | wandb.log({'loss_og': loss_og.item(), 132 | 'loss_adv': loss_adv.item(), 133 | 'lr': optimizer.param_groups[0]['lr']}) 134 | if args.global_step % 1000 == 0: 135 | if args.adv: 136 | reconst_images(x_i, gx, x_j_adv) 137 | return loss_epoch 138 | 139 | 140 | def main(): 141 | args.device = device = 'cuda' if torch.cuda.is_available() else 'cpu' 142 | 143 | train_sampler = None 144 | if args.dataset == "cifar10": 145 | root = "../../data" 146 | train_dataset = torchvision.datasets.CIFAR10( 147 | root, download=True, transform=TransformsSimCLR() 148 | ) 149 | data = 'non_imagenet' 150 | transform_test = transforms.Compose([ 151 | transforms.Resize(size=32), 152 | transforms.ToTensor(), 153 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 154 | ]) 155 | testset = torchvision.datasets.CIFAR10(root='../../data', train=False, download=True, transform=transform_test) 156 | vae = CVAE_cifar_withbn(128, args.dim) 157 | elif args.dataset == "cifar100": 158 | root = "../../data" 159 | train_dataset = torchvision.datasets.CIFAR100( 160 | root, download=True, 
transform=TransformsSimCLR() 161 | ) 162 | data = 'non_imagenet' 163 | transform_test = transforms.Compose([ 164 | transforms.Resize(size=32), 165 | transforms.ToTensor(), 166 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 167 | ]) 168 | testset = torchvision.datasets.CIFAR100(root='../../data', train=False, download=True, transform=transform_test) 169 | vae = CVAE_cifar_withbn(128, args.dim) 170 | else: 171 | raise NotImplementedError 172 | 173 | train_loader = torch.utils.data.DataLoader( 174 | train_dataset, 175 | batch_size=args.batch_size, 176 | shuffle=(train_sampler is None), 177 | drop_last=True, 178 | num_workers=args.workers, 179 | sampler=train_sampler, 180 | ) 181 | testloader = torch.utils.data.DataLoader(testset, 182 | batch_size=100, shuffle=False, num_workers=4) 183 | 184 | ndata = train_dataset.__len__() 185 | log_dir = "log/" + args.dataset + '_log/' 186 | 187 | if not os.path.isdir(log_dir): 188 | os.makedirs(log_dir) 189 | 190 | suffix = args.dataset + '_{}_batch_{}'.format(args.resnet, args.batch_size) 191 | if args.adv: 192 | suffix = suffix + '_alpha_{}_adv_eps_{}'.format(args.alpha, args.eps) 193 | model, optimizer, scheduler = load_model(args, train_loader, bn_adv_flag=True, 194 | bn_adv_momentum=args.bn_adv_momentum, data=data) 195 | else: 196 | model, optimizer, scheduler = load_model(args, train_loader, bn_adv_flag=False, 197 | bn_adv_momentum=args.bn_adv_momentum, data=data) 198 | 199 | vae.load_state_dict(torch.load(args.vae_path)) 200 | vae.to(args.device) 201 | vae.eval() 202 | if args.amp: 203 | [model, vae], optimizer = amp.initialize( 204 | [model, vae], optimizer, opt_level=args.opt_level) 205 | 206 | suffix = suffix + '_proj_dim_{}'.format(args.projection_dim) 207 | suffix = suffix + '_bn_adv_momentum_{}_seed_{}'.format(args.bn_adv_momentum, args.seed) 208 | wandb.init(config=args, name=suffix.replace("_log/", '')) 209 | 210 | test_log_file = open(log_dir + suffix + '.txt', "w") 211 | 212 | if not os.path.isdir(args.model_dir): 213 | os.mkdir(args.model_dir) 214 | args.model_dir = args.model_dir + args.dataset + '/' 215 | if not os.path.isdir(args.model_dir): 216 | os.mkdir(args.model_dir) 217 | 218 | mask = mask_correlated_samples(args) 219 | criterion = NT_Xent(args.batch_size, args.temperature, mask, args.device) 220 | 221 | args.global_step = 0 222 | args.current_epoch = 0 223 | best_acc = 0 224 | for epoch in range(0, args.epochs): 225 | loss_epoch = train(args, epoch, train_loader, model, vae, criterion, optimizer) 226 | model.eval() 227 | if epoch > 10: 228 | scheduler.step() 229 | print('epoch: {}% \t (loss: {}%)'.format(epoch, loss_epoch / len(train_loader)), file=test_log_file) 230 | print('----------Evaluation---------') 231 | start = time.time() 232 | acc = kNN(epoch, model, train_loader, testloader, 200, args.temperature, ndata, low_dim=args.projection_dim) 233 | print("Evaluation Time: '{}'s".format(time.time() - start)) 234 | 235 | if acc >= best_acc: 236 | print('Saving..') 237 | state = { 238 | 'model': model.state_dict(), 239 | 'acc': acc, 240 | 'epoch': epoch, 241 | } 242 | if not os.path.isdir(args.model_dir): 243 | os.mkdir(args.model_dir) 244 | torch.save(state, args.model_dir + suffix + '_best.t') 245 | best_acc = acc 246 | print('accuracy: {}% \t (best acc: {}%)'.format(acc, best_acc)) 247 | print('[Epoch]: {}'.format(epoch), file=test_log_file) 248 | print('accuracy: {}% \t (best acc: {}%)'.format(acc, best_acc), file=test_log_file) 249 | wandb.log({'acc': acc}) 250 | test_log_file.flush() 251 | 252 
| args.current_epoch += 1 253 | if args.debug: 254 | break 255 | if epoch % 50 == 0: 256 | save_model(args.model_dir + suffix, model, optimizer, epoch) 257 | 258 | save_model(args.model_dir + suffix, model, optimizer, args.epochs) 259 | 260 | 261 | def reconst_images(x_i, gx, x_j_adv): 262 | grid_X = torchvision.utils.make_grid(x_i[32:96].data, nrow=8, padding=2, normalize=True) 263 | wandb.log({"X.jpg": [wandb.Image(grid_X)]}, commit=False) 264 | grid_GX = torchvision.utils.make_grid(gx[32:96].data, nrow=8, padding=2, normalize=True) 265 | wandb.log({"GX.jpg": [wandb.Image(grid_GX)]}, commit=False) 266 | grid_RX = torchvision.utils.make_grid((x_i[32:96] - gx[32:96]).data, nrow=8, padding=2, normalize=True) 267 | wandb.log({"RX.jpg": [wandb.Image(grid_RX)]}, commit=False) 268 | grid_AdvX = torchvision.utils.make_grid(x_j_adv[32:96].data, nrow=8, padding=2, normalize=True) 269 | wandb.log({"AdvX.jpg": [wandb.Image(grid_AdvX)]}, commit=False) 270 | grid_delta = torchvision.utils.make_grid((x_j_adv - x_i)[32:96].data, nrow=8, padding=2, normalize=True) 271 | wandb.log({"Delta.jpg": [wandb.Image(grid_delta)]}, commit=False) 272 | wandb.log({'l2_norm': torch.mean((x_j_adv - x_i).reshape(x_i.shape[0], -1).norm(dim=1)), 273 | 'linf_norm': torch.mean((x_j_adv - x_i).reshape(x_i.shape[0], -1).abs().max(dim=1)[0]) 274 | }, commit=False) 275 | 276 | 277 | if __name__ == "__main__": 278 | main() 279 | -------------------------------------------------------------------------------- /SimCLR/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from modules import SimCLR_BN 4 | 5 | 6 | def load_model(args, loader, reload_model=False, load_path = None, bn_adv_flag=False, bn_adv_momentum = 0.01, data='non_imagenet'): 7 | 8 | model = SimCLR_BN(args, bn_adv_flag=bn_adv_flag, bn_adv_momentum = bn_adv_momentum, data = data) 9 | 10 | if reload_model: 11 | if os.path.isfile(load_path): 12 | model_fp = os.path.join(load_path) 13 | else: 14 | print("No file to load") 15 | return 16 | model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_fp, map_location=lambda storage, loc: storage).items()}) 17 | 18 | #model = model.to(args.device) 19 | model.cuda() 20 | optimizer = torch.optim.Adam(model.parameters(), lr=3e-4) # TODO: LARS 21 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 22 | optimizer, args.epochs, eta_min=0, last_epoch=-1 23 | ) 24 | return model, optimizer, scheduler 25 | 26 | 27 | def save_model(model_dir, model, optimizer, epoch): 28 | 29 | 30 | # To save a DataParallel model generically, save the model.module.state_dict(). 31 | # This way, you have the flexibility to load the model any way you want to any device you want. 
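# Note that load_model() above strips any 'module.' prefix from state-dict keys
# when reloading, so checkpoints written from either a plain model or a
# DataParallel wrapper load interchangeably.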
32 | if isinstance(model, torch.nn.DataParallel): 33 | torch.save(model.module.state_dict(), model_dir + '_epoch_{}.pt'.format(epoch)) 34 | else: 35 | torch.save(model.state_dict(), model_dir + '_epoch_{}.pt'.format(epoch)) 36 | -------------------------------------------------------------------------------- /SimCLR/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .simclr_BN import SimCLR_BN 2 | from .nt_xent import NT_Xent 3 | from .logistic_regression import LogisticRegression 4 | from .lars import LARS 5 | -------------------------------------------------------------------------------- /SimCLR/modules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/modules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /SimCLR/modules/__pycache__/lars.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/modules/__pycache__/lars.cpython-38.pyc -------------------------------------------------------------------------------- /SimCLR/modules/__pycache__/logistic_regression.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/modules/__pycache__/logistic_regression.cpython-38.pyc -------------------------------------------------------------------------------- /SimCLR/modules/__pycache__/nt_xent.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/modules/__pycache__/nt_xent.cpython-38.pyc -------------------------------------------------------------------------------- /SimCLR/modules/__pycache__/resnet_BN.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/modules/__pycache__/resnet_BN.cpython-38.pyc -------------------------------------------------------------------------------- /SimCLR/modules/__pycache__/resnet_BN_imagenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/modules/__pycache__/resnet_BN_imagenet.cpython-38.pyc -------------------------------------------------------------------------------- /SimCLR/modules/__pycache__/simclr_BN.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/modules/__pycache__/simclr_BN.cpython-38.pyc -------------------------------------------------------------------------------- /SimCLR/modules/lars.py: -------------------------------------------------------------------------------- 1 | """ 2 | LARS: Layer-wise Adaptive Rate Scaling 3 | 4 | Converted from TensorFlow to PyTorch 5 | https://github.com/google-research/simclr/blob/master/lars_optimizer.py 6 | """ 7 | 8 | import torch 9 | from torch.optim.optimizer import 
Optimizer, required 10 | import re 11 | 12 | EETA_DEFAULT = 0.001 13 | 14 | class LARS(Optimizer): 15 | """ 16 | Layer-wise Adaptive Rate Scaling for large batch training. 17 | Introduced by "Large Batch Training of Convolutional Networks" by Y. You, 18 | I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888) 19 | """ 20 | 21 | def __init__( 22 | self, 23 | params, 24 | lr=required, 25 | momentum=0.9, 26 | use_nesterov=False, 27 | weight_decay=0.0, 28 | exclude_from_weight_decay=None, 29 | exclude_from_layer_adaptation=None, 30 | classic_momentum=True, 31 | eeta=EETA_DEFAULT, 32 | ): 33 | """Constructs a LARSOptimizer. 34 | Args: 35 | lr: A `float` for learning rate. 36 | momentum: A `float` for momentum. 37 | use_nesterov: A 'Boolean' for whether to use nesterov momentum. 38 | weight_decay: A `float` for weight decay. 39 | exclude_from_weight_decay: A list of `string` for variable screening, if 40 | any of the string appears in a variable's name, the variable will be 41 | excluded for computing weight decay. For example, one could specify 42 | the list like ['batch_normalization', 'bias'] to exclude BN and bias 43 | from weight decay. 44 | exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but 45 | for layer adaptation. If it is None, it will be defaulted the same as 46 | exclude_from_weight_decay. 47 | classic_momentum: A `boolean` for whether to use classic (or popular) 48 | momentum. The learning rate is applied during momeuntum update in 49 | classic momentum, but after momentum for popular momentum. 50 | eeta: A `float` for scaling of learning rate when computing trust ratio. 51 | name: The name for the scope. 52 | """ 53 | 54 | self.epoch = 0 55 | defaults = dict( 56 | lr=lr, 57 | momentum=momentum, 58 | use_nesterov=use_nesterov, 59 | weight_decay=weight_decay, 60 | exclude_from_weight_decay=exclude_from_weight_decay, 61 | exclude_from_layer_adaptation=exclude_from_layer_adaptation, 62 | classic_momentum=classic_momentum, 63 | eeta=eeta, 64 | ) 65 | 66 | super(LARS, self).__init__(params, defaults) 67 | self.lr = lr 68 | self.momentum = momentum 69 | self.weight_decay = weight_decay 70 | self.use_nesterov = use_nesterov 71 | self.classic_momentum = classic_momentum 72 | self.eeta = eeta 73 | self.exclude_from_weight_decay = exclude_from_weight_decay 74 | # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the 75 | # arg is None. 
76 | if exclude_from_layer_adaptation: 77 | self.exclude_from_layer_adaptation = exclude_from_layer_adaptation 78 | else: 79 | self.exclude_from_layer_adaptation = exclude_from_weight_decay 80 | 81 | def step(self, epoch=None, closure=None): 82 | loss = None 83 | if closure is not None: 84 | loss = closure() 85 | 86 | if epoch is None: 87 | epoch = self.epoch 88 | self.epoch += 1 89 | 90 | for group in self.param_groups: 91 | weight_decay = group["weight_decay"] 92 | momentum = group["momentum"] 93 | eeta = group["eeta"] 94 | lr = group["lr"] 95 | 96 | for p in group["params"]: 97 | if p.grad is None: 98 | continue 99 | 100 | param = p.data 101 | grad = p.grad.data 102 | 103 | param_state = self.state[p] 104 | 105 | # TODO: get param names 106 | # if self._use_weight_decay(param_name): 107 | grad += self.weight_decay * param 108 | 109 | if self.classic_momentum: 110 | trust_ratio = 1.0 111 | 112 | # TODO: get param names 113 | # if self._do_layer_adaptation(param_name): 114 | w_norm = torch.norm(param) 115 | g_norm = torch.norm(grad) 116 | 117 | device = g_norm.get_device() 118 | trust_ratio = torch.where( 119 | w_norm.ge(0), 120 | torch.where(g_norm.ge(0), (self.eeta * w_norm / g_norm), torch.Tensor([1.0]).to(device)), 121 | torch.Tensor([1.0]).to(device), 122 | ).item() 123 | 124 | scaled_lr = lr * trust_ratio 125 | if "momentum_buffer" not in param_state: 126 | next_v = param_state["momentum_buffer"] = torch.zeros_like( 127 | p.data 128 | ) 129 | else: 130 | next_v = param_state["momentum_buffer"] 131 | 132 | next_v.mul_(momentum).add_(scaled_lr, grad) 133 | if self.use_nesterov: 134 | update = (self.momentum * next_v) + (scaled_lr * grad) 135 | else: 136 | update = next_v 137 | 138 | p.data.add_(-update) 139 | else: 140 | raise NotImplementedError 141 | 142 | return loss 143 | 144 | def _use_weight_decay(self, param_name): 145 | """Whether to use L2 weight decay for `param_name`.""" 146 | if not self.weight_decay: 147 | return False 148 | if self.exclude_from_weight_decay: 149 | for r in self.exclude_from_weight_decay: 150 | if re.search(r, param_name) is not None: 151 | return False 152 | return True 153 | 154 | def _do_layer_adaptation(self, param_name): 155 | """Whether to do layer-wise learning rate adaptation for `param_name`.""" 156 | if self.exclude_from_layer_adaptation: 157 | for r in self.exclude_from_layer_adaptation: 158 | if re.search(r, param_name) is not None: 159 | return False 160 | return True 161 | -------------------------------------------------------------------------------- /SimCLR/modules/logistic_regression.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class LogisticRegression(nn.Module): 4 | 5 | def __init__(self, n_features, n_classes): 6 | super(LogisticRegression, self).__init__() 7 | 8 | self.model = nn.Linear(n_features, n_classes) 9 | 10 | def forward(self, x): 11 | return self.model(x) -------------------------------------------------------------------------------- /SimCLR/modules/nt_xent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pdb 4 | 5 | def gen_mask(k, feat_dim): 6 | mask = None 7 | for i in range(k): 8 | tmp_mask = torch.triu(torch.randint(0, 2, (feat_dim, feat_dim)), 1) 9 | tmp_mask = tmp_mask + torch.triu(1-tmp_mask,1).t() 10 | tmp_mask = tmp_mask.view(tmp_mask.shape[0], tmp_mask.shape[1],1) 11 | mask = tmp_mask if mask is None else torch.cat([mask,tmp_mask],2) 12 | return mask 
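# Hedged sketch (not part of the original file): NT_Xent below expects `mask` to
# be a boolean (2N x 2N) matrix that keeps the 2N-2 negatives in each row, i.e.
# everything except the diagonal and the two positive-pair diagonals at offset
# +/- batch_size. The real mask comes from utils.mask_correlated_samples (in
# SimCLR/utils/masks.py, not shown here); a functionally equivalent stand-in:
#
#   def mask_correlated_samples(args):
#       N = 2 * args.batch_size
#       mask = torch.ones((N, N), dtype=torch.bool)
#       mask.fill_diagonal_(False)            # drop self-similarities
#       for i in range(args.batch_size):      # drop the two positive pairs
#           mask[i, args.batch_size + i] = False
#           mask[args.batch_size + i, i] = False
#       return mask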
13 | 14 | 15 | def entropy(prob): 16 | # assume m x m x k input 17 | return -torch.sum(prob*torch.log(prob),1) 18 | 19 | 20 | class NT_Xent(nn.Module): 21 | 22 | def __init__(self, batch_size, temperature, mask, device): 23 | super(NT_Xent, self).__init__() 24 | self.batch_size = batch_size 25 | self.temperature = temperature 26 | self.mask = mask 27 | self.device = device 28 | 29 | self.criterion = nn.CrossEntropyLoss(reduction="sum") 30 | self.similarity_f = nn.CosineSimilarity(dim=2) 31 | 32 | def forward(self, z_i, z_j): 33 | """ 34 | We do not sample negative examples explicitly. 35 | Instead, given a positive pair, similar to (Chen et al., 2017), we treat the other 2(N − 1) augmented examples within a minibatch as negative examples. 36 | """ 37 | p1 = torch.cat((z_i, z_j), dim=0) 38 | sim = self.similarity_f(p1.unsqueeze(1), p1.unsqueeze(0)) / self.temperature 39 | 40 | 41 | sim_i_j = torch.diag(sim, self.batch_size) 42 | sim_j_i = torch.diag(sim, -self.batch_size) 43 | 44 | positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(self.batch_size * 2, 1) 45 | negative_samples = sim[self.mask].reshape(self.batch_size * 2, -1) 46 | 47 | 48 | labels = torch.zeros(self.batch_size * 2).to(self.device).long() 49 | logits = torch.cat((positive_samples, negative_samples), dim=1) 50 | 51 | 52 | loss = self.criterion(logits, labels) 53 | loss /= 2 * self.batch_size 54 | 55 | return loss 56 | -------------------------------------------------------------------------------- /SimCLR/modules/resnet_BN.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import math 13 | from torch.autograd import Variable 14 | 15 | class Normalize(nn.Module): 16 | 17 | def __init__(self, power=2): 18 | super(Normalize, self).__init__() 19 | self.power = power 20 | 21 | def forward(self, x): 22 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1./self.power) 23 | out = x.div(norm) 24 | return out 25 | 26 | class MySequential(nn.Sequential): 27 | def forward(self, x, adv): 28 | for module in self._modules.values(): 29 | x = module(x, adv=adv) 30 | return x 31 | 32 | class BasicBlock(nn.Module): 33 | expansion = 1 34 | 35 | def __init__(self, in_planes, planes, stride=1, bn_adv_flag=False, bn_adv_momentum=0.01): 36 | super(BasicBlock, self).__init__() 37 | self.bn_adv_momentum = bn_adv_momentum 38 | self.bn_adv_flag = bn_adv_flag 39 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 40 | self.bn1 = nn.BatchNorm2d(planes) 41 | if self.bn_adv_flag: 42 | self.bn1_adv = nn.BatchNorm2d(planes, momentum = self.bn_adv_momentum) 43 | 44 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 45 | 46 | self.bn2 = nn.BatchNorm2d(planes) 47 | if self.bn_adv_flag: 48 | self.bn2_adv = nn.BatchNorm2d(planes, momentum = self.bn_adv_momentum) 49 | 50 | self.shortcut = nn.Sequential() 51 | self.shortcut_bn = None 52 | self.shortcut_bn_adv = None 53 | if stride != 1 or in_planes != self.expansion*planes: 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 56 | ) 57 | self.shortcut_bn = nn.BatchNorm2d(self.expansion*planes) 58 | if self.bn_adv_flag: 59 | self.shortcut_bn_adv = nn.BatchNorm2d(self.expansion*planes, momentum = self.bn_adv_momentum) 60 | 61 | def forward(self, x, adv=False): 62 | if adv and self.bn_adv_flag: 63 | out = F.relu(self.bn1_adv(self.conv1(x))) 64 | out = self.conv2(out) 65 | out = self.bn2_adv(out) 66 | if self.shortcut_bn_adv: 67 | out += self.shortcut_bn_adv(self.shortcut(x)) 68 | else: 69 | out += self.shortcut(x) 70 | else: 71 | out = F.relu(self.bn1(self.conv1(x))) 72 | out = self.conv2(out) 73 | out = self.bn2(out) 74 | if self.shortcut_bn: 75 | out += self.shortcut_bn(self.shortcut(x)) 76 | else: 77 | out += self.shortcut(x) 78 | 79 | out = F.relu(out) 80 | return out 81 | 82 | 83 | class Bottleneck(nn.Module): 84 | expansion = 4 85 | 86 | def __init__(self, in_planes, planes, stride=1, bn_adv_flag=False, bn_adv_momentum=0.01): 87 | super(Bottleneck, self).__init__() 88 | self.bn_adv_momentum = bn_adv_momentum 89 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 90 | self.bn_adv_flag = bn_adv_flag 91 | 92 | self.bn1 = nn.BatchNorm2d(planes) 93 | if self.bn_adv_flag: 94 | self.bn1_adv = nn.BatchNorm2d(planes, momentum = self.bn_adv_momentum) 95 | 96 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 97 | 98 | self.bn2 = nn.BatchNorm2d(planes) 99 | self.bn2 = nn.BatchNorm2d(planes) 100 | if self.bn_adv_flag: 101 | self.bn2_adv = nn.BatchNorm2d(planes, momentum = self.bn_adv_momentum) 102 | 103 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 104 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 105 | if self.bn_adv_flag: 106 | self.bn3_adv = nn.BatchNorm2d(self.expansion*planes, momentum = self.bn_adv_momentum) 107 | 108 | self.shortcut = nn.Sequential() 109 | self.shortcut_bn = None 110 
| self.shortcut_bn_adv = None 111 | if stride != 1 or in_planes != self.expansion*planes: 112 | self.shortcut = nn.Sequential( 113 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 114 | ) 115 | self.shortcut_bn = nn.BatchNorm2d(self.expansion*planes) 116 | if self.bn_adv_flag: 117 | self.shortcut_bn_adv = nn.BatchNorm2d(self.expansion*planes, momentum = self.bn_adv_momentum) 118 | 119 | def forward(self, x, adv=False): 120 | 121 | if adv and self.bn_adv_flag: 122 | 123 | out = F.relu(self.bn1_adv(self.conv1(x))) 124 | out = F.relu(self.bn2_adv(self.conv2(out))) 125 | out = self.bn3_adv(self.conv3(out)) 126 | if self.shortcut_bn_adv: 127 | out += self.shortcut_bn_adv(self.shortcut(x)) 128 | else: 129 | out += self.shortcut(x) 130 | else: 131 | 132 | out = F.relu(self.bn1(self.conv1(x))) 133 | out = F.relu(self.bn2(self.conv2(out))) 134 | out = self.bn3(self.conv3(out)) 135 | if self.shortcut_bn: 136 | out += self.shortcut_bn(self.shortcut(x)) 137 | else: 138 | out += self.shortcut(x) 139 | 140 | out = F.relu(out) 141 | return out 142 | 143 | 144 | class ResNetAdvProp_all(nn.Module): 145 | def __init__(self, block, num_blocks, pool_len =4, low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01): 146 | super(ResNetAdvProp_all, self).__init__() 147 | self.in_planes = 64 148 | self.bn_adv_momentum = bn_adv_momentum 149 | self.bn_adv_flag = bn_adv_flag 150 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 151 | self.bn_adv_flag = bn_adv_flag 152 | 153 | self.bn1 = nn.BatchNorm2d(64) 154 | if bn_adv_flag: 155 | self.bn1_adv = nn.BatchNorm2d(64, momentum = self.bn_adv_momentum) 156 | 157 | 158 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, bn_adv_flag = self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum) 159 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2, bn_adv_flag = self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum) 160 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, bn_adv_flag = self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum) 161 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, bn_adv_flag = self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum) 162 | self.fc = nn.Linear(512*block.expansion, low_dim) 163 | 164 | self.pool_len = pool_len 165 | # for m in self.modules(): 166 | # if isinstance(m, nn.Conv2d): 167 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 168 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 169 | # elif isinstance(m, nn.BatchNorm2d): 170 | # m.weight.data.fill_(1) 171 | # m.bias.data.zero_() 172 | 173 | def _make_layer(self, block, planes, num_blocks, stride, bn_adv_flag=False, bn_adv_momentum=0.01): 174 | strides = [stride] + [1]*(num_blocks-1) 175 | layers = [] 176 | for stride in strides: 177 | layers.append(block(self.in_planes, planes, stride, bn_adv_flag=bn_adv_flag, bn_adv_momentum = bn_adv_momentum)) 178 | self.in_planes = planes * block.expansion 179 | return MySequential(*layers) 180 | #return layers 181 | 182 | def forward(self, x, adv = False): 183 | if adv and self.bn_adv_flag: 184 | out = F.relu(self.bn1_adv(self.conv1(x))) 185 | else: 186 | out = F.relu(self.bn1(self.conv1(x))) 187 | 188 | out = self.layer1(out, adv=adv) 189 | out = self.layer2(out, adv=adv) 190 | out = self.layer3(out, adv=adv) 191 | out = self.layer4(out, adv=adv) 192 | 193 | out = F.avg_pool2d(out, self.pool_len) 194 | 195 | out = out.view(out.size(0), -1) 196 | 197 | out = self.fc(out) 198 | return out 199 | 200 | 201 | def resnet18(pool_len = 4, low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01): 202 | return ResNetAdvProp_all(BasicBlock, [2,2,2,2], pool_len, low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum) 203 | 204 | def resnet34(pool_len = 4, low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01): 205 | return ResNetAdvProp_all(BasicBlock, [3,4,6,3], pool_len, low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum) 206 | 207 | def resnet50(pool_len = 4, low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01): 208 | return ResNetAdvProp_all(Bottleneck, [3,4,6,3], pool_len, low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum) 209 | 210 | def resnet101(pool_len = 4, low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01): 211 | return ResNetAdvProp_all(Bottleneck, [3,4,23,3], pool_len, low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum) 212 | 213 | def resnet152(pool_len = 4, low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01): 214 | return ResNetAdvProp_all(Bottleneck, [3,8,36,3], pool_len, low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum) 215 | 216 | 217 | def test(): 218 | net = ResNet18() 219 | # y = net(Variable(torch.randn(1,3,32,32))) 220 | # pdb.set_trace() 221 | y = net(Variable(torch.randn(1,3,96,96))) 222 | # pdb.set_trace() 223 | print(y.size()) 224 | 225 | # test() 226 | -------------------------------------------------------------------------------- /SimCLR/modules/resnet_BN_imagenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from torch.autograd import Variable 6 | import pdb 7 | 8 | 9 | class MySequential(nn.Sequential): 10 | def forward(self, x, adv): 11 | for module in self._modules.values(): 12 | x = module(x, adv=adv) 13 | return x 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | "3x3 convolution with padding" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=1, bias=False) 19 | 20 | 21 | class BasicBlock(nn.Module): 22 | expansion = 1 23 | 24 | def __init__(self, inplanes, planes, stride=1, downsample=False, expansion=0, bn_adv_flag=False, bn_adv_momentum=0.01): 25 | super(BasicBlock, self).__init__() 26 | self.bn_adv_momentum = bn_adv_momentum 27 | self.bn_adv_flag = bn_adv_flag 28 | 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | if self.bn_adv_flag: 32 
| self.bn1_adv = nn.BatchNorm2d(planes, momentum = self.bn_adv_momentum) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.conv2 = conv3x3(planes, planes) 35 | 36 | self.bn2 = nn.BatchNorm2d(planes) 37 | if self.bn_adv_flag: 38 | self.bn2_adv = nn.BatchNorm2d(planes, momentum = self.bn_adv_momentum) 39 | 40 | self.downsample = downsample 41 | if self.downsample: 42 | self.ds_conv1 = nn.Conv2d(inplanes, planes * expansion, kernel_size=1, stride=stride, bias=False) 43 | self.ds_bn1 = nn.BatchNorm2d(planes*expansion) 44 | self.ds_bn1_adv = nn.BatchNorm2d(planes*expansion) 45 | self.stride = stride 46 | 47 | def forward(self, x, adv = False): 48 | residual = x 49 | if adv and self.bn_adv_flag: 50 | out = self.conv1(x) 51 | out = self.bn1_adv(out) 52 | out = self.relu(out) 53 | out = self.conv2(out) 54 | out = self.bn2_adv(out) 55 | if self.downsample: 56 | 57 | residual = self.ds_bn1_adv(self.ds_conv1(x)) 58 | out += residual 59 | out = self.relu(out) 60 | else: 61 | out = self.conv1(x) 62 | out = self.bn1(out) 63 | out = self.relu(out) 64 | out = self.conv2(out) 65 | out = self.bn2(out) 66 | if self.downsample: 67 | residual = self.ds_bn1(self.ds_conv1(x)) 68 | out += residual 69 | out = self.relu(out) 70 | 71 | return out 72 | 73 | 74 | class Bottleneck(nn.Module): 75 | expansion = 4 76 | 77 | def __init__(self, inplanes, planes, stride=1, downsample=False, expansion=0, bn_adv_flag=False, bn_adv_momentum=0.01): 78 | super(Bottleneck, self).__init__() 79 | self.bn_adv_flag = bn_adv_flag 80 | self.bn_adv_momentum = bn_adv_momentum 81 | 82 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 83 | self.bn1 = nn.BatchNorm2d(planes) 84 | if self.bn_adv_flag: 85 | self.bn1_adv = nn.BatchNorm2d(planes, momentum = self.bn_adv_momentum) 86 | 87 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 88 | padding=1, bias=False) 89 | self.bn2 = nn.BatchNorm2d(planes) 90 | if self.bn_adv_flag: 91 | self.bn2_adv = nn.BatchNorm2d(planes, momentum = self.bn_adv_momentum) 92 | 93 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 94 | self.bn3 = nn.BatchNorm2d(planes * 4) 95 | if self.bn_adv_flag: 96 | self.bn3_adv = nn.BatchNorm2d(self.expansion*planes, momentum = self.bn_adv_momentum) 97 | 98 | self.relu = nn.ReLU(inplace=True) 99 | self.downsample = downsample 100 | if self.downsample: 101 | self.ds_conv1 = nn.Conv2d(inplanes, planes * expansion, kernel_size=1, stride=stride, bias=False) 102 | self.ds_bn1 = nn.BatchNorm2d(planes * expansion) 103 | self.ds_bn1_adv = nn.BatchNorm2d(planes * expansion) 104 | self.stride = stride 105 | 106 | def forward(self, x, adv = False): 107 | residual = x 108 | if adv and self.bn_adv_flag: 109 | out = self.conv1(x) 110 | out = self.bn1_adv(out) 111 | out = self.relu(out) 112 | 113 | out = self.conv2(out) 114 | out = self.bn2_adv(out) 115 | out = self.relu(out) 116 | 117 | out = self.conv3(out) 118 | out = self.bn3_adv(out) 119 | 120 | if self.downsample: 121 | 122 | residual = self.ds_bn1_adv(self.ds_conv1(x)) 123 | 124 | out += residual 125 | out = self.relu(out) 126 | else: 127 | out = self.conv1(x) 128 | out = self.bn1(out) 129 | out = self.relu(out) 130 | 131 | out = self.conv2(out) 132 | out = self.bn2(out) 133 | out = self.relu(out) 134 | 135 | out = self.conv3(out) 136 | out = self.bn3(out) 137 | 138 | if self.downsample: 139 | 140 | residual = self.ds_bn1(self.ds_conv1(x)) 141 | 142 | out += residual 143 | out = self.relu(out) 144 | return out 145 | 146 | 147 | class ResNetAdvProp_imgnet(nn.Module): 148 | 149 | 
def __init__(self, block, layers, low_dim=128, is_feature=None, bn_adv_flag=False, bn_adv_momentum=0.01):
150 | super(ResNetAdvProp_imgnet, self).__init__()
151 | self.inplanes = 64
152 | self.bn_adv_flag = bn_adv_flag
153 | self.bn_adv_momentum = bn_adv_momentum
154 |
155 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
156 | bias=False)
157 | self.bn1 = nn.BatchNorm2d(64)
158 | if bn_adv_flag:
159 | self.bn1_adv = nn.BatchNorm2d(64, momentum = self.bn_adv_momentum)
160 |
161 | self.relu = nn.ReLU(inplace=True)
162 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
163 | self.layer1 = self._make_layer(block, 64, layers[0], bn_adv_flag = self.bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
164 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, bn_adv_flag = self.bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
165 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, bn_adv_flag = self.bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
166 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, bn_adv_flag = self.bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
167 | self.avgpool = nn.AdaptiveAvgPool2d(1)
168 | self.fc = nn.Linear(512 * block.expansion, low_dim)
169 | self.dropout = nn.Dropout(p=0.5)
170 |
171 | for m in self.modules():
172 | if isinstance(m, nn.Conv2d):
173 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
174 | m.weight.data.normal_(0, math.sqrt(2. / n))
175 | elif isinstance(m, nn.BatchNorm2d):
176 | m.weight.data.fill_(1)
177 | m.bias.data.zero_()
178 |
179 |
180 | def _make_layer(self, block, planes, blocks, stride=1, bn_adv_flag=False, bn_adv_momentum=0.01):
181 | downsample = False
182 | if stride != 1 or self.inplanes != planes * block.expansion:
183 | downsample = True
184 | layers = []
185 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, expansion=block.expansion, bn_adv_flag=bn_adv_flag, bn_adv_momentum = bn_adv_momentum))
186 | self.inplanes = planes * block.expansion
187 | for i in range(1, blocks):
188 | layers.append(block(self.inplanes, planes, bn_adv_flag=bn_adv_flag, bn_adv_momentum = bn_adv_momentum))
189 |
190 | return MySequential(*layers)
191 |
192 | def forward(self, x, adv = False):
193 | x = self.conv1(x)
194 | if adv and self.bn_adv_flag:
195 | x = self.bn1_adv(x)
196 | else:
197 | x = self.bn1(x)
198 |
199 | x = self.relu(x)
200 | x = self.maxpool(x)
201 |
202 | x = self.layer1(x, adv=adv)
203 | x = self.layer2(x, adv=adv)
204 | x = self.layer3(x, adv=adv)
205 | x = self.layer4(x, adv=adv)
206 |
207 | x = self.avgpool(x)
208 | x = x.view(x.size(0), -1)
209 | x = self.fc(x)
210 | return x
211 |
212 |
213 | def resnet18_imagenet(low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01):
214 | return ResNetAdvProp_imgnet(BasicBlock, [2,2,2,2], low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
215 |
216 | def resnet34_imagenet(low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01):
217 | return ResNetAdvProp_imgnet(BasicBlock, [3,4,6,3], low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
218 |
219 | def resnet50_imagenet(low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01):
220 | return ResNetAdvProp_imgnet(Bottleneck, [3,4,6,3], low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
221 |
222 | def resnet101_imagenet(low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01):
223 | return ResNetAdvProp_imgnet(Bottleneck, [3,4,23,3], low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
224 |
225 | def
225 | def resnet152_imagenet(low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01):
226 |     return ResNetAdvProp_imgnet(Bottleneck, [3, 8, 36, 3], low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
227 | 
228 | 
229 | def test():
230 |     net = resnet50_imagenet()
231 |     # y = net(torch.randn(1, 3, 32, 32))
232 |     # pdb.set_trace()
233 |     y = net(torch.randn(1, 3, 224, 224), adv=True)
234 |     # pdb.set_trace()
235 |     print(y.size())
236 | #test()
237 | 
--------------------------------------------------------------------------------
/SimCLR/modules/simclr_BN.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torchvision
3 | from .resnet_BN import *
4 | from .resnet_BN_imagenet import *
5 | 
6 | 
7 | class Identity(nn.Module):
8 |     def __init__(self):
9 |         super(Identity, self).__init__()
10 | 
11 |     def forward(self, x):
12 |         return x
13 | 
14 | 
15 | class SimCLR_BN(nn.Module):
16 |     """
17 |     We opt for simplicity and adopt the commonly used ResNet (He et al., 2016) to obtain h_i = f(x̃_i) = ResNet(x̃_i), where h_i ∈ R^d is the output after the average pooling layer.
18 |     """
19 | 
20 |     def __init__(self, args, bn_adv_flag=False, bn_adv_momentum=0.01, data='non_imagenet'):
21 |         super(SimCLR_BN, self).__init__()
22 | 
23 |         self.args = args
24 |         self.bn_adv_flag = bn_adv_flag
25 |         self.bn_adv_momentum = bn_adv_momentum
26 |         if data == 'imagenet':
27 |             self.encoder = self.get_imagenet_resnet(args.resnet)
28 |         else:
29 |             self.encoder = self.get_resnet(args.resnet)
30 | 
31 |         self.n_features = self.encoder.fc.in_features  # get dimensions of fc layer
32 |         self.encoder.fc = Identity()  # remove fully-connected layer after pooling layer
33 | 
34 |         self.projector = nn.Sequential(
35 |             nn.Linear(self.n_features, self.n_features),
36 |             nn.ReLU(),
37 |             nn.Linear(self.n_features, args.projection_dim),
38 |         )
39 | 
40 |     def get_resnet(self, name):
41 |         resnets = {
42 |             "resnet18": resnet18(pool_len=4, bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
43 |             "resnet34": resnet34(pool_len=4, bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
44 |             "resnet50": resnet50(pool_len=4, bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
45 |             "resnet101": resnet101(pool_len=4, bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
46 |             "resnet152": resnet152(pool_len=4, bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
47 |         }
48 |         if name not in resnets.keys():
49 |             raise KeyError(f"{name} is not a valid ResNet version")
50 |         return resnets[name]
51 | 
52 |     def get_imagenet_resnet(self, name):
53 |         resnets = {
54 |             "resnet18": resnet18_imagenet(bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
55 |             "resnet34": resnet34_imagenet(bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
56 |             "resnet50": resnet50_imagenet(bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
57 |             "resnet101": resnet101_imagenet(bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
58 |             "resnet152": resnet152_imagenet(bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
59 |         }
60 |         if name not in resnets.keys():
61 |             raise KeyError(f"{name} is not a valid ResNet version")
62 |         return resnets[name]
63 | 
64 |     def forward(self, x, adv=False):
65 |         h = self.encoder(x, adv=adv)
66 |         z = self.projector(h)
67 | 
68 |         if self.args.normalize:
69 |             z = nn.functional.normalize(z, dim=1)
70 |         return h, z
71 | 
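Usage note (editor's sketch, not part of the repository): `SimCLR_BN` keeps two sets of BatchNorm statistics inside one encoder and switches between them with the `adv` flag, so clean views and VAE-generated adversarial views do not contaminate each other's running estimates. A minimal sketch of exercising both paths, assuming it is run from the `SimCLR/` directory so `modules` is importable; the hypothetical `args` namespace stands in for the fields `main.py` is expected to pass (`resnet`, `projection_dim`, `normalize`):

```
import torch
from types import SimpleNamespace

from modules.simclr_BN import SimCLR_BN

# hypothetical stand-in for the argparse namespace built by main.py
args = SimpleNamespace(resnet="resnet18", projection_dim=64, normalize=True)
model = SimCLR_BN(args, bn_adv_flag=True, bn_adv_momentum=0.01).eval()

x = torch.randn(4, 3, 32, 32)      # a CIFAR-sized batch
h, z = model(x, adv=False)         # clean path: bn1/bn2 running statistics
h_adv, z_adv = model(x, adv=True)  # adversarial path: bn1_adv/bn2_adv statistics
print(h.shape, z.shape)            # e.g. torch.Size([4, 512]), torch.Size([4, 64])
```

The same `adv` switch threads through every residual block, which is why the `forward` methods of the ResNet variants above all take an `adv` argument.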
--------------------------------------------------------------------------------
/SimCLR/modules/transformations/__init__.py:
--------------------------------------------------------------------------------
1 | from .simclr import TransformsSimCLR
2 | from .simclr import TransformsSimCLR_imagenet
3 | 
--------------------------------------------------------------------------------
/SimCLR/modules/transformations/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/modules/transformations/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/SimCLR/modules/transformations/__pycache__/simclr.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/modules/transformations/__pycache__/simclr.cpython-38.pyc
--------------------------------------------------------------------------------
/SimCLR/modules/transformations/simclr.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | from PIL import ImageFilter
3 | import random
4 | 
5 | 
6 | class GaussianBlur(object):
7 |     """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
8 | 
9 |     def __init__(self, sigma=[.1, 2.]):
10 |         self.sigma = sigma
11 | 
12 |     def __call__(self, x):
13 |         sigma = random.uniform(self.sigma[0], self.sigma[1])
14 |         x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
15 |         return x
16 | 
17 | 
18 | class TransformsSimCLR:
19 |     """
20 |     A stochastic data augmentation module that transforms any given data example randomly,
21 |     resulting in two correlated views of the same example,
22 |     denoted x̃_i and x̃_j, which we consider as a positive pair.
23 |     """
24 | 
25 |     def __init__(self, size=32):
26 |         s = 1
27 |         color_jitter = torchvision.transforms.ColorJitter(
28 |             0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
29 |         )
30 |         normalize = torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5],
31 |                                                      std=[0.5, 0.5, 0.5])
32 |         self.train_transform = torchvision.transforms.Compose(
33 |             [
34 |                 torchvision.transforms.RandomResizedCrop(size=size),
35 |                 #torchvision.transforms.RandomResizedCrop(size=96),
36 |                 torchvision.transforms.RandomHorizontalFlip(),  # with 0.5 probability
37 |                 torchvision.transforms.RandomApply([color_jitter], p=0.8),
38 |                 torchvision.transforms.RandomGrayscale(p=0.2),
39 |                 torchvision.transforms.ToTensor(),
40 |                 torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
41 |             ]
42 |         )
43 | 
44 |     def __call__(self, x):
45 |         return self.train_transform(x), self.train_transform(x)
46 | 
47 | 
48 | class TransformsSimCLR_imagenet:
49 |     """
50 |     A stochastic data augmentation module that transforms any given data example randomly,
51 |     resulting in two correlated views of the same example,
52 |     denoted x̃_i and x̃_j, which we consider as a positive pair.
53 |     """
54 | 
55 |     def __init__(self, size=224):
56 |         s = 0.5
57 |         color_jitter = torchvision.transforms.ColorJitter(
58 |             0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
59 |         )
60 | 
61 |         normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
62 |                                                      std=[0.229, 0.224, 0.225])
63 |         self.train_transform = torchvision.transforms.Compose(
64 |             [
65 |                 torchvision.transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)),
66 |                 torchvision.transforms.RandomApply([color_jitter], p=0.8),
67 |                 torchvision.transforms.RandomGrayscale(p=0.2),
68 |                 torchvision.transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
69 |                 torchvision.transforms.RandomHorizontalFlip(),
70 |                 torchvision.transforms.ToTensor(),
71 |                 normalize
72 |             ]
73 |         )
74 | 
75 |     def __call__(self, x):
76 |         return self.train_transform(x), self.train_transform(x)
77 | 
78 | 
--------------------------------------------------------------------------------
/SimCLR/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .masks import mask_correlated_samples
2 | #from .yaml_config_hook import post_config_hook
3 | #from .filestorage import CustomFileStorageObserver
--------------------------------------------------------------------------------
/SimCLR/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/SimCLR/utils/__pycache__/masks.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/utils/__pycache__/masks.cpython-38.pyc
--------------------------------------------------------------------------------
/SimCLR/utils/masks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | def mask_correlated_samples(args):
4 |     # mask out self-similarities and the two positive-pair entries for NT-Xent
5 |     mask = torch.ones((args.batch_size * 2, args.batch_size * 2), dtype=bool)
6 |     mask = mask.fill_diagonal_(0)
7 |     for i in range(args.batch_size):
8 |         mask[i, args.batch_size + i] = 0
9 |         mask[args.batch_size + i, i] = 0
10 |     return mask
11 | 
--------------------------------------------------------------------------------
/set.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import os.path as osp
5 | import random
6 | import torch
7 | import numpy as np
8 | from torch.optim.optimizer import Optimizer
9 | import math
10 | 
11 | 
12 | def mkdir_if_missing(dirname):
13 |     """Create dirname if it is missing."""
14 |     if not osp.exists(dirname):
15 |         try:
16 |             os.makedirs(dirname)
17 |         except FileExistsError:
18 |             # the directory was created concurrently; nothing to do
19 |             pass
20 | 
21 | 
22 | def set_random_seed(seed):
23 |     random.seed(seed)
24 |     np.random.seed(seed)
25 |     torch.manual_seed(seed)
26 |     torch.cuda.manual_seed_all(seed)
27 | 
28 | 
29 | 
30 | class Logger:
31 |     """Write console output to an external text file.
32 | 
33 |     Adapted from an external open-source logging utility.
34 | 
35 |     Args:
36 |         fpath (str): path to the log file.
37 | 38 | Examples:: 39 | >>> import sys 40 | >>> import os.path as osp 41 | >>> save_dir = 'output/experiment-1' 42 | >>> log_name = 'train.log' 43 | >>> sys.stdout = Logger(osp.join(save_dir, log_name)) 44 | """ 45 | 46 | def __init__(self, fpath=None): 47 | self.console = sys.stdout 48 | self.file = None 49 | if fpath is not None: 50 | mkdir_if_missing(osp.dirname(fpath)) 51 | self.file = open(fpath, 'w') 52 | 53 | def __del__(self): 54 | self.close() 55 | 56 | def __enter__(self): 57 | pass 58 | 59 | def __exit__(self, *args): 60 | self.close() 61 | 62 | def write(self, msg): 63 | self.console.write(msg) 64 | if self.file is not None: 65 | self.file.write(msg) 66 | 67 | def flush(self): 68 | self.console.flush() 69 | if self.file is not None: 70 | self.file.flush() 71 | os.fsync(self.file.fileno()) 72 | 73 | def close(self): 74 | self.console.close() 75 | if self.file is not None: 76 | self.file.close() 77 | 78 | 79 | def setup_logger(output=None): 80 | if output is None: 81 | return 82 | 83 | if output.endswith('.txt') or output.endswith('.log'): 84 | fpath = output 85 | else: 86 | fpath = osp.join(output, 'log.txt') 87 | 88 | if osp.exists(fpath): 89 | # make sure the existing log file is not over-written 90 | fpath += time.strftime('-%Y-%m-%d-%H-%M-%S') 91 | 92 | sys.stdout = Logger(fpath) 93 | 94 | 95 | def accuracy(output, target, topk=(1,)): 96 | 97 | maxk = max(topk) 98 | batch_size = target.size(0) 99 | 100 | _, pred = output.topk(maxk, 1, True, True) 101 | pred = pred.t() 102 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 103 | 104 | res = [] 105 | for k in topk: 106 | correct_k = correct[:k].contiguous().view(-1).float().sum(0) 107 | res.append(correct_k.mul_(100.0 / batch_size)) 108 | 109 | if len(res) == 1: 110 | return res[0] 111 | else: 112 | return (res[0], res[1], correct[0], pred[0]) 113 | 114 | 115 | class AdamW(Optimizer): 116 | """Implements Adam algorithm. 117 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 118 | Arguments: 119 | params (iterable): iterable of parameters to optimize or dicts defining 120 | parameter groups 121 | lr (float, optional): learning rate (default: 1e-3) 122 | betas (Tuple[float, float], optional): coefficients used for computing 123 | running averages of gradient and its square (default: (0.9, 0.999)) 124 | eps (float, optional): term added to the denominator to improve 125 | numerical stability (default: 1e-8) 126 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 127 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 128 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 129 | .. _Adam\: A Method for Stochastic Optimization: 130 | https://arxiv.org/abs/1412.6980 131 | .. 
_On the Convergence of Adam and Beyond: 132 | https://openreview.net/forum?id=ryQu7f-RZ 133 | """ 134 | 135 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 136 | weight_decay=0, amsgrad=False): 137 | if not 0.0 <= lr: 138 | raise ValueError("Invalid learning rate: {}".format(lr)) 139 | if not 0.0 <= eps: 140 | raise ValueError("Invalid epsilon value: {}".format(eps)) 141 | if not 0.0 <= betas[0] < 1.0: 142 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 143 | if not 0.0 <= betas[1] < 1.0: 144 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 145 | defaults = dict(lr=lr, betas=betas, eps=eps, 146 | weight_decay=weight_decay, amsgrad=amsgrad) 147 | super(AdamW, self).__init__(params, defaults) 148 | 149 | def __setstate__(self, state): 150 | super(AdamW, self).__setstate__(state) 151 | for group in self.param_groups: 152 | group.setdefault('amsgrad', False) 153 | 154 | def step(self, closure=None): 155 | """Performs a single optimization step. 156 | Arguments: 157 | closure (callable, optional): A closure that reevaluates the model 158 | and returns the loss. 159 | """ 160 | loss = None 161 | if closure is not None: 162 | loss = closure() 163 | 164 | for group in self.param_groups: 165 | for p in group['params']: 166 | if p.grad is None: 167 | continue 168 | grad = p.grad.data 169 | if grad.is_sparse: 170 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 171 | amsgrad = group['amsgrad'] 172 | 173 | state = self.state[p] 174 | 175 | # State initialization 176 | if len(state) == 0: 177 | state['step'] = 0 178 | # Exponential moving average of gradient values 179 | state['exp_avg'] = torch.zeros_like(p.data) 180 | # Exponential moving average of squared gradient values 181 | state['exp_avg_sq'] = torch.zeros_like(p.data) 182 | if amsgrad: 183 | # Maintains max of all exp. moving avg. of sq. grad. values 184 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 185 | 186 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 187 | if amsgrad: 188 | max_exp_avg_sq = state['max_exp_avg_sq'] 189 | beta1, beta2 = group['betas'] 190 | 191 | state['step'] += 1 192 | 193 | # if group['weight_decay'] != 0: 194 | # grad = grad.add(group['weight_decay'], p.data) 195 | 196 | # Decay the first and second moment running average coefficient 197 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 198 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 199 | if amsgrad: 200 | # Maintains the maximum of all 2nd moment running avg. till now 201 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 202 | # Use the max. for normalizing running avg. 
of gradient
203 |                     denom = max_exp_avg_sq.sqrt().add_(group['eps'])
204 |                 else:
205 |                     denom = exp_avg_sq.sqrt().add_(group['eps'])
206 | 
207 |                 bias_correction1 = 1 - beta1 ** state['step']
208 |                 bias_correction2 = 1 - beta2 ** state['step']
209 |                 step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
210 | 
211 |                 # p.data.addcdiv_(-step_size, exp_avg, denom)
212 |                 p.data.add_(-step_size, torch.mul(p.data, group['weight_decay']).addcdiv_(1, exp_avg, denom))
213 | 
214 |         return loss
215 | 
216 | 
217 | class AverageMeter(object):
218 |     """Computes and stores the average and current value
219 |     Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
220 |     """
221 | 
222 |     def __init__(self):
223 |         self.reset()
224 | 
225 |     def reset(self):
226 |         self.val = 0
227 |         self.avg = 0
228 |         self.sum = 0
229 |         self.count = 0
230 | 
231 |     def update(self, val, n=1):
232 |         self.val = val
233 |         self.sum += val * n
234 |         self.count += n
235 |         self.avg = self.sum / self.count
236 | 
237 | 
--------------------------------------------------------------------------------
/train_vae.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch.nn.functional as F
7 | import torch.backends.cudnn as cudnn
8 | from tqdm import tqdm
9 | from copy import deepcopy
10 | import torchvision
11 | import torchvision.transforms as transforms
12 | import wandb
13 | import os
14 | import time
15 | import argparse
16 | import datetime
17 | from torch.autograd import Variable
18 | import pdb
19 | import sys
20 | import torch.autograd as autograd
21 | import torchvision.models as models
22 | sys.path.append('.')
23 | 
24 | from vae import *
25 | from set import *
26 | try: from apex import amp
27 | except ImportError: amp = None  # apex is optional (see README); required only with --amp
28 | 
29 | def reconst_images(batch_size=64, batch_num=1, dataloader=None, model=None):
30 |     cifar10_dataloader = dataloader
31 |     model.eval()
32 |     with torch.no_grad():
33 |         for batch_idx, (X, y) in enumerate(cifar10_dataloader):
34 |             if batch_idx >= batch_num:
35 |                 break
36 |             else:
37 |                 X, y = X.cuda(), y.cuda().view(-1, )
38 |                 _, gx, _, _ = model(X)
39 | 
40 |                 grid_X = torchvision.utils.make_grid(X[:batch_size].data, nrow=8, padding=2, normalize=True)
41 |                 wandb.log({"_Batch_{batch}_X.jpg".format(batch=batch_idx): [
42 |                     wandb.Image(grid_X)]}, commit=False)
43 |                 grid_Xi = torchvision.utils.make_grid(gx[:batch_size].data, nrow=8, padding=2, normalize=True)
44 |                 wandb.log({"_Batch_{batch}_GX.jpg".format(batch=batch_idx): [
45 |                     wandb.Image(grid_Xi)]}, commit=False)
46 |                 grid_X_Xi = torchvision.utils.make_grid((X[:batch_size] - gx[:batch_size]).data, nrow=8, padding=2,
47 |                                                         normalize=True)
48 |                 wandb.log({"_Batch_{batch}_RX.jpg".format(batch=batch_idx): [
49 |                     wandb.Image(grid_X_Xi)]}, commit=False)
50 |     print('reconstruction complete!')
51 | 
52 | 
53 | def test(epoch, model, testloader):
54 |     # set model as testing mode
55 |     model.eval()
56 |     acc_gx_avg = AverageMeter()
57 |     acc_rx_avg = AverageMeter()
58 | 
59 |     with torch.no_grad():
60 |         for batch_idx, (x, y) in enumerate(testloader):
61 |             # distribute data to device
62 |             x, y = x.cuda(), y.cuda().view(-1, )
63 |             bs = x.size(0)
64 |             norm = torch.norm(torch.abs(x.view(bs, -1)), p=2, dim=1)
65 |             _, gx, _, _ = model(x)
66 |             acc_gx = 1 - F.mse_loss(torch.div(gx, norm.unsqueeze(1).unsqueeze(2).unsqueeze(3)), \
67 |                                     torch.div(x, norm.unsqueeze(1).unsqueeze(2).unsqueeze(3)), \
68 |                                     reduction='sum')
/ bs
69 |             acc_rx = 1 - F.mse_loss(torch.div(x - gx, norm.unsqueeze(1).unsqueeze(2).unsqueeze(3)), \
70 |                                     torch.div(x, norm.unsqueeze(1).unsqueeze(2).unsqueeze(3)), \
71 |                                     reduction='sum') / bs
72 | 
73 |             acc_gx_avg.update(acc_gx.data.item(), bs)
74 |             acc_rx_avg.update(acc_rx.data.item(), bs)
75 | 
76 |     wandb.log({'acc_gx_avg': acc_gx_avg.avg, \
77 |                'acc_rx_avg': acc_rx_avg.avg}, commit=False)
78 |     # plot progress
79 |     print("\n| Validation Epoch #%d\t\tRec_gx: %.4f Rec_rx: %.4f " % (epoch, acc_gx_avg.avg, acc_rx_avg.avg))
80 |     reconst_images(batch_size=64, batch_num=2, dataloader=testloader, model=model)
81 |     torch.save(model.state_dict(),
82 |                os.path.join(args.save_dir, 'model_epoch{}.pth'.format(epoch + 1)))  # save the VAE weights
83 |     print("Epoch {} model saved!".format(epoch + 1))
84 | 
85 | 
86 | def main(args):
87 |     setup_logger(args.save_dir)
88 |     use_cuda = torch.cuda.is_available()
89 |     print('\n[Phase 1] : Data Preparation')
90 | 
91 |     if args.dataset == 'imagenet':
92 |         size = 224
93 |         normalizer = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
94 |         model = CVAE_imagenet_withbn(128, args.dim)
95 |         p_blur = 0.5
96 |     else:
97 |         size = 32
98 |         normalizer = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
99 |         model = CVAE_cifar_withbn(128, args.dim)
100 |         p_blur = 0.0
101 | 
102 |     if args.mode == 'simclr':
103 |         print('\nData Augmentation: SimCLR')
104 |         s = 1
105 |         color_jitter = transforms.ColorJitter(
106 |             0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
107 |         )
108 |         transform_train = transforms.Compose(
109 |             [
110 |                 transforms.RandomResizedCrop(size=size),
111 |                 transforms.RandomHorizontalFlip(),  # with 0.5 probability
112 |                 transforms.RandomApply([color_jitter], p=0.8),
113 |                 transforms.RandomGrayscale(p=0.2),
114 |                 transforms.ToTensor(),
115 |                 normalizer
116 |             ]
117 |         )
118 |     elif args.mode == 'simsiam':
119 |         print('\nData Augmentation: SimSiam')
120 |         transform_train = transforms.Compose([
121 |             transforms.RandomResizedCrop(size, scale=(0.2, 1.0)),
122 |             transforms.RandomHorizontalFlip(),
123 |             transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
124 |             transforms.RandomGrayscale(p=0.2),
125 |             transforms.RandomApply([transforms.GaussianBlur(kernel_size=size // 20 * 2 + 1, sigma=(0.1, 2.0))], p=p_blur),
126 |             transforms.ToTensor(),
127 |             normalizer
128 |         ])
129 |     else:
130 |         print('\nData Augmentation: Normal')
131 |         transform_train = transforms.Compose([
132 |             transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)),
133 |             transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
134 |             transforms.RandomGrayscale(p=0.2),
135 |             transforms.RandomHorizontalFlip(),
136 |             transforms.ToTensor(),
137 |             normalizer
138 |         ])
139 |     if args.dataset == 'cifar10':
140 |         print("| Preparing CIFAR-10 dataset...")
141 |         sys.stdout.write("| ")
142 |         trainset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)
143 |     elif args.dataset == 'cifar100':
144 |         print("| Preparing CIFAR-100 dataset...")
145 |         sys.stdout.write("| ")
146 |         trainset = torchvision.datasets.CIFAR100(root='../data', train=True, download=True, transform=transform_train)
147 |     elif args.dataset == 'imagenet':
148 |         print("| Preparing imagenet dataset...")
149 |         sys.stdout.write("| ")
150 |         root = '/gpub/imagenet_raw'
151 |         train_path = os.path.join(root, 'train')
152 |         trainset = torchvision.datasets.ImageFolder(root=train_path, transform=transform_train)
153 | 
154 |     trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True,
num_workers=4, 155 | drop_last=True) 156 | # Model 157 | print('\n[Phase 2] : Model setup') 158 | if use_cuda: 159 | model.cuda() 160 | cudnn.benchmark = True 161 | 162 | optimizer = AdamW([ 163 | {'params': model.parameters()}, 164 | ], lr=args.lr, betas=(0.0, 0.9)) 165 | 166 | scheduler = optim.lr_scheduler.LambdaLR( 167 | optimizer, lambda epoch: 1 - epoch / args.epochs) 168 | 169 | if args.amp: 170 | model, optimizer = amp.initialize( 171 | model, optimizer, opt_level=args.opt_level) 172 | 173 | print('\n[Phase 3] : Training model') 174 | print('| Training Epochs = ' + str(args.epochs)) 175 | print('| Initial Learning Rate = ' + str(args.lr)) 176 | 177 | start_epoch = 1 178 | for epoch in range(start_epoch, start_epoch + args.epochs): 179 | model.train() 180 | 181 | loss_avg = AverageMeter() 182 | loss_rec = AverageMeter() 183 | loss_kl = AverageMeter() 184 | 185 | print('\n=> Training Epoch #%d, LR=%.4f' % (epoch, optimizer.param_groups[0]['lr'])) 186 | for batch_idx, (x, y) in enumerate(trainloader): 187 | x, y = x.cuda(), y.cuda().view(-1, ) 188 | x, y = Variable(x), Variable(y) 189 | bs = x.size(0) 190 | 191 | _, gx, mu, logvar = model(x) 192 | optimizer.zero_grad() 193 | l_rec = F.mse_loss(x, gx) 194 | l_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 195 | l_kl /= bs * 3 * args.dim 196 | loss = l_rec + args.kl * l_kl 197 | 198 | if args.amp: 199 | with amp.scale_loss(loss, optimizer) as scaled_loss: 200 | scaled_loss.backward() 201 | else: 202 | loss.backward() 203 | 204 | optimizer.step() 205 | 206 | loss_avg.update(loss.data.item(), bs) 207 | loss_rec.update(l_rec.data.item(), bs) 208 | loss_kl.update(l_kl.data.item(), bs) 209 | 210 | n_iter = (epoch - 1) * len(trainloader) + batch_idx 211 | wandb.log({'loss': loss_avg.avg, \ 212 | 'loss_rec': loss_rec.avg, \ 213 | 'loss_kl': loss_kl.avg, \ 214 | 'lr': optimizer.param_groups[0]['lr']}, step=n_iter) 215 | if (batch_idx + 1) % 30 == 0: 216 | sys.stdout.write('\r') 217 | sys.stdout.write( 218 | '| Epoch [%3d/%3d] Iter[%3d/%3d]\t\t Loss_rec: %.4f Loss_kl: %.4f' 219 | % (epoch, args.epochs, batch_idx + 1, 220 | len(trainloader), loss_rec.avg, loss_kl.avg)) 221 | scheduler.step() 222 | test(epoch, model, trainloader) 223 | wandb.finish() 224 | 225 | 226 | if __name__ == '__main__': 227 | parser = argparse.ArgumentParser(description='VAE Training') 228 | parser.add_argument('--lr', default=5e-4, type=float, help='learning_rate') 229 | parser.add_argument('--save_dir', default='./results/vae_cifar10_simclr', type=str, help='save_dir') 230 | parser.add_argument('--seed', default=666, type=int, help='seed') 231 | parser.add_argument('--dataset', default='cifar10', type=str, help='dataset = [cifar10/cifar100/imagenet]') 232 | parser.add_argument('--epochs', default=300, type=int, help='training_epochs') 233 | parser.add_argument('--batch_size', default=128, type=int, help='batch_size') 234 | parser.add_argument('--dim', default=128, type=int, help='CNN_embed_dim') 235 | parser.add_argument('--kl', default=0.1, type=float, help='kl weight') 236 | parser.add_argument('--mode', default='normal', type=str, help='augmentation mode') 237 | parser.add_argument("--amp", action="store_true", 238 | help="use 16-bit (mixed) precision through NVIDIA apex AMP") 239 | parser.add_argument("--opt_level", type=str, default="O1", 240 | help="apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
241 |                              "See details at https://nvidia.github.io/apex/amp.html")
242 |     args = parser.parse_args()
243 |     wandb.init(config=args, name=args.save_dir.replace("./results/", ''))
244 |     set_random_seed(args.seed)
245 |     main(args)
246 | 
--------------------------------------------------------------------------------
/vae.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import abc
3 | import os
4 | import math
5 | import pdb
6 | import numpy as np
7 | import logging
8 | import torch
9 | import torch.utils.data
10 | from torch import nn
11 | from torch.nn import init
12 | from torch.nn import functional as F
13 | from torch.autograd import Function, Variable
14 | import torchvision.models as models
15 | 
16 | 
17 | 
18 | 
19 | class ResBlock(nn.Module):
20 |     def __init__(self, in_channels, out_channels, mid_channels=None, bn=False):
21 |         super(ResBlock, self).__init__()
22 | 
23 |         if mid_channels is None:
24 |             mid_channels = out_channels
25 | 
26 |         layers = [
27 |             nn.LeakyReLU(),
28 |             nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1),
29 |             nn.LeakyReLU(),
30 |             nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1, padding=0)]
31 |         if bn:
32 |             layers.insert(2, nn.BatchNorm2d(mid_channels))  # BN follows the first conv, which outputs mid_channels
33 |         self.convs = nn.Sequential(*layers)
34 | 
35 |     def forward(self, x):
36 |         return x + self.convs(x)
37 | 
38 | 
39 | class AbstractAutoEncoder(nn.Module):
40 |     __metaclass__ = abc.ABCMeta
41 | 
42 |     @abc.abstractmethod
43 |     def encode(self, x):
44 |         return
45 | 
46 |     @abc.abstractmethod
47 |     def decode(self, z):
48 |         return
49 | 
50 |     @abc.abstractmethod
51 |     def forward(self, x):
52 |         """model returns (reconstructed_x, *)"""
53 |         return
54 | 
55 |     @abc.abstractmethod
56 |     def sample(self, size):
57 |         """sample new images from model"""
58 |         return
59 | 
60 |     @abc.abstractmethod
61 |     def loss_function(self, **kwargs):
62 |         """accepts (original images, *) where * is the same as returned from forward()"""
63 |         return
64 | 
65 |     @abc.abstractmethod
66 |     def latest_losses(self):
67 |         """returns the latest losses in a dictionary.
Useful for logging.""" 68 | return 69 | 70 | 71 | class CVAE_cifar_withbn(AbstractAutoEncoder): 72 | def __init__(self, d, z, **kwargs): 73 | super(CVAE_cifar_withbn, self).__init__() 74 | 75 | self.encoder = nn.Sequential( 76 | nn.Conv2d(3, d // 2, kernel_size=4, stride=2, padding=1, bias=False), 77 | nn.BatchNorm2d(d // 2), 78 | nn.ReLU(inplace=True), 79 | nn.Conv2d(d // 2, d, kernel_size=4, stride=2, padding=1, bias=False), 80 | nn.BatchNorm2d(d), 81 | nn.ReLU(inplace=True), 82 | ResBlock(d, d, bn=True), 83 | nn.BatchNorm2d(d), 84 | ResBlock(d, d, bn=True), 85 | ) 86 | 87 | self.decoder = nn.Sequential( 88 | ResBlock(d, d, bn=True), 89 | nn.BatchNorm2d(d), 90 | ResBlock(d, d, bn=True), 91 | nn.BatchNorm2d(d), 92 | nn.ConvTranspose2d(d, d // 2, kernel_size=4, stride=2, padding=1, bias=False), 93 | nn.BatchNorm2d(d // 2), 94 | nn.LeakyReLU(inplace=True), 95 | nn.ConvTranspose2d(d // 2, 3, kernel_size=4, stride=2, padding=1, bias=False), 96 | ) 97 | self.bn = nn.BatchNorm2d(3) 98 | self.f = 8 99 | self.d = d 100 | self.z = z 101 | self.fc11 = nn.Linear(d * self.f ** 2, self.z) 102 | self.fc12 = nn.Linear(d * self.f ** 2, self.z) 103 | self.fc21 = nn.Linear(self.z, d * self.f ** 2) 104 | 105 | def encode(self, x): 106 | h = self.encoder(x) 107 | h1 = h.view(-1, self.d * self.f ** 2) 108 | return h, self.fc11(h1), self.fc12(h1) 109 | 110 | def reparameterize(self, mu, logvar): 111 | if self.training: 112 | std = logvar.mul(0.5).exp_() 113 | eps = std.new(std.size()).normal_() 114 | return eps.mul(std).add_(mu) 115 | else: 116 | return mu 117 | 118 | def decode(self, z): 119 | z = z.view(-1, self.d, self.f, self.f) 120 | h3 = self.decoder(z) 121 | return torch.tanh(h3) 122 | 123 | def forward(self, x, decode=False): 124 | if decode: 125 | z_projected = self.fc21(x) 126 | gx = self.decode(z_projected) 127 | gx = self.bn(gx) 128 | return gx 129 | else: 130 | _, mu, logvar = self.encode(x) 131 | z = self.reparameterize(mu, logvar) 132 | z_projected = self.fc21(z) 133 | gx = self.decode(z_projected) 134 | gx = self.bn(gx) 135 | return z, gx, mu, logvar 136 | 137 | 138 | class CVAE_imagenet_withbn(AbstractAutoEncoder): 139 | def __init__(self, d, z, **kwargs): 140 | super(CVAE_imagenet_withbn, self).__init__() 141 | 142 | self.encoder = nn.Sequential( 143 | nn.Conv2d(3, d // 16, kernel_size=4, stride=2, padding=1, bias=False), 144 | nn.BatchNorm2d(d // 16), 145 | nn.ReLU(inplace=True), 146 | nn.Conv2d(d // 16, d // 8, kernel_size=4, stride=2, padding=1, bias=False), 147 | nn.BatchNorm2d(d // 8), 148 | nn.ReLU(inplace=True), 149 | nn.Conv2d(d // 8, d // 4, kernel_size=4, stride=2, padding=1, bias=False), 150 | nn.BatchNorm2d(d // 4), 151 | nn.ReLU(inplace=True), 152 | nn.Conv2d(d // 4, d // 2, kernel_size=4, stride=2, padding=1, bias=False), 153 | nn.BatchNorm2d(d // 2), 154 | nn.ReLU(inplace=True), 155 | nn.Conv2d(d // 2, d, kernel_size=4, stride=2, padding=1, bias=False), 156 | nn.BatchNorm2d(d), 157 | nn.ReLU(inplace=True), 158 | ResBlock(d, d, bn=True), 159 | nn.BatchNorm2d(d), 160 | ResBlock(d, d, bn=True), 161 | nn.BatchNorm2d(d) 162 | ) 163 | 164 | self.decoder = nn.Sequential( 165 | nn.BatchNorm2d(d), 166 | ResBlock(d, d, bn=True), 167 | nn.BatchNorm2d(d), 168 | ResBlock(d, d, bn=True), 169 | nn.BatchNorm2d(d), 170 | nn.ConvTranspose2d(d, d // 2, kernel_size=4, stride=2, padding=1, bias=False), 171 | nn.BatchNorm2d(d // 2), 172 | nn.LeakyReLU(inplace=True), 173 | nn.ConvTranspose2d(d // 2, d // 4, kernel_size=4, stride=2, padding=1, bias=False), 174 | nn.BatchNorm2d(d // 4), 175 | 
nn.LeakyReLU(inplace=True), 176 | nn.ConvTranspose2d(d // 4, d // 8, kernel_size=4, stride=2, padding=1, bias=False), 177 | nn.BatchNorm2d(d // 8), 178 | nn.LeakyReLU(inplace=True), 179 | nn.ConvTranspose2d(d // 8, d // 16, kernel_size=4, stride=2, padding=1, bias=False), 180 | nn.BatchNorm2d(d // 16), 181 | nn.LeakyReLU(inplace=True), 182 | nn.ConvTranspose2d(d // 16, 3, kernel_size=4, stride=2, padding=1, bias=False), 183 | ) 184 | self.bn = nn.BatchNorm2d(3) 185 | self.f = 7 186 | self.d = d 187 | self.z = z 188 | self.fc11 = nn.Linear(d * self.f ** 2, self.z) 189 | self.fc12 = nn.Linear(d * self.f ** 2, self.z) 190 | self.fc21 = nn.Linear(self.z, d * self.f ** 2) 191 | 192 | def encode(self, x): 193 | h = self.encoder(x) 194 | h1 = h.view(-1, self.d * self.f ** 2) 195 | return h, self.fc11(h1), self.fc12(h1) 196 | 197 | def reparameterize(self, mu, logvar): 198 | if self.training: 199 | std = logvar.mul(0.5).exp_() 200 | eps = std.new(std.size()).normal_() 201 | return eps.mul(std).add_(mu) 202 | else: 203 | return mu 204 | 205 | def decode(self, z): 206 | z = z.view(-1, self.d, self.f, self.f) 207 | h3 = self.decoder(z) 208 | return torch.tanh(h3) 209 | 210 | def forward(self, x, decode=False): 211 | if decode: 212 | z_projected = self.fc21(x) 213 | gx = self.decode(z_projected) 214 | gx = self.bn(gx) 215 | return gx 216 | else: 217 | _, mu, logvar = self.encode(x) 218 | z = self.reparameterize(mu, logvar) 219 | z_projected = self.fc21(z) 220 | gx = self.decode(z_projected) 221 | gx = self.bn(gx) 222 | return z, gx, mu, logvar 223 | 224 | 225 | --------------------------------------------------------------------------------
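Usage note (editor's sketch, not part of the repository): `train_vae.py` optimizes the CVAE above with a pixel-space reconstruction loss plus a KL term weighted by `--kl` and normalized by `batch_size * 3 * dim`. A minimal single-step reproduction of that objective, assuming it is run from the repository root so `vae.py` is importable:

```
import torch
import torch.nn.functional as F

from vae import CVAE_cifar_withbn

dim, kl_weight = 512, 0.1            # mirrors --dim 512 --kl 0.1 in the README
model = CVAE_cifar_withbn(128, dim)  # d=128 feature channels, z=512 latent dims
x = torch.randn(8, 3, 32, 32)        # stands in for a normalized CIFAR batch

z, gx, mu, logvar = model(x)         # latent code, reconstruction, posterior params
l_rec = F.mse_loss(x, gx)            # reconstruction term
l_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
l_kl = l_kl / (x.size(0) * 3 * dim)  # same normalization as train_vae.py
loss = l_rec + kl_weight * l_kl
loss.backward()
```

During training, `reparameterize` draws `z = mu + eps * exp(0.5 * logvar)`; in eval mode it returns `mu`, so reconstructions become deterministic.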