├── IDAA.png
├── LICENSE
├── README.md
├── SimCLR
│   ├── eval_knn.py
│   ├── eval_lr.py
│   ├── main.py
│   ├── model.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── lars.cpython-38.pyc
│   │   │   ├── logistic_regression.cpython-38.pyc
│   │   │   ├── nt_xent.cpython-38.pyc
│   │   │   ├── resnet_BN.cpython-38.pyc
│   │   │   ├── resnet_BN_imagenet.cpython-38.pyc
│   │   │   └── simclr_BN.cpython-38.pyc
│   │   ├── lars.py
│   │   ├── logistic_regression.py
│   │   ├── nt_xent.py
│   │   ├── resnet_BN.py
│   │   ├── resnet_BN_imagenet.py
│   │   ├── simclr_BN.py
│   │   └── transformations
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-38.pyc
│   │       │   └── simclr.cpython-38.pyc
│   │       └── simclr.py
│   └── utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-38.pyc
│       │   └── masks.cpython-38.pyc
│       └── masks.py
├── set.py
├── train_vae.py
└── vae.py
/IDAA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/IDAA.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Kaiwen Yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # IDAA
2 |
3 | Official implementation:
4 | - Identity-Disentangled Adversarial Augmentation for Self-Supervised Learning, ICML 2022. ([Paper](https://proceedings.mlr.press/v162/yang22s/yang22s.pdf))
5 |
6 | ![IDAA](IDAA.png)
7 |
8 | *Architecture and pipeline of Identity-Disentangled Adversarial Augmentation (IDAA)*
9 |
10 |
11 |
12 | For questions, contact kwyang@mail.ustc.edu.cn.
13 |
14 | ## Requirements
15 |
16 | 1. [Python](https://www.python.org/)
17 | 2. [PyTorch](https://pytorch.org/)
18 | 3. [Wandb](https://wandb.ai/site)
19 | 4. [Torchvision](https://pytorch.org/vision/stable/index.html)
20 | 5. [Apex (optional)](https://github.com/NVIDIA/apex)
21 |
22 | ## Pretrain a VAE
23 |
24 |
25 | ```
26 | python train_vae.py --dim 512 --kl 0.1 --save_dir ./results/vae_cifar10_dim512_kl0.1_simclr --mode simclr --dataset cifar10
27 | ```
28 |
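The checkpoint saved here is what `SimCLR/main.py` later loads through `--vae_path` and keeps frozen during contrastive training. A minimal sketch of that loading step (the path and epoch follow the example commands in this README; `CVAE_cifar_withbn` is defined in `vae.py`):

```
import torch
from vae import CVAE_cifar_withbn

vae = CVAE_cifar_withbn(128, 512)  # arguments mirror SimCLR/main.py: CVAE_cifar_withbn(128, args.dim)
vae.load_state_dict(torch.load('../results/vae_cifar10_dim512_kl0.1_simclr/model_epoch292.pth'))
vae.eval()  # the VAE only generates augmentations; it is never updated
```
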
29 | ## Apply IDAA to SimCLR
30 | ```
31 | cd SimCLR
32 | ```
33 |
34 | SimCLR training and evaluation:
35 | ```
36 | python main.py --seed 1 --gpu 0 --dataset cifar10 --resnet resnet18;
37 | python eval_lr.py --seed 1 --gpu 0 --dataset cifar10 --resnet resnet18
38 | ```
39 | SimCLR+IDAA training and evaluation:
40 | ```
41 | python main.py --adv --eps 0.1 --seed 1 --gpu 0 --dataset cifar10 --dim 512 --vae_path ../results/vae_cifar10_dim512_kl0.1_simclr/model_epoch292.pth --resnet resnet18;
42 | python eval_lr.py --adv --eps 0.1 --seed 1 --gpu 0 --dataset cifar10 --dim 512 --resnet resnet18
43 | ```
44 |
45 | ## References
46 | We borrow some code from https://github.com/chihhuiho/CLAE.
47 |
48 |
49 | ## Citation
50 |
51 | If you find this repo useful for your research, please consider citing the paper
52 | ```
53 | @inproceedings{yang2022identity,
54 | title={Identity-Disentangled Adversarial Augmentation for Self-supervised Learning},
55 | author={Yang, Kaiwen and Zhou, Tianyi and Tian, Xinmei and Tao, Dacheng},
56 | booktitle={International Conference on Machine Learning},
57 | pages={25364--25381},
58 | year={2022},
59 | organization={PMLR}
60 | }
61 | ```
62 |
--------------------------------------------------------------------------------
/SimCLR/eval_knn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torchvision.transforms as transforms
4 | import argparse
5 | import sys
6 | import os
7 | #from experiment import ex
8 | from model import load_model, save_model
9 | import wandb
10 | from modules import LogisticRegression
11 | import numpy as np
12 | from tqdm import tqdm
13 |
14 |
15 | def kNN(epoch, net, trainloader, testloader, K, sigma, ndata, low_dim = 128):
16 | net.eval()
17 | total = 0
18 | correct_t = 0
19 | testsize = testloader.dataset.__len__()
20 |
21 |     try:
22 |         trainLabels = torch.LongTensor(trainloader.dataset.targets).cuda()
23 |     except AttributeError:  # some datasets expose .labels instead of .targets
24 |         trainLabels = torch.LongTensor(trainloader.dataset.labels).cuda()
25 | trainFeatures = np.zeros((low_dim, ndata))
26 | trainFeatures = torch.Tensor(trainFeatures).cuda()
27 | C = trainLabels.max() + 1
28 |     C = int(C)  # np.int was removed in NumPy 1.24; the builtin is equivalent here
29 |
30 | with torch.no_grad():
31 | transform_bak = trainloader.dataset.transform
32 | trainloader.dataset.transform = testloader.dataset.transform
33 | temploader = torch.utils.data.DataLoader(trainloader.dataset, batch_size=256, shuffle=False, num_workers=4)
34 |         for batch_idx, (inputs, targets) in tqdm(enumerate(temploader)):
35 |             targets = targets.cuda()
36 |             batchSize = inputs.size(0)
37 |             _, features = net(inputs.cuda())
38 |             start = batch_idx * temploader.batch_size  # offset by the loader batch size so a smaller final batch lands correctly
39 |             trainFeatures[:, start:start + batchSize] = features.t()
40 |
41 | trainloader.dataset.transform = transform_bak
42 | #
43 |
44 |
45 | top1 = 0.
46 | top5 = 0.
47 | with torch.no_grad():
48 | retrieval_one_hot = torch.zeros(K, C).cuda()
49 | for batch_idx, (inputs, targets) in enumerate(testloader):
50 |
51 | targets = targets.cuda()
52 | batchSize = inputs.size(0)
53 | _, features = net(inputs.cuda())
54 | total += targets.size(0)
55 |
56 | dist = torch.mm(features, trainFeatures)
57 | yd, yi = dist.topk(K, dim=1, largest=True, sorted=True)
58 | candidates = trainLabels.view(1,-1).expand(batchSize, -1)
59 | retrieval = torch.gather(candidates, 1, yi)
60 | retrieval_one_hot.resize_(batchSize * K, C).zero_()
61 | retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1)
62 | yd_transform = yd.clone().div_(sigma).exp_()
63 | probs = torch.sum(torch.mul(retrieval_one_hot.view(batchSize, -1 , C), yd_transform.view(batchSize, -1, 1)), 1)
64 |
65 | _, predictions = probs.sort(1, True)
66 | # Find which predictions match the target
67 | correct = predictions.eq(targets.data.view(-1,1))
68 |
69 | top1 = top1 + correct.narrow(1,0,1).sum().item()
70 |
71 | print(top1*100./total)
72 |
73 | return top1*100./total
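
# A toy sketch of the weighted vote computed in the loop above: the K nearest
# training features vote for their class with weight exp(similarity / sigma).
# Sizes and values below are illustrative only.
def _knn_vote_demo(sigma=0.5):
    K, C = 3, 2
    yd = torch.tensor([[0.9, 0.8, 0.1]])   # similarities of the K retrieved neighbors
    retrieval = torch.tensor([[1, 1, 0]])  # their training labels
    one_hot = torch.zeros(K, C).scatter_(1, retrieval.view(-1, 1), 1)
    probs = (one_hot.view(1, K, C) * (yd / sigma).exp().view(1, K, 1)).sum(1)
    return probs.argmax(1)                 # tensor([1]): class 1 wins the weighted vote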
--------------------------------------------------------------------------------
/SimCLR/eval_lr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torchvision.transforms as transforms
4 | import argparse
5 | import os
6 | from model import load_model, save_model
7 | import sys
8 | import wandb
9 | from modules import LogisticRegression
10 | sys.path.append('.')
11 | sys.path.append('..')
12 | from set import *
13 | from utils import *
14 |
15 |
16 | parser = argparse.ArgumentParser(description='PyTorch Seen Testing Category Training')
17 | parser.add_argument('--batch_size', default=256, type=int,
18 | metavar='B', help='training batch size')
19 | parser.add_argument('--logistic_batch_size', default=256, type=int,
20 |                     metavar='B', help='batch size for the logistic regression head')
21 | parser.add_argument('--logistic_epochs', default=1000, type=int, help='logistic_epochs')
22 | parser.add_argument('--workers', default=4, type=int, help='workers')
23 | parser.add_argument('--epochs', default=300, type=int,help='epochs')
24 | parser.add_argument('--resnet', default="resnet18", type=str, help="resnet")
25 | parser.add_argument('--normalize', default=True, action='store_true', help='normalize')
26 | parser.add_argument('--projection_dim', default=64, type=int,help='projection_dim')
27 | parser.add_argument('--optimizer', default="Adam", type=str, help="optimizer")
28 | parser.add_argument('--weight_decay', default=1.0e-6, type=float, help='weight_decay')
29 | parser.add_argument('--temperature', default=0.5, type=float, help='temperature')
30 | parser.add_argument('--model_path', default='checkpoint/', type=str,
31 | help='model save path')
32 | parser.add_argument('--model_dir', default='checkpoint/', type=str,
33 | help='model save path')
34 | parser.add_argument('--lr', default=3e-4, type=float, help='learning rate')
35 | parser.add_argument('--dataset', default='cifar10',
36 | help='[cifar10, cifar100]')
37 | parser.add_argument('--gpu', default='0', type=str,
38 | help='gpu device ids for CUDA_VISIBLE_DEVICES')
39 | parser.add_argument('--trial', type=int, help='trial')
40 | parser.add_argument('--adv', default=False, action='store_true', help='adversarial example')
41 | parser.add_argument('--eps', default=0.01, type=float, help='eps for adversarial')
42 | parser.add_argument('--bn_adv_momentum', default=0.01, type=float, help='batch norm momentum for advprop')
43 | parser.add_argument('--alpha', default=1.0, type=float, help='weight for contrastive loss with adversarial example')
44 | parser.add_argument('--debug', default=False, action='store_true', help='debug mode')
45 | parser.add_argument('--seed', default=1, type=int, help='seed')
46 | parser.add_argument('--dim', default=512, type=int, help='CNN_embed_dim')
47 | args = parser.parse_args()
48 | set_random_seed(args.seed)
49 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
50 |
51 |
52 | def train(args, loader, simclr_model, model, criterion, optimizer):
53 | loss_epoch = 0
54 | accuracy_epoch = 0
55 | for step, (x, y) in enumerate(loader):
56 | optimizer.zero_grad()
57 |
58 | x = x.to(args.device)
59 | y = y.to(args.device)
60 |
61 | # get encoding
62 | with torch.no_grad():
63 | h, z = simclr_model(x)
64 | # h = 512
65 | # z = 64
66 |
67 | output = model(h)
68 | loss = criterion(output, y)
69 |
70 | predicted = output.argmax(1)
71 | acc = (predicted == y).sum().item() / y.size(0)
72 | accuracy_epoch += acc
73 |
74 | loss.backward()
75 | optimizer.step()
76 |
77 | loss_epoch += loss.item()
78 | if step % 100 == 0:
79 | print(f"Step [{step}/{len(loader)}]\t Loss: {loss.item()}\t Accuracy: {acc}")
80 |
81 | if args.debug:
82 | break
83 |
84 | return loss_epoch, accuracy_epoch
85 |
86 |
87 | def test(args, loader, simclr_model, model, criterion, optimizer):
88 | loss_epoch = 0
89 | accuracy_epoch = 0
90 | model.eval()
91 | for step, (x, y) in enumerate(loader):
92 | model.zero_grad()
93 |
94 | x = x.to(args.device)
95 | y = y.to(args.device)
96 |
97 | # get encoding
98 | with torch.no_grad():
99 | h, z = simclr_model(x)
100 | # h = 512
101 | # z = 64
102 |
103 | output = model(h)
104 | loss = criterion(output, y)
105 |
106 | predicted = output.argmax(1)
107 | acc = (predicted == y).sum().item() / y.size(0)
108 | accuracy_epoch += acc
109 |
110 | loss_epoch += loss.item()
111 |
112 | return loss_epoch, accuracy_epoch
113 |
114 |
115 | def main():
116 | args.device = device = 'cuda' if torch.cuda.is_available() else 'cpu'
117 |
118 | root = "../../data"
119 |
120 | if args.dataset == 'tinyImagenet':
121 | transform = transforms.Compose([
122 | torchvision.transforms.Resize((224, 224)),
123 | transforms.ToTensor(),
124 | #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
125 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
126 | std=[0.229, 0.224, 0.225])
127 | ])
128 | data = 'imagenet'
129 | elif args.dataset == 'miniImagenet':
130 | transform = transforms.Compose([
131 | torchvision.transforms.Resize((84, 84)),
132 | transforms.ToTensor(),
133 | # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
134 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
135 | std=[0.229, 0.224, 0.225])
136 | ])
137 | data = 'imagenet'
138 | elif args.dataset == 'imagenet100':
139 | transform = transforms.Compose([
140 | torchvision.transforms.Resize((224, 224)),
141 | transforms.ToTensor(),
142 | # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
143 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
144 | std=[0.229, 0.224, 0.225])
145 | ])
146 | data = 'imagenet'
147 | else:
148 | transform = transforms.Compose([
149 | torchvision.transforms.Resize(size=32),
150 | transforms.ToTensor(),
151 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
152 | ])
153 | data = 'non_imagenet'
154 |
155 | if args.dataset == "cifar10" :
156 | train_dataset = torchvision.datasets.CIFAR10(
157 | root, train=True, download=True, transform=transform
158 | )
159 | test_dataset = torchvision.datasets.CIFAR10(
160 | root, train=False, download=True, transform=transform
161 | )
162 | elif args.dataset == "cifar100":
163 | train_dataset = torchvision.datasets.CIFAR100(
164 | root, train=True, download=True, transform=transform
165 | )
166 | test_dataset = torchvision.datasets.CIFAR100(
167 | root, train=False, download=True, transform=transform
168 | )
169 | else:
170 | raise NotImplementedError
171 |
172 | train_loader = torch.utils.data.DataLoader(
173 | train_dataset,
174 | batch_size=args.logistic_batch_size,
175 | shuffle=True,
176 | drop_last=True,
177 | num_workers=args.workers,
178 | )
179 |
180 | test_loader = torch.utils.data.DataLoader(
181 | test_dataset,
182 | batch_size=args.logistic_batch_size,
183 | shuffle=False,
184 | drop_last=True,
185 | num_workers=args.workers,
186 | )
187 |
188 | log_dir = "log_eval/" + args.dataset + '_LR_log/'
189 |
190 | if not os.path.isdir(log_dir):
191 | os.makedirs(log_dir)
192 |
193 | suffix = args.dataset + '_{}_batch_{}'.format(args.resnet, args.batch_size)
194 | if args.adv:
195 | suffix = suffix + '_alpha_{}_adv_eps_{}'.format(args.alpha, args.eps)
196 |
197 | suffix = suffix + '_proj_dim_{}'.format(args.projection_dim)
198 | suffix = suffix + '_bn_adv_momentum_{}_seed_{}'.format(args.bn_adv_momentum, args.seed)
199 | wandb.init(config=args, name='LR/' + suffix.replace("_log/", ''))
200 | args.model_dir = args.model_dir + args.dataset + '/'
201 | print("Loading {}".format(args.model_dir + suffix + '_epoch_{}.pt'.format(args.epochs)))
202 | if args.adv:
203 | simclr_model, _, _ = load_model(args, train_loader, reload_model=True , load_path = args.model_dir + suffix + '_epoch_{}.pt'.format(args.epochs), bn_adv_flag = True, bn_adv_momentum = args.bn_adv_momentum, data=data)
204 | else:
205 | simclr_model, _, _ = load_model(args, train_loader, reload_model=True , load_path = args.model_dir + suffix + '_epoch_{}.pt'.format(args.epochs), bn_adv_flag = False, bn_adv_momentum = args.bn_adv_momentum, data=data)
206 |
207 | test_log_file = open(log_dir + suffix + '.txt', "w")
208 | simclr_model = simclr_model.to(args.device)
209 | simclr_model.eval()
210 |
211 | ## Logistic Regression
212 | if args.dataset == "cifar100":
213 |         n_classes = 100  # CIFAR-100
214 | else:
215 | n_classes = 10
216 |
217 | model = LogisticRegression(simclr_model.n_features, n_classes)
218 | model = model.to(args.device)
219 |
220 | optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
221 | criterion = torch.nn.CrossEntropyLoss()
222 |
223 |
224 | best_acc = 0
225 | for epoch in range(args.logistic_epochs):
226 | loss_epoch, accuracy_epoch = train(args, train_loader, simclr_model, model, criterion, optimizer)
227 | print("Train Epoch [{}]\t Loss: {}\t Accuracy: {}".format(epoch, loss_epoch / len(train_loader), accuracy_epoch / len(train_loader)), file = test_log_file)
228 | test_log_file.flush()
229 | wandb.log({'Train/Loss': loss_epoch / len(train_loader),
230 | 'Train/ACC': accuracy_epoch / len(train_loader)})
231 |
232 | # final testing
233 | test_loss_epoch, test_accuracy_epoch = test(args, test_loader, simclr_model, model, criterion, optimizer)
234 | test_current_acc = test_accuracy_epoch / len(test_loader)
235 | if test_current_acc > best_acc:
236 | best_acc = test_current_acc
237 | print("Test Epoch [{}]\t Loss: {}\t Accuracy: {}\t Best Accuracy: {}".format(epoch, test_loss_epoch / len(test_loader), test_current_acc, best_acc), file = test_log_file)
238 | wandb.log({'Test/Loss': test_loss_epoch / len(test_loader),
239 | 'Test/ACC': test_current_acc,
240 | 'Test/BestACC': best_acc})
241 | test_log_file.flush()
242 |
243 | if args.debug:
244 | break
245 |     print("Final \t Best Accuracy: {}".format(best_acc), file = test_log_file)
246 | test_log_file.flush()
247 | if not os.path.isdir("checkpoint/" + args.dataset + '_eval/'):
248 | os.makedirs("checkpoint/" + args.dataset + '_eval/')
249 | save_model("checkpoint/" + args.dataset + '_eval/' + suffix, model, optimizer, 0)
250 |
251 |
252 |
253 | if __name__ == "__main__":
254 | main()
255 |
--------------------------------------------------------------------------------
/SimCLR/main.py:
--------------------------------------------------------------------------------
1 | import os, time  # time is used in main() to time the kNN evaluation
2 | import torch
3 | import torchvision
4 | import argparse
5 | import sys
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import wandb
9 | import torchvision.transforms as transforms
10 | from model import load_model, save_model
11 | from modules import NT_Xent
12 | from modules.transformations import TransformsSimCLR
13 | from utils import mask_correlated_samples
14 | from eval_knn import kNN
15 | sys.path.append('..')
16 | from set import *
17 | from vae import *
18 | try:  # Apex is optional (see README); only needed when --amp is passed
19 |     from apex import amp
20 | except ImportError: amp = None
21 | parser = argparse.ArgumentParser(description='Seen Testing Category Training')
22 | parser.add_argument('--batch_size', default=256, type=int,
23 | metavar='B', help='training batch size')
24 | parser.add_argument('--dim', default=512, type=int, help='CNN_embed_dim')
25 | parser.add_argument('--workers', default=4, type=int, help='workers')
26 | parser.add_argument('--epochs', default=300, type=int, help='epochs')
27 | parser.add_argument('--save_epochs', default=100, type=int, help='save epochs')
28 | parser.add_argument('--resnet', default="resnet18", type=str, help="resnet")
29 | parser.add_argument('--normalize', default=True, action='store_true', help='normalize')
30 | parser.add_argument('--projection_dim', default=64, type=int, help='projection_dim')
31 | parser.add_argument('--optimizer', default="Adam", type=str, help="optimizer")
32 | parser.add_argument('--weight_decay', default=1.0e-6, type=float, help='weight_decay')
33 | parser.add_argument('--temperature', default=0.5, type=float, help='temperature')
34 | parser.add_argument('--model_path', default='log/', type=str,
35 | help='model save path')
36 | parser.add_argument('--model_dir', default='checkpoint/', type=str,
37 | help='model save path')
38 |
39 | parser.add_argument('--dataset', default='cifar10',
40 | help='[cifar10, cifar100]')
41 | parser.add_argument('--gpu', default='0', type=str,
42 | help='gpu device ids for CUDA_VISIBLE_DEVICES')
43 | parser.add_argument('--adv', default=False, action='store_true', help='adversarial example')
44 | parser.add_argument('--eps', default=0.01, type=float, help='eps for adversarial')
45 | parser.add_argument('--bn_adv_momentum', default=0.01, type=float, help='batch norm momentum for advprop')
46 | parser.add_argument('--alpha', default=1.0, type=float, help='weight for contrastive loss with adversarial example')
47 | parser.add_argument('--debug', default=False, action='store_true', help='debug mode')
48 | parser.add_argument('--vae_path',
49 | default='../results/vae_dim512_kl0.1_simclr/model_epoch92.pth',
50 | type=str, help='vae_path')
51 | parser.add_argument('--seed', default=1, type=int, help='seed')
52 | parser.add_argument("--amp", action="store_true",
53 | help="use 16-bit (mixed) precision through NVIDIA apex AMP")
54 | parser.add_argument("--opt_level", type=str, default="O1",
55 | help="apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
56 | "See details at https://nvidia.github.io/apex/amp.html")
57 | args = parser.parse_args()
58 | print(args)
59 | set_random_seed(args.seed)
60 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
61 |
62 |
63 | def gen_adv(model, vae, x_i, criterion, optimizer):
64 | x_i = x_i.detach()
65 | h_i, z_i = model(x_i, adv=True)
66 |
67 | with torch.no_grad():
68 | z, gx, _, _ = vae(x_i)
69 | variable_bottle = Variable(z.detach(), requires_grad=True)
70 | adv_gx = vae(variable_bottle, True)
71 | x_j_adv = adv_gx + (x_i - gx).detach()
72 | h_j_adv, z_j_adv = model(x_j_adv, adv=True)
73 | tmp_loss = criterion(z_i, z_j_adv)
74 | if args.amp:
75 | with amp.scale_loss(tmp_loss, optimizer) as scaled_loss:
76 | scaled_loss.backward()
77 | else:
78 | tmp_loss.backward()
79 |
80 | with torch.no_grad():
81 | sign_grad = variable_bottle.grad.data.sign()
82 | variable_bottle.data = variable_bottle.data + args.eps * sign_grad
83 | adv_gx = vae(variable_bottle, True)
84 | x_j_adv = adv_gx + (x_i - gx).detach()
85 |         x_j_adv = x_j_adv.detach()  # belt-and-braces: tensors built under no_grad carry no autograd history
86 |
87 | return x_j_adv, gx
88 |
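
# Condensed restatement of gen_adv above (without AMP): IDAA runs FGSM in the
# VAE latent space and re-attaches the identity residual (x - gx), so only the
# VAE-modelled content is perturbed. This sketch assumes the interfaces used
# above: vae(x) -> (z, gx, _, _) encodes and reconstructs; vae(z, True) decodes z.
def latent_fgsm_sketch(model, vae, x, criterion, eps):
    x = x.detach()
    _, z_anchor = model(x, adv=True)
    with torch.no_grad():
        z, gx, _, _ = vae(x)                        # encode + reconstruct
    z = z.clone().requires_grad_(True)
    x_adv = vae(z, True) + (x - gx).detach()        # decode, keep the identity residual
    _, z_adv = model(x_adv, adv=True)
    criterion(z_anchor, z_adv).backward()           # contrastive loss drives the attack on z
    with torch.no_grad():
        x_adv = vae(z + eps * z.grad.sign(), True) + (x - gx).detach()
    return x_adv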
89 |
90 | def train(args, epoch, train_loader, model, vae, criterion, optimizer):
91 | model.train()
92 | loss_epoch = 0
93 | for step, ((x_i, x_j), _) in enumerate(train_loader):
94 |
95 | optimizer.zero_grad()
96 | x_i = x_i.to(args.device)
97 | x_j = x_j.to(args.device)
98 |
99 | # positive pair, with encoding
100 | h_i, z_i = model(x_i)
101 | if args.adv:
102 | x_j_adv, gx = gen_adv(model, vae, x_i, criterion, optimizer)
103 |
104 | optimizer.zero_grad()
105 | h_j, z_j = model(x_j)
106 | loss_og = criterion(z_i, z_j)
107 | if args.adv:
108 | _, z_j_adv = model(x_j_adv, adv=True)
109 | loss_adv = criterion(z_i, z_j_adv)
110 | loss = loss_og + args.alpha * loss_adv
111 | else:
112 | loss = loss_og
113 | loss_adv = loss_og
114 | if args.amp:
115 | with amp.scale_loss(loss, optimizer) as scaled_loss:
116 | scaled_loss.backward()
117 | else:
118 | loss.backward()
119 |
120 | optimizer.step()
121 |
122 | if step % 50 == 0:
123 | print(f"[Epoch]: {epoch} [{step}/{len(train_loader)}]\t Loss: {loss.item():.3f} Loss_og: {loss_og.item():.3f} Loss_adv: {loss_adv.item():.3f}")
124 |
125 | loss_epoch += loss.item()
126 | args.global_step += 1
127 |
128 | if args.debug:
129 | break
130 | if step % 10 == 0:
131 | wandb.log({'loss_og': loss_og.item(),
132 | 'loss_adv': loss_adv.item(),
133 | 'lr': optimizer.param_groups[0]['lr']})
134 | if args.global_step % 1000 == 0:
135 | if args.adv:
136 | reconst_images(x_i, gx, x_j_adv)
137 | return loss_epoch
138 |
139 |
140 | def main():
141 | args.device = device = 'cuda' if torch.cuda.is_available() else 'cpu'
142 |
143 | train_sampler = None
144 | if args.dataset == "cifar10":
145 | root = "../../data"
146 | train_dataset = torchvision.datasets.CIFAR10(
147 | root, download=True, transform=TransformsSimCLR()
148 | )
149 | data = 'non_imagenet'
150 | transform_test = transforms.Compose([
151 | transforms.Resize(size=32),
152 | transforms.ToTensor(),
153 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
154 | ])
155 | testset = torchvision.datasets.CIFAR10(root='../../data', train=False, download=True, transform=transform_test)
156 | vae = CVAE_cifar_withbn(128, args.dim)
157 | elif args.dataset == "cifar100":
158 | root = "../../data"
159 | train_dataset = torchvision.datasets.CIFAR100(
160 | root, download=True, transform=TransformsSimCLR()
161 | )
162 | data = 'non_imagenet'
163 | transform_test = transforms.Compose([
164 | transforms.Resize(size=32),
165 | transforms.ToTensor(),
166 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
167 | ])
168 | testset = torchvision.datasets.CIFAR100(root='../../data', train=False, download=True, transform=transform_test)
169 | vae = CVAE_cifar_withbn(128, args.dim)
170 | else:
171 | raise NotImplementedError
172 |
173 | train_loader = torch.utils.data.DataLoader(
174 | train_dataset,
175 | batch_size=args.batch_size,
176 | shuffle=(train_sampler is None),
177 | drop_last=True,
178 | num_workers=args.workers,
179 | sampler=train_sampler,
180 | )
181 | testloader = torch.utils.data.DataLoader(testset,
182 | batch_size=100, shuffle=False, num_workers=4)
183 |
184 | ndata = train_dataset.__len__()
185 | log_dir = "log/" + args.dataset + '_log/'
186 |
187 | if not os.path.isdir(log_dir):
188 | os.makedirs(log_dir)
189 |
190 | suffix = args.dataset + '_{}_batch_{}'.format(args.resnet, args.batch_size)
191 | if args.adv:
192 | suffix = suffix + '_alpha_{}_adv_eps_{}'.format(args.alpha, args.eps)
193 | model, optimizer, scheduler = load_model(args, train_loader, bn_adv_flag=True,
194 | bn_adv_momentum=args.bn_adv_momentum, data=data)
195 | else:
196 | model, optimizer, scheduler = load_model(args, train_loader, bn_adv_flag=False,
197 | bn_adv_momentum=args.bn_adv_momentum, data=data)
198 |
199 | vae.load_state_dict(torch.load(args.vae_path))
200 | vae.to(args.device)
201 | vae.eval()
202 | if args.amp:
203 | [model, vae], optimizer = amp.initialize(
204 | [model, vae], optimizer, opt_level=args.opt_level)
205 |
206 | suffix = suffix + '_proj_dim_{}'.format(args.projection_dim)
207 | suffix = suffix + '_bn_adv_momentum_{}_seed_{}'.format(args.bn_adv_momentum, args.seed)
208 | wandb.init(config=args, name=suffix.replace("_log/", ''))
209 |
210 | test_log_file = open(log_dir + suffix + '.txt', "w")
211 |
212 | if not os.path.isdir(args.model_dir):
213 | os.mkdir(args.model_dir)
214 | args.model_dir = args.model_dir + args.dataset + '/'
215 | if not os.path.isdir(args.model_dir):
216 | os.mkdir(args.model_dir)
217 |
218 | mask = mask_correlated_samples(args)
219 | criterion = NT_Xent(args.batch_size, args.temperature, mask, args.device)
220 |
221 | args.global_step = 0
222 | args.current_epoch = 0
223 | best_acc = 0
224 | for epoch in range(0, args.epochs):
225 | loss_epoch = train(args, epoch, train_loader, model, vae, criterion, optimizer)
226 | model.eval()
227 | if epoch > 10:
228 | scheduler.step()
229 |         print('epoch: {} \t (loss: {})'.format(epoch, loss_epoch / len(train_loader)), file=test_log_file)
230 | print('----------Evaluation---------')
231 | start = time.time()
232 | acc = kNN(epoch, model, train_loader, testloader, 200, args.temperature, ndata, low_dim=args.projection_dim)
233 | print("Evaluation Time: '{}'s".format(time.time() - start))
234 |
235 | if acc >= best_acc:
236 | print('Saving..')
237 | state = {
238 | 'model': model.state_dict(),
239 | 'acc': acc,
240 | 'epoch': epoch,
241 | }
242 | if not os.path.isdir(args.model_dir):
243 | os.mkdir(args.model_dir)
244 | torch.save(state, args.model_dir + suffix + '_best.t')
245 | best_acc = acc
246 | print('accuracy: {}% \t (best acc: {}%)'.format(acc, best_acc))
247 | print('[Epoch]: {}'.format(epoch), file=test_log_file)
248 | print('accuracy: {}% \t (best acc: {}%)'.format(acc, best_acc), file=test_log_file)
249 | wandb.log({'acc': acc})
250 | test_log_file.flush()
251 |
252 | args.current_epoch += 1
253 | if args.debug:
254 | break
255 | if epoch % 50 == 0:
256 | save_model(args.model_dir + suffix, model, optimizer, epoch)
257 |
258 | save_model(args.model_dir + suffix, model, optimizer, args.epochs)
259 |
260 |
261 | def reconst_images(x_i, gx, x_j_adv):
262 | grid_X = torchvision.utils.make_grid(x_i[32:96].data, nrow=8, padding=2, normalize=True)
263 | wandb.log({"X.jpg": [wandb.Image(grid_X)]}, commit=False)
264 | grid_GX = torchvision.utils.make_grid(gx[32:96].data, nrow=8, padding=2, normalize=True)
265 | wandb.log({"GX.jpg": [wandb.Image(grid_GX)]}, commit=False)
266 | grid_RX = torchvision.utils.make_grid((x_i[32:96] - gx[32:96]).data, nrow=8, padding=2, normalize=True)
267 | wandb.log({"RX.jpg": [wandb.Image(grid_RX)]}, commit=False)
268 | grid_AdvX = torchvision.utils.make_grid(x_j_adv[32:96].data, nrow=8, padding=2, normalize=True)
269 | wandb.log({"AdvX.jpg": [wandb.Image(grid_AdvX)]}, commit=False)
270 | grid_delta = torchvision.utils.make_grid((x_j_adv - x_i)[32:96].data, nrow=8, padding=2, normalize=True)
271 | wandb.log({"Delta.jpg": [wandb.Image(grid_delta)]}, commit=False)
272 | wandb.log({'l2_norm': torch.mean((x_j_adv - x_i).reshape(x_i.shape[0], -1).norm(dim=1)),
273 | 'linf_norm': torch.mean((x_j_adv - x_i).reshape(x_i.shape[0], -1).abs().max(dim=1)[0])
274 | }, commit=False)
275 |
276 |
277 | if __name__ == "__main__":
278 | main()
279 |
--------------------------------------------------------------------------------
/SimCLR/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from modules import SimCLR_BN
4 |
5 |
6 | def load_model(args, loader, reload_model=False, load_path = None, bn_adv_flag=False, bn_adv_momentum = 0.01, data='non_imagenet'):
7 |
8 | model = SimCLR_BN(args, bn_adv_flag=bn_adv_flag, bn_adv_momentum = bn_adv_momentum, data = data)
9 |
10 | if reload_model:
11 | if os.path.isfile(load_path):
12 | model_fp = os.path.join(load_path)
13 | else:
14 |             raise FileNotFoundError('No checkpoint found at {}'.format(load_path))
15 |
16 | model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_fp, map_location=lambda storage, loc: storage).items()})
17 |
18 | #model = model.to(args.device)
19 | model.cuda()
20 | optimizer = torch.optim.Adam(model.parameters(), lr=3e-4) # TODO: LARS
21 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
22 | optimizer, args.epochs, eta_min=0, last_epoch=-1
23 | )
24 | return model, optimizer, scheduler
25 |
26 |
27 | def save_model(model_dir, model, optimizer, epoch):
28 |
29 |
30 | # To save a DataParallel model generically, save the model.module.state_dict().
31 | # This way, you have the flexibility to load the model any way you want to any device you want.
32 | if isinstance(model, torch.nn.DataParallel):
33 | torch.save(model.module.state_dict(), model_dir + '_epoch_{}.pt'.format(epoch))
34 | else:
35 | torch.save(model.state_dict(), model_dir + '_epoch_{}.pt'.format(epoch))
36 |
--------------------------------------------------------------------------------
/SimCLR/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .simclr_BN import SimCLR_BN
2 | from .nt_xent import NT_Xent
3 | from .logistic_regression import LogisticRegression
4 | from .lars import LARS
5 |
--------------------------------------------------------------------------------
/SimCLR/modules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/modules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/SimCLR/modules/__pycache__/lars.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/modules/__pycache__/lars.cpython-38.pyc
--------------------------------------------------------------------------------
/SimCLR/modules/__pycache__/logistic_regression.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/modules/__pycache__/logistic_regression.cpython-38.pyc
--------------------------------------------------------------------------------
/SimCLR/modules/__pycache__/nt_xent.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/modules/__pycache__/nt_xent.cpython-38.pyc
--------------------------------------------------------------------------------
/SimCLR/modules/__pycache__/resnet_BN.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/modules/__pycache__/resnet_BN.cpython-38.pyc
--------------------------------------------------------------------------------
/SimCLR/modules/__pycache__/resnet_BN_imagenet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/modules/__pycache__/resnet_BN_imagenet.cpython-38.pyc
--------------------------------------------------------------------------------
/SimCLR/modules/__pycache__/simclr_BN.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/modules/__pycache__/simclr_BN.cpython-38.pyc
--------------------------------------------------------------------------------
/SimCLR/modules/lars.py:
--------------------------------------------------------------------------------
1 | """
2 | LARS: Layer-wise Adaptive Rate Scaling
3 |
4 | Converted from TensorFlow to PyTorch
5 | https://github.com/google-research/simclr/blob/master/lars_optimizer.py
6 | """
7 |
8 | import torch
9 | from torch.optim.optimizer import Optimizer, required
10 | import re
11 |
12 | EETA_DEFAULT = 0.001
13 |
14 | class LARS(Optimizer):
15 | """
16 | Layer-wise Adaptive Rate Scaling for large batch training.
17 | Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
18 | I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)
19 | """
20 |
21 | def __init__(
22 | self,
23 | params,
24 | lr=required,
25 | momentum=0.9,
26 | use_nesterov=False,
27 | weight_decay=0.0,
28 | exclude_from_weight_decay=None,
29 | exclude_from_layer_adaptation=None,
30 | classic_momentum=True,
31 | eeta=EETA_DEFAULT,
32 | ):
33 | """Constructs a LARSOptimizer.
34 | Args:
35 | lr: A `float` for learning rate.
36 | momentum: A `float` for momentum.
37 | use_nesterov: A 'Boolean' for whether to use nesterov momentum.
38 | weight_decay: A `float` for weight decay.
39 | exclude_from_weight_decay: A list of `string` for variable screening, if
40 | any of the string appears in a variable's name, the variable will be
41 | excluded for computing weight decay. For example, one could specify
42 | the list like ['batch_normalization', 'bias'] to exclude BN and bias
43 | from weight decay.
44 | exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but
45 | for layer adaptation. If it is None, it will be defaulted the same as
46 | exclude_from_weight_decay.
47 | classic_momentum: A `boolean` for whether to use classic (or popular)
48 |             classic momentum. The learning rate is applied during momentum update in
49 | classic momentum, but after momentum for popular momentum.
50 | eeta: A `float` for scaling of learning rate when computing trust ratio.
51 | name: The name for the scope.
52 | """
53 |
54 | self.epoch = 0
55 | defaults = dict(
56 | lr=lr,
57 | momentum=momentum,
58 | use_nesterov=use_nesterov,
59 | weight_decay=weight_decay,
60 | exclude_from_weight_decay=exclude_from_weight_decay,
61 | exclude_from_layer_adaptation=exclude_from_layer_adaptation,
62 | classic_momentum=classic_momentum,
63 | eeta=eeta,
64 | )
65 |
66 | super(LARS, self).__init__(params, defaults)
67 | self.lr = lr
68 | self.momentum = momentum
69 | self.weight_decay = weight_decay
70 | self.use_nesterov = use_nesterov
71 | self.classic_momentum = classic_momentum
72 | self.eeta = eeta
73 | self.exclude_from_weight_decay = exclude_from_weight_decay
74 | # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
75 | # arg is None.
76 | if exclude_from_layer_adaptation:
77 | self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
78 | else:
79 | self.exclude_from_layer_adaptation = exclude_from_weight_decay
80 |
81 | def step(self, epoch=None, closure=None):
82 | loss = None
83 | if closure is not None:
84 | loss = closure()
85 |
86 | if epoch is None:
87 | epoch = self.epoch
88 | self.epoch += 1
89 |
90 | for group in self.param_groups:
91 | weight_decay = group["weight_decay"]
92 | momentum = group["momentum"]
93 | eeta = group["eeta"]
94 | lr = group["lr"]
95 |
96 | for p in group["params"]:
97 | if p.grad is None:
98 | continue
99 |
100 | param = p.data
101 | grad = p.grad.data
102 |
103 | param_state = self.state[p]
104 |
105 | # TODO: get param names
106 | # if self._use_weight_decay(param_name):
107 | grad += self.weight_decay * param
108 |
109 | if self.classic_momentum:
110 | trust_ratio = 1.0
111 |
112 | # TODO: get param names
113 | # if self._do_layer_adaptation(param_name):
114 | w_norm = torch.norm(param)
115 | g_norm = torch.norm(grad)
116 |
117 |                     one = torch.ones_like(w_norm)  # device-safe fallback (get_device() fails for CPU tensors)
118 |                     trust_ratio = torch.where(
119 |                         w_norm.gt(0),
120 |                         torch.where(g_norm.gt(0), (self.eeta * w_norm / g_norm), one),
121 |                         one,
122 |                     ).item()
123 |
124 | scaled_lr = lr * trust_ratio
125 | if "momentum_buffer" not in param_state:
126 | next_v = param_state["momentum_buffer"] = torch.zeros_like(
127 | p.data
128 | )
129 | else:
130 | next_v = param_state["momentum_buffer"]
131 |
132 |                     next_v.mul_(momentum).add_(grad, alpha=scaled_lr)  # keyword form; the positional .add_(scalar, tensor) overload was removed
133 | if self.use_nesterov:
134 | update = (self.momentum * next_v) + (scaled_lr * grad)
135 | else:
136 | update = next_v
137 |
138 | p.data.add_(-update)
139 | else:
140 | raise NotImplementedError
141 |
142 | return loss
143 |
144 | def _use_weight_decay(self, param_name):
145 | """Whether to use L2 weight decay for `param_name`."""
146 | if not self.weight_decay:
147 | return False
148 | if self.exclude_from_weight_decay:
149 | for r in self.exclude_from_weight_decay:
150 | if re.search(r, param_name) is not None:
151 | return False
152 | return True
153 |
154 | def _do_layer_adaptation(self, param_name):
155 | """Whether to do layer-wise learning rate adaptation for `param_name`."""
156 | if self.exclude_from_layer_adaptation:
157 | for r in self.exclude_from_layer_adaptation:
158 | if re.search(r, param_name) is not None:
159 | return False
160 | return True
161 |
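
# Example wiring (cf. the "TODO: LARS" note in SimCLR/model.py). The exclusion
# lists are regex patterns matched against parameter names by the helpers above;
# note that step() currently applies decay and adaptation to all parameters
# (see its TODOs), so the exclusions only take effect once names are wired in.
def build_lars(model, lr=0.1):
    return LARS(
        model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-6,
        exclude_from_weight_decay=["batch_normalization", "bias"],
    )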
--------------------------------------------------------------------------------
/SimCLR/modules/logistic_regression.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class LogisticRegression(nn.Module):
4 |
5 | def __init__(self, n_features, n_classes):
6 | super(LogisticRegression, self).__init__()
7 |
8 | self.model = nn.Linear(n_features, n_classes)
9 |
10 | def forward(self, x):
11 | return self.model(x)
--------------------------------------------------------------------------------
/SimCLR/modules/nt_xent.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import pdb
4 |
5 | def gen_mask(k, feat_dim):
6 | mask = None
7 | for i in range(k):
8 | tmp_mask = torch.triu(torch.randint(0, 2, (feat_dim, feat_dim)), 1)
9 | tmp_mask = tmp_mask + torch.triu(1-tmp_mask,1).t()
10 | tmp_mask = tmp_mask.view(tmp_mask.shape[0], tmp_mask.shape[1],1)
11 | mask = tmp_mask if mask is None else torch.cat([mask,tmp_mask],2)
12 | return mask
13 |
14 |
15 | def entropy(prob):
16 | # assume m x m x k input
17 | return -torch.sum(prob*torch.log(prob),1)
18 |
19 |
20 | class NT_Xent(nn.Module):
21 |
22 | def __init__(self, batch_size, temperature, mask, device):
23 | super(NT_Xent, self).__init__()
24 | self.batch_size = batch_size
25 | self.temperature = temperature
26 | self.mask = mask
27 | self.device = device
28 |
29 | self.criterion = nn.CrossEntropyLoss(reduction="sum")
30 | self.similarity_f = nn.CosineSimilarity(dim=2)
31 |
32 | def forward(self, z_i, z_j):
33 | """
34 | We do not sample negative examples explicitly.
35 | Instead, given a positive pair, similar to (Chen et al., 2017), we treat the other 2(N − 1) augmented examples within a minibatch as negative examples.
36 | """
37 | p1 = torch.cat((z_i, z_j), dim=0)
38 | sim = self.similarity_f(p1.unsqueeze(1), p1.unsqueeze(0)) / self.temperature
39 |
40 |
41 | sim_i_j = torch.diag(sim, self.batch_size)
42 | sim_j_i = torch.diag(sim, -self.batch_size)
43 |
44 | positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(self.batch_size * 2, 1)
45 | negative_samples = sim[self.mask].reshape(self.batch_size * 2, -1)
46 |
47 |
48 | labels = torch.zeros(self.batch_size * 2).to(self.device).long()
49 | logits = torch.cat((positive_samples, negative_samples), dim=1)
50 |
51 |
52 | loss = self.criterion(logits, labels)
53 | loss /= 2 * self.batch_size
54 |
55 | return loss
56 |
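# Smoke test: build the (2N, 2N) negatives mask by hand -- True everywhere except
# the diagonal and the two positive diagonals, the layout that
# mask_correlated_samples in utils is expected to produce -- and run the loss.
def _nt_xent_demo(batch_size=4, dim=64):
    n = 2 * batch_size
    mask = torch.ones(n, n, dtype=torch.bool)
    mask.fill_diagonal_(False)
    for i in range(batch_size):
        mask[i, batch_size + i] = False
        mask[batch_size + i, i] = False
    criterion = NT_Xent(batch_size, temperature=0.5, mask=mask, device='cpu')
    return criterion(torch.randn(batch_size, dim), torch.randn(batch_size, dim))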
--------------------------------------------------------------------------------
/SimCLR/modules/resnet_BN.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 |
3 | For Pre-activation ResNet, see 'preact_resnet.py'.
4 |
5 | Reference:
6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
8 | '''
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import math
13 | from torch.autograd import Variable
14 |
15 | class Normalize(nn.Module):
16 |
17 | def __init__(self, power=2):
18 | super(Normalize, self).__init__()
19 | self.power = power
20 |
21 | def forward(self, x):
22 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1./self.power)
23 | out = x.div(norm)
24 | return out
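
# Quick check: with power=2, Normalize rescales each row to unit L2 norm, e.g.
#   Normalize()(torch.tensor([[3.0, 4.0]]))  ->  tensor([[0.6000, 0.8000]])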
25 |
26 | class MySequential(nn.Sequential):
27 | def forward(self, x, adv):
28 | for module in self._modules.values():
29 | x = module(x, adv=adv)
30 | return x
31 |
32 | class BasicBlock(nn.Module):
33 | expansion = 1
34 |
35 | def __init__(self, in_planes, planes, stride=1, bn_adv_flag=False, bn_adv_momentum=0.01):
36 | super(BasicBlock, self).__init__()
37 | self.bn_adv_momentum = bn_adv_momentum
38 | self.bn_adv_flag = bn_adv_flag
39 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
40 | self.bn1 = nn.BatchNorm2d(planes)
41 | if self.bn_adv_flag:
42 | self.bn1_adv = nn.BatchNorm2d(planes, momentum = self.bn_adv_momentum)
43 |
44 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
45 |
46 | self.bn2 = nn.BatchNorm2d(planes)
47 | if self.bn_adv_flag:
48 | self.bn2_adv = nn.BatchNorm2d(planes, momentum = self.bn_adv_momentum)
49 |
50 | self.shortcut = nn.Sequential()
51 | self.shortcut_bn = None
52 | self.shortcut_bn_adv = None
53 | if stride != 1 or in_planes != self.expansion*planes:
54 | self.shortcut = nn.Sequential(
55 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
56 | )
57 | self.shortcut_bn = nn.BatchNorm2d(self.expansion*planes)
58 | if self.bn_adv_flag:
59 | self.shortcut_bn_adv = nn.BatchNorm2d(self.expansion*planes, momentum = self.bn_adv_momentum)
60 |
61 | def forward(self, x, adv=False):
62 | if adv and self.bn_adv_flag:
63 | out = F.relu(self.bn1_adv(self.conv1(x)))
64 | out = self.conv2(out)
65 | out = self.bn2_adv(out)
66 | if self.shortcut_bn_adv:
67 | out += self.shortcut_bn_adv(self.shortcut(x))
68 | else:
69 | out += self.shortcut(x)
70 | else:
71 | out = F.relu(self.bn1(self.conv1(x)))
72 | out = self.conv2(out)
73 | out = self.bn2(out)
74 | if self.shortcut_bn:
75 | out += self.shortcut_bn(self.shortcut(x))
76 | else:
77 | out += self.shortcut(x)
78 |
79 | out = F.relu(out)
80 | return out
81 |
82 |
83 | class Bottleneck(nn.Module):
84 | expansion = 4
85 |
86 | def __init__(self, in_planes, planes, stride=1, bn_adv_flag=False, bn_adv_momentum=0.01):
87 | super(Bottleneck, self).__init__()
88 | self.bn_adv_momentum = bn_adv_momentum
89 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
90 | self.bn_adv_flag = bn_adv_flag
91 |
92 | self.bn1 = nn.BatchNorm2d(planes)
93 | if self.bn_adv_flag:
94 | self.bn1_adv = nn.BatchNorm2d(planes, momentum = self.bn_adv_momentum)
95 |
96 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
97 |
98 |         self.bn2 = nn.BatchNorm2d(planes)
99 |
100 | if self.bn_adv_flag:
101 | self.bn2_adv = nn.BatchNorm2d(planes, momentum = self.bn_adv_momentum)
102 |
103 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
104 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
105 | if self.bn_adv_flag:
106 | self.bn3_adv = nn.BatchNorm2d(self.expansion*planes, momentum = self.bn_adv_momentum)
107 |
108 | self.shortcut = nn.Sequential()
109 | self.shortcut_bn = None
110 | self.shortcut_bn_adv = None
111 | if stride != 1 or in_planes != self.expansion*planes:
112 | self.shortcut = nn.Sequential(
113 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
114 | )
115 | self.shortcut_bn = nn.BatchNorm2d(self.expansion*planes)
116 | if self.bn_adv_flag:
117 | self.shortcut_bn_adv = nn.BatchNorm2d(self.expansion*planes, momentum = self.bn_adv_momentum)
118 |
119 | def forward(self, x, adv=False):
120 |
121 | if adv and self.bn_adv_flag:
122 |
123 | out = F.relu(self.bn1_adv(self.conv1(x)))
124 | out = F.relu(self.bn2_adv(self.conv2(out)))
125 | out = self.bn3_adv(self.conv3(out))
126 | if self.shortcut_bn_adv:
127 | out += self.shortcut_bn_adv(self.shortcut(x))
128 | else:
129 | out += self.shortcut(x)
130 | else:
131 |
132 | out = F.relu(self.bn1(self.conv1(x)))
133 | out = F.relu(self.bn2(self.conv2(out)))
134 | out = self.bn3(self.conv3(out))
135 | if self.shortcut_bn:
136 | out += self.shortcut_bn(self.shortcut(x))
137 | else:
138 | out += self.shortcut(x)
139 |
140 | out = F.relu(out)
141 | return out
142 |
143 |
144 | class ResNetAdvProp_all(nn.Module):
145 | def __init__(self, block, num_blocks, pool_len =4, low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01):
146 | super(ResNetAdvProp_all, self).__init__()
147 | self.in_planes = 64
148 | self.bn_adv_momentum = bn_adv_momentum
149 | self.bn_adv_flag = bn_adv_flag
150 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
151 | self.bn_adv_flag = bn_adv_flag
152 |
153 | self.bn1 = nn.BatchNorm2d(64)
154 | if bn_adv_flag:
155 | self.bn1_adv = nn.BatchNorm2d(64, momentum = self.bn_adv_momentum)
156 |
157 |
158 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, bn_adv_flag = self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum)
159 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2, bn_adv_flag = self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum)
160 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, bn_adv_flag = self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum)
161 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, bn_adv_flag = self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum)
162 | self.fc = nn.Linear(512*block.expansion, low_dim)
163 |
164 | self.pool_len = pool_len
165 | # for m in self.modules():
166 | # if isinstance(m, nn.Conv2d):
167 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
168 | # m.weight.data.normal_(0, math.sqrt(2. / n))
169 | # elif isinstance(m, nn.BatchNorm2d):
170 | # m.weight.data.fill_(1)
171 | # m.bias.data.zero_()
172 |
173 | def _make_layer(self, block, planes, num_blocks, stride, bn_adv_flag=False, bn_adv_momentum=0.01):
174 | strides = [stride] + [1]*(num_blocks-1)
175 | layers = []
176 | for stride in strides:
177 | layers.append(block(self.in_planes, planes, stride, bn_adv_flag=bn_adv_flag, bn_adv_momentum = bn_adv_momentum))
178 | self.in_planes = planes * block.expansion
179 | return MySequential(*layers)
180 | #return layers
181 |
182 | def forward(self, x, adv = False):
183 | if adv and self.bn_adv_flag:
184 | out = F.relu(self.bn1_adv(self.conv1(x)))
185 | else:
186 | out = F.relu(self.bn1(self.conv1(x)))
187 |
188 | out = self.layer1(out, adv=adv)
189 | out = self.layer2(out, adv=adv)
190 | out = self.layer3(out, adv=adv)
191 | out = self.layer4(out, adv=adv)
192 |
193 | out = F.avg_pool2d(out, self.pool_len)
194 |
195 | out = out.view(out.size(0), -1)
196 |
197 | out = self.fc(out)
198 | return out
199 |
200 |
201 | def resnet18(pool_len = 4, low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01):
202 | return ResNetAdvProp_all(BasicBlock, [2,2,2,2], pool_len, low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
203 |
204 | def resnet34(pool_len = 4, low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01):
205 | return ResNetAdvProp_all(BasicBlock, [3,4,6,3], pool_len, low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
206 |
207 | def resnet50(pool_len = 4, low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01):
208 | return ResNetAdvProp_all(Bottleneck, [3,4,6,3], pool_len, low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
209 |
210 | def resnet101(pool_len = 4, low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01):
211 | return ResNetAdvProp_all(Bottleneck, [3,4,23,3], pool_len, low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
212 |
213 | def resnet152(pool_len = 4, low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01):
214 | return ResNetAdvProp_all(Bottleneck, [3,8,36,3], pool_len, low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
215 |
216 |
217 | def test():
218 |     net = resnet18()
219 | # y = net(Variable(torch.randn(1,3,32,32)))
220 | # pdb.set_trace()
221 | y = net(Variable(torch.randn(1,3,96,96)))
222 | # pdb.set_trace()
223 | print(y.size())
224 |
225 | # test()
226 |
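# Usage sketch of the dual-BN (AdvProp-style) routing above: with bn_adv_flag=True
# each block keeps auxiliary *_adv BatchNorm layers, and calling the network with
# adv=True selects them, so clean and adversarial views keep separate statistics.
def _dual_bn_demo():
    net = resnet18(pool_len=4, low_dim=128, bn_adv_flag=True, bn_adv_momentum=0.01)
    x = torch.randn(2, 3, 32, 32)
    return net(x), net(x, adv=True)  # standard BN branch vs. adversarial BN branch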
--------------------------------------------------------------------------------
/SimCLR/modules/resnet_BN_imagenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | from torch.autograd import Variable
6 | import pdb
7 |
8 |
9 | class MySequential(nn.Sequential):
10 | def forward(self, x, adv):
11 | for module in self._modules.values():
12 | x = module(x, adv=adv)
13 | return x
14 |
15 | def conv3x3(in_planes, out_planes, stride=1):
16 | "3x3 convolution with padding"
17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
18 | padding=1, bias=False)
19 |
20 |
21 | class BasicBlock(nn.Module):
22 | expansion = 1
23 |
24 | def __init__(self, inplanes, planes, stride=1, downsample=False, expansion=0, bn_adv_flag=False, bn_adv_momentum=0.01):
25 | super(BasicBlock, self).__init__()
26 | self.bn_adv_momentum = bn_adv_momentum
27 | self.bn_adv_flag = bn_adv_flag
28 |
29 | self.conv1 = conv3x3(inplanes, planes, stride)
30 | self.bn1 = nn.BatchNorm2d(planes)
31 | if self.bn_adv_flag:
32 | self.bn1_adv = nn.BatchNorm2d(planes, momentum = self.bn_adv_momentum)
33 | self.relu = nn.ReLU(inplace=True)
34 | self.conv2 = conv3x3(planes, planes)
35 |
36 | self.bn2 = nn.BatchNorm2d(planes)
37 | if self.bn_adv_flag:
38 | self.bn2_adv = nn.BatchNorm2d(planes, momentum = self.bn_adv_momentum)
39 |
40 | self.downsample = downsample
41 | if self.downsample:
42 | self.ds_conv1 = nn.Conv2d(inplanes, planes * expansion, kernel_size=1, stride=stride, bias=False)
43 | self.ds_bn1 = nn.BatchNorm2d(planes*expansion)
44 |             self.ds_bn1_adv = nn.BatchNorm2d(planes*expansion, momentum = self.bn_adv_momentum)  # match the other adversarial BNs
45 | self.stride = stride
46 |
47 | def forward(self, x, adv = False):
48 | residual = x
49 | if adv and self.bn_adv_flag:
50 | out = self.conv1(x)
51 | out = self.bn1_adv(out)
52 | out = self.relu(out)
53 | out = self.conv2(out)
54 | out = self.bn2_adv(out)
55 | if self.downsample:
56 |
57 | residual = self.ds_bn1_adv(self.ds_conv1(x))
58 | out += residual
59 | out = self.relu(out)
60 | else:
61 | out = self.conv1(x)
62 | out = self.bn1(out)
63 | out = self.relu(out)
64 | out = self.conv2(out)
65 | out = self.bn2(out)
66 | if self.downsample:
67 | residual = self.ds_bn1(self.ds_conv1(x))
68 | out += residual
69 | out = self.relu(out)
70 |
71 | return out
72 |
73 |
74 | class Bottleneck(nn.Module):
75 | expansion = 4
76 |
77 | def __init__(self, inplanes, planes, stride=1, downsample=False, expansion=0, bn_adv_flag=False, bn_adv_momentum=0.01):
78 | super(Bottleneck, self).__init__()
79 | self.bn_adv_flag = bn_adv_flag
80 | self.bn_adv_momentum = bn_adv_momentum
81 |
82 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
83 | self.bn1 = nn.BatchNorm2d(planes)
84 | if self.bn_adv_flag:
85 | self.bn1_adv = nn.BatchNorm2d(planes, momentum = self.bn_adv_momentum)
86 |
87 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
88 | padding=1, bias=False)
89 | self.bn2 = nn.BatchNorm2d(planes)
90 | if self.bn_adv_flag:
91 | self.bn2_adv = nn.BatchNorm2d(planes, momentum = self.bn_adv_momentum)
92 |
93 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
94 | self.bn3 = nn.BatchNorm2d(planes * 4)
95 | if self.bn_adv_flag:
96 | self.bn3_adv = nn.BatchNorm2d(self.expansion*planes, momentum = self.bn_adv_momentum)
97 |
98 | self.relu = nn.ReLU(inplace=True)
99 | self.downsample = downsample
100 | if self.downsample:
101 | self.ds_conv1 = nn.Conv2d(inplanes, planes * expansion, kernel_size=1, stride=stride, bias=False)
102 | self.ds_bn1 = nn.BatchNorm2d(planes * expansion)
103 |             self.ds_bn1_adv = nn.BatchNorm2d(planes * expansion, momentum = self.bn_adv_momentum)  # match the other adversarial BNs
104 | self.stride = stride
105 |
106 | def forward(self, x, adv = False):
107 | residual = x
108 | if adv and self.bn_adv_flag:
109 | out = self.conv1(x)
110 | out = self.bn1_adv(out)
111 | out = self.relu(out)
112 |
113 | out = self.conv2(out)
114 | out = self.bn2_adv(out)
115 | out = self.relu(out)
116 |
117 | out = self.conv3(out)
118 | out = self.bn3_adv(out)
119 |
120 | if self.downsample:
121 |
122 | residual = self.ds_bn1_adv(self.ds_conv1(x))
123 |
124 | out += residual
125 | out = self.relu(out)
126 | else:
127 | out = self.conv1(x)
128 | out = self.bn1(out)
129 | out = self.relu(out)
130 |
131 | out = self.conv2(out)
132 | out = self.bn2(out)
133 | out = self.relu(out)
134 |
135 | out = self.conv3(out)
136 | out = self.bn3(out)
137 |
138 | if self.downsample:
139 |
140 | residual = self.ds_bn1(self.ds_conv1(x))
141 |
142 | out += residual
143 | out = self.relu(out)
144 | return out
145 |
146 |
147 | class ResNetAdvProp_imgnet(nn.Module):
148 |
149 | def __init__(self, block, layers, low_dim=128, is_feature=None, bn_adv_flag=False, bn_adv_momentum=0.01):
150 | super(ResNetAdvProp_imgnet, self).__init__()
151 | self.inplanes = 64
152 | self.bn_adv_flag = bn_adv_flag
153 | self.bn_adv_momentum = bn_adv_momentum
154 |
155 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
156 | bias=False)
157 | self.bn1 = nn.BatchNorm2d(64)
158 | if bn_adv_flag:
159 | self.bn1_adv = nn.BatchNorm2d(64, momentum = self.bn_adv_momentum)
160 |
161 | self.relu = nn.ReLU(inplace=True)
162 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
163 | self.layer1 = self._make_layer(block, 64, layers[0], bn_adv_flag = self.bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
164 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, bn_adv_flag = self.bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
165 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, bn_adv_flag = self.bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
166 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, bn_adv_flag = self.bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
167 | self.avgpool = nn.AdaptiveAvgPool2d(1)
168 | self.fc = nn.Linear(512 * block.expansion, low_dim)
169 | self.dropout = nn.Dropout(p=0.5)
170 |
171 | for m in self.modules():
172 | if isinstance(m, nn.Conv2d):
173 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
174 | m.weight.data.normal_(0, math.sqrt(2. / n))
175 | elif isinstance(m, nn.BatchNorm2d):
176 | m.weight.data.fill_(1)
177 | m.bias.data.zero_()
178 |
179 |
180 | def _make_layer(self, block, planes, blocks, stride=1, bn_adv_flag=False, bn_adv_momentum=0.01):
181 | downsample = False
182 | if stride != 1 or self.inplanes != planes * block.expansion:
183 | downsample = True
184 | layers = []
185 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, expansion=block.expansion , bn_adv_flag=bn_adv_flag, bn_adv_momentum = bn_adv_momentum))
186 | self.inplanes = planes * block.expansion
187 | for i in range(1, blocks):
188 | layers.append(block(self.inplanes, planes, bn_adv_flag=bn_adv_flag, bn_adv_momentum = bn_adv_momentum))
189 |
190 | return MySequential(*layers)
191 |
192 |     def forward(self, x, adv=False):
193 |         x = self.conv1(x)
194 |         if adv and self.bn_adv_flag:
195 |             x = self.bn1_adv(x)  # adversarial inputs are normalized with their own BN statistics
196 |         else:
197 |             x = self.bn1(x)
198 | 
199 |         x = self.relu(x)
200 | x = self.maxpool(x)
201 |
202 | x = self.layer1(x, adv=adv)
203 | x = self.layer2(x, adv=adv)
204 | x = self.layer3(x, adv=adv)
205 | x = self.layer4(x, adv=adv)
206 |
207 | x = self.avgpool(x)
208 | x = x.view(x.size(0), -1)
209 | x = self.fc(x)
210 | return x
211 |
212 |
213 | def resnet18_imagenet(low_dim=128, bn_adv_flag=False,bn_adv_momentum=0.01):
214 | return ResNetAdvProp_imgnet(BasicBlock, [2,2,2,2], low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
215 |
216 | def resnet34_imagenet(low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01):
217 | return ResNetAdvProp_imgnet(BasicBlock, [3,4,6,3], low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
218 |
219 | def resnet50_imagenet(low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01):
220 | return ResNetAdvProp_imgnet(Bottleneck, [3,4,6,3], low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
221 |
222 | def resnet101_imagenet( low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01):
223 | return ResNetAdvProp_imgnet(Bottleneck, [3,4,23,3], low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
224 |
225 | def resnet152_imagenet(low_dim=128, bn_adv_flag=False, bn_adv_momentum=0.01):
226 | return ResNetAdvProp_imgnet(Bottleneck, [3,8,36,3], low_dim, bn_adv_flag=bn_adv_flag, bn_adv_momentum=bn_adv_momentum)
227 |
228 |
229 | def test():
230 |     net = resnet50_imagenet(bn_adv_flag=True)
231 |     # exercise the adversarial-BN path on an ImageNet-sized input
232 |     y = net(Variable(torch.randn(1, 3, 224, 224)), adv=True)
233 |     print(y.size())
234 | 
235 | 
236 | #test()
237 |
238 |
--------------------------------------------------------------------------------
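
The ResNets above keep two parallel sets of batch-norm layers (`bn*` and `bn*_adv`) and route each forward pass through one of them via the `adv` flag, so clean and adversarially augmented inputs are normalized with separate statistics. Below is a minimal, self-contained sketch of that switching pattern; the `DualBNConv` module is illustrative only and not part of this repository:

```python
import torch
import torch.nn as nn


class DualBNConv(nn.Module):
    """One conv layer whose clean and adversarial activations are
    normalized by separate BatchNorm statistics, as in the ResNets above."""

    def __init__(self, c_in, c_out, bn_adv_momentum=0.01):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(c_out)                                # clean branch
        self.bn_adv = nn.BatchNorm2d(c_out, momentum=bn_adv_momentum)  # adversarial branch

    def forward(self, x, adv=False):
        out = self.conv(x)
        return self.bn_adv(out) if adv else self.bn(out)


layer = DualBNConv(3, 64)
x = torch.randn(4, 3, 32, 32)
print(layer(x).shape)            # normalized with the clean statistics
print(layer(x, adv=True).shape)  # same shape, adversarial statistics
```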
/SimCLR/modules/simclr_BN.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torchvision
3 | from .resnet_BN import *
4 | from .resnet_BN_imagenet import *
5 |
6 |
7 | class Identity(nn.Module):
8 | def __init__(self):
9 | super(Identity, self).__init__()
10 |
11 | def forward(self, x):
12 | return x
13 |
14 |
15 | class SimCLR_BN(nn.Module):
16 |     """
17 |     We opt for simplicity and adopt the commonly used ResNet (He et al., 2016) to obtain h_i = f(x̃_i) = ResNet(x̃_i), where h_i ∈ R^d is the output after the average pooling layer.
18 |     """
19 |
20 | def __init__(self, args, bn_adv_flag=False, bn_adv_momentum = 0.01, data='non_imagenet'):
21 | super(SimCLR_BN, self).__init__()
22 |
23 | self.args = args
24 | self.bn_adv_flag = bn_adv_flag
25 | self.bn_adv_momentum = bn_adv_momentum
26 | if data == 'imagenet':
27 | self.encoder = self.get_imagenet_resnet(args.resnet)
28 | else:
29 | self.encoder = self.get_resnet(args.resnet)
30 |
31 | self.n_features = self.encoder.fc.in_features # get dimensions of fc layer
32 | self.encoder.fc = Identity() # remove fully-connected layer after pooling layer
33 |
34 | self.projector = nn.Sequential(
35 | nn.Linear(self.n_features, self.n_features),
36 | nn.ReLU(),
37 | nn.Linear(self.n_features, args.projection_dim),
38 | )
39 |
40 | def get_resnet(self, name):
41 | resnets = {
42 | "resnet18": resnet18(pool_len=4, bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
43 | "resnet34": resnet34(pool_len=4, bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
44 | "resnet50": resnet50(pool_len=4, bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
45 | "resnet101": resnet101(pool_len=4, bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
46 | "resnet152": resnet152(pool_len=4, bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
47 | }
48 | if name not in resnets.keys():
49 | raise KeyError(f"{name} is not a valid ResNet version")
50 | return resnets[name]
51 |
52 | def get_imagenet_resnet(self, name):
53 | resnets = {
54 | "resnet18": resnet18_imagenet(bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
55 | "resnet34": resnet34_imagenet(bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
56 | "resnet50": resnet50_imagenet(bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
57 | "resnet101": resnet101_imagenet(bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
58 | "resnet152": resnet152_imagenet(bn_adv_flag=self.bn_adv_flag, bn_adv_momentum=self.bn_adv_momentum),
59 | }
60 | if name not in resnets.keys():
61 | raise KeyError(f"{name} is not a valid ResNet version")
62 | return resnets[name]
63 |
64 | def forward(self, x, adv=False):
65 | h = self.encoder(x, adv=adv)
66 | z = self.projector(h)
67 |
68 | if self.args.normalize:
69 | z = nn.functional.normalize(z, dim=1)
70 | return h, z
71 |
--------------------------------------------------------------------------------
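
A usage sketch for `SimCLR_BN`: one clean and one adversarial forward pass through the same encoder. It assumes the `SimCLR/` directory is the working directory, and the `SimpleNamespace` stands in for the real argparse namespace from `main.py`, carrying only the fields the class actually reads (`resnet`, `projection_dim`, `normalize`):

```python
import torch
from types import SimpleNamespace
from modules.simclr_BN import SimCLR_BN

args = SimpleNamespace(resnet="resnet18", projection_dim=64, normalize=True)
model = SimCLR_BN(args, bn_adv_flag=True, bn_adv_momentum=0.01)

x = torch.randn(8, 3, 32, 32)
h, z = model(x)                    # clean pass through the regular BN layers
h_adv, z_adv = model(x, adv=True)  # adversarial pass through the *_adv BN layers
print(h.shape, z.shape)            # torch.Size([8, 512]) torch.Size([8, 64]) for resnet18
```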
/SimCLR/modules/transformations/__init__.py:
--------------------------------------------------------------------------------
1 | from .simclr import TransformsSimCLR
2 | from .simclr import TransformsSimCLR_imagenet
3 |
--------------------------------------------------------------------------------
/SimCLR/modules/transformations/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/modules/transformations/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/SimCLR/modules/transformations/__pycache__/simclr.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/modules/transformations/__pycache__/simclr.cpython-38.pyc
--------------------------------------------------------------------------------
/SimCLR/modules/transformations/simclr.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | from PIL import ImageFilter
3 | import random
4 |
5 |
6 | class GaussianBlur(object):
7 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
8 |
9 | def __init__(self, sigma=[.1, 2.]):
10 | self.sigma = sigma
11 |
12 | def __call__(self, x):
13 | sigma = random.uniform(self.sigma[0], self.sigma[1])
14 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
15 | return x
16 |
17 |
18 | class TransformsSimCLR:
19 |     """
20 |     A stochastic data augmentation module that randomly transforms any given data example,
21 |     resulting in two correlated views of the same example,
22 |     denoted x̃_i and x̃_j, which we consider a positive pair.
23 |     """
24 |
25 | def __init__(self, size=32):
26 | s = 1
27 | color_jitter = torchvision.transforms.ColorJitter(
28 | 0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
29 | )
30 |         # normalization below uses CIFAR-10 channel statistics rather than a generic
31 |         # (0.5, 0.5, 0.5) rescaling
32 | self.train_transform = torchvision.transforms.Compose(
33 | [
34 | torchvision.transforms.RandomResizedCrop(size=size),
35 | #torchvision.transforms.RandomResizedCrop(size=96),
36 | torchvision.transforms.RandomHorizontalFlip(), # with 0.5 probability
37 | torchvision.transforms.RandomApply([color_jitter], p=0.8),
38 | torchvision.transforms.RandomGrayscale(p=0.2),
39 | torchvision.transforms.ToTensor(),
40 | torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
41 | ]
42 | )
43 |
44 | def __call__(self, x):
45 | return self.train_transform(x), self.train_transform(x)
46 |
47 |
48 | class TransformsSimCLR_imagenet:
49 |     """
50 |     A stochastic data augmentation module that randomly transforms any given data example,
51 |     resulting in two correlated views of the same example,
52 |     denoted x̃_i and x̃_j, which we consider a positive pair.
53 |     """
54 |
55 | def __init__(self, size=224):
56 | s = 0.5
57 | color_jitter = torchvision.transforms.ColorJitter(
58 | 0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
59 | )
60 |
61 | normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
62 | std=[0.229, 0.224, 0.225])
63 | self.train_transform = torchvision.transforms.Compose(
64 | [
65 | torchvision.transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)),
66 | torchvision.transforms.RandomApply([color_jitter], p=0.8),
67 | torchvision.transforms.RandomGrayscale(p=0.2),
68 | torchvision.transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
69 | torchvision.transforms.RandomHorizontalFlip(),
70 | torchvision.transforms.ToTensor(),
71 | normalize
72 | ]
73 | )
74 |
75 | def __call__(self, x):
76 | return self.train_transform(x), self.train_transform(x)
77 |
78 |
--------------------------------------------------------------------------------
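
Calling either transform returns two independently augmented views of the same image, which become the positive pair for contrastive training. A quick check (assumes the `SimCLR/` directory is on `sys.path`; the solid-color image is just a stand-in for a CIFAR-10 sample):

```python
import torch
from PIL import Image
from modules.transformations import TransformsSimCLR

transform = TransformsSimCLR(size=32)
img = Image.new("RGB", (32, 32), color=(128, 64, 32))  # stand-in for a CIFAR-10 image
x_i, x_j = transform(img)     # two independently augmented views
print(x_i.shape, x_j.shape)   # torch.Size([3, 32, 32]) each
print(torch.equal(x_i, x_j))  # almost always False: the augmentations differ
```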
/SimCLR/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .masks import mask_correlated_samples
2 | #from .yaml_config_hook import post_config_hook
3 | #from .filestorage import CustomFileStorageObserver
--------------------------------------------------------------------------------
/SimCLR/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/SimCLR/utils/__pycache__/masks.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kai-wen-yang/IDAA/20c81fca963003cb0defcab91a95400e2974c0a3/SimCLR/utils/__pycache__/masks.cpython-38.pyc
--------------------------------------------------------------------------------
/SimCLR/utils/masks.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def mask_correlated_samples(args):
4 | mask = torch.ones((args.batch_size * 2, args.batch_size * 2), dtype=bool)
5 | mask = mask.fill_diagonal_(0)
6 | for i in range(args.batch_size):
7 | mask[i, args.batch_size + i] = 0
8 | mask[args.batch_size + i, i] = 0
9 | return mask
10 |
--------------------------------------------------------------------------------
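
For a batch of N images the NT-Xent loss sees 2N embeddings; `mask_correlated_samples` returns a boolean matrix that is True exactly at the negative-pair positions, zeroing out the diagonal (self-similarity) and the two (i, i+N) positive positions. A small worked example (assumes the `SimCLR/` directory is on `sys.path`; the `SimpleNamespace` mimics the argparse namespace):

```python
import torch
from types import SimpleNamespace
from utils.masks import mask_correlated_samples

mask = mask_correlated_samples(SimpleNamespace(batch_size=2))
print(mask.int())
# tensor([[0, 1, 0, 1],
#         [1, 0, 1, 0],
#         [0, 1, 0, 1],
#         [1, 0, 1, 0]], dtype=torch.int32)
```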
/set.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import os.path as osp
5 | import random
6 | import torch
7 | import numpy as np
8 | from torch.optim.optimizer import Optimizer
9 | import math
10 |
11 |
12 | def mkdir_if_missing(dirname):
13 | """Create dirname if it is missing."""
14 | if not osp.exists(dirname):
15 |         try:
16 |             os.makedirs(dirname)
17 |         except OSError:
18 |             if not osp.exists(dirname):  # re-raise unless another process created it first
19 |                 raise
20 |
21 |
22 | def set_random_seed(seed):
23 |     # seed every RNG the training code touches
24 |     random.seed(seed)
25 |     np.random.seed(seed)
26 |     torch.manual_seed(seed)
27 |     torch.cuda.manual_seed_all(seed)
28 |
29 |
30 | class Logger:
31 | """Write console output to external text file.
32 |
33 |     Imported from `<https://github.com/Cysu/open-reid>`_
34 |
35 | Args:
36 | fpath (str): directory to save logging file.
37 |
38 | Examples::
39 | >>> import sys
40 | >>> import os.path as osp
41 | >>> save_dir = 'output/experiment-1'
42 | >>> log_name = 'train.log'
43 | >>> sys.stdout = Logger(osp.join(save_dir, log_name))
44 | """
45 |
46 | def __init__(self, fpath=None):
47 | self.console = sys.stdout
48 | self.file = None
49 | if fpath is not None:
50 | mkdir_if_missing(osp.dirname(fpath))
51 | self.file = open(fpath, 'w')
52 |
53 | def __del__(self):
54 | self.close()
55 |
56 |     def __enter__(self):
57 |         return self
58 |
59 | def __exit__(self, *args):
60 | self.close()
61 |
62 | def write(self, msg):
63 | self.console.write(msg)
64 | if self.file is not None:
65 | self.file.write(msg)
66 |
67 | def flush(self):
68 | self.console.flush()
69 | if self.file is not None:
70 | self.file.flush()
71 | os.fsync(self.file.fileno())
72 |
73 | def close(self):
74 | self.console.close()
75 | if self.file is not None:
76 | self.file.close()
77 |
78 |
79 | def setup_logger(output=None):
80 | if output is None:
81 | return
82 |
83 | if output.endswith('.txt') or output.endswith('.log'):
84 | fpath = output
85 | else:
86 | fpath = osp.join(output, 'log.txt')
87 |
88 | if osp.exists(fpath):
89 | # make sure the existing log file is not over-written
90 | fpath += time.strftime('-%Y-%m-%d-%H-%M-%S')
91 |
92 | sys.stdout = Logger(fpath)
93 |
94 |
95 | def accuracy(output, target, topk=(1,)):
96 |
97 | maxk = max(topk)
98 | batch_size = target.size(0)
99 |
100 | _, pred = output.topk(maxk, 1, True, True)
101 | pred = pred.t()
102 | correct = pred.eq(target.view(1, -1).expand_as(pred))
103 |
104 | res = []
105 | for k in topk:
106 | correct_k = correct[:k].contiguous().view(-1).float().sum(0)
107 | res.append(correct_k.mul_(100.0 / batch_size))
108 |
109 | if len(res) == 1:
110 | return res[0]
111 |     else:  # assumes topk=(1, k): returns top-1/top-k accuracy plus per-sample correctness and top-1 predictions
112 | return (res[0], res[1], correct[0], pred[0])
113 |
114 |
115 | class AdamW(Optimizer):
116 |     """Implements the AdamW variant of Adam (decoupled weight decay).
117 | It has been proposed in `Adam: A Method for Stochastic Optimization`_.
118 | Arguments:
119 | params (iterable): iterable of parameters to optimize or dicts defining
120 | parameter groups
121 | lr (float, optional): learning rate (default: 1e-3)
122 | betas (Tuple[float, float], optional): coefficients used for computing
123 | running averages of gradient and its square (default: (0.9, 0.999))
124 | eps (float, optional): term added to the denominator to improve
125 | numerical stability (default: 1e-8)
126 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
127 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this
128 | algorithm from the paper `On the Convergence of Adam and Beyond`_
129 | .. _Adam\: A Method for Stochastic Optimization:
130 | https://arxiv.org/abs/1412.6980
131 | .. _On the Convergence of Adam and Beyond:
132 | https://openreview.net/forum?id=ryQu7f-RZ
133 | """
134 |
135 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
136 | weight_decay=0, amsgrad=False):
137 | if not 0.0 <= lr:
138 | raise ValueError("Invalid learning rate: {}".format(lr))
139 | if not 0.0 <= eps:
140 | raise ValueError("Invalid epsilon value: {}".format(eps))
141 | if not 0.0 <= betas[0] < 1.0:
142 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
143 | if not 0.0 <= betas[1] < 1.0:
144 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
145 | defaults = dict(lr=lr, betas=betas, eps=eps,
146 | weight_decay=weight_decay, amsgrad=amsgrad)
147 | super(AdamW, self).__init__(params, defaults)
148 |
149 | def __setstate__(self, state):
150 | super(AdamW, self).__setstate__(state)
151 | for group in self.param_groups:
152 | group.setdefault('amsgrad', False)
153 |
154 | def step(self, closure=None):
155 | """Performs a single optimization step.
156 | Arguments:
157 | closure (callable, optional): A closure that reevaluates the model
158 | and returns the loss.
159 | """
160 | loss = None
161 | if closure is not None:
162 | loss = closure()
163 |
164 | for group in self.param_groups:
165 | for p in group['params']:
166 | if p.grad is None:
167 | continue
168 | grad = p.grad.data
169 | if grad.is_sparse:
170 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
171 | amsgrad = group['amsgrad']
172 |
173 | state = self.state[p]
174 |
175 | # State initialization
176 | if len(state) == 0:
177 | state['step'] = 0
178 | # Exponential moving average of gradient values
179 | state['exp_avg'] = torch.zeros_like(p.data)
180 | # Exponential moving average of squared gradient values
181 | state['exp_avg_sq'] = torch.zeros_like(p.data)
182 | if amsgrad:
183 | # Maintains max of all exp. moving avg. of sq. grad. values
184 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
185 |
186 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
187 | if amsgrad:
188 | max_exp_avg_sq = state['max_exp_avg_sq']
189 | beta1, beta2 = group['betas']
190 |
191 | state['step'] += 1
192 |
193 | # if group['weight_decay'] != 0:
194 | # grad = grad.add(group['weight_decay'], p.data)
195 |
196 | # Decay the first and second moment running average coefficient
197 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
198 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
199 | if amsgrad:
200 | # Maintains the maximum of all 2nd moment running avg. till now
201 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
202 | # Use the max. for normalizing running avg. of gradient
203 | denom = max_exp_avg_sq.sqrt().add_(group['eps'])
204 | else:
205 | denom = exp_avg_sq.sqrt().add_(group['eps'])
206 |
207 | bias_correction1 = 1 - beta1 ** state['step']
208 | bias_correction2 = 1 - beta2 ** state['step']
209 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
210 |
211 |                 # decoupled weight decay: p <- p - step_size * (p * wd + exp_avg / denom)
212 |                 p.data.add_(torch.mul(p.data, group['weight_decay']).addcdiv_(exp_avg, denom, value=1), alpha=-step_size)
213 |
214 | return loss
215 |
216 |
217 | class AverageMeter(object):
218 | """Computes and stores the average and current value
219 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
220 | """
221 |
222 | def __init__(self):
223 | self.reset()
224 |
225 | def reset(self):
226 | self.val = 0
227 | self.avg = 0
228 | self.sum = 0
229 | self.count = 0
230 |
231 | def update(self, val, n=1):
232 | self.val = val
233 | self.sum += val * n
234 | self.count += n
235 | self.avg = self.sum / self.count
236 |
237 |
--------------------------------------------------------------------------------
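
A short sketch of how the `accuracy` and `AverageMeter` helpers above compose during evaluation (assumes the repository root is on `sys.path`; the logits and targets are made up for illustration):

```python
import torch
from set import AverageMeter, accuracy

meter = AverageMeter()
logits = torch.tensor([[2.0, 1.0, 0.1],   # predicts class 0 (correct)
                       [0.1, 2.0, 1.0]])  # predicts class 1 (target is 2)
target = torch.tensor([0, 2])

top1 = accuracy(logits, target, topk=(1,))  # a single k returns only the top-1 accuracy
meter.update(top1.item(), n=target.size(0))
print(meter.avg)  # 50.0
```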
/train_vae.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch.nn.functional as F
7 | import torch.backends.cudnn as cudnn
8 | from tqdm import tqdm
9 | from copy import deepcopy
10 | import torchvision
11 | import torchvision.transforms as transforms
12 | import wandb
13 | import os
14 | import time
15 | import argparse
16 | import datetime
17 | from torch.autograd import Variable
18 | import pdb
19 | import sys
20 | import torch.autograd as autograd
21 | import torchvision.models as models
22 | sys.path.append('.')
23 |
24 | from vae import *
25 | from set import *
26 | from apex import amp  # NOTE: imported unconditionally; install apex or guard this import if you do not pass --amp
27 |
28 |
29 | def reconst_images(batch_size=64, batch_num=1, dataloader=None, model=None):
30 | cifar10_dataloader = dataloader
31 | model.eval()
32 | with torch.no_grad():
33 | for batch_idx, (X, y) in enumerate(cifar10_dataloader):
34 | if batch_idx >= batch_num:
35 | break
36 | else:
37 | X, y = X.cuda(), y.cuda().view(-1, )
38 | _, gx, _, _ = model(X)
39 |
40 | grid_X = torchvision.utils.make_grid(X[:batch_size].data, nrow=8, padding=2, normalize=True)
41 | wandb.log({"_Batch_{batch}_X.jpg".format(batch=batch_idx): [
42 | wandb.Image(grid_X)]}, commit=False)
43 | grid_Xi = torchvision.utils.make_grid(gx[:batch_size].data, nrow=8, padding=2, normalize=True)
44 | wandb.log({"_Batch_{batch}_GX.jpg".format(batch=batch_idx): [
45 | wandb.Image(grid_Xi)]}, commit=False)
46 | grid_X_Xi = torchvision.utils.make_grid((X[:batch_size] - gx[:batch_size]).data, nrow=8, padding=2,
47 | normalize=True)
48 | wandb.log({"_Batch_{batch}_RX.jpg".format(batch=batch_idx): [
49 | wandb.Image(grid_X_Xi)]}, commit=False)
50 | print('reconstruction complete!')
51 |
52 |
53 | def test(epoch, model, testloader):
54 | # set model as testing mode
55 | model.eval()
56 | acc_gx_avg = AverageMeter()
57 | acc_rx_avg = AverageMeter()
58 |
59 | with torch.no_grad():
60 | for batch_idx, (x, y) in enumerate(testloader):
61 | # distribute data to device
62 | x, y = x.cuda(), y.cuda().view(-1, )
63 | bs = x.size(0)
64 |             norm = torch.norm(x.view(bs, -1), p=2, dim=1)  # per-sample L2 norm
65 | _, gx, _, _ = model(x)
66 | acc_gx = 1 - F.mse_loss(torch.div(gx, norm.unsqueeze(1).unsqueeze(2).unsqueeze(3)), \
67 | torch.div(x, norm.unsqueeze(1).unsqueeze(2).unsqueeze(3)), \
68 | reduction='sum') / bs
69 | acc_rx = 1 - F.mse_loss(torch.div(x - gx, norm.unsqueeze(1).unsqueeze(2).unsqueeze(3)), \
70 | torch.div(x, norm.unsqueeze(1).unsqueeze(2).unsqueeze(3)), \
71 | reduction='sum') / bs
72 |
73 | acc_gx_avg.update(acc_gx.data.item(), bs)
74 | acc_rx_avg.update(acc_rx.data.item(), bs)
75 |
76 | wandb.log({'acc_gx_avg': acc_gx_avg.avg, \
77 | 'acc_rx_avg': acc_rx_avg.avg}, commit=False)
78 | # plot progress
79 | print("\n| Validation Epoch #%d\t\tRec_gx: %.4f Rec_rx: %.4f " % (epoch, acc_gx_avg.avg, acc_rx_avg.avg))
80 | reconst_images(batch_size=64, batch_num=2, dataloader=testloader, model=model)
81 |     torch.save(model.state_dict(),
82 |                os.path.join(args.save_dir, 'model_epoch{}.pth'.format(epoch + 1)))  # checkpoint the VAE
83 | print("Epoch {} model saved!".format(epoch + 1))
84 |
85 |
86 | def main(args):
87 | setup_logger(args.save_dir)
88 | use_cuda = torch.cuda.is_available()
89 | print('\n[Phase 1] : Data Preparation')
90 |
91 | if args.dataset == 'imagenet':
92 | size = 224
93 | normalizer = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
94 | model = CVAE_imagenet_withbn(128, args.dim)
95 | p_blur = 0.5
96 | else:
97 | size = 32
98 | normalizer = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
99 | model = CVAE_cifar_withbn(128, args.dim)
100 | p_blur = 0.0
101 |
102 | if args.mode=='simclr':
103 | print('\nData Augmentation: SimCLR')
104 | s = 1
105 | color_jitter = transforms.ColorJitter(
106 | 0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
107 | )
108 | transform_train = transforms.Compose(
109 | [
110 | transforms.RandomResizedCrop(size=size),
111 | transforms.RandomHorizontalFlip(), # with 0.5 probability
112 | transforms.RandomApply([color_jitter], p=0.8),
113 | transforms.RandomGrayscale(p=0.2),
114 | transforms.ToTensor(),
115 | normalizer
116 | ]
117 | )
118 | elif args.mode=='simsiam':
119 | print('\nData Augmentation: SimSiam')
120 | transform_train = transforms.Compose([
121 | transforms.RandomResizedCrop(size, scale=(0.2, 1.0)),
122 | transforms.RandomHorizontalFlip(),
123 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
124 | transforms.RandomGrayscale(p=0.2),
125 | transforms.RandomApply([transforms.GaussianBlur(kernel_size=size // 20 * 2 + 1, sigma=(0.1, 2.0))], p=p_blur),
126 | transforms.ToTensor(),
127 | normalizer
128 | ])
129 | else:
130 | print('\nData Augmentation: Normal')
131 | transform_train = transforms.Compose([
132 | transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)),
133 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
134 | transforms.RandomGrayscale(p=0.2),
135 | transforms.RandomHorizontalFlip(),
136 | transforms.ToTensor(),
137 | normalizer
138 | ])
139 | if args.dataset == 'cifar10':
140 | print("| Preparing CIFAR-10 dataset...")
141 | sys.stdout.write("| ")
142 | trainset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)
143 | elif args.dataset == 'cifar100':
144 | print("| Preparing CIFAR-100 dataset...")
145 | sys.stdout.write("| ")
146 | trainset = torchvision.datasets.CIFAR100(root='../data', train=True, download=True, transform=transform_train)
147 | elif args.dataset == 'imagenet':
148 | print("| Preparing imagenet dataset...")
149 | sys.stdout.write("| ")
150 | root='/gpub/imagenet_raw'
151 | train_path = os.path.join(root, 'train')
152 |         trainset = torchvision.datasets.ImageFolder(root=train_path, transform=transform_train)
153 |
154 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4,
155 | drop_last=True)
156 | # Model
157 | print('\n[Phase 2] : Model setup')
158 | if use_cuda:
159 | model.cuda()
160 | cudnn.benchmark = True
161 |
162 | optimizer = AdamW([
163 | {'params': model.parameters()},
164 | ], lr=args.lr, betas=(0.0, 0.9))
165 |
166 | scheduler = optim.lr_scheduler.LambdaLR(
167 | optimizer, lambda epoch: 1 - epoch / args.epochs)
168 |
169 | if args.amp:
170 | model, optimizer = amp.initialize(
171 | model, optimizer, opt_level=args.opt_level)
172 |
173 | print('\n[Phase 3] : Training model')
174 | print('| Training Epochs = ' + str(args.epochs))
175 | print('| Initial Learning Rate = ' + str(args.lr))
176 |
177 | start_epoch = 1
178 | for epoch in range(start_epoch, start_epoch + args.epochs):
179 | model.train()
180 |
181 | loss_avg = AverageMeter()
182 | loss_rec = AverageMeter()
183 | loss_kl = AverageMeter()
184 |
185 | print('\n=> Training Epoch #%d, LR=%.4f' % (epoch, optimizer.param_groups[0]['lr']))
186 | for batch_idx, (x, y) in enumerate(trainloader):
187 | x, y = x.cuda(), y.cuda().view(-1, )
188 | x, y = Variable(x), Variable(y)
189 | bs = x.size(0)
190 |
191 | _, gx, mu, logvar = model(x)
192 | optimizer.zero_grad()
193 | l_rec = F.mse_loss(x, gx)
194 | l_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
195 | l_kl /= bs * 3 * args.dim
196 | loss = l_rec + args.kl * l_kl
197 |
198 | if args.amp:
199 | with amp.scale_loss(loss, optimizer) as scaled_loss:
200 | scaled_loss.backward()
201 | else:
202 | loss.backward()
203 |
204 | optimizer.step()
205 |
206 | loss_avg.update(loss.data.item(), bs)
207 | loss_rec.update(l_rec.data.item(), bs)
208 | loss_kl.update(l_kl.data.item(), bs)
209 |
210 | n_iter = (epoch - 1) * len(trainloader) + batch_idx
211 | wandb.log({'loss': loss_avg.avg, \
212 | 'loss_rec': loss_rec.avg, \
213 | 'loss_kl': loss_kl.avg, \
214 | 'lr': optimizer.param_groups[0]['lr']}, step=n_iter)
215 | if (batch_idx + 1) % 30 == 0:
216 | sys.stdout.write('\r')
217 | sys.stdout.write(
218 | '| Epoch [%3d/%3d] Iter[%3d/%3d]\t\t Loss_rec: %.4f Loss_kl: %.4f'
219 | % (epoch, args.epochs, batch_idx + 1,
220 | len(trainloader), loss_rec.avg, loss_kl.avg))
221 | scheduler.step()
222 |         test(epoch, model, trainloader)  # reconstruction quality is monitored on the training set
223 | wandb.finish()
224 |
225 |
226 | if __name__ == '__main__':
227 | parser = argparse.ArgumentParser(description='VAE Training')
228 | parser.add_argument('--lr', default=5e-4, type=float, help='learning_rate')
229 | parser.add_argument('--save_dir', default='./results/vae_cifar10_simclr', type=str, help='save_dir')
230 | parser.add_argument('--seed', default=666, type=int, help='seed')
231 | parser.add_argument('--dataset', default='cifar10', type=str, help='dataset = [cifar10/cifar100/imagenet]')
232 | parser.add_argument('--epochs', default=300, type=int, help='training_epochs')
233 | parser.add_argument('--batch_size', default=128, type=int, help='batch_size')
234 | parser.add_argument('--dim', default=128, type=int, help='CNN_embed_dim')
235 | parser.add_argument('--kl', default=0.1, type=float, help='kl weight')
236 | parser.add_argument('--mode', default='normal', type=str, help='augmentation mode')
237 | parser.add_argument("--amp", action="store_true",
238 | help="use 16-bit (mixed) precision through NVIDIA apex AMP")
239 | parser.add_argument("--opt_level", type=str, default="O1",
240 | help="apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
241 | "See details at https://nvidia.github.io/apex/amp.html")
242 | args = parser.parse_args()
243 | wandb.init(config=args, name=args.save_dir.replace("./results/", ''))
244 | set_random_seed(args.seed)
245 | main(args)
246 |
--------------------------------------------------------------------------------
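
The training objective above is a plain weighted VAE loss: pixel-space MSE plus a KL term normalized by `bs * 3 * args.dim`. A toy numeric check of that formula, with random tensors standing in for the data and the model outputs:

```python
import torch
import torch.nn.functional as F

bs, dim, kl_weight = 4, 512, 0.1
x, gx = torch.rand(bs, 3, 32, 32), torch.rand(bs, 3, 32, 32)  # gx stands in for the reconstruction
mu, logvar = torch.zeros(bs, dim), torch.zeros(bs, dim)       # posterior equal to the N(0, I) prior

l_rec = F.mse_loss(x, gx)
l_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / (bs * 3 * dim)
loss = l_rec + kl_weight * l_kl
print(l_rec.item(), l_kl.item(), loss.item())  # the KL term is exactly 0 at the prior
```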
/vae.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import abc
3 | import os
4 | import math
5 | import pdb
6 | import numpy as np
7 | import logging
8 | import torch
9 | import torch.utils.data
10 | from torch import nn
11 | from torch.nn import init
12 | from torch.nn import functional as F
13 | from torch.autograd import Variable
14 | import torchvision.models as models
15 | from torch.autograd import Function
16 | 
17 |
18 |
19 | class ResBlock(nn.Module):
20 | def __init__(self, in_channels, out_channels, mid_channels=None, bn=False):
21 | super(ResBlock, self).__init__()
22 |
23 | if mid_channels is None:
24 | mid_channels = out_channels
25 |
26 | layers = [
27 | nn.LeakyReLU(),
28 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1),
29 | nn.LeakyReLU(),
30 | nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1, padding=0)]
31 |         if bn:
32 |             layers.insert(2, nn.BatchNorm2d(mid_channels))  # normalizes the first conv's output, which is mid_channels wide
33 | self.convs = nn.Sequential(*layers)
34 |
35 | def forward(self, x):
36 | return x + self.convs(x)
37 |
38 |
39 | class AbstractAutoEncoder(nn.Module):
40 |     __metaclass__ = abc.ABCMeta  # Python 2 idiom; has no effect under Python 3, so the abstract methods below are not enforced
41 |
42 | @abc.abstractmethod
43 | def encode(self, x):
44 | return
45 |
46 | @abc.abstractmethod
47 | def decode(self, z):
48 | return
49 |
50 | @abc.abstractmethod
51 | def forward(self, x):
52 | """model return (reconstructed_x, *)"""
53 | return
54 |
55 | @abc.abstractmethod
56 | def sample(self, size):
57 | """sample new images from model"""
58 | return
59 |
60 | @abc.abstractmethod
61 | def loss_function(self, **kwargs):
62 | """accepts (original images, *) where * is the same as returned from forward()"""
63 | return
64 |
65 | @abc.abstractmethod
66 | def latest_losses(self):
67 | """returns the latest losses in a dictionary. Useful for logging."""
68 | return
69 |
70 |
71 | class CVAE_cifar_withbn(AbstractAutoEncoder):
72 | def __init__(self, d, z, **kwargs):
73 | super(CVAE_cifar_withbn, self).__init__()
74 |
75 | self.encoder = nn.Sequential(
76 | nn.Conv2d(3, d // 2, kernel_size=4, stride=2, padding=1, bias=False),
77 | nn.BatchNorm2d(d // 2),
78 | nn.ReLU(inplace=True),
79 | nn.Conv2d(d // 2, d, kernel_size=4, stride=2, padding=1, bias=False),
80 | nn.BatchNorm2d(d),
81 | nn.ReLU(inplace=True),
82 | ResBlock(d, d, bn=True),
83 | nn.BatchNorm2d(d),
84 | ResBlock(d, d, bn=True),
85 | )
86 |
87 | self.decoder = nn.Sequential(
88 | ResBlock(d, d, bn=True),
89 | nn.BatchNorm2d(d),
90 | ResBlock(d, d, bn=True),
91 | nn.BatchNorm2d(d),
92 | nn.ConvTranspose2d(d, d // 2, kernel_size=4, stride=2, padding=1, bias=False),
93 | nn.BatchNorm2d(d // 2),
94 | nn.LeakyReLU(inplace=True),
95 | nn.ConvTranspose2d(d // 2, 3, kernel_size=4, stride=2, padding=1, bias=False),
96 | )
97 | self.bn = nn.BatchNorm2d(3)
98 | self.f = 8
99 | self.d = d
100 | self.z = z
101 | self.fc11 = nn.Linear(d * self.f ** 2, self.z)
102 | self.fc12 = nn.Linear(d * self.f ** 2, self.z)
103 | self.fc21 = nn.Linear(self.z, d * self.f ** 2)
104 |
105 | def encode(self, x):
106 | h = self.encoder(x)
107 | h1 = h.view(-1, self.d * self.f ** 2)
108 | return h, self.fc11(h1), self.fc12(h1)
109 |
110 | def reparameterize(self, mu, logvar):
111 | if self.training:
112 |             std = logvar.mul(0.5).exp_()
113 |             eps = torch.randn_like(std)
114 |             return eps.mul(std).add_(mu)
115 | else:
116 | return mu
117 |
118 | def decode(self, z):
119 | z = z.view(-1, self.d, self.f, self.f)
120 | h3 = self.decoder(z)
121 | return torch.tanh(h3)
122 |
123 | def forward(self, x, decode=False):
124 | if decode:
125 | z_projected = self.fc21(x)
126 | gx = self.decode(z_projected)
127 | gx = self.bn(gx)
128 | return gx
129 | else:
130 | _, mu, logvar = self.encode(x)
131 | z = self.reparameterize(mu, logvar)
132 | z_projected = self.fc21(z)
133 | gx = self.decode(z_projected)
134 | gx = self.bn(gx)
135 | return z, gx, mu, logvar
136 |
137 |
138 | class CVAE_imagenet_withbn(AbstractAutoEncoder):
139 | def __init__(self, d, z, **kwargs):
140 | super(CVAE_imagenet_withbn, self).__init__()
141 |
142 | self.encoder = nn.Sequential(
143 | nn.Conv2d(3, d // 16, kernel_size=4, stride=2, padding=1, bias=False),
144 | nn.BatchNorm2d(d // 16),
145 | nn.ReLU(inplace=True),
146 | nn.Conv2d(d // 16, d // 8, kernel_size=4, stride=2, padding=1, bias=False),
147 | nn.BatchNorm2d(d // 8),
148 | nn.ReLU(inplace=True),
149 | nn.Conv2d(d // 8, d // 4, kernel_size=4, stride=2, padding=1, bias=False),
150 | nn.BatchNorm2d(d // 4),
151 | nn.ReLU(inplace=True),
152 | nn.Conv2d(d // 4, d // 2, kernel_size=4, stride=2, padding=1, bias=False),
153 | nn.BatchNorm2d(d // 2),
154 | nn.ReLU(inplace=True),
155 | nn.Conv2d(d // 2, d, kernel_size=4, stride=2, padding=1, bias=False),
156 | nn.BatchNorm2d(d),
157 | nn.ReLU(inplace=True),
158 | ResBlock(d, d, bn=True),
159 | nn.BatchNorm2d(d),
160 | ResBlock(d, d, bn=True),
161 | nn.BatchNorm2d(d)
162 | )
163 |
164 | self.decoder = nn.Sequential(
165 | nn.BatchNorm2d(d),
166 | ResBlock(d, d, bn=True),
167 | nn.BatchNorm2d(d),
168 | ResBlock(d, d, bn=True),
169 | nn.BatchNorm2d(d),
170 | nn.ConvTranspose2d(d, d // 2, kernel_size=4, stride=2, padding=1, bias=False),
171 | nn.BatchNorm2d(d // 2),
172 | nn.LeakyReLU(inplace=True),
173 | nn.ConvTranspose2d(d // 2, d // 4, kernel_size=4, stride=2, padding=1, bias=False),
174 | nn.BatchNorm2d(d // 4),
175 | nn.LeakyReLU(inplace=True),
176 | nn.ConvTranspose2d(d // 4, d // 8, kernel_size=4, stride=2, padding=1, bias=False),
177 | nn.BatchNorm2d(d // 8),
178 | nn.LeakyReLU(inplace=True),
179 | nn.ConvTranspose2d(d // 8, d // 16, kernel_size=4, stride=2, padding=1, bias=False),
180 | nn.BatchNorm2d(d // 16),
181 | nn.LeakyReLU(inplace=True),
182 | nn.ConvTranspose2d(d // 16, 3, kernel_size=4, stride=2, padding=1, bias=False),
183 | )
184 | self.bn = nn.BatchNorm2d(3)
185 | self.f = 7
186 | self.d = d
187 | self.z = z
188 | self.fc11 = nn.Linear(d * self.f ** 2, self.z)
189 | self.fc12 = nn.Linear(d * self.f ** 2, self.z)
190 | self.fc21 = nn.Linear(self.z, d * self.f ** 2)
191 |
192 | def encode(self, x):
193 | h = self.encoder(x)
194 | h1 = h.view(-1, self.d * self.f ** 2)
195 | return h, self.fc11(h1), self.fc12(h1)
196 |
197 | def reparameterize(self, mu, logvar):
198 | if self.training:
199 |             std = logvar.mul(0.5).exp_()
200 |             eps = torch.randn_like(std)
201 |             return eps.mul(std).add_(mu)
202 | else:
203 | return mu
204 |
205 | def decode(self, z):
206 | z = z.view(-1, self.d, self.f, self.f)
207 | h3 = self.decoder(z)
208 | return torch.tanh(h3)
209 |
210 | def forward(self, x, decode=False):
211 | if decode:
212 | z_projected = self.fc21(x)
213 | gx = self.decode(z_projected)
214 | gx = self.bn(gx)
215 | return gx
216 | else:
217 | _, mu, logvar = self.encode(x)
218 | z = self.reparameterize(mu, logvar)
219 | z_projected = self.fc21(z)
220 | gx = self.decode(z_projected)
221 | gx = self.bn(gx)
222 | return z, gx, mu, logvar
223 |
224 |
225 |
--------------------------------------------------------------------------------
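
A usage sketch for the CIFAR VAE above (assumes the repository root is on `sys.path`). Note the two entry points of `forward`: the full encode-reparameterize-decode path, and the `decode=True` path that maps a latent vector straight through `fc21` and the decoder, which is how a (possibly perturbed) latent can be turned back into an image:

```python
import torch
from vae import CVAE_cifar_withbn

model = CVAE_cifar_withbn(d=128, z=512)
model.eval()  # reparameterize() returns mu deterministically in eval mode

x = torch.randn(2, 3, 32, 32)
z, gx, mu, logvar = model(x)    # encode -> reparameterize -> decode
print(z.shape, gx.shape)        # torch.Size([2, 512]) torch.Size([2, 3, 32, 32])

gx2 = model(z, decode=True)     # decode-only path, starting from the latent
print(torch.allclose(gx, gx2))  # True in eval mode, since z == mu
```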