├── README.md
├── attack.py
├── config
│   └── celeba
│       ├── attacking
│       │   ├── celeba.json
│       │   └── ffhq.json
│       ├── training_GAN
│       │   ├── general_gan
│       │   │   ├── celeba.json
│       │   │   └── ffhq.json
│       │   └── specific_gan
│       │       ├── celeba.json
│       │       └── ffhq.json
│       ├── training_augmodel
│       │   ├── celeba.json
│       │   └── ffhq.json
│       └── training_classifiers
│           └── classify.json
├── dataloader.py
├── engine.py
├── evaluation.py
├── losses.py
├── metrics
│   ├── KNN_dist.py
│   ├── __init__.py
│   ├── eval_accuracy.py
│   ├── fid.py
│   └── metric_utils.py
├── models
│   ├── __init__.py
│   ├── classify.py
│   ├── discri.py
│   ├── evolve.py
│   ├── facenet.py
│   └── generator.py
├── recovery.py
├── requirements.txt
├── train_augmented_model.py
├── train_classifier.py
├── train_gan.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Implementation of the paper "Re-thinking Model Inversion Attacks Against Deep Neural Networks" (CVPR 2023)
2 | [Paper](https://arxiv.org/pdf/2304.01669.pdf) | [Project page](https://ngoc-nguyen-0.github.io/re-thinking_model_inversion_attacks/)
3 | ## 1. Setup Environment
4 | This code has been tested with Python 3.7, PyTorch 1.11.0, and CUDA 11.3.
5 |
6 | ```
7 | conda create -n MI python=3.7
8 |
9 | conda activate MI
10 |
11 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
12 |
13 | pip install -r requirements.txt
14 | ```
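
A quick optional sanity check (not part of the original instructions) to confirm that the CUDA build of PyTorch is visible before moving on:
```
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```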
15 |
16 | ## 2. Prepare Dataset & Checkpoints
17 |
18 | * Download the CelebA and FFHQ datasets from their official websites (see the expected dataset layout sketched below).
19 |     - CelebA: download and extract [CelebA](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). Then, place the `img_align_celeba` folder in `./datasets/celeba`
20 |
21 |     - FFHQ: download and extract [FFHQ](https://github.com/NVlabs/ffhq-dataset). Then, place the `thumbnails128x128` folder in `./datasets/ffhq`
22 |
23 | * Download the metadata for the experiments at: https://drive.google.com/drive/folders/1kq4ArFiPmCWYKY7iiV0WxxUSXtP70bFQ?usp=sharing
24 |
25 |
26 | * We use the same target models and GANs as previous papers. You can download the target models and generators at https://drive.google.com/drive/folders/1kq4ArFiPmCWYKY7iiV0WxxUSXtP70bFQ?usp=sharing
27 |
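
With the downloads above in place, the config files in `./config` assume a dataset layout roughly like the sketch below (a sketch only; the file names are taken from the provided configs, so adjust the config paths if your layout differs):
```
datasets
├── celeba
│   ├── img_align_celeba
│   └── meta
│       ├── trainset.txt
│       ├── testset.txt
│       ├── ganset.txt
│       ├── celeba_target_300ids.npy
│       └── fea_target_300ids.npy
└── ffhq
    ├── thumbnails128x128
    └── meta
        └── ganset_ffhq.txt
```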
28 | Otherwise, you can train the target classifier and GAN yourself as follows:
29 |
30 |
31 | ### 2.1. Training the target classifier (Optional)
32 |
33 | - Modify the configuration in `./config/celeba/training_classifiers/classify.json`
34 | - Then, run the following command to train the target model
35 | ```
36 | python train_classifier.py
37 | ```
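
The classifier architecture is selected via `model_name` in `classify.json`; the per-architecture hyperparameters (epochs, learning rate, resume checkpoint, etc.) live in the corresponding blocks of the same file (`VGG16`, `FaceNet`, `IR152`, ...).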
38 |
39 | ### 2.2. Training GAN (Optional)
40 |
41 | SOTA MI attacks work with a general GAN [1]. However, Inversion-Specific GANs [2] help improve the attack accuracy. In this repo, we provide code for training both a general GAN and an Inversion-Specific GAN.
42 |
43 | #### 2.2.1. Build an Inversion-Specific GAN
44 | * Modify the configuration in
45 |   * `./config/celeba/training_GAN/specific_gan/celeba.json` if training an Inversion-Specific GAN on CelebA (KEDMI[2]).
46 |   * `./config/celeba/training_GAN/specific_gan/ffhq.json` if training an Inversion-Specific GAN on FFHQ (KEDMI[2]).
47 |
48 | * Then, run the following command to train the Inversion-Specific GAN
49 | ```
50 | python train_gan.py --configs path/to/config.json --mode "specific"
51 | ```
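
For example, to train an Inversion-Specific GAN on CelebA with the provided config:
```
python train_gan.py --configs ./config/celeba/training_GAN/specific_gan/celeba.json --mode "specific"
```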
52 |
53 | #### 2.2.2. Build a general GAN
54 | * Modify the configuration in
55 | * `./config/celeba/training_GAN/general_gan/celeba.json` if training a general GAN on CelebA (GMI[1]).
56 | * `./config/celeba/training_GAN/general_gan/ffhq.json` if training a general GAN on FFHQ (GMI[1]).
57 |
58 | * Then, run the following command to train the general GAN
59 | ```
60 | python train_gan.py --configs path/to/config.json --mode "general"
61 | ```
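
For example, with the provided FFHQ config (per `engine.py`, the resulting checkpoints are written under `./checkpoints/GAN/<dataset_name>/general_GAN/`):
```
python train_gan.py --configs ./config/celeba/training_GAN/general_gan/ffhq.json --mode "general"
```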
62 |
63 | ## 3. Learn Augmented Models
64 | We provide code to train augmented models (i.e., `efficientnet_b0`, `efficientnet_b1`, and `efficientnet_b2`) from a ***target model***.
65 | * Modify the configuration in
66 | * `./config/celeba/training_augmodel/celeba.json` if training an augmented model on CelebA
67 | * `./config/celeba/training_augmodel/ffhq.json` if training an augmented model on FFHQ
68 |
69 | * Then, run the following command line to train augmented models
70 | ```
71 | python train_augmented_model.py --configs path/to/config.json
72 | ```
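
For example, using the provided CelebA config (the augmented architecture is chosen via `student_model_name`, e.g. `efficientnet_b0`):
```
python train_augmented_model.py --configs ./config/celeba/training_augmodel/celeba.json
```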
73 |
74 | Pretrained augmented models and p_reg can be downloaded at https://drive.google.com/drive/u/2/folders/1kq4ArFiPmCWYKY7iiV0WxxUSXtP70bFQ
75 |
76 | ***We remark that if you train your own augmented models, please do not use our p_reg***. Delete the files in `./p_reg/` before running the inversion; our code will automatically re-estimate p_reg with the new augmented models.
77 |
78 | ## 4. Model Inversion Attack
79 |
80 | * Modify the configuration in
81 |   * `./config/celeba/attacking/celeba.json` to attack using a GAN/prior trained on CelebA
82 |   * `./config/celeba/attacking/ffhq.json` to attack using a GAN/prior trained on FFHQ
83 |
84 | * Important arguments:
85 |   * `method`: select the method, either ***gmi*** or ***kedmi***
86 |   * `variant`: select the variant, either ***baseline***, ***L_aug***, ***L_logit***, or ***ours***
87 |
88 | * Then, run the following command line to attack
89 | ```
90 | python recovery.py --configs path/to/config.json
91 | ```
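
For reference, this is the `attack` block of the provided `./config/celeba/attacking/celeba.json`; `method` and `variant` are the arguments described above:
```
"attack":{
    "method": "kedmi",
    "variant": "L_logit",
    "iters_mi": 2400,
    "lr": 0.02,
    "lam": 1.0,
    "same_z":"",
    "eval_metric": "fid, acc, knn"
}
```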
92 |
93 | ## 5. Evaluation
94 |
95 | After the attack, use the same configuration file and run the following command to get the results:
96 | ```
97 | python evaluation.py --configs path/to/config.json
98 | ```
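
The selected metrics (`eval_metric` in the config) are printed to the console and appended to `Eval_results.csv` inside the corresponding attack-results folder (see `evaluation.py`).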
99 |
100 | ## Acknowledgements
101 | We gratefully acknowledge the following works:
102 | - Knowledge-Enriched-Distributional-Model-Inversion-Attacks: https://github.com/SCccc21/Knowledge-Enriched-DMI
103 | - EfficientNet (PyTorch): https://pytorch.org/vision/stable/models/efficientnet.html
104 | - Experiment tracking with Weights & Biases: https://www.wandb.com/
105 |
106 |
107 | ## Reference
108 | [1] Zhang, Yuheng, et al. "The secret revealer: Generative model-inversion attacks against deep neural networks." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020.
109 |
110 |
111 | [2] Si Chen, Mostafa Kahla, Ruoxi Jia, and Guo-Jun Qi. "Knowledge-enriched distributional model inversion attacks." In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 16178–16187, 2021.
112 |
113 |
--------------------------------------------------------------------------------
/attack.py:
--------------------------------------------------------------------------------
1 | import torch, os, time, utils
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from utils import log_sum_exp, save_tensor_images
6 | from torch.autograd import Variable
7 | import torch.optim as optim
8 |
9 |
10 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
11 |
12 | def reg_loss(featureT,fea_mean, fea_logvar):
13 |
14 |     fea_reg = reparameterize(fea_mean, fea_logvar)
15 |     fea_reg = fea_mean.repeat(featureT.shape[0],1)  # NOTE: overwrites the sampled fea_reg above; only the repeated mean is used
16 | loss_reg = torch.mean((featureT - fea_reg).pow(2))
17 | # print('loss_reg',loss_reg)
18 | return loss_reg
19 |
20 | def attack_acc(fake,iden,E,):
21 |
22 | eval_prob = E(utils.low2high(fake))[-1]
23 |
24 | eval_iden = torch.argmax(eval_prob, dim=1).view(-1)
25 |
26 | cnt, cnt5 = 0, 0
27 | bs = fake.shape[0]
28 | # print('correct id')
29 | for i in range(bs):
30 | gt = iden[i].item()
31 | if eval_iden[i].item() == gt:
32 | cnt += 1
33 | # print(gt)
34 | _, top5_idx = torch.topk(eval_prob[i], 5)
35 | if gt in top5_idx:
36 | cnt5 += 1
37 | return cnt*100.0/bs, cnt5*100.0/bs
38 |
39 | def reparameterize(mu, logvar):
40 | """
41 |     Reparameterization trick: sample from N(mu, var) using a sample
42 |     from N(0, 1).
43 |     :param mu: (Tensor) Mean of the latent Gaussian [B x D]
44 |     :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
45 | :return: (Tensor) [B x D]
46 | """
47 | std = torch.exp(0.5 * logvar)
48 | eps = torch.randn_like(std)
49 |
50 | return eps * std + mu
51 |
52 | def find_criterion(used_loss):
53 | criterion = None
54 | if used_loss=='logit_loss':
55 | criterion = nn.NLLLoss().to(device)
56 | print('criterion:{}'.format(used_loss))
57 | elif used_loss=='cel':
58 | criterion = nn.CrossEntropyLoss().to(device)
59 | print('criterion',criterion)
60 | else:
61 | print('criterion:{}'.format(used_loss))
62 | return criterion
63 |
64 | def get_act_reg(train_loader,T,device,Nsample=5000):
65 | all_fea = []
66 | with torch.no_grad():
67 | for batch_idx, data in enumerate(train_loader): # batchsize =100
68 | # print(data.shape)
69 | if batch_idx*len(data) > Nsample:
70 | break
71 | data = data.to(device)
72 | fea,_ = T(data)
73 | if batch_idx == 0:
74 | all_fea = fea
75 | else:
76 | all_fea = torch.cat((all_fea,fea))
77 | fea_mean = torch.mean(all_fea,dim=0)
78 | fea_logvar = torch.std(all_fea,dim=0)
79 |
80 | print(fea_mean.shape, fea_logvar.shape, all_fea.shape)
81 | return fea_mean,fea_logvar
82 |
83 | def iden_loss(T,fake, iden, used_loss,criterion,fea_mean=0, fea_logvar=0,lam=0.1):
84 | Iden_Loss = 0
85 | loss_reg = 0
86 | for tn in T:
87 |
88 | feat,out = tn(fake)
89 | if used_loss == 'logit_loss': #reg only with the target classifier, reg is randomly from distribution
90 | if Iden_Loss ==0:
91 | loss_sdt = criterion(out, iden)
92 | loss_reg = lam*reg_loss(feat,fea_mean[0], fea_logvar[0]) #reg only with the target classifier
93 |
94 | Iden_Loss = Iden_Loss + loss_sdt
95 | else:
96 | loss_sdt = criterion(out, iden)
97 | Iden_Loss = Iden_Loss + loss_sdt
98 |
99 | else:
100 | loss_sdt = criterion(out, iden)
101 | Iden_Loss = Iden_Loss + loss_sdt
102 |
103 | Iden_Loss = Iden_Loss/len(T) + loss_reg
104 | return Iden_Loss
105 |
106 |
107 |
108 | def dist_inversion(G, D, T, E, iden, lr=2e-2, momentum=0.9, lamda=100, \
109 | iter_times=1500, clip_range=1.0, improved=False, num_seeds=5, \
110 | used_loss='cel', prefix='', random_seed=0, save_img_dir='',fea_mean=0, \
111 | fea_logvar=0, lam=0.1, clipz=False):
112 |
113 | iden = iden.view(-1).long().to(device)
114 | criterion = find_criterion(used_loss)
115 | bs = iden.shape[0]
116 |
117 | G.eval()
118 | D.eval()
119 | E.eval()
120 |
121 | #NOTE
122 | mu = Variable(torch.zeros(bs, 100), requires_grad=True)
123 | log_var = Variable(torch.ones(bs, 100), requires_grad=True)
124 |
125 | params = [mu, log_var]
126 | solver = optim.Adam(params, lr=lr)
127 | outputs_z = "{}_iter_{}_{}_dis.npy".format(prefix, random_seed, iter_times-1)
128 |
129 | if not os.path.exists(outputs_z):
130 | outputs_z = "{}_iter_{}_{}_dis".format(prefix, random_seed, 0)
131 | outputs_label = "{}_iter_{}_{}_label".format(prefix, random_seed, 0)
132 | np.save(outputs_z,{"mu":mu.detach().cpu().numpy(),"log_var":log_var.detach().cpu().numpy()})
133 | np.save(outputs_label,iden.detach().cpu().numpy())
134 |
135 | for i in range(iter_times):
136 | z = reparameterize(mu, log_var)
137 | if clipz==True:
138 | z = torch.clamp(z,-clip_range,clip_range).float()
139 | fake = G(z)
140 |
141 | if improved == True:
142 | _, label = D(fake)
143 | else:
144 | label = D(fake)
145 |
146 | for p in params:
147 | if p.grad is not None:
148 | p.grad.data.zero_()
149 | Iden_Loss = iden_loss(T,fake, iden, used_loss, criterion, fea_mean, fea_logvar, lam)
150 |
151 | if improved:
152 | Prior_Loss = torch.mean(F.softplus(log_sum_exp(label))) - torch.mean(log_sum_exp(label))
153 | else:
154 | Prior_Loss = - label.mean()
155 |
156 | Total_Loss = Prior_Loss + lamda * Iden_Loss
157 |
158 | Total_Loss.backward()
159 | solver.step()
160 |
161 | Prior_Loss_val = Prior_Loss.item()
162 | Iden_Loss_val = Iden_Loss.item()
163 |
164 | if (i+1) % 300 == 0:
165 | outputs_z = "{}_iter_{}_{}_dis".format(prefix, random_seed, i)
166 | outputs_label = "{}_iter_{}_{}_label".format(prefix, random_seed, i)
167 | np.save(outputs_z,{"mu":mu.detach().cpu().numpy(),"log_var":log_var.detach().cpu().numpy()})
168 | np.save(outputs_label,iden.detach().cpu().numpy())
169 |
170 | with torch.no_grad():
171 | z = reparameterize(mu, log_var)
172 | if clipz==True:
173 | z = torch.clamp(z,-clip_range, clip_range).float()
174 | fake_img = G(z.detach())
175 | eval_prob = E(utils.low2high(fake_img))[-1]
176 |
177 | eval_iden = torch.argmax(eval_prob, dim=1).view(-1)
178 | acc = iden.eq(eval_iden.long()).sum().item() * 100.0 / bs
179 | save_tensor_images(fake_img, save_img_dir + '{}.png'.format(i+1))
180 | print("Iteration:{}\tPrior Loss:{:.2f}\tIden Loss:{:.2f}\tAttack Acc:{:.2f}".format(i+1, Prior_Loss_val, Iden_Loss_val, acc))
181 |
182 |
183 | outputs_z = "{}_iter_{}_{}_dis".format(prefix, random_seed, iter_times)
184 | outputs_label = "{}_iter_{}_{}_label".format(prefix, random_seed, iter_times)
185 | np.save(outputs_z,{"mu":mu.detach().cpu().numpy(),"log_var":log_var.detach().cpu().numpy()})
186 | np.save(outputs_label,iden.detach().cpu().numpy())
187 |
188 | def inversion(G, D, T, E, iden, lr=2e-2, momentum=0.9, lamda=100, \
189 | iter_times=1500, clip_range=1, improved=False, num_seeds=5, \
190 | used_loss='cel', prefix='', save_img_dir='', fea_mean=0, \
191 | fea_logvar=0, lam=0.1, istart=0, same_z=''):
192 |
193 | iden = iden.view(-1).long().to(device)
194 | criterion = find_criterion(used_loss)
195 | bs = iden.shape[0]
196 |
197 | G.eval()
198 | D.eval()
199 | E.eval()
200 |
201 | for random_seed in range(istart, num_seeds, 1):
202 | outputs_z = "{}_iter_{}_{}_z.npy".format(prefix, random_seed, iter_times-1)
203 |
204 | if not os.path.exists(outputs_z):
205 | tf = time.time()
206 | if same_z=='': #no prior z
207 | z = torch.randn(bs, 100).to(device).float()
208 | else:
209 | z_path = "{}_iter_{}_{}_z.npy".format(same_z, random_seed, 0)
210 | print('load z ', z_path)
211 | z = torch.from_numpy(np.load(z_path)).to(device).float()
212 | print('z',z)
213 | # exit()
214 | z.requires_grad = True
215 | v = torch.zeros(bs, 100).to(device).float()
216 |
217 | outputs_z = "{}_iter_{}_{}_z".format(prefix, random_seed, 0)
218 |             outputs_label = "{}_iter_{}_{}_label".format(prefix, random_seed, 0)
219 | np.save(outputs_z,z.detach().cpu().numpy())
220 | np.save(outputs_label,iden.detach().cpu().numpy())
221 |
222 | for i in range(iter_times):
223 | fake = G(z)
224 | if improved == True:
225 | _, label = D(fake)
226 | else:
227 | label = D(fake)
228 |
229 | if z.grad is not None:
230 | z.grad.data.zero_()
231 |
232 | if improved:
233 | Prior_Loss = torch.mean(F.softplus(log_sum_exp(label))) - torch.mean(log_sum_exp(label))
234 | else:
235 | Prior_Loss = - label.mean()
236 |
237 | Iden_Loss = iden_loss(T,fake, iden, used_loss, criterion, fea_mean, fea_logvar, lam)
238 |
239 | Total_Loss = Prior_Loss + lamda*Iden_Loss
240 |
241 | Total_Loss.backward()
242 |
243 | v_prev = v.clone()
244 | gradient = z.grad.data
245 | v = momentum*v - lr*gradient
246 | z = z + ( - momentum*v_prev + (1 + momentum)*v)
247 | z = torch.clamp(z.detach(), -clip_range, clip_range).float()
248 | z.requires_grad = True
249 |
250 | Prior_Loss_val = Prior_Loss.item()
251 | Iden_Loss_val = Iden_Loss.item()
252 |
253 | if (i+1) % 300 == 0:
254 | outputs_z = "{}_iter_{}_{}_z".format(prefix, random_seed, i)
255 | outputs_label = "{}_iter_{}_{}_label".format(prefix, random_seed, i)
256 | np.save(outputs_z, z.detach().cpu().numpy())
257 | np.save(outputs_label, iden.detach().cpu().numpy())
258 | with torch.no_grad():
259 | fake_img = G(z.detach())
260 |
261 | eval_prob = E(utils.low2high(fake_img))[-1]
262 | eval_iden = torch.argmax(eval_prob, dim=1).view(-1)
263 | acc = iden.eq(eval_iden.long()).sum().item() * 100.0 / bs
264 | print("Iteration:{}\tPrior Loss:{:.2f}\tIden Loss:{:.2f}\tAttack Acc:{:.3f}".format(i+1, Prior_Loss_val, Iden_Loss_val, acc))
265 |
266 |
--------------------------------------------------------------------------------
/config/celeba/attacking/celeba.json:
--------------------------------------------------------------------------------
1 | {
2 | "root_path": "./attack_results/",
3 | "dataset":{
4 | "model_name": "VGG16",
5 | "test_file_path": "./datasets/celeba/meta/testset.txt",
6 | "gan_file_path": "./datasets/celeba/meta/ganset.txt",
7 | "name": "celeba",
8 | "img_path": "./datasets/celeba/img_align_celeba",
9 | "img_gan_path": "./datasets/celeba/img_align_celeba",
10 | "n_classes":1000,
11 | "fid_real_path": "./datasets/celeba/meta/celeba_target_300ids.npy",
12 | "KNN_real_path": "./datasets/celeba/meta/fea_target_300ids.npy",
13 | "p_reg_path": "./checkpoints/p_reg"
14 | },
15 |
16 | "train":{
17 | "model_types": "VGG16,efficientnet_b0,efficientnet_b1,efficientnet_b2",
18 | "cls_ckpts": "./checkpoints/target_model/target_ckp/VGG16_88.26.tar,./checkpoints/aug_ckp/celeba/VGG16_efficientnet_b0_0.02_1.0/VGG16_efficientnet_b0_kd_0_20.pt,./checkpoints/aug_ckp/celeba/VGG16_efficientnet_b1_0.02_1.0/VGG16_efficientnet_b1_kd_0_20.pt,./checkpoints/aug_ckp/celeba/VGG16_efficientnet_b2_0.02_1.0/VGG16_efficientnet_b2_kd_0_20.pt",
19 | "num_seeds": 5,
20 | "Nclass": 300,
21 | "gan_model_dir": "./checkpoints/GAN",
22 | "eval_model": "FaceNet",
23 | "eval_dir": "./checkpoints/target_model/target_ckp/FaceNet_95.88.tar"
24 | },
25 |
26 | "attack":{
27 | "method": "kedmi",
28 | "variant": "L_logit",
29 | "iters_mi": 2400,
30 | "lr": 0.02,
31 | "lam": 1.0,
32 | "same_z":"",
33 | "eval_metric": "fid, acc, knn"
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/config/celeba/attacking/ffhq.json:
--------------------------------------------------------------------------------
1 | {
2 | "root_path": "./attack_results/",
3 | "dataset":{
4 | "model_name": "VGG16",
5 | "test_file_path": "./datasets/celeba/meta/testset.txt",
6 | "gan_file_path": "./datasets/ffhq/meta/ganset_ffhq.txt",
7 | "name": "ffhq",
8 | "img_path": "./datasets/celeba/img_align_celeba",
9 | "img_gan_path": "./datasets/ffhq/",
10 | "n_classes":1000,
11 | "fid_real_path": "./datasets/celeba/meta/celeba_target_300ids.npy",
12 | "KNN_real_path": "./datasets/celeba/meta/fea_target_300ids.npy",
13 | "p_reg_path": "./checkpoints/p_reg"
14 | },
15 |
16 | "train":{
17 | "model_types": "VGG16,efficientnet_b0,efficientnet_b1,efficientnet_b2",
18 | "cls_ckpts": "checkpoints/target_model/target_ckp/VGG16_88.26.tar,./checkpoints/aug_ckp/ffhq/VGG16_efficientnet_b0_0.02_1.0/VGG16_efficientnet_b0_kd_0_20.pt,./checkpoints/aug_ckp/ffhq/VGG16_efficientnet_b1_0.02_1.0/VGG16_efficientnet_b1_kd_0_20.pt,./checkpoints/aug_ckp/ffhq/VGG16_efficientnet_b2_0.02_1.0/VGG16_efficientnet_b2_kd_0_20.pt",
19 | "num_seeds": 5,
20 | "Nclass": 300,
21 | "gan_model_dir": "./checkpoints/GAN",
22 | "eval_model": "FaceNet",
23 | "eval_dir": "./checkpoints/target_model/target_ckp/FaceNet_95.88.tar"
24 | },
25 |
26 | "attack":{
27 | "method": "kedmi",
28 | "variant": "L_logit",
29 | "iters_mi": 2400,
30 | "lr": 0.02,
31 | "lam": 1.0,
32 | "same_z":"",
33 | "eval_metric": "fid, acc, knn"
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/config/celeba/training_GAN/general_gan/celeba.json:
--------------------------------------------------------------------------------
1 | {
2 | "root_path":"./checkpoints/GAN",
3 |
4 | "dataset":{
5 | "gan_file_path": "./datasets/celeba/meta/ganset.txt",
6 | "model_name": "train_gan - first stage",
7 | "name": "celeba",
8 | "img_gan_path": "./datasets/celeba/img_align_celeba",
9 | "n_classes":1000
10 | },
11 |
12 | "train_gan - first stage":{
13 | "lr": 0.0002,
14 | "batch_size": 64,
15 | "z_dim": 100,
16 | "epochs": 120,
17 | "n_critic": 5,
18 | "unlabel_weight": 10
19 | }
20 | }
--------------------------------------------------------------------------------
/config/celeba/training_GAN/general_gan/ffhq.json:
--------------------------------------------------------------------------------
1 | {
2 | "root_path":"./checkpoints/GAN",
3 |
4 | "dataset":{
5 | "gan_file_path": "./datasets/ffhq/meta/ganset_ffhq.txt",
6 | "model_name": "train_gan - first stage",
7 | "name": "ffhq",
8 | "img_gan_path": "./datasets/ffhq/thumbnails128x128",
9 | "n_classes":1000
10 | },
11 |
12 | "train_gan - first stage":{
13 | "lr": 0.0002,
14 | "batch_size": 64,
15 | "z_dim": 100,
16 | "epochs": 120,
17 | "n_critic": 5,
18 | "unlabel_weight": 10
19 | }
20 | }
--------------------------------------------------------------------------------
/config/celeba/training_GAN/specific_gan/celeba.json:
--------------------------------------------------------------------------------
1 | {
2 | "root_path":"./checkpoints/GAN",
3 |
4 | "dataset":{
5 | "gan_file_path": "./datasets/celeba/meta/ganset.txt",
6 | "model_name": "VGG16",
7 | "name": "celeba",
8 | "img_gan_path": "./datasets/celeba/img_align_celeba",
9 | "n_classes":1000
10 | },
11 |
12 | "train":{
13 | "model_types": "VGG16",
14 | "num_seeds": 5,
15 | "Nclass": 300
16 | },
17 |
18 | "FaceNet64":{
19 | "lr": 0.0002,
20 | "batch_size": 64,
21 | "z_dim": 100,
22 | "epochs": 120,
23 | "n_critic": 5,
24 | "unlabel_weight": 10,
25 | "cls_ckpts": "./checkpoints/download/target_model/target_ckp/FaceNet64_88.50.tar"
26 | },
27 |
28 | "VGG16":{
29 | "lr": 0.0002,
30 | "batch_size": 64,
31 | "z_dim": 100,
32 | "epochs": 120,
33 | "n_critic": 5,
34 | "unlabel_weight": 10,
35 | "cls_ckpts": "./checkpoints/download/target_model/target_ckp/VGG16_87.10_allclass.tar"
36 | },
37 |
38 | "IR152":{
39 | "lr": 0.0002,
40 | "batch_size": 64,
41 | "z_dim": 100,
42 | "epochs": 120,
43 | "n_critic": 5,
44 | "unlabel_weight": 10,
45 | "cls_ckpts": "./checkpoints/download/target_model/target_ckp/IR152_91.16.tar"
46 | }
47 | }
--------------------------------------------------------------------------------
/config/celeba/training_GAN/specific_gan/ffhq.json:
--------------------------------------------------------------------------------
1 | {
2 | "root_path":"./checkpoints/GAN",
3 |
4 | "dataset":{
5 | "gan_file_path": "./datasets/ffhq/meta/ganset_ffhq.txt",
6 | "model_name": "VGG16",
7 | "name": "ffhq",
8 | "img_gan_path": "./datasets/ffhq/thumbnails128x128",
9 | "n_classes":1000
10 | },
11 |
12 | "train":{
13 | "model_types": "VGG16",
14 | "num_seeds": 5,
15 | "Nclass": 300
16 | },
17 |
18 | "FaceNet64":{
19 | "lr": 0.0002,
20 | "batch_size": 64,
21 | "z_dim": 100,
22 | "epochs": 120,
23 | "n_critic": 5,
24 | "unlabel_weight": 10,
25 | "cls_ckpts": "./checkpoints/download/target_model/target_ckp/FaceNet64_88.50.tar"
26 | },
27 |
28 | "VGG16":{
29 | "lr": 0.0002,
30 | "batch_size": 64,
31 | "z_dim": 100,
32 | "epochs": 120,
33 | "n_critic": 5,
34 | "unlabel_weight": 10,
35 | "cls_ckpts": "./checkpoints/download/target_model/target_ckp/VGG16_87.10_allclass.tar"
36 | },
37 |
38 | "IR152":{
39 | "lr": 0.0002,
40 | "batch_size": 64,
41 | "z_dim": 100,
42 | "epochs": 120,
43 | "n_critic": 5,
44 | "unlabel_weight": 10,
45 | "cls_ckpts": "./checkpoints/download/target_model/target_ckp/IR152_91.16.tar"
46 | }
47 | }
--------------------------------------------------------------------------------
/config/celeba/training_augmodel/celeba.json:
--------------------------------------------------------------------------------
1 | {
2 | "root_path": "./checkpoints/aug_ckp/",
3 | "dataset":{
4 | "gan_file_path": "./datasets/celeba/meta/ganset.txt",
5 | "test_file_path": "./datasets/celeba/meta/testset.txt",
6 | "name": "celeba",
7 | "model_name": "VGG16",
8 | "img_path": "./datasets/celeba/img_align_celeba",
9 | "img_gan_path": "./datasets/celeba/img_align_celeba",
10 | "n_classes": 1000,
11 | "batch_size": 64
12 | },
13 |
14 | "train":{
15 | "epochs": 20,
16 | "target_model_name": "VGG16",
17 | "target_model_ckpt": "./checkpoints/target_model/target_ckp/VGG16_88.26.tar",
18 | "student_model_name": "efficientnet_b0",
19 | "device": "cuda",
20 | "lr": 0.01,
21 | "temperature": 1.0,
22 | "seed": 1,
23 | "log_interval": -1
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/config/celeba/training_augmodel/ffhq.json:
--------------------------------------------------------------------------------
1 | {
2 | "root_path": "./checkpoints/aug_ckp/",
3 | "dataset":{
4 | "gan_file_path": "./datasets/ffhq/meta/ganset_ffhq.txt",
5 | "test_file_path": "./datasets/celeba/meta/testset.txt",
6 | "name": "ffhq",
7 | "model_name": "VGG16",
8 | "img_path": "./datasets/celeba/img_align_celeba",
9 | "img_gan_path": "./datasets/ffhq/thumbnails128x128",
10 | "n_classes": 1000,
11 | "batch_size": 64
12 | },
13 |
14 | "train":{
15 | "epochs": 20,
16 | "target_model_name": "VGG16",
17 | "target_model_ckpt": "./checkpoints/target_model/target_ckp/VGG16_88.26.tar",
18 | "student_model_name": "efficientnet_b0",
19 | "device": "cuda",
20 | "lr": 0.01,
21 | "temperature": 1.0,
22 | "seed": 1,
23 | "log_interval": -1
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/config/celeba/training_classifiers/classify.json:
--------------------------------------------------------------------------------
1 | {
2 | "root_path":"./checkpoints/target_model",
3 |
4 | "dataset":{
5 | "name":"celeba",
6 | "train_file_path":"./datasets/celeba/meta/trainset.txt",
7 | "test_file_path":"./datasets/celeba/meta/testset.txt",
8 | "img_path": "./datasets/celeba/img_align_celeba",
9 | "model_name":"FaceNet",
10 | "mode":"reg",
11 | "n_classes":1000,
12 | "device":"cuda"
13 | },
14 |
15 | "VGG16":{
16 | "epochs":50,
17 | "batch_size":64,
18 | "instance":4,
19 | "lr":1e-2,
20 | "momentum":0.9,
21 | "weight_decay":1e-4,
22 | "gamma":0.2,
23 | "adjust_epochs":[20, 35],
24 | "resume":""
25 | },
26 |
27 | "FaceNet":{
28 | "epochs":30,
29 | "batch_size":64,
30 | "instance":4,
31 | "lr":1e-2,
32 | "momentum":0.9,
33 | "weight_decay":1e-4,
34 | "adjust_lr":[1e-3, 1e-4],
35 | "adjust_epochs":[15, 25],
36 | "resume":"./checkpoints/backbone/backbone_ir50_ms1m_epoch120.pth"
37 | },
38 |
39 | "FaceNet_all":{
40 | "epochs":100,
41 | "batch_size":64,
42 | "instance":4,
43 | "lr":1e-2,
44 | "momentum":0.9,
45 | "weight_decay":1e-4,
46 | "adjust_lr":[1e-3, 1e-4],
47 | "adjust_epochs":[15, 25],
48 | "resume":"./checkpoints/backbone/backbone_ir50_ms1m_epoch120.pth"
49 | },
50 |
51 | "FaceNet64":{
52 | "epochs":50,
53 | "batch_size":64,
54 | "lr":1e-2,
55 | "momentum":0.9,
56 | "weight_decay":1e-4,
57 | "lrdecay_epoch":10,
58 | "lrdecay":0.1,
59 | "resume":"./checkpoints/backbone/backbone_ir50_ms1m_epoch120.pth"
60 | },
61 |
62 | "IR152":{
63 | "epochs":40,
64 | "batch_size":64,
65 | "lr":1e-2,
66 | "momentum":0.9,
67 | "weight_decay":1e-4,
68 | "lrdecay_epoch":10,
69 | "lrdecay":0.1,
70 | "resume":"./checkpoints/backbone/Backbone_IR_152_Epoch_112_Batch_2547328_Time_2019-07-13-02-59_checkpoint.pth"
71 | },
72 |
73 | "IR50":{
74 | "epochs":40,
75 | "batch_size":64,
76 | "lr":1e-2,
77 | "momentum":0.9,
78 | "weight_decay":1e-4,
79 | "lrdecay_epoch":10,
80 | "lrdecay":0.1,
81 | "resume":"./checkpoints/backbone/ir50.pth"
82 | }
83 | }
84 |
85 |
86 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import os, glob, torchvision, PIL  # glob is needed by load_attri()
2 | import torch
3 | from PIL import Image
4 | import torch.nn.functional as F
5 | import torch.utils.data as data
6 | from torchvision import transforms
7 | from torch.utils.data import DataLoader
8 | from torch.nn.modules.loss import _Loss
9 | from torch.utils.data.sampler import SubsetRandomSampler
10 |
11 |
12 | class ImageFolder(data.Dataset):
13 | def __init__(self, args, file_path, mode):
14 | self.args = args
15 | self.mode = mode
16 | if mode == 'gan':
17 | self.img_path = args["dataset"]["img_gan_path"]
18 | else:
19 | self.img_path = args["dataset"]["img_path"]
20 | self.model_name = args["dataset"]["model_name"]
21 | # self.img_list = os.listdir(self.img_path)
22 | self.processor = self.get_processor()
23 | self.name_list, self.label_list = self.get_list(file_path)
24 | self.image_list = self.load_img()
25 | self.num_img = len(self.image_list)
26 | self.n_classes = args["dataset"]["n_classes"]
27 |         if self.mode != "gan":
28 | print("Load " + str(self.num_img) + " images")
29 |
30 | def get_list(self, file_path):
31 | name_list, label_list = [], []
32 | f = open(file_path, "r")
33 | for line in f.readlines():
34 | if self.mode == "gan":
35 | img_name = line.strip()
36 | else:
37 | img_name, iden = line.strip().split(' ')
38 | label_list.append(int(iden))
39 | name_list.append(img_name)
40 |
41 |
42 | return name_list, label_list
43 |
44 |
45 | def load_img(self):
46 | img_list = []
47 | for i, img_name in enumerate(self.name_list):
48 | if img_name.endswith(".png") or img_name.endswith(".jpg") or img_name.endswith(".jpeg") :
49 | path = self.img_path + "/" + img_name
50 | img = PIL.Image.open(path)
51 | img = img.convert('RGB')
52 | img_list.append(img)
53 | return img_list
54 |
55 |
56 | def get_processor(self):
57 | if self.model_name in ("FaceNet", "FaceNet_all"):
58 | re_size = 112
59 | else:
60 | re_size = 64
61 | if self.args["dataset"]["name"] =='celeba':
62 | crop_size = 108
63 | offset_height = (218 - crop_size) // 2
64 | offset_width = (178 - crop_size) // 2
65 | elif self.args["dataset"]["name"] == 'facescrub':
66 | # NOTE: dataset face scrub
67 | if self.mode=='gan':
68 | crop_size = 54
69 | offset_height = (64 - crop_size) // 2
70 | offset_width = (64 - crop_size) // 2
71 | else:
72 | crop_size = 108
73 | offset_height = (218 - crop_size) // 2
74 | offset_width = (178 - crop_size) // 2
75 | elif self.args["dataset"]["name"] == 'ffhq':
76 | # print('ffhq')
77 | #NOTE: dataset ffhq
78 | if self.mode=='gan':
79 | crop_size = 88
80 | offset_height = (128 - crop_size) // 2
81 | offset_width = (128 - crop_size) // 2
82 | else:
83 | crop_size = 108
84 | offset_height = (218 - crop_size) // 2
85 | offset_width = (178 - crop_size) // 2
86 |
87 | # #NOTE: dataset pf83
88 | # crop_size = 176
89 | # offset_height = (256 - crop_size) // 2
90 | # offset_width = (256 - crop_size) // 2
91 | crop = lambda x: x[:, offset_height:offset_height + crop_size, offset_width:offset_width + crop_size]
92 |
93 | proc = []
94 | if self.mode == "train":
95 | proc.append(transforms.ToTensor())
96 | proc.append(transforms.Lambda(crop))
97 | proc.append(transforms.ToPILImage())
98 | proc.append(transforms.Resize((re_size, re_size)))
99 | proc.append(transforms.RandomHorizontalFlip(p=0.5))
100 | proc.append(transforms.ToTensor())
101 | else:
102 |
103 | proc.append(transforms.ToTensor())
104 | if self.mode=='test' or self.mode=='train' or self.args["dataset"]["name"] != 'facescrub':
105 | proc.append(transforms.Lambda(crop))
106 | proc.append(transforms.ToPILImage())
107 | proc.append(transforms.Resize((re_size, re_size)))
108 | proc.append(transforms.ToTensor())
109 |
110 |
111 | return transforms.Compose(proc)
112 |
113 | def __getitem__(self, index):
114 | processer = self.get_processor()
115 | img = processer(self.image_list[index])
116 | if self.mode == "gan":
117 | return img
118 | label = self.label_list[index]
119 |
120 | return img, label
121 |
122 | def __len__(self):
123 | return self.num_img
124 |
125 | class GrayFolder(data.Dataset):
126 | def __init__(self, args, file_path, mode):
127 | self.args = args
128 | self.mode = mode
129 | self.img_path = args["dataset"]["img_path"]
130 | self.img_list = os.listdir(self.img_path)
131 | self.processor = self.get_processor()
132 | self.name_list, self.label_list = self.get_list(file_path)
133 | self.image_list = self.load_img()
134 | self.num_img = len(self.image_list)
135 | self.n_classes = args["dataset"]["n_classes"]
136 | print("Load " + str(self.num_img) + " images")
137 |
138 | def get_list(self, file_path):
139 | name_list, label_list = [], []
140 | f = open(file_path, "r")
141 | for line in f.readlines():
142 | if self.mode == "gan":
143 | img_name = line.strip()
144 | else:
145 | img_name, iden = line.strip().split(' ')
146 | label_list.append(int(iden))
147 | name_list.append(img_name)
148 |
149 | return name_list, label_list
150 |
151 |
152 | def load_img(self):
153 | img_list = []
154 | for i, img_name in enumerate(self.name_list):
155 | if img_name.endswith(".png"):
156 | path = self.img_path + "/" + img_name
157 | img = PIL.Image.open(path)
158 | img = img.convert('L')
159 | img_list.append(img)
160 | return img_list
161 |
162 | def get_processor(self):
163 | proc = []
164 | if self.args['dataset']['name'] == "mnist":
165 | re_size = 32
166 | else:
167 | re_size = 64
168 | proc.append(transforms.Resize((re_size, re_size)))
169 | proc.append(transforms.ToTensor())
170 |
171 | return transforms.Compose(proc)
172 |
173 | def __getitem__(self, index):
174 | processer = self.get_processor()
175 | img = processer(self.image_list[index])
176 | if self.mode == "gan":
177 | return img
178 | label = self.label_list[index]
179 |
180 | return img, label
181 |
182 | def __len__(self):
183 | return self.num_img
184 |
185 | def load_mnist():
186 | transform = transforms.Compose([transforms.ToTensor()])
187 | trainset = torchvision.datasets.MNIST(mnist_path, train=True, transform=transform, download=True)
188 | testset = torchvision.datasets.MNIST(mnist_path, train=False, transform=transform, download=True)
189 |
190 | train_loader = DataLoader(trainset, batch_size=1)
191 | test_loader = DataLoader(testset, batch_size=1)
192 | cnt = 0
193 |
194 | for imgs, labels in train_loader:
195 | cnt += 1
196 | img_name = str(cnt) + '_' + str(labels.item()) + '.png'
197 | # utils.save_tensor_images(imgs, os.path.join(mnist_img_path, img_name))
198 | print("number of train files:", cnt)
199 |
200 | for imgs, labels in test_loader:
201 | cnt += 1
202 | img_name = str(cnt) + '_' + str(labels.item()) + '.png'
203 | # utils.save_tensor_images(imgs, os.path.join(mnist_img_path, img_name))
204 |
205 | class celeba(data.Dataset):
206 | def __init__(self, data_path=None, label_path=None):
207 | self.data_path = data_path
208 | self.label_path = label_path
209 |
210 | # Data transforms
211 | crop_size = 108
212 | offset_height = (218 - crop_size) // 2
213 |         offset_width = (178 - crop_size) // 2
214 |         crop = lambda x: x[:, offset_height:offset_height + crop_size, offset_width:offset_width + crop_size]  # was undefined here, causing a NameError in the transform
215 |         proc = []
215 | proc.append(transforms.ToTensor())
216 | proc.append(transforms.Lambda(crop))
217 | proc.append(transforms.ToPILImage())
218 | proc.append(transforms.Resize((112, 112)))
219 | proc.append(transforms.ToTensor())
220 | proc.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
221 |
222 | self.transform = transforms.Compose(proc)
223 |
224 | def __len__(self):
225 | return len(self.data_path)
226 |
227 | def __getitem__(self, idx):
228 | image_set = Image.open(self.data_path[idx])
229 | image_tensor = self.transform(image_set)
230 | image_label = torch.Tensor(self.label_path[idx])
231 | return image_tensor, image_label
232 |
233 | def load_attri(file_path):
234 | data_path = sorted(glob.glob('./data/img_align_celeba_png/*.png'))
235 | print(len(data_path))
236 | # get label
237 | att_path = './data/list_attr_celeba.txt'
238 | att_list = open(att_path).readlines()[2:] # start from 2nd row
239 | data_label = []
240 | for i in range(len(att_list)):
241 | data_label.append(att_list[i].split())
242 |
243 | # transform label into 0 and 1
244 | for m in range(len(data_label)):
245 | data_label[m] = [n.replace('-1', '0') for n in data_label[m]][1:]
246 | data_label[m] = [int(p) for p in data_label[m]]
247 |
248 | dataset = celeba(data_path, data_label)
249 | # split data into train, valid, test set 7:2:1
250 | indices = list(range(202599))
251 | split_train = 141819
252 | split_valid = 182339
253 | train_idx, valid_idx, test_idx = indices[:split_train], indices[split_train:split_valid], indices[split_valid:]
254 |
255 | train_sampler = SubsetRandomSampler(train_idx)
256 | valid_sampler = SubsetRandomSampler(valid_idx)
257 | test_sampler = SubsetRandomSampler(test_idx)
258 |
259 | trainloader = torch.utils.data.DataLoader(dataset, batch_size=64, sampler=train_sampler)
260 |
261 | validloader = torch.utils.data.DataLoader(dataset, sampler=valid_sampler)
262 |
263 | testloader = torch.utils.data.DataLoader(dataset, sampler=test_sampler)
264 |
265 | print(len(trainloader))
266 | print(len(validloader))
267 | print(len(testloader))
268 |
269 | return trainloader, validloader, testloader
270 |
271 |
272 | if __name__ == "__main__":
273 | print("ok")
--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------
1 | import torch, os, time, utils
2 | import torch.nn as nn
3 | from copy import deepcopy
4 | from utils import *
5 | from models.discri import MinibatchDiscriminator, DGWGAN
6 | from models.generator import Generator
7 | from models.classify import *
8 | from tensorboardX import SummaryWriter
9 |
10 |
11 | def test(model, criterion=None, dataloader=None, device='cuda'):
12 | tf = time.time()
13 | model.eval()
14 | loss, cnt, ACC, correct_top5 = 0.0, 0, 0,0
15 | with torch.no_grad():
16 | for i,(img, iden) in enumerate(dataloader):
17 | img, iden = img.to(device), iden.to(device)
18 |
19 | bs = img.size(0)
20 | iden = iden.view(-1)
21 | _,out_prob = model(img)
22 | out_iden = torch.argmax(out_prob, dim=1).view(-1)
23 | ACC += torch.sum(iden == out_iden).item()
24 |
25 |
26 | _, top5 = torch.topk(out_prob,5, dim = 1)
27 | for ind,top5pred in enumerate(top5):
28 | if iden[ind] in top5pred:
29 | correct_top5 += 1
30 |
31 | cnt += bs
32 |
33 | return ACC*100.0/cnt, correct_top5*100.0/cnt
34 |
35 | def train_reg(args, model, criterion, optimizer, trainloader, testloader, n_epochs, device='cuda'):
36 | best_ACC = (0.0, 0.0)
37 |
38 | for epoch in range(n_epochs):
39 | tf = time.time()
40 | ACC, cnt, loss_tot = 0, 0, 0.0
41 | model.train()
42 |
43 | for i, (img, iden) in enumerate(trainloader):
44 | img, iden = img.to(device), iden.to(device)
45 | bs = img.size(0)
46 | iden = iden.view(-1)
47 |
48 | feats, out_prob = model(img)
49 | cross_loss = criterion(out_prob, iden)
50 | loss = cross_loss
51 |
52 | optimizer.zero_grad()
53 | loss.backward()
54 | optimizer.step()
55 |
56 | out_iden = torch.argmax(out_prob, dim=1).view(-1)
57 | ACC += torch.sum(iden == out_iden).item()
58 | loss_tot += loss.item() * bs
59 | cnt += bs
60 |
61 | train_loss, train_acc = loss_tot * 1.0 / cnt, ACC * 100.0 / cnt
62 | test_acc = test(model, criterion, testloader)
63 |
64 | interval = time.time() - tf
65 | if test_acc[0] > best_ACC[0]:
66 | best_ACC = test_acc
67 | best_model = deepcopy(model)
68 |
69 | print("Epoch:{}\tTime:{:.2f}\tTrain Loss:{:.2f}\tTrain Acc:{:.2f}\tTest Acc:{:.2f}".format(epoch, interval, train_loss, train_acc, test_acc[0]))
70 |
71 | print("Best Acc:{:.2f}".format(best_ACC[0]))
72 | return best_model, best_ACC
73 |
74 | def train_vib(args, model, criterion, optimizer, trainloader, testloader, n_epochs, device='cuda'):
75 | best_ACC = (0.0, 0.0)
76 |
77 | for epoch in range(n_epochs):
78 | tf = time.time()
79 | ACC, cnt, loss_tot = 0, 0, 0.0
80 |
81 | for i, (img, iden) in enumerate(trainloader):
82 | img, one_hot, iden = img.to(device), one_hot.to(device), iden.to(device)
83 | bs = img.size(0)
84 | iden = iden.view(-1)
85 |
86 | ___, out_prob, mu, std = model(img, "train")
87 | cross_loss = criterion(out_prob, one_hot)
88 | info_loss = - 0.5 * (1 + 2 * std.log() - mu.pow(2) - std.pow(2)).sum(dim=1).mean()
89 | loss = cross_loss + beta * info_loss
90 |
91 | optimizer.zero_grad()
92 | loss.backward()
93 | optimizer.step()
94 |
95 | out_iden = torch.argmax(out_prob, dim=1).view(-1)
96 | ACC += torch.sum(iden == out_iden).item()
97 | loss_tot += loss.item() * bs
98 | cnt += bs
99 |
100 | train_loss, train_acc = loss_tot * 1.0 / cnt, ACC * 100.0 / cnt
101 | test_loss, test_acc = test(model, criterion, testloader)
102 |
103 | interval = time.time() - tf
104 | if test_acc[0] > best_ACC[0]:
105 | best_ACC = test_acc
106 | best_model = deepcopy(model)
107 |
108 | print("Epoch:{}\tTime:{:.2f}\tTrain Loss:{:.2f}\tTrain Acc:{:.2f}\tTest Acc:{:.2f}".format(epoch, interval, train_loss, train_acc, test_acc[0]))
109 |
110 |
111 | print("Best Acc:{:.2f}".format(best_ACC[0]))
112 | return best_model, best_ACC
113 |
114 |
115 | def get_T(model_name_T, cfg):
116 | if model_name_T.startswith("VGG16"):
117 | T = VGG16(cfg['dataset']["n_classes"])
118 | elif model_name_T.startswith('IR152'):
119 | T = IR152(cfg['dataset']["n_classes"])
120 | elif model_name_T == "FaceNet64":
121 | T = FaceNet64(cfg['dataset']["n_classes"])
122 | T = torch.nn.DataParallel(T).cuda()
123 | ckp_T = torch.load(cfg[cfg['dataset']['model_name']]['cls_ckpts'])
124 | T.load_state_dict(ckp_T['state_dict'], strict=False)
125 |
126 | return T
127 |
128 |
129 | def train_specific_gan(cfg):
130 | # Hyperparams
131 | file_path = cfg['dataset']['gan_file_path']
132 | model_name_T = cfg['dataset']['model_name']
133 | batch_size = cfg[model_name_T]['batch_size']
134 | z_dim = cfg[model_name_T]['z_dim']
135 | n_critic = cfg[model_name_T]['n_critic']
136 | dataset_name = cfg['dataset']['name']
137 |
138 | # Create save folders
139 | root_path = cfg["root_path"]
140 | save_model_dir = os.path.join(root_path, os.path.join(dataset_name, model_name_T))
141 | save_img_dir = os.path.join(save_model_dir, "imgs")
142 | os.makedirs(save_model_dir, exist_ok=True)
143 | os.makedirs(save_img_dir, exist_ok=True)
144 |
145 |
146 | # Log file
147 | log_path = os.path.join(save_model_dir, "attack_logs")
148 | os.makedirs(log_path, exist_ok=True)
149 | log_file = "improvedGAN_{}.txt".format(model_name_T)
150 | utils.Tee(os.path.join(log_path, log_file), 'w')
151 | writer = SummaryWriter(log_path)
152 |
153 |
154 | # Load target model
155 | T = get_T(model_name_T=model_name_T, cfg=cfg)
156 |
157 | # Dataset
158 | dataset, dataloader = utils.init_dataloader(cfg, file_path, cfg[model_name_T]['batch_size'], mode="gan")
159 |
160 | # Start Training
161 | print("Training GAN for %s" % model_name_T)
162 | utils.print_params(cfg["dataset"], cfg[model_name_T])
163 |
164 | G = Generator(cfg[model_name_T]['z_dim'])
165 | DG = MinibatchDiscriminator()
166 |
167 | G = torch.nn.DataParallel(G).cuda()
168 | DG = torch.nn.DataParallel(DG).cuda()
169 |
170 | dg_optimizer = torch.optim.Adam(DG.parameters(), lr=cfg[model_name_T]['lr'], betas=(0.5, 0.999))
171 | g_optimizer = torch.optim.Adam(G.parameters(), lr=cfg[model_name_T]['lr'], betas=(0.5, 0.999))
172 |
173 | entropy = HLoss()
174 |
175 | step = 0
176 | for epoch in range(cfg[model_name_T]['epochs']):
177 | start = time.time()
178 | _, unlabel_loader1 = init_dataloader(cfg, file_path, batch_size, mode="gan", iterator=True)
179 | _, unlabel_loader2 = init_dataloader(cfg, file_path, batch_size, mode="gan", iterator=True)
180 |
181 | for i, imgs in enumerate(dataloader):
182 | current_iter = epoch * len(dataloader) + i + 1
183 |
184 | step += 1
185 | imgs = imgs.cuda()
186 | bs = imgs.size(0)
187 | x_unlabel = unlabel_loader1.next()
188 | x_unlabel2 = unlabel_loader2.next()
189 |
190 | freeze(G)
191 | unfreeze(DG)
192 |
193 | z = torch.randn(bs, z_dim).cuda()
194 | f_imgs = G(z)
195 |
196 | y_prob = T(imgs)[-1]
197 | y = torch.argmax(y_prob, dim=1).view(-1)
198 |
199 |
200 | _, output_label = DG(imgs)
201 | _, output_unlabel = DG(x_unlabel)
202 | _, output_fake = DG(f_imgs)
203 |
204 | loss_lab = softXEnt(output_label, y_prob)
205 | loss_unlab = 0.5*(torch.mean(F.softplus(log_sum_exp(output_unlabel)))-torch.mean(log_sum_exp(output_unlabel))+torch.mean(F.softplus(log_sum_exp(output_fake))))
206 | dg_loss = loss_lab + loss_unlab
207 |
208 | acc = torch.mean((output_label.max(1)[1] == y).float())
209 |
210 | dg_optimizer.zero_grad()
211 | dg_loss.backward()
212 | dg_optimizer.step()
213 |
214 | writer.add_scalar('loss_label_batch', loss_lab, current_iter)
215 | writer.add_scalar('loss_unlabel_batch', loss_unlab, current_iter)
216 | writer.add_scalar('DG_loss_batch', dg_loss, current_iter)
217 | writer.add_scalar('Acc_batch', acc, current_iter)
218 |
219 | # train G
220 | if step % n_critic == 0:
221 | freeze(DG)
222 | unfreeze(G)
223 | z = torch.randn(bs, z_dim).cuda()
224 | f_imgs = G(z)
225 | mom_gen, output_fake = DG(f_imgs)
226 | mom_unlabel, _ = DG(x_unlabel2)
227 |
228 | mom_gen = torch.mean(mom_gen, dim = 0)
229 | mom_unlabel = torch.mean(mom_unlabel, dim = 0)
230 |
231 | Hloss = entropy(output_fake)
232 | g_loss = torch.mean((mom_gen - mom_unlabel).abs()) + 1e-4 * Hloss
233 |
234 | g_optimizer.zero_grad()
235 | g_loss.backward()
236 | g_optimizer.step()
237 |
238 | writer.add_scalar('G_loss_batch', g_loss, current_iter)
239 |
240 | end = time.time()
241 | interval = end - start
242 |
243 | print("Epoch:%d \tTime:%.2f\tG_loss:%.2f\t train_acc:%.2f" % (epoch, interval, g_loss, acc))
244 |
245 | torch.save({'state_dict':G.state_dict()}, os.path.join(save_model_dir, "improved_{}_G.tar".format(dataset_name)))
246 | torch.save({'state_dict':DG.state_dict()}, os.path.join(save_model_dir, "improved_{}_D.tar".format(dataset_name)))
247 |
248 | if (epoch+1) % 10 == 0:
249 | z = torch.randn(32, z_dim).cuda()
250 | fake_image = G(z)
251 | save_tensor_images(fake_image.detach(), os.path.join(save_img_dir, "improved_celeba_img_{}.png".format(epoch)), nrow = 8)
252 |
253 |
254 | def train_general_gan(cfg):
255 | # Hyperparams
256 | file_path = cfg['dataset']['gan_file_path']
257 | model_name = cfg['dataset']['model_name']
258 | lr = cfg[model_name]['lr']
259 | batch_size = cfg[model_name]['batch_size']
260 | z_dim = cfg[model_name]['z_dim']
261 | epochs = cfg[model_name]['epochs']
262 | n_critic = cfg[model_name]['n_critic']
263 | dataset_name = cfg['dataset']['name']
264 |
265 |
266 | # Create save folders
267 | root_path = cfg["root_path"]
268 | save_model_dir = os.path.join(root_path, os.path.join(dataset_name, 'general_GAN'))
269 | save_img_dir = os.path.join(save_model_dir, "imgs")
270 | os.makedirs(save_model_dir, exist_ok=True)
271 | os.makedirs(save_img_dir, exist_ok=True)
272 |
273 |
274 | # Log file
275 | log_path = os.path.join(save_model_dir, "logs")
276 | os.makedirs(log_path, exist_ok=True)
277 | log_file = "GAN_{}.txt".format(dataset_name)
278 | utils.Tee(os.path.join(log_path, log_file), 'w')
279 | writer = SummaryWriter(log_path)
280 |
281 |
282 | # Dataset
283 | dataset, dataloader = init_dataloader(cfg, file_path, batch_size, mode="gan")
284 |
285 |
286 | # Start Training
287 | print("Training general GAN for %s" % dataset_name)
288 | utils.print_params(cfg["dataset"], cfg[model_name])
289 |
290 | G = Generator(z_dim)
291 | DG = DGWGAN(3)
292 |
293 | G = torch.nn.DataParallel(G).cuda()
294 | DG = torch.nn.DataParallel(DG).cuda()
295 |
296 | dg_optimizer = torch.optim.Adam(DG.parameters(), lr=lr, betas=(0.5, 0.999))
297 | g_optimizer = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
298 |
299 | step = 0
300 |
301 | for epoch in range(epochs):
302 | start = time.time()
303 | for i, imgs in enumerate(dataloader):
304 |
305 | step += 1
306 | imgs = imgs.cuda()
307 | bs = imgs.size(0)
308 |
309 | freeze(G)
310 | unfreeze(DG)
311 |
312 | z = torch.randn(bs, z_dim).cuda()
313 | f_imgs = G(z)
314 |
315 | r_logit = DG(imgs)
316 | f_logit = DG(f_imgs)
317 |
318 | wd = r_logit.mean() - f_logit.mean() # Wasserstein-1 Distance
319 | gp = gradient_penalty(imgs.data, f_imgs.data, DG=DG)
320 | dg_loss = - wd + gp * 10.0
321 |
322 | dg_optimizer.zero_grad()
323 | dg_loss.backward()
324 | dg_optimizer.step()
325 |
326 | # train G
327 |
328 | if step % n_critic == 0:
329 | freeze(DG)
330 | unfreeze(G)
331 | z = torch.randn(bs, z_dim).cuda()
332 | f_imgs = G(z)
333 | logit_dg = DG(f_imgs)
334 | # calculate g_loss
335 | g_loss = - logit_dg.mean()
336 |
337 | g_optimizer.zero_grad()
338 | g_loss.backward()
339 | g_optimizer.step()
340 |
341 | end = time.time()
342 | interval = end - start
343 |
344 | print("Epoch:%d \t Time:%.2f\t Generator loss:%.2f" % (epoch, interval, g_loss))
345 | if (epoch+1) % 10 == 0:
346 | z = torch.randn(32, z_dim).cuda()
347 | fake_image = G(z)
348 | save_tensor_images(fake_image.detach(), os.path.join(save_img_dir, "result_image_{}.png".format(epoch)), nrow = 8)
349 |
350 | torch.save({'state_dict':G.state_dict()}, os.path.join(save_model_dir, "celeba_G.tar"))
351 | torch.save({'state_dict':DG.state_dict()}, os.path.join(save_model_dir, "celeba_D.tar"))
352 |
353 | def train_augmodel(cfg):
354 | # Hyperparams
355 | target_model_name = cfg['train']['target_model_name']
356 | student_model_name = cfg['train']['student_model_name']
357 | device = cfg['train']['device']
358 | lr = cfg['train']['lr']
359 | temperature = cfg['train']['temperature']
360 | dataset_name = cfg['dataset']['name']
361 | n_classes = cfg['dataset']['n_classes']
362 | batch_size = cfg['dataset']['batch_size']
363 | seed = cfg['train']['seed']
364 | epochs = cfg['train']['epochs']
365 | log_interval = cfg['train']['log_interval']
366 |
367 |
368 | # Create save folder
369 | save_dir = os.path.join(cfg['root_path'], dataset_name)
370 | save_dir = os.path.join(save_dir, '{}_{}_{}_{}'.format(target_model_name, student_model_name, lr, temperature))
371 | os.makedirs(save_dir, exist_ok=True)
372 |
373 | # Log file
374 | now = datetime.now() # current date and time
375 | log_file = "studentKD_logs_{}.txt".format(now.strftime("%m_%d_%Y_%H_%M_%S"))
376 | utils.Tee(os.path.join(save_dir, log_file), 'w')
377 | torch.manual_seed(seed)
378 |
379 |
380 | kwargs = {'batch_size': batch_size}
381 | if device == 'cuda':
382 | kwargs.update({'num_workers': 1,
383 | 'pin_memory': True,
384 | 'shuffle': True},
385 | )
386 |
387 | # Get models
388 | teacher_model = get_augmodel(target_model_name, n_classes, cfg['train']['target_model_ckpt'])
389 | model = get_augmodel(student_model_name, n_classes)
390 | model = model.to(device)
391 |     print('Target model {}: {} params'.format(target_model_name, count_parameters(teacher_model)))
392 |     print('Augmented model {}: {} params'.format(student_model_name, count_parameters(model)))
393 |
394 |
395 | # Optimizer
396 | optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum = 0.9, weight_decay = 1e-4)
397 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
398 |
399 |
400 | # Get dataset
401 | _, dataloader_train = init_dataloader(cfg, cfg['dataset']['gan_file_path'], batch_size, mode="gan")
402 | _, dataloader_test = init_dataloader(cfg, cfg['dataset']['test_file_path'], batch_size, mode="test")
403 |
404 |
405 | # Start training
406 | top1,top5 = test(teacher_model, dataloader=dataloader_test)
407 | print("Target model {}: top 1 = {}, top 5 = {}".format(target_model_name, top1, top5))
408 |
409 |
410 | loss_function = nn.KLDivLoss(reduction='sum')
411 | teacher_model.eval()
412 | for epoch in range(1, epochs + 1):
413 | model.train()
414 | for batch_idx, data in enumerate(dataloader_train):
415 | data = data.to(device)
416 |
417 | curr_batch_size = len(data)
418 | optimizer.zero_grad()
419 | _, output_t = teacher_model(data)
420 | _, output = model(data)
421 |
422 | loss = loss_function(
423 | F.log_softmax(output / temperature, dim=-1),
424 | F.softmax(output_t / temperature, dim=-1)
425 | ) / (temperature * temperature) / curr_batch_size
426 |
427 |
428 | loss.backward()
429 | optimizer.step()
430 |
431 | if (log_interval > 0) and (batch_idx % log_interval == 0):
432 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
433 | epoch, batch_idx * len(data), len(dataloader_train.dataset),
434 | 100. * batch_idx / len(dataloader_train), loss.item()))
435 |
436 | scheduler.step()
437 | top1, top5 = test(model, dataloader=dataloader_test)
438 | print("epoch {}: top 1 = {}, top 5 = {}".format(epoch, top1, top5))
439 |
440 | if (epoch+1)%10 == 0:
441 | save_path = os.path.join(save_dir, "{}_{}_kd_{}_{}.pt".format(target_model_name, student_model_name, seed, epoch+1))
442 | torch.save({'state_dict':model.state_dict()}, save_path)
443 |
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from metrics.KNN_dist import eval_KNN
3 | from metrics.eval_accuracy import eval_accuracy, eval_acc_class
4 | from metrics.fid import eval_fid
5 | from utils import load_json, get_attack_model
6 | import os
7 | import csv
8 |
9 | parser = ArgumentParser(description='Evaluation')
10 | parser.add_argument('--configs', type=str, default='./config/celeba/attacking/ffhq.json')
11 |
12 | args = parser.parse_args()
13 |
14 |
15 | def init_attack_args(cfg):
16 | if cfg["attack"]["method"] =='kedmi':
17 | args.improved_flag = True
18 | args.clipz = True
19 | args.num_seeds = 1
20 | else:
21 | args.improved_flag = False
22 | args.clipz = False
23 | args.num_seeds = 5
24 |
25 | if cfg["attack"]["variant"] == 'L_logit' or cfg["attack"]["variant"] == 'ours':
26 | args.loss = 'logit_loss'
27 | else:
28 | args.loss = 'cel'
29 |
30 | if cfg["attack"]["variant"] == 'L_aug' or cfg["attack"]["variant"] == 'ours':
31 | args.classid = '0,1,2,3'
32 | else:
33 | args.classid = '0'
34 |
35 |
36 | if __name__ == '__main__':
37 | # Load Data
38 | cfg = load_json(json_file=args.configs)
39 | init_attack_args(cfg=cfg)
40 |
41 | # Save dir
42 | if args.improved_flag == True:
43 | prefix = os.path.join(cfg["root_path"], "kedmi_300ids")
44 | else:
45 | prefix = os.path.join(cfg["root_path"], "gmi_300ids")
46 | save_folder = os.path.join("{}_{}".format(cfg["dataset"]["name"], cfg["dataset"]["model_name"]), cfg["attack"]["variant"])
47 | prefix = os.path.join(prefix, save_folder)
48 | save_dir = os.path.join(prefix, "latent")
49 | save_img_dir = os.path.join(prefix, "imgs_{}".format(cfg["attack"]["variant"]))
50 |
51 | # Load models
52 | _, E, G, _, _, _, _ = get_attack_model(args, cfg, eval_mode=True)
53 |
54 | # Metrics
55 | metric = cfg["attack"]["eval_metric"].split(',')
56 | fid = 0
57 | aver_acc, aver_acc5, aver_std, aver_std5 = 0, 0, 0, 0
58 |     knn = 0
59 | nsamples = 0
60 | dataset, model_types = '', ''
61 |
62 |
63 |
64 | for metric_ in metric:
65 | metric_ = metric_.strip()
66 | if metric_ == 'fid':
67 | fid, nsamples = eval_fid(G=G, E=E, save_dir=save_dir, cfg=cfg, args=args)
68 | elif metric_ == 'acc':
69 | aver_acc, aver_acc5, aver_std, aver_std5 = eval_accuracy(G=G, E=E, save_dir=save_dir, args=args)
70 | elif metric_ == 'knn':
71 | knn = eval_KNN(G=G, E=E, save_dir=save_dir, KNN_real_path=cfg["dataset"]["KNN_real_path"], args=args)
72 |
73 | csv_file = os.path.join(prefix, 'Eval_results.csv')
74 | if not os.path.exists(csv_file):
75 |         header = ['Save_dir', 'Method', 'Variant',
76 |                   'acc','std','acc5','std5',
77 |                   'fid','knn']
78 | with open(csv_file, 'w') as f:
79 | writer = csv.writer(f)
80 | writer.writerow(header)
81 |
82 | fields=['{}'.format(save_dir),
83 | '{}'.format(cfg["attack"]["method"]),
84 | '{}'.format(cfg["attack"]["variant"]),
85 | '{:.2f}'.format(aver_acc),
86 | '{:.2f}'.format(aver_std),
87 | '{:.2f}'.format(aver_acc5),
88 | '{:.2f}'.format(aver_std5),
89 | '{:.2f}'.format(fid),
90 | '{:.2f}'.format(knn)]
91 |
92 | print("---------------Evaluation---------------")
93 | print('Method: {} '.format(cfg["attack"]["method"]))
94 |
95 | print('Variant: {}'.format(cfg["attack"]["variant"]))
96 | print('Top 1 attack accuracy:{:.2f} +/- {:.2f} '.format(aver_acc, aver_std))
97 | print('Top 5 attack accuracy:{:.2f} +/- {:.2f} '.format(aver_acc5, aver_std5))
98 | print('KNN distance: {:.3f}'.format(knn))
99 | print('FID score: {:.3f}'.format(fid))
100 |
101 | print("----------------------------------------")
102 | with open(csv_file, 'a') as f:
103 | writer = csv.writer(f)
104 | writer.writerow(fields)
105 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.modules.loss import _Loss
3 |
4 | def completion_network_loss(input, output, mask):
5 | bs = input.size(0)
6 | loss = torch.sum(torch.abs(output * mask - input * mask)) / bs
7 | #return mse_loss(output * mask, input * mask)
8 | return loss
9 |
10 | def noise_loss(V, img1, img2):
11 | feat1 = V(img1)[0]
12 | feat2 = V(img2)[0]
13 | loss = torch.mean(torch.abs(feat1 - feat2))
14 | return loss
15 |
16 | class ContextLoss(_Loss):
17 | def forward(self, mask, gen, images):
18 | bs = gen.size(0)
19 | context_loss = torch.sum(torch.abs(torch.mul(mask, gen) - torch.mul(mask, images))) / bs
20 | return context_loss
21 |
22 | class CrossEntropyLoss(_Loss):
23 | def forward(self, out, gt):
24 | bs = out.size(0)
25 | #print(out.size(), gt.size())
26 | loss = - torch.mul(gt.float(), torch.log(out.float() + 1e-7))
27 | loss = torch.sum(loss) / bs
28 | return loss
29 |
--------------------------------------------------------------------------------
/metrics/KNN_dist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | from metrics.fid import concatenate_list, gen_samples
5 | from utils import load_json, save_tensor_images
6 |
7 |
8 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
9 |
10 | def find_shortest_dist(fea_target,fea_fake):
11 | shortest_dist = 0
12 | pdist = torch.nn.PairwiseDistance(p=2)
13 |
14 | fea_target = torch.from_numpy(fea_target).to(device)
15 | fea_fake = torch.from_numpy(fea_fake).to(device)
16 | # print('---fea_fake.shape[0]',fea_fake.shape[0])
17 | for i in range(fea_fake.shape[0]):
18 | dist = pdist(fea_fake[i,:], fea_target)
19 |
20 | min_d = min(dist)
21 |
22 | # print('--KNN dist',min_d)
23 | shortest_dist = shortest_dist + min_d*min_d
24 | # print('--KNN dist',shortest_dist)
25 |
26 | return shortest_dist
27 |
28 | def run_KNN(target_dir, fake_dir):
29 | knn = 0
30 | target = np.load(target_dir,allow_pickle=True)
31 | fake = np.load(fake_dir,allow_pickle=True)
32 | target_fea = target.item().get('fea')
33 | target_y = target.item().get('label')
34 | fake_fea = fake.item().get('fea')
35 | fake_y = fake.item().get('label')
36 |
37 | fake_fea = concatenate_list(fake_fea)
38 | fake_y = concatenate_list(fake_y)
39 |
40 | N = fake_fea.shape[0]
41 | for id in range(300):
42 | id_f = fake_y == id
43 | id_t = target_y == id
44 | fea_f = fake_fea[id_f,:]
45 | fea_t = target_fea[id_t]
46 |
47 | shortest_dist = find_shortest_dist(fea_t,fea_f)
48 | knn = knn + shortest_dist
49 |
50 | return knn/N
51 |
52 | def eval_KNN(G, E, save_dir, KNN_real_path, args):
53 |
54 | fea_path, _ = gen_samples(G, E, save_dir, args.improved_flag)
55 |
56 | fea_path = fea_path + 'full.npy'
57 |
58 | knn = run_KNN(KNN_real_path, fea_path)
59 | print("KNN:{:.3f} ".format(knn))
60 | return knn
61 |
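
Both `.npy` files passed to `run_KNN` are expected to hold a dict with `'fea'` (identity features) and `'label'` arrays. For each of the 300 attacked identities, the code accumulates the squared distance from every reconstructed feature to its nearest real feature of the same identity, then averages over all reconstructed samples. A small self-contained sketch of that per-identity step on toy tensors:

```python
import torch

fea_real = torch.randn(6, 512)   # toy: 6 real features of one identity
fea_fake = torch.randn(4, 512)   # toy: 4 reconstructed features of the same identity

pdist = torch.nn.PairwiseDistance(p=2)
total = 0.0
for i in range(fea_fake.shape[0]):
    d = pdist(fea_fake[i, :], fea_real)  # distances to every real feature
    total = total + d.min() ** 2         # squared distance to the nearest neighbour
knn_dist = total / fea_fake.shape[0]     # average over the reconstructed samples
```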
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/metrics/eval_accuracy.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | from models.classify import *
3 | from models.generator import *
4 | from models.discri import *
5 | import torch
6 | import numpy as np
7 |
8 | from attack import attack_acc
9 | import statistics
10 |
11 | from metrics.fid import concatenate_list, gen_samples
12 |
13 |
14 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
15 |
16 | def accuracy(fake_dir, E):
17 |
18 | aver_acc, aver_acc5, aver_std, aver_std5 = 0, 0, 0, 0
19 |
20 | N = 5
21 | E.eval()
22 | for i in range(1):
23 | all_fake = np.load(fake_dir+'full.npy',allow_pickle=True)
24 | all_imgs = all_fake.item().get('imgs')
25 | all_label = all_fake.item().get('label')
26 |
27 | # calculate attack accuracy
28 | with torch.no_grad():
29 | N_succesful = 0
30 | N_failure = 0
31 |
32 | for random_seed in range(len(all_imgs)):
33 | if random_seed % N == 0:
34 | res, res5 = [], []
35 |
36 | #################### attack accuracy #################
37 | fake = all_imgs[random_seed]
38 | label = all_label[random_seed]
39 |
40 | label = torch.from_numpy(label)
41 | fake = torch.from_numpy(fake)
42 |
43 | acc,acc5 = attack_acc(fake,label,E)
44 |
45 |
46 | print("Seed:{} Top1/Top5:{:.3f}/{:.3f}\t".format(random_seed, acc,acc5))
47 | res.append(acc)
48 | res5.append(acc5)
49 |
50 |
51 | if (random_seed+1)%5 == 0:
52 | acc, acc_5 = statistics.mean(res), statistics.mean(res5)
53 | std = statistics.stdev(res)
54 | std5 = statistics.stdev(res5)
55 |
56 | print("Top1/Top5:{:.3f}/{:.3f}, std top1/top5:{:.3f}/{:.3f}".format(acc, acc_5, std, std5))
57 |
58 | aver_acc += acc / N
59 | aver_acc5 += acc_5 / N
60 | aver_std += std / N
61 | aver_std5 += std5 / N
62 | print('N_succesful',N_succesful,N_failure)
63 |
64 |
65 | return aver_acc, aver_acc5, aver_std, aver_std5
66 |
67 |
68 |
69 | def eval_accuracy(G, E, save_dir, args):
70 |
71 | successful_imgs, _ = gen_samples(G, E, save_dir, args.improved_flag)
72 |
73 | aver_acc, aver_acc5, \
74 | aver_std, aver_std5 = accuracy(successful_imgs, E)
75 |
76 |
77 | return aver_acc, aver_acc5, aver_std, aver_std5
78 |
79 | def acc_class(filename,fake_dir,E):
80 |
81 | E.eval()
82 |
83 | sucessful_fake = np.load(fake_dir + 'success.npy',allow_pickle=True)
84 | sucessful_imgs = sucessful_fake.item().get('sucessful_imgs')
85 | sucessful_label = sucessful_fake.item().get('label')
86 | sucessful_imgs = concatenate_list(sucessful_imgs)
87 | sucessful_label = concatenate_list(sucessful_label)
88 |
89 | N_img = 5
90 | N_id = 300
91 | with torch.no_grad():
92 | acc = np.zeros(N_id)
93 | for id in range(N_id):
94 | index = sucessful_label == id
95 | acc[id] = sum(index)
96 |
97 | acc=acc*100.0/N_img
98 | print('acc',acc)
99 | csv_file = '{}acc_class.csv'.format(filename)
100 | print('csv_file',csv_file)
101 | import csv
102 | with open(csv_file, 'a') as f:
103 | writer = csv.writer(f)
104 | for i in range(N_id):
105 | # writer.writerow(['{}'.format(i),'{}'.format(acc[i])])
106 | writer.writerow([i,acc[i]])
107 |
108 | def eval_acc_class(G, E, save_dir, prefix, args):
109 |
110 | successful_imgs, _ = gen_samples(G, E, save_dir, args.improved_flag)
111 |
112 | filename = "{}/{}_".format(prefix, args.loss)
113 |
114 | acc_class(filename,successful_imgs,E)
115 |
116 |
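
`attack_acc` is imported from `attack.py`, which is not reproduced in this listing; conceptually it scores the reconstructed images with the evaluation model `E` and reports top-1/top-5 accuracy against the target identities. The sketch below is a generic top-k accuracy helper in that spirit, not the actual implementation:

```python
import torch

def topk_accuracy(logits, labels, k=5):
    """Percentage of samples whose true label appears in the top-k predictions."""
    topk = logits.topk(k, dim=1).indices               # (N, k)
    correct = (topk == labels.view(-1, 1)).any(dim=1)  # (N,)
    return correct.float().mean().item() * 100.0

logits = torch.randn(8, 300)                 # toy scores for 8 samples, 300 identities
labels = torch.randint(0, 300, (8,))
top1, top5 = topk_accuracy(logits, labels, 1), topk_accuracy(logits, labels, 5)
```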
--------------------------------------------------------------------------------
/metrics/fid.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import torch
3 | import numpy as np
4 | from scipy import linalg
5 | from metrics import metric_utils
6 | import utils
7 | from attack import reparameterize
8 | import os
9 | from utils import save_tensor_images
10 |
11 |
12 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
13 | _feature_detector_cache = None
14 | def get_feature_detector():
15 | global _feature_detector_cache
16 | if _feature_detector_cache is None:
17 | _feature_detector_cache = metric_utils.get_feature_detector(
18 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/'
19 | 'metrics/inception-2015-12-05.pt', device)
20 |
21 | return _feature_detector_cache
22 |
23 |
24 | def postprocess(x):
25 | """."""
26 | return ((x * .5 + .5) * 255).to(torch.uint8)
27 |
28 |
29 | def run_fid(x1, x2):
30 | # Extract features
31 | x1 = run_batch_extract(x1, device)
32 | x2 = run_batch_extract(x2, device)
33 |
34 | npx1 = x1.detach().cpu().numpy()
35 | npx2 = x2.detach().cpu().numpy()
36 | mu1 = np.mean(npx1, axis=0)
37 | sigma1 = np.cov(npx1, rowvar=False)
38 | mu2 = np.mean(npx2, axis=0)
39 | sigma2 = np.cov(npx2, rowvar=False)
40 | frechet = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
41 | return frechet
42 |
43 |
44 | def run_feature_extractor(x):
45 | assert x.dtype == torch.uint8
46 | assert x.min() >= 0
47 | assert x.max() <= 255
48 | assert len(x.shape) == 4
49 | assert x.shape[1] == 3
50 | feature_extractor = get_feature_detector()
51 | return feature_extractor(x, return_features=True)
52 |
53 |
54 | def run_batch_extract(x, device, bs=500):
55 | z = []
56 | with torch.no_grad():
57 | for start in tqdm(range(0, len(x), bs), desc='run_batch_extract'):
58 | stop = start + bs
59 | x_ = x[start:stop].to(device)
60 | z_ = run_feature_extractor(postprocess(x_)).cpu()
61 | z.append(z_)
62 | z = torch.cat(z)
63 | return z
64 |
65 |
66 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6, return_details=False):
67 | """Numpy implementation of the Frechet Distance.
68 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
69 | and X_2 ~ N(mu_2, C_2) is
70 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
71 |
72 | Stable version by Dougal J. Sutherland.
73 |
74 | Params:
75 | -- mu1 : Numpy array containing the activations of a layer of the
76 | inception net (like returned by the function 'get_predictions')
77 | for generated samples.
78 | -- mu2 : The sample mean over activations, precalculated on a
79 | representative data set.
80 | -- sigma1: The covariance matrix over activations for generated samples.
81 | -- sigma2: The covariance matrix over activations, precalculated on a
82 | representative data set.
83 |
84 | Returns:
85 | -- : The Frechet Distance.
86 | """
87 |
88 | mu1 = np.atleast_1d(mu1)
89 | mu2 = np.atleast_1d(mu2)
90 |
91 | sigma1 = np.atleast_2d(sigma1)
92 | sigma2 = np.atleast_2d(sigma2)
93 |
94 | assert mu1.shape == mu2.shape, \
95 | 'Training and test mean vectors have different lengths'
96 | assert sigma1.shape == sigma2.shape, \
97 | 'Training and test covariances have different dimensions'
98 |
99 | diff = mu1 - mu2
100 |
101 | # Product might be almost singular
102 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
103 | if not np.isfinite(covmean).all():
104 | msg = ('fid calculation produces singular product; '
105 | 'adding %s to diagonal of cov estimates') % eps
106 | print(msg)
107 | offset = np.eye(sigma1.shape[0]) * eps
108 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
109 |
110 | # Numerical error might give slight imaginary component
111 | if np.iscomplexobj(covmean):
112 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
113 | m = np.max(np.abs(covmean.imag))
114 | raise ValueError('Imaginary component {}'.format(m))
115 | covmean = covmean.real
116 |
117 | tr_covmean = np.trace(covmean)
118 | if not return_details:
119 | return (diff.dot(diff) + np.trace(sigma1) +
120 | np.trace(sigma2) - 2 * tr_covmean)
121 | else:
122 | t1 = diff.dot(diff)
123 | t2 = np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
124 | return (t1 + t2), t1, t2
125 |
126 |
127 | def get_z(improved_gan, save_dir, loop, i, j):
128 | if improved_gan==True: #KEDMI
129 | outputs_z = os.path.join(save_dir, "{}_{}_iter_0_{}_dis.npy".format(loop, i, 2399))
130 | outputs_label = os.path.join(save_dir, "{}_{}_iter_0_{}_label.npy".format(loop, i, 2399))
131 |
132 | dis = np.load(outputs_z, allow_pickle=True)
133 | mu = torch.from_numpy(dis.item().get('mu')).to(device)
134 | log_var = torch.from_numpy(dis.item().get('log_var')).to(device)
135 | iden = np.load(outputs_label)
136 | z = reparameterize(mu, log_var)
137 | else: #GMI
138 | outputs_z = os.path.join(save_dir, "{}_{}_iter_{}_{}_z.npy".format(loop, i, j, 2399))
139 | outputs_label = os.path.join(save_dir, "{}_{}_iter_{}_{}_label.npy".format(loop, i, j, 2399))
140 |
141 | z = np.load(outputs_z)
142 | iden = np.load(outputs_label)
143 | z = torch.from_numpy(z).to(device)
144 | return z, iden
145 |
146 | def gen_samples(G, E, save_dir, improved_gan, n_iden=5, n_img=5):
147 | total_gen = 0
148 | seed = 9
149 | torch.manual_seed(seed)
150 | img_ids_path = os.path.join(save_dir, 'attack{}_'.format(seed))
151 |
154 |
155 | all_imgs = []
156 | all_fea = []
157 | all_id = []
158 | all_sucessful_imgs = []
159 | all_sucessful_id =[]
160 | all_sucessful_fea=[]
161 |
162 | all_failure_imgs = []
163 | all_failure_fea = []
164 | all_failure_id = []
165 |
166 | E.eval()
167 | G.eval()
168 | if not os.path.exists(img_ids_path + 'full.npy'):
169 | for loop in range(1):
170 | for i in range(n_iden): #300 ides
171 | for j in range(n_img): #5 images/iden
172 | z, iden = get_z(improved_gan, save_dir, loop, i, j)
173 | z = torch.clamp(z, -1.0, 1.0).float()
174 | total_gen = total_gen + z.shape[0]
175 | # calculate attack accuracy
176 | with torch.no_grad():
177 | fake = G(z.to(device))
178 | save_tensor_images(fake, os.path.join(save_dir, "gen_{}_{}.png".format(i,j)), nrow = 60)
179 |
180 | eval_fea, eval_prob = E(utils.low2high(fake))
181 |
182 | ### successfully attacked samples
183 | eval_iden = torch.argmax(eval_prob, dim=1).view(-1)
185 | sucessful_iden = []
186 | failure_iden = []
187 | for id in range(iden.shape[0]):
188 | if eval_iden[id]==iden[id]:
189 | sucessful_iden.append(id)
190 | else:
191 | failure_iden.append(id)
192 |
193 |
194 | fake = fake.detach().cpu().numpy()
195 | eval_fea = eval_fea.detach().cpu().numpy()
196 |
197 | all_imgs.append(fake)
198 | all_fea.append(eval_fea)
199 | all_id.append(iden)
200 |
201 | if len(sucessful_iden)>0:
202 | sucessful_iden = np.array(sucessful_iden)
203 | sucessful_fake = fake[sucessful_iden,:,:,:]
204 | sucessful_eval_fea = eval_fea[sucessful_iden,:]
205 | sucessful_iden = iden[sucessful_iden]
206 | else:
207 | sucessful_fake = []
208 | sucessful_iden = []
209 | sucessful_eval_fea = []
210 |
211 | all_sucessful_imgs.append(sucessful_fake)
212 | all_sucessful_id.append(sucessful_iden)
213 | all_sucessful_fea.append(sucessful_eval_fea)
214 |
215 | if len(failure_iden)>0:
216 | failure_iden = np.array(failure_iden)
217 | failure_fake = fake[failure_iden,:,:,:]
218 | failure_eval_fea = eval_fea[failure_iden,:]
219 | failure_iden = iden[failure_iden]
220 | else:
221 | failure_fake = []
222 | failure_iden = []
223 | failure_eval_fea = []
224 |
225 | all_failure_imgs.append(failure_fake)
226 | all_failure_id.append(failure_iden)
227 | all_failure_fea.append(failure_eval_fea)
228 | np.save(img_ids_path+'full',{'imgs':all_imgs,'label':all_id,'fea':all_fea})
229 | np.save(img_ids_path+'success',{'sucessful_imgs':all_sucessful_imgs,'label':all_sucessful_id,'sucessful_fea':all_sucessful_fea})
230 | np.save(img_ids_path+'failure',{'failure_imgs':all_failure_imgs,'label':all_failure_id,'failure_fea':all_failure_fea})
231 |
232 | return img_ids_path, total_gen
233 |
234 |
235 | def concatenate_list(listA):
236 | result = []
237 | for i in range(len(listA)):
238 | val = listA [i]
239 | if len(val)>0:
240 | if len(result)==0:
241 | result = listA[i]
242 | else:
243 | result = np.concatenate((result,val))
244 | return result
245 |
246 |
247 | def eval_fid(G, E, save_dir, cfg, args):
248 |
249 | successful_imgs,_ = gen_samples(G, E, save_dir, args.improved_flag)
250 |
251 | #real data
252 | target_x = np.load(cfg['dataset']['fid_real_path'])
253 |
254 | # Load Samples
255 | sucessful_data = np.load(successful_imgs+'success.npy',allow_pickle=True)
256 | fake = sucessful_data.item().get('sucessful_imgs')
257 |
258 | fake = concatenate_list(fake)
259 |
260 | fake = torch.from_numpy(fake).to(device)
261 | target_x = torch.from_numpy(target_x).to(device)
262 | fid = run_fid(target_x, fake)
263 |
264 | return fid, fake.shape[0]
265 |
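
`calculate_frechet_distance` implements the standard FID formula d^2 = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2*sqrt(C1*C2)) on Inception features. A minimal numeric sketch on toy Gaussian features (no Inception network involved; it only illustrates the mean/covariance step):

```python
import numpy as np
from metrics.fid import calculate_frechet_distance

rng = np.random.default_rng(0)
feat_real = rng.normal(0.0, 1.0, size=(1000, 64))   # stand-in for real-image features
feat_fake = rng.normal(0.5, 1.0, size=(1000, 64))   # stand-in for reconstructed-image features

mu1, sigma1 = feat_real.mean(axis=0), np.cov(feat_real, rowvar=False)
mu2, sigma2 = feat_fake.mean(axis=0), np.cov(feat_fake, rowvar=False)
fid = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
# A pure mean shift of 0.5 in 64 dimensions contributes roughly 64 * 0.25 = 16.
```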
--------------------------------------------------------------------------------
/metrics/metric_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import time
11 | import hashlib
12 | import pickle
13 | import copy
14 | import uuid
15 | import numpy as np
16 | import torch
17 | import dnnlib
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | class MetricOptions:
22 | def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
23 | assert 0 <= rank < num_gpus
24 | self.G = G
25 | self.G_kwargs = dnnlib.EasyDict(G_kwargs)
26 | self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
27 | self.num_gpus = num_gpus
28 | self.rank = rank
29 | self.device = device if device is not None else torch.device('cuda', rank)
30 | self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
31 | self.cache = cache
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | _feature_detector_cache = dict()
36 |
37 | def get_feature_detector_name(url):
38 | return os.path.splitext(url.split('/')[-1])[0]
39 |
40 | def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
41 | assert 0 <= rank < num_gpus
42 | print('device',device)
43 | key = (url, device)
44 | if key not in _feature_detector_cache:
45 | is_leader = (rank == 0)
46 | if not is_leader and num_gpus > 1:
47 | torch.distributed.barrier() # leader goes first
48 | with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
49 | _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
50 | if is_leader and num_gpus > 1:
51 | torch.distributed.barrier() # others follow
52 | return _feature_detector_cache[key]
53 |
54 | #----------------------------------------------------------------------------
55 |
56 | class FeatureStats:
57 | def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
58 | self.capture_all = capture_all
59 | self.capture_mean_cov = capture_mean_cov
60 | self.max_items = max_items
61 | self.num_items = 0
62 | self.num_features = None
63 | self.all_features = None
64 | self.raw_mean = None
65 | self.raw_cov = None
66 |
67 | def set_num_features(self, num_features):
68 | if self.num_features is not None:
69 | assert num_features == self.num_features
70 | else:
71 | self.num_features = num_features
72 | self.all_features = []
73 | self.raw_mean = np.zeros([num_features], dtype=np.float64)
74 | self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
75 |
76 | def is_full(self):
77 | return (self.max_items is not None) and (self.num_items >= self.max_items)
78 |
79 | def append(self, x):
80 | x = np.asarray(x, dtype=np.float32)
81 | assert x.ndim == 2
82 | if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
83 | if self.num_items >= self.max_items:
84 | return
85 | x = x[:self.max_items - self.num_items]
86 |
87 | self.set_num_features(x.shape[1])
88 | self.num_items += x.shape[0]
89 | if self.capture_all:
90 | self.all_features.append(x)
91 | if self.capture_mean_cov:
92 | x64 = x.astype(np.float64)
93 | self.raw_mean += x64.sum(axis=0)
94 | self.raw_cov += x64.T @ x64
95 |
96 | def append_torch(self, x, num_gpus=1, rank=0):
97 | assert isinstance(x, torch.Tensor) and x.ndim == 2
98 | assert 0 <= rank < num_gpus
99 | if num_gpus > 1:
100 | ys = []
101 | for src in range(num_gpus):
102 | y = x.clone()
103 | torch.distributed.broadcast(y, src=src)
104 | ys.append(y)
105 | x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
106 | self.append(x.cpu().numpy())
107 |
108 | def get_all(self):
109 | assert self.capture_all
110 | return np.concatenate(self.all_features, axis=0)
111 |
112 | def get_all_torch(self):
113 | return torch.from_numpy(self.get_all())
114 |
115 | def get_mean_cov(self):
116 | assert self.capture_mean_cov
117 | mean = self.raw_mean / self.num_items
118 | cov = self.raw_cov / self.num_items
119 | cov = cov - np.outer(mean, mean)
120 | return mean, cov
121 |
122 | def save(self, pkl_file):
123 | with open(pkl_file, 'wb') as f:
124 | pickle.dump(self.__dict__, f)
125 |
126 | @staticmethod
127 | def load(pkl_file):
128 | with open(pkl_file, 'rb') as f:
129 | s = dnnlib.EasyDict(pickle.load(f))
130 | obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
131 | obj.__dict__.update(s)
132 | return obj
133 |
134 | #----------------------------------------------------------------------------
135 |
136 | class ProgressMonitor:
137 | def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
138 | self.tag = tag
139 | self.num_items = num_items
140 | self.verbose = verbose
141 | self.flush_interval = flush_interval
142 | self.progress_fn = progress_fn
143 | self.pfn_lo = pfn_lo
144 | self.pfn_hi = pfn_hi
145 | self.pfn_total = pfn_total
146 | self.start_time = time.time()
147 | self.batch_time = self.start_time
148 | self.batch_items = 0
149 | if self.progress_fn is not None:
150 | self.progress_fn(self.pfn_lo, self.pfn_total)
151 |
152 | def update(self, cur_items):
153 | assert (self.num_items is None) or (cur_items <= self.num_items)
154 | if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
155 | return
156 | cur_time = time.time()
157 | total_time = cur_time - self.start_time
158 | time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
159 | if (self.verbose) and (self.tag is not None):
160 | print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
161 | self.batch_time = cur_time
162 | self.batch_items = cur_items
163 |
164 | if (self.progress_fn is not None) and (self.num_items is not None):
165 | self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
166 |
167 | def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
168 | return ProgressMonitor(
169 | tag = tag,
170 | num_items = num_items,
171 | flush_interval = flush_interval,
172 | verbose = self.verbose,
173 | progress_fn = self.progress_fn,
174 | pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
175 | pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
176 | pfn_total = self.pfn_total,
177 | )
178 |
179 | #----------------------------------------------------------------------------
180 |
181 | def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
182 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
183 | if data_loader_kwargs is None:
184 | data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
185 |
186 | # Try to lookup from cache.
187 | cache_file = None
188 | if opts.cache:
189 | # Choose cache file name.
190 | args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
191 | md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
192 | cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
193 | cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
194 |
195 | # Check if the file exists (all processes must agree).
196 | flag = os.path.isfile(cache_file) if opts.rank == 0 else False
197 | if opts.num_gpus > 1:
198 | flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
199 | torch.distributed.broadcast(tensor=flag, src=0)
200 | flag = (float(flag.cpu()) != 0)
201 |
202 | # Load.
203 | if flag:
204 | return FeatureStats.load(cache_file)
205 |
206 | # Initialize.
207 | num_items = len(dataset)
208 | if max_items is not None:
209 | num_items = min(num_items, max_items)
210 | stats = FeatureStats(max_items=num_items, **stats_kwargs)
211 | progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
212 | detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
213 |
214 | # Main loop.
215 | item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
216 | for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
217 | if images.shape[1] == 1:
218 | images = images.repeat([1, 3, 1, 1])
219 | features = detector(images.to(opts.device), **detector_kwargs)
220 | stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
221 | progress.update(stats.num_items)
222 |
223 | # Save to cache.
224 | if cache_file is not None and opts.rank == 0:
225 | os.makedirs(os.path.dirname(cache_file), exist_ok=True)
226 | temp_file = cache_file + '.' + uuid.uuid4().hex
227 | stats.save(temp_file)
228 | os.replace(temp_file, cache_file) # atomic
229 | return stats
230 |
231 | #----------------------------------------------------------------------------
232 |
233 | def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, **stats_kwargs):
234 | if batch_gen is None:
235 | batch_gen = min(batch_size, 4)
236 | assert batch_size % batch_gen == 0
237 |
238 | # Setup generator and load labels.
239 | G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
240 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
241 |
242 | # Image generation func.
243 | def run_generator(z, c):
244 | img = G(z=z, c=c, **opts.G_kwargs)
245 | img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
246 | return img
247 |
248 | # JIT.
249 | if jit:
250 | z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
251 | c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
252 | run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False)
253 |
254 | # Initialize.
255 | stats = FeatureStats(**stats_kwargs)
256 | assert stats.max_items is not None
257 | progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
258 | detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
259 |
260 | # Main loop.
261 | while not stats.is_full():
262 | images = []
263 | for _i in range(batch_size // batch_gen):
264 | z = torch.randn([batch_gen, G.z_dim], device=opts.device)
265 | c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)]
266 | c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
267 | images.append(run_generator(z, c))
268 | images = torch.cat(images)
269 | if images.shape[1] == 1:
270 | images = images.repeat([1, 3, 1, 1])
271 | features = detector(images, **detector_kwargs)
272 | stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
273 | progress.update(stats.num_items)
274 | return stats
275 |
276 | #----------------------------------------------------------------------------
277 |
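
Most of this NVIDIA utility module is only needed for the StyleGAN2-ADA metric pipeline, but `FeatureStats` is self-contained: it accumulates either raw features or a running mean/covariance. A minimal sketch of the mean/covariance mode (importing the module still requires `dnnlib` to be installed, since it is imported at the top of the file):

```python
import torch
from metrics.metric_utils import FeatureStats

stats = FeatureStats(capture_mean_cov=True, max_items=10000)
for _ in range(5):
    feats = torch.randn(64, 2048)   # stand-in for detector features
    stats.append_torch(feats)       # single-process path: num_gpus=1, rank=0
mean, cov = stats.get_mean_cov()    # running estimates over all appended items
print(mean.shape, cov.shape)        # (2048,), (2048, 2048)
```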
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sutd-visual-computing-group/Re-thinking_MI/ed4b6b8900e4d7fed8f70d7e6a444893963d920e/models/__init__.py
--------------------------------------------------------------------------------
/models/classify.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.nn as nn
4 | import torchvision.models
5 | import torch.nn.functional as F
6 | # from torch.nn.modules.loss import _Loss
7 | import models.evolve as evolve
8 | import utils
9 |
10 | class Flatten(nn.Module):
11 | def forward(self, input):
12 | return input.view(input.size(0), -1)
13 |
14 | class MCNN2(nn.Module):
15 | def __init__(self, num_classes=10):
16 | super(MCNN2, self).__init__()
17 | self.feat_dim = 12800
18 | self.num_classes = num_classes
19 | self.feature = nn.Sequential(
20 | nn.Conv2d(1, 64, 7, stride=1, padding=1),
21 | nn.BatchNorm2d(64),
22 | nn.LeakyReLU(0.2),
23 | nn.MaxPool2d(2, 2),
24 | nn.Conv2d(64, 128, 5, stride=1),
25 | nn.BatchNorm2d(128),
26 | nn.LeakyReLU(0.2))
27 |
28 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes)
29 |
30 | def forward(self, x):
31 | feature = self.feature(x)
32 | feature = feature.view(feature.size(0), -1)
33 | out = self.fc_layer(feature)
34 | return feature,out
35 |
36 | class MCNN4(nn.Module):
37 | def __init__(self, num_classes=10):
38 | super(MCNN4, self).__init__()
39 | self.feat_dim = 128
40 | self.num_classes = num_classes
41 | self.feature = nn.Sequential(
42 | nn.Conv2d(1, 32, 7, stride=1, padding=1),
43 | nn.BatchNorm2d(32),
44 | nn.LeakyReLU(0.2),
45 | nn.MaxPool2d(2, 2),
46 | nn.Conv2d(32, 64, 5, stride=1),
47 | nn.BatchNorm2d(64),
48 | nn.LeakyReLU(0.2),
49 | nn.MaxPool2d(2, 2),
50 | nn.Conv2d(64, 128, 5, stride=1),
51 | nn.BatchNorm2d(128),
52 | nn.LeakyReLU(0.2))
53 |
54 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes)
55 |
56 | def forward(self, x):
57 | feature = self.feature(x)
58 | feature = feature.view(feature.size(0), -1)
59 | out = self.fc_layer(feature)
60 | return feature,out
61 |
62 | class MCNN(nn.Module):
63 | def __init__(self, num_classes=10):
64 | super(MCNN, self).__init__()
65 | self.feat_dim = 256
66 | self.num_classes = num_classes
67 | self.feature = nn.Sequential(
68 | nn.Conv2d(1, 64, 7, stride=1, padding=1),
69 | nn.BatchNorm2d(64),
70 | nn.LeakyReLU(0.2),
71 | nn.MaxPool2d(2, 2),
72 | nn.Conv2d(64, 128, 5, stride=1),
73 | nn.BatchNorm2d(128),
74 | nn.LeakyReLU(0.2),
75 | nn.MaxPool2d(2, 2),
76 | nn.Conv2d(128, 256, 5, stride=1),
77 | nn.BatchNorm2d(256),
78 | nn.LeakyReLU(0.2))
79 |
80 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes)
81 |
82 | def forward(self, x):
83 | feature = self.feature(x)
84 | feature = feature.view(feature.size(0), -1)
85 | out = self.fc_layer(feature)
86 | return feature,out
87 |
88 | class SCNN(nn.Module):
89 | def __init__(self, num_classes=10):
90 | super(SCNN, self).__init__()
91 | self.feat_dim = 512
92 | self.num_classes = num_classes
93 | self.feature = nn.Sequential(
94 | nn.Conv2d(1, 32, 7, stride=1, padding=1),
95 | nn.BatchNorm2d(32),
96 | nn.LeakyReLU(0.2),
97 | nn.MaxPool2d(2, 2),
98 | nn.Conv2d(32, 64, 5, stride=1),
99 | nn.BatchNorm2d(64),
100 | nn.LeakyReLU(0.2),
101 | nn.MaxPool2d(2, 2),
102 | nn.Conv2d(64, 128, 5, stride=1),
103 | nn.BatchNorm2d(128),
104 | nn.LeakyReLU(0.2),
105 | nn.Conv2d(128, 256, 3, stride=1, padding=1),
106 | nn.BatchNorm2d(256),
107 | nn.LeakyReLU(0.2),
108 | nn.Conv2d(256, 512, 3, stride=1, padding=1),
109 | nn.BatchNorm2d(512),
110 | nn.LeakyReLU(0.2))
111 |
112 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes)
113 |
114 | def forward(self, x):
115 | feature = self.feature(x)
116 | feature = feature.view(feature.size(0), -1)
117 | out = self.fc_layer(feature)
118 | return feature,out
119 |
120 | class Mnist_CNN(nn.Module):
121 | def __init__(self):
122 | super(Mnist_CNN, self).__init__()
123 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
124 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
125 | self.conv2_drop = nn.Dropout2d()
126 | self.fc1 = nn.Linear(500, 50)
127 | self.fc2 = nn.Linear(50, 5)
128 |
129 | def forward(self, x):
130 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
131 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
132 | x = x.view(x.size(0), -1)
133 | x = F.relu(self.fc1(x))
134 | x = F.dropout(x, training=self.training)
135 | res = self.fc2(x)
136 | return [x, res]
137 |
138 |
139 | class VGG16_xray8(nn.Module):
140 | def __init__(self, num_classes=7):
141 | super(VGG16_xray8, self).__init__()
142 | model = torchvision.models.vgg16_bn(pretrained=True)
143 | model.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
144 | self.feature = model.features
145 | self.feat_dim = 2048
146 | self.num_classes = num_classes
147 | self.bn = nn.BatchNorm1d(self.feat_dim)
148 | self.bn.bias.requires_grad_(False) # no shift
149 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes)
150 | self.model = model
151 |
152 | def forward(self, x):
153 | feature = self.feature(x)
154 | feature = feature.view(feature.size(0), -1)
155 | feature = self.bn(feature)
156 | res = self.fc_layer(feature)
157 |
158 | return feature,res
159 |
160 | def predict(self, x):
161 | feature = self.feature(x)
162 | feature = feature.view(feature.size(0), -1)
163 | feature = self.bn(feature)
164 | res = self.fc_layer(feature)
165 |
166 | return res
167 |
168 | class VGG16(nn.Module):
169 | def __init__(self, n_classes):
170 | super(VGG16, self).__init__()
171 | model = torchvision.models.vgg16_bn(pretrained=True)
172 | self.feature = model.features
173 | self.feat_dim = 512 * 2 * 2
174 | self.n_classes = n_classes
175 | self.bn = nn.BatchNorm1d(self.feat_dim)
176 | self.bn.bias.requires_grad_(False) # no shift
177 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes)
178 |
179 | def forward(self, x):
180 | feature = self.feature(x)
181 | feature = feature.view(feature.size(0), -1)
182 | feature = self.bn(feature)
183 | res = self.fc_layer(feature)
184 | return feature,res
185 |
186 | def predict(self, x):
187 | feature = self.feature(x)
188 | feature = feature.view(feature.size(0), -1)
189 | feature = self.bn(feature)
190 | res = self.fc_layer(feature)
191 | out = F.softmax(res, dim=1)
192 |
193 | return feature,out
194 |
195 | class VGG16_vib(nn.Module):
196 | def __init__(self, n_classes):
197 | super(VGG16_vib, self).__init__()
198 | model = torchvision.models.vgg16_bn(pretrained=True)
199 | self.feature = model.features
200 | self.feat_dim = 512 * 2 * 2
201 | self.k = self.feat_dim // 2
202 | self.n_classes = n_classes
203 | self.st_layer = nn.Linear(self.feat_dim, self.k * 2)
204 | self.fc_layer = nn.Linear(self.k, self.n_classes)
205 |
206 | def forward(self, x, mode="train"):
207 | feature = self.feature(x)
208 | feature = feature.view(feature.size(0), -1)
209 | statis = self.st_layer(feature)
210 | mu, std = statis[:, :self.k], statis[:, self.k:]
211 |
212 | std = F.softplus(std-5, beta=1)
213 | eps = torch.FloatTensor(std.size()).normal_().cuda()
214 | res = mu + std * eps
215 | out = self.fc_layer(res)
216 |
217 | return [feature, out, mu, std]
218 |
219 | def predict(self, x):
220 | feature = self.feature(x)
221 | feature = feature.view(feature.size(0), -1)
222 | statis = self.st_layer(feature)
223 | mu, std = statis[:, :self.k], statis[:, self.k:]
224 |
225 | std = F.softplus(std-5, beta=1)
226 | eps = torch.FloatTensor(std.size()).normal_().cuda()
227 | res = mu + std * eps
228 | out = self.fc_layer(res)
229 |
230 | return out
231 |
232 | class VGG19(nn.Module):
233 | def __init__(self, num_of_classes):
234 | super(VGG19, self).__init__()
235 | model = torchvision.models.vgg19_bn(pretrained=True)
236 | self.feature = nn.Sequential(*list(model.children())[:-2])
237 |
238 | self.feat_dim = 512 * 2 * 2
239 | self.num_of_classes = num_of_classes
240 | self.fc_layer = nn.Linear(self.feat_dim, self.num_of_classes)
241 | def classifier(self, x):
242 | out = self.fc_layer(x)
243 | __, iden = torch.max(out, dim = 1)
244 | iden = iden.view(-1, 1)
245 | return out, iden
246 |
247 | def forward(self, x):
248 | feature = self.feature(x)
249 | feature = feature.view(feature.size(0), -1)
250 | out, iden = self.classifier(feature)
251 | return feature, out
252 |
253 | class VGG19_xray8(nn.Module):
254 | def __init__(self, num_classes):
255 | super(VGG19_xray8, self).__init__()
256 | model = torchvision.models.vgg19_bn(pretrained=True)
257 | model.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
258 | self.feature = model.features
259 | self.feat_dim = 2048
260 | self.num_classes = num_classes
261 | self.bn = nn.BatchNorm1d(self.feat_dim)
262 | self.bn.bias.requires_grad_(False) # no shift
263 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes)
264 | self.model = model
265 |
266 | def classifier(self, x):
267 | out = self.fc_layer(x)
268 | __, iden = torch.max(out, dim = 1)
269 | iden = iden.view(-1, 1)
270 | return out, iden
271 |
272 | def forward(self, x):
273 | feature = self.feature(x)
274 | # print(feature.shape)
275 | feature = feature.view(feature.size(0), -1)
276 | out, iden = self.classifier(feature)
277 | return feature, out
278 |
279 |
280 | class EfficientNet_b0(nn.Module):
281 | def __init__(self, n_classes):
282 | super(EfficientNet_b0, self).__init__()
283 | model = torchvision.models.efficientnet.efficientnet_b0(pretrained=True)
284 | self.feature = nn.Sequential(*list(model.children())[:-1])
285 | self.n_classes = n_classes
286 | self.feat_dim = 1280
287 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes)
288 |
289 | def forward(self, x):
290 | feature = self.feature(x)
291 | feature = feature.view(feature.size(0), -1)
292 | res = self.fc_layer(feature)
293 | return feature,res
294 |
295 | def predict(self, x):
296 | feature = self.feature(x)
297 | feature = feature.view(feature.size(0), -1)
298 | res = self.fc_layer(feature)
299 | out = F.softmax(res, dim=1)
300 |
301 | return feature,out
302 |
303 | class EfficientNet_b1(nn.Module):
304 | def __init__(self, n_classes):
305 | super(EfficientNet_b1, self).__init__()
306 | model = torchvision.models.efficientnet.efficientnet_b1(pretrained=True)
307 | self.feature = nn.Sequential(*list(model.children())[:-1])
308 | self.n_classes = n_classes
309 | self.feat_dim = 1280
310 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes)
311 |
312 | def forward(self, x):
313 | feature = self.feature(x)
314 | feature = feature.view(feature.size(0), -1)
315 | res = self.fc_layer(feature)
316 | return feature,res
317 |
318 | def predict(self, x):
319 | feature = self.feature(x)
320 | feature = feature.view(feature.size(0), -1)
321 | res = self.fc_layer(feature)
322 | out = F.softmax(res, dim=1)
323 |
324 | return feature,out
325 |
326 | class EfficientNet_b2(nn.Module):
327 | def __init__(self, n_classes):
328 | super(EfficientNet_b2, self).__init__()
329 | model = torchvision.models.efficientnet.efficientnet_b2(pretrained=True)
330 | self.feature = nn.Sequential(*list(model.children())[:-1])
331 | self.n_classes = n_classes
332 | self.feat_dim = 1408
333 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes)
334 |
335 | def forward(self, x):
336 | feature = self.feature(x)
337 | feature = feature.view(feature.size(0), -1)
338 | res = self.fc_layer(feature)
339 | return feature,res
340 |
341 | def predict(self, x):
342 | feature = self.feature(x)
343 | feature = feature.view(feature.size(0), -1)
344 | res = self.fc_layer(feature)
345 | out = F.softmax(res, dim=1)
346 |
347 | return feature,out
348 |
349 | class EfficientNet_v2_s2(nn.Module):
350 | def __init__(self, n_classes,dataset='celeba'):
351 | super(EfficientNet_v2_s2, self).__init__()
352 | model = torchvision.models.efficientnet.efficientnet_v2_s(pretrained=True)
353 | self.feature = nn.Sequential(*list(model.children())[:-1])
354 | self.n_classes = n_classes
355 | self.feat_dim = 1028
356 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes)
357 |
358 | def forward(self, x):
359 | feature = self.feature(x)
360 | feature = feature.view(feature.size(0), -1)
361 | res = self.fc_layer(feature)
362 | return feature,res
363 |
364 | def predict(self, x):
365 | feature = self.feature(x)
366 |
367 | feature = feature.view(feature.size(0), -1)
368 | res = self.fc_layer(feature)
369 | out = F.softmax(res, dim=1)
370 |
371 | return feature,out
372 |
373 | class EfficientNet_v2_m2(nn.Module):
374 | def __init__(self, n_classes,dataset='celeba'):
375 | super(EfficientNet_v2_m2, self).__init__()
376 | model = torchvision.models.efficientnet.efficientnet_v2_m(pretrained=True)
377 |
378 | self.feature = nn.Sequential(*list(model.children())[:-1])
379 | self.n_classes = n_classes
380 | self.feat_dim = 1028
381 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes)
382 |
383 | def forward(self, x):
384 | feature = self.feature(x)
385 | feature = feature.view(feature.size(0), -1)
386 | res = self.fc_layer(feature)
387 | return feature,res
388 |
389 | def predict(self, x):
390 | feature = self.feature(x)
391 | feature = feature.view(feature.size(0), -1)
392 | res = self.fc_layer(feature)
393 | out = F.softmax(res, dim=1)
394 |
395 | return feature,out
396 |
397 | class EfficientNet_v2_l2(nn.Module):
398 | def __init__(self, n_classes, dataset='celeba'):
399 | super(EfficientNet_v2_l2, self).__init__()
400 | model = torchvision.models.efficientnet.efficientnet_v2_l(pretrained=True)
401 | self.feature = nn.Sequential(*list(model.children())[:-1])
402 | self.n_classes = n_classes
403 | self.feat_dim = 1028
404 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes)
405 |
406 | def forward(self, x):
407 | feature = self.feature(x)
408 | feature = feature.view(feature.size(0), -1)
409 | res = self.fc_layer(feature)
410 | return feature,res
411 |
412 | def predict(self, x):
413 | feature = self.feature(x)
414 | feature = feature.view(feature.size(0), -1)
415 | res = self.fc_layer(feature)
416 | out = F.softmax(res, dim=1)
417 |
418 | return feature,out
419 |
420 | class EfficientNet_v2_s(nn.Module):
421 | def __init__(self, n_classes, dataset='celeba'):
422 | super(EfficientNet_v2_s, self).__init__()
423 | model = torchvision.models.efficientnet.efficientnet_v2_s(pretrained=True)
424 | self.feature = nn.Sequential(*list(model.children())[:-2])
425 | self.n_classes = n_classes
426 | self.feat_dim = 5120
427 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes)
428 |
429 | def forward(self, x):
430 | feature = self.feature(x)
431 | feature = feature.view(feature.size(0), -1)
432 | res = self.fc_layer(feature)
433 | return feature,res
434 |
435 | def predict(self, x):
436 | feature = self.feature(x)
437 |
438 | feature = feature.view(feature.size(0), -1)
439 | res = self.fc_layer(feature)
440 | out = F.softmax(res, dim=1)
441 |
442 | return feature,out
443 |
444 | class EfficientNet_v2_m(nn.Module):
445 | def __init__(self, n_classes, dataset='celeba'):
446 | super(EfficientNet_v2_m, self).__init__()
447 | model = torchvision.models.efficientnet.efficientnet_v2_m(pretrained=True)
448 |
449 | self.feature = nn.Sequential(*list(model.children())[:-2])
450 | self.n_classes = n_classes
451 | self.feat_dim = 5120
452 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes)
453 |
454 | def forward(self, x):
455 | feature = self.feature(x)
456 | feature = feature.view(feature.size(0), -1)
457 | res = self.fc_layer(feature)
458 | return feature,res
459 |
460 | def predict(self, x):
461 | feature = self.feature(x)
462 | feature = feature.view(feature.size(0), -1)
463 | res = self.fc_layer(feature)
464 | out = F.softmax(res, dim=1)
465 |
466 | return feature,out
467 |
468 | class EfficientNet_v2_l(nn.Module):
469 | def __init__(self, n_classes, dataset='celeba'):
470 | super(EfficientNet_v2_l, self).__init__()
471 | model = torchvision.models.efficientnet.efficientnet_v2_l(pretrained=True)
472 | self.feature = nn.Sequential(*list(model.children())[:-2])
473 | self.n_classes = n_classes
474 | self.feat_dim = 5120
475 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes)
476 |
477 | def forward(self, x):
478 | feature = self.feature(x)
479 | feature = feature.view(feature.size(0), -1)
480 | res = self.fc_layer(feature)
481 | return feature,res
482 |
483 | def predict(self, x):
484 | feature = self.feature(x)
485 | feature = feature.view(feature.size(0), -1)
486 | res = self.fc_layer(feature)
487 | out = F.softmax(res, dim=1)
488 |
489 | return feature,out
490 |
491 | class efficientNet_v2_l_xray(nn.Module):
492 | def __init__(self, n_classes):
493 | super(efficientNet_v2_l_xray, self).__init__()
494 | model = torchvision.models.efficientnet.efficientnet_v2_l(pretrained=True)
495 | model.features[0][0] =nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
496 | self.feature = model.features
497 | self.n_classes = n_classes
498 | self.feat_dim = 5120
499 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes)
500 |
501 | def forward(self, x):
502 | feature = self.feature(x)
503 | feature = feature.view(feature.size(0), -1)
504 | res = self.fc_layer(feature)
505 | return feature,res
506 |
507 | def predict(self, x):
508 | feature = self.feature(x)
509 | feature = feature.view(feature.size(0), -1)
510 | res = self.fc_layer(feature)
511 | out = F.softmax(res, dim=1)
512 |
513 | return feature,out
514 |
515 | class efficientnet_v2_m_xray(nn.Module):
516 | def __init__(self, n_classes):
517 | super(efficientnet_v2_m_xray, self).__init__()
518 | model = torchvision.models.efficientnet.efficientnet_v2_m(pretrained=True)
519 |
520 | model.features[0][0] =nn.Conv2d(1, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
521 | self.feature = model.features
522 | self.n_classes = n_classes
523 | self.feat_dim = 5120
524 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes)
525 |
526 | def forward(self, x):
527 | feature = self.feature(x)
528 | feature = feature.view(feature.size(0), -1)
529 | res = self.fc_layer(feature)
530 | return feature,res
531 |
532 | def predict(self, x):
533 | feature = self.feature(x)
534 | feature = feature.view(feature.size(0), -1)
535 | res = self.fc_layer(feature)
536 | out = F.softmax(res, dim=1)
537 |
538 | return feature,out
539 |
540 | class efficientnet_v2_s_xray(nn.Module):
541 | def __init__(self, n_classes):
542 | super(efficientnet_v2_s_xray, self).__init__()
543 | model = torchvision.models.efficientnet.efficientnet_v2_s(pretrained=True)
544 | model.features[0][0] =nn.Conv2d(1, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
545 |
546 | self.feature = model.features
547 | self.n_classes = n_classes
548 | self.feat_dim = 5120
549 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes)
550 |
551 | def forward(self, x):
552 | feature = self.feature(x)
553 | feature = feature.view(feature.size(0), -1)
554 | res = self.fc_layer(feature)
555 | return feature,res
556 |
557 | def predict(self, x):
558 | feature = self.feature(x)
559 | feature = feature.view(feature.size(0), -1)
560 | res = self.fc_layer(feature)
561 | out = F.softmax(res, dim=1)
562 |
563 | return feature,out
564 |
565 |
566 | class ResNet18(nn.Module):
567 | def __init__(self, num_of_classes):
568 | super(ResNet18, self).__init__()
569 | model = torchvision.models.resnet18(pretrained=True)
570 | self.feature = nn.Sequential(*list(model.children())[:-2])
571 | self.feat_dim = 2048 * 1 * 1
572 | self.num_of_classes = num_of_classes
573 | self.fc_layer = nn.Linear(self.feat_dim, self.num_of_classes)
574 |
575 | def classifier(self, x):
576 | out = self.fc_layer(x)
577 | __, iden = torch.max(out, dim = 1)
578 | iden = iden.view(-1, 1)
579 | return out, iden
580 |
581 | def forward(self, x):
582 | feature = self.feature(x)
583 | feature = feature.view(feature.size(0), -1)
584 | out, iden = self.classifier(feature)
585 | return feature, out
586 |
587 | class ResNet34(nn.Module):
588 | def __init__(self, num_of_classes):
589 | super(ResNet34, self).__init__()
590 | model = torchvision.models.resnet34(pretrained=True)
591 | self.feature = nn.Sequential(*list(model.children())[:-2])
592 | self.feat_dim = 2048 * 1 * 1
593 | self.num_of_classes = num_of_classes
594 |
595 | self.fc_layer = nn.Linear(self.feat_dim, self.num_of_classes)
596 |
597 | def classifier(self, x):
598 | out = self.fc_layer(x)
599 | __, iden = torch.max(out, dim = 1)
600 | iden = iden.view(-1, 1)
601 | return out, iden
602 |
603 | def forward(self, x):
604 | feature = self.feature(x)
605 | feature = feature.view(feature.size(0), -1)
606 | out, iden = self.classifier(feature)
607 | return feature, out
608 |
609 | class ResNet34_xray8(nn.Module):
610 | def __init__(self, num_of_classes):
611 | super(ResNet34_xray8, self).__init__()
612 | model = torchvision.models.resnet34(pretrained=True)
613 | model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
614 |
615 | self.feature = nn.Sequential(*list(model.children())[:-2])
616 | self.feat_dim = 2048 * 1 * 1
617 | self.num_of_classes = num_of_classes
618 |
619 | self.fc_layer = nn.Linear(self.feat_dim, self.num_of_classes)
620 |
621 | def classifier(self, x):
622 | out = self.fc_layer(x)
623 | __, iden = torch.max(out, dim = 1)
624 | iden = iden.view(-1, 1)
625 | return out, iden
626 |
627 | def forward(self, x):
628 | feature = self.feature(x)
629 | feature = feature.view(feature.size(0), -1)
630 | out, iden = self.classifier(feature)
631 | return feature, out
632 |
633 |
634 | class Mobilenet_v3_small(nn.Module):
635 | def __init__(self, num_of_classes):
636 | super(Mobilenet_v3_small, self).__init__()
637 | model = torchvision.models.mobilenet_v3_small(pretrained=True)
638 | self.feature = nn.Sequential(*list(model.children())[:-2])
639 | self.feat_dim = 2304
640 | self.num_of_classes = num_of_classes
641 | self.fc_layer = nn.Sequential(
642 | nn.Linear(self.feat_dim, self.num_of_classes),)
643 |
644 | def classifier(self, x):
645 | out = self.fc_layer(x)
646 | __, iden = torch.max(out, dim = 1)
647 | iden = iden.view(-1, 1)
648 | return out, iden
649 |
650 | def forward(self, x):
651 | feature = self.feature(x)
652 | feature = feature.view(feature.size(0), -1)
653 | out, iden = self.classifier(feature)
654 | return feature,out
655 |
656 | class Mobilenet_v2(nn.Module):
657 | def __init__(self, num_of_classes):
658 | super(Mobilenet_v2, self).__init__()
659 | model = torchvision.models.mobilenet_v2(pretrained=True)
660 | self.feature = nn.Sequential(*list(model.children())[:-2])
661 | self.feat_dim = 12288
662 | self.num_of_classes = num_of_classes
663 | self.fc_layer = nn.Sequential(
664 | nn.Linear(self.feat_dim, self.num_of_classes),)
665 |
666 | def classifier(self, x):
667 | out = self.fc_layer(x)
668 | __, iden = torch.max(out, dim = 1)
669 | iden = iden.view(-1, 1)
670 | return out, iden
671 |
672 | def forward(self, x):
673 | feature = self.feature(x)
674 | feature = feature.view(feature.size(0), -1)
675 | out, iden = self.classifier(feature)
676 | return feature,out
677 |
678 | class EvolveFace(nn.Module):
679 | def __init__(self, num_of_classes, IR152):
680 | super(EvolveFace, self).__init__()
681 | if IR152:
682 | model = evolve.IR_152_64((64,64))
683 | else:
684 | model = evolve.IR_50_64((64,64))
685 | self.model = model
686 | self.feat_dim = 512
687 | self.num_classes = num_of_classes
688 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512),
689 | nn.Dropout(p=0.15),
690 | Flatten(),
691 | nn.Linear(512 * 4 * 4, 512),
692 | nn.BatchNorm1d(512))
693 |
694 | self.fc_layer = nn.Sequential(
695 | nn.Linear(self.feat_dim, self.num_classes),)
696 |
697 |
698 | def classifier(self, x):
699 | out = self.fc_layer(x)
700 | __, iden = torch.max(out, dim = 1)
701 | iden = iden.view(-1, 1)
702 | return out, iden
703 |
704 | def forward(self,x):
705 | feature = self.model(x)
706 | feature = self.output_layer(feature)
707 | feature = feature.view(feature.size(0), -1)
708 | out, iden = self.classifier(feature)
709 |
710 | return out
711 |
712 | class FaceNet(nn.Module):
713 | def __init__(self, num_classes=1000):
714 | super(FaceNet, self).__init__()
715 | self.feature = evolve.IR_50_112((112, 112))
716 | self.feat_dim = 512
717 | self.num_classes = num_classes
718 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes)
719 |
720 | def predict(self, x):
721 | feat = self.feature(x)
722 | feat = feat.view(feat.size(0), -1)
723 | out = self.fc_layer(feat)
724 | return out
725 |
726 | def forward(self, x):
727 | feat = self.feature(x)
728 | feat = feat.view(feat.size(0), -1)
729 | out = self.fc_layer(feat)
730 | return [feat, out]
731 |
732 | class FaceNet64(nn.Module):
733 | def __init__(self, num_classes = 1000):
734 | super(FaceNet64, self).__init__()
735 | self.feature = evolve.IR_50_64((64, 64))
736 | self.feat_dim = 512
737 | self.num_classes = num_classes
738 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512),
739 | nn.Dropout(),
740 | Flatten(),
741 | nn.Linear(512 * 4 * 4, 512),
742 | nn.BatchNorm1d(512))
743 |
744 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes)
745 |
746 | def forward(self, x):
747 | feat = self.feature(x)
748 | feat = self.output_layer(feat)
749 | feat = feat.view(feat.size(0), -1)
750 | out = self.fc_layer(feat)
751 | __, iden = torch.max(out, dim=1)
752 | iden = iden.view(-1, 1)
753 | return feat, out
754 |
755 |
756 | class IR152(nn.Module):
757 | def __init__(self, num_classes=1000):
758 | super(IR152, self).__init__()
759 | self.feature = evolve.IR_152_64((64, 64))
760 | self.feat_dim = 512
761 | self.num_classes = num_classes
762 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512),
763 | nn.Dropout(),
764 | Flatten(),
765 | nn.Linear(512 * 4 * 4, 512),
766 | nn.BatchNorm1d(512))
767 |
768 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes)
769 |
770 | def forward(self, x):
771 | feat = self.feature(x)
772 | feat = self.output_layer(feat)
773 | feat = feat.view(feat.size(0), -1)
774 | out = self.fc_layer(feat)
775 | return feat, out
776 |
777 | class IR152_vib(nn.Module):
778 | def __init__(self, num_classes=1000):
779 | super(IR152_vib, self).__init__()
780 | self.feature = evolve.IR_152_64((64, 64))
781 | self.feat_dim = 512
782 | self.k = self.feat_dim // 2
783 | self.n_classes = num_classes
784 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512),
785 | nn.Dropout(),
786 | Flatten(),
787 | nn.Linear(512 * 4 * 4, 512),
788 | nn.BatchNorm1d(512))
789 |
790 | self.st_layer = nn.Linear(self.feat_dim, self.k * 2)
791 | self.fc_layer = nn.Sequential(
792 | nn.Linear(self.k, self.n_classes),
793 | nn.Softmax(dim = 1))
794 |
795 | def forward(self, x):
796 | feature = self.output_layer(self.feature(x))
797 | feature = feature.view(feature.size(0), -1)
798 | statis = self.st_layer(feature)
799 | mu, std = statis[:, :self.k], statis[:, self.k:]
800 |
801 | std = F.softplus(std-5, beta=1)
802 | eps = torch.FloatTensor(std.size()).normal_().cuda()
803 | res = mu + std * eps
804 | out = self.fc_layer(res)
805 | __, iden = torch.max(out, dim=1)
806 | iden = iden.view(-1, 1)
807 |
808 | return feature, out, iden, mu#, st
809 |
810 |
811 | class IR50(nn.Module):
812 | def __init__(self, num_classes=1000):
813 | super(IR50, self).__init__()
814 | self.feature = evolve.IR_50_64((64, 64))
815 | self.feat_dim = 512
816 | self.n_classes = num_classes
817 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512),
818 | nn.Dropout(),
819 | Flatten(),
820 | nn.Linear(512 * 4 * 4, 512),
821 | nn.BatchNorm1d(512))
822 | self.k = self.feat_dim // 2
823 | self.st_layer = nn.Linear(self.feat_dim, self.k * 2)
824 | self.fc_layer = nn.Sequential(
825 | nn.Linear(self.k, self.n_classes),
826 | nn.Softmax(dim = 1))
827 |
828 | def forward(self, x):
829 | feature = self.output_layer(self.feature(x))
830 | feature = feature.view(feature.size(0), -1)
831 | statis = self.st_layer(feature)
832 | mu, std = statis[:, :self.k], statis[:, self.k:]
833 |
834 | std = F.softplus(std-5, beta=1)
835 | eps = torch.FloatTensor(std.size()).normal_().cuda()
836 | res = mu + std * eps
837 | out = self.fc_layer(res)
838 | __, iden = torch.max(out, dim=1)
839 | iden = iden.view(-1, 1)
840 |
841 | return feature, out, iden, mu, std
842 |
843 | class IR50_vib(nn.Module):
844 | def __init__(self, num_classes=1000):
845 | super(IR50_vib, self).__init__()
846 | self.feature = evolve.IR_50_64((64, 64))
847 | self.feat_dim = 512
848 | self.n_classes = num_classes
849 | self.k = self.feat_dim // 2
850 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512),
851 | nn.Dropout(),
852 | Flatten(),
853 | nn.Linear(512 * 4 * 4, 512),
854 | nn.BatchNorm1d(512))
855 |
856 | self.st_layer = nn.Linear(self.feat_dim, self.k * 2)
857 | self.fc_layer = nn.Sequential(
858 | nn.Linear(self.k, self.n_classes),
859 | nn.Softmax(dim=1))
860 |
861 | def forward(self, x):
862 | feat = self.output_layer(self.feature(x))
863 | feat = feat.view(feat.size(0), -1)
864 | statis = self.st_layer(feat)
865 | mu, std = statis[:, :self.k], statis[:, self.k:]
866 |
867 | std = F.softplus(std-5, beta=1)
868 | eps = torch.FloatTensor(std.size()).normal_().cuda()
869 | res = mu + std * eps
870 | out = self.fc_layer(res)
871 | __, iden = torch.max(out, dim=1)
872 | iden = iden.view(-1, 1)
873 |
874 | return feat, out, iden, mu, std
875 |
876 |
877 |
878 | def get_classifier(model_name, mode, n_classes, resume_path):
879 | if model_name == "VGG16":
880 | if mode == "reg":
881 | net = VGG16(n_classes)
882 | elif mode == "vib":
883 | net = VGG16_vib(n_classes)
884 |
885 | elif model_name == "FaceNet":
886 | net = FaceNet(n_classes)
887 |
888 | elif model_name == "FaceNet_all":
889 | net = FaceNet(202599)
890 |
891 | elif model_name == "FaceNet64":
892 | net = FaceNet64(n_classes)
893 |
894 | elif model_name == "IR50":
895 | if mode == "reg":
896 | net = IR50(n_classes)
897 | elif mode == "vib":
898 | net = IR50_vib(n_classes)
899 |
900 | elif model_name == "IR152":
901 | if mode == "reg":
902 | net = IR152(n_classes)
903 | else:
904 | net = IR152_vib(n_classes)
905 |
906 | else:
907 | print("Model name Error")
908 | exit()
909 |
910 | if model_name in ['FaceNet', 'FaceNet_all', 'FaceNet64', 'IR50', 'IR152']:
911 | if resume_path != "":
912 | print("Resume")
913 | utils.load_state_dict(net.feature, torch.load(resume_path))
914 | else:
915 | print("No Resume")
916 |
917 | return net
918 |
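A minimal usage sketch for `get_classifier` (illustrative; it assumes the repository root is on `PYTHONPATH` and that no pretrained backbone is being resumed):

```
from models import classify

# Build the regular (non-VIB) IR-50 target classifier for, e.g., 1,000 identities;
# passing a backbone checkpoint as resume_path would warm-start the IR-50 feature extractor.
net = classify.get_classifier(model_name="IR50", mode="reg",
                              n_classes=1000, resume_path="")
print(type(net).__name__)  # IR50
```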
--------------------------------------------------------------------------------
/models/discri.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | # from utils import LinearWeightNorm
5 | import torch.nn.init as init
6 |
7 | class MinibatchDiscriminator_MNIST(nn.Module):
8 | def __init__(self,in_dim=3, dim=64, n_classes=1000):
9 | super(MinibatchDiscriminator_MNIST, self).__init__()
10 | self.n_classes = n_classes
11 |
12 | def conv_ln_lrelu(in_dim, out_dim, k, s, p):
13 | return nn.Sequential(
14 | nn.Conv2d(in_dim, out_dim, k, s, p),
15 | # Since there is no effective implementation of LayerNorm,
16 | # we use InstanceNorm2d instead of LayerNorm here.
17 | nn.InstanceNorm2d(out_dim, affine=True),
18 | nn.LeakyReLU(0.2))
19 |
20 | self.layer1 = conv_ln_lrelu(in_dim, dim, 5, 2, 2)
21 | self.layer2 = conv_ln_lrelu(dim, dim*2, 5, 2, 2)
22 | self.layer3 = conv_ln_lrelu(dim*2, dim*4, 5, 2, 2)
23 | self.layer4 = conv_ln_lrelu(dim*4, dim*4, 3, 2, 1)
24 | self.mbd1 = MinibatchDiscrimination(dim*4*4, 64, 50)
25 | self.fc_layer = nn.Linear(dim*4*4+64, self.n_classes)
26 |
27 | def forward(self, x):
28 | out = []
29 | bs = x.shape[0]
30 | feat1 = self.layer1(x)
31 | out.append(feat1)
32 | feat2 = self.layer2(feat1)
33 | out.append(feat2)
34 | feat3 = self.layer3(feat2)
35 | out.append(feat3)
36 | feat4 = self.layer4(feat3)
37 | out.append(feat4)
38 | feat = feat4.view(bs, -1)
39 | # print('feat:', feat.shape)
40 | mb_out = self.mbd1(feat) # Nx(A+B)
41 | y = self.fc_layer(mb_out)
42 |
43 | return feat, y
44 | # return mb_out, y
45 |
46 |
47 | class MinibatchDiscrimination(nn.Module):
48 | def __init__(self, in_features, out_features, kernel_dims, mean=False):
49 | super().__init__()
50 | self.in_features = in_features
51 | self.out_features = out_features
52 | self.kernel_dims = kernel_dims
53 | self.mean = mean
54 | self.T = nn.Parameter(torch.Tensor(in_features, out_features, kernel_dims))
55 |         init.normal_(self.T, 0, 1)
56 |
57 | def forward(self, x):
58 | # x is NxA
59 | # T is AxBxC
60 | matrices = x.mm(self.T.view(self.in_features, -1))
61 | matrices = matrices.view(-1, self.out_features, self.kernel_dims)
62 |
63 | M = matrices.unsqueeze(0) # 1xNxBxC
64 | M_T = M.permute(1, 0, 2, 3) # Nx1xBxC
65 | norm = torch.abs(M - M_T).sum(3) # NxNxB
66 | expnorm = torch.exp(-norm)
67 | o_b = (expnorm.sum(0) - 1) # NxB, subtract self distance
68 | if self.mean:
69 | o_b /= x.size(0) - 1
70 |
71 | x = torch.cat([x, o_b], 1)
72 | return x
73 |
74 | class MinibatchDiscriminator(nn.Module):
75 | def __init__(self,in_dim=3, dim=64, n_classes=1000):
76 | super(MinibatchDiscriminator, self).__init__()
77 | self.n_classes = n_classes
78 |
79 | def conv_ln_lrelu(in_dim, out_dim, k, s, p):
80 | return nn.Sequential(
81 | nn.Conv2d(in_dim, out_dim, k, s, p),
82 | # Since there is no effective implementation of LayerNorm,
83 | # we use InstanceNorm2d instead of LayerNorm here.
84 | nn.InstanceNorm2d(out_dim, affine=True),
85 | nn.LeakyReLU(0.2))
86 |
87 | self.layer1 = conv_ln_lrelu(in_dim, dim, 5, 2, 2)
88 | self.layer2 = conv_ln_lrelu(dim, dim*2, 5, 2, 2)
89 | self.layer3 = conv_ln_lrelu(dim*2, dim*4, 5, 2, 2)
90 | self.layer4 = conv_ln_lrelu(dim*4, dim*4, 3, 2, 1)
91 | self.mbd1 = MinibatchDiscrimination(dim*4*4*4, 64, 50)
92 | self.fc_layer = nn.Linear(dim*4*4*4+64, self.n_classes)
93 |
94 | def forward(self, x):
95 | out = []
96 | bs = x.shape[0]
97 | feat1 = self.layer1(x)
98 | out.append(feat1)
99 | feat2 = self.layer2(feat1)
100 | out.append(feat2)
101 | feat3 = self.layer3(feat2)
102 | out.append(feat3)
103 | feat4 = self.layer4(feat3)
104 | out.append(feat4)
105 | feat = feat4.view(bs, -1)
106 | # print('feat:', feat.shape)
107 | mb_out = self.mbd1(feat) # Nx(A+B)
108 | y = self.fc_layer(mb_out)
109 |
110 | return feat, y
111 | # return mb_out, y
112 |
113 |
114 | class Discriminator(nn.Module):
115 | def __init__(self,in_dim=3, dim=64, n_classes=1000):
116 | super(Discriminator, self).__init__()
117 | self.n_classes = n_classes
118 |
119 | def conv_ln_lrelu(in_dim, out_dim, k, s, p):
120 | return nn.Sequential(
121 | nn.Conv2d(in_dim, out_dim, k, s, p),
122 | # Since there is no effective implementation of LayerNorm,
123 | # we use InstanceNorm2d instead of LayerNorm here.
124 | nn.InstanceNorm2d(out_dim, affine=True),
125 | nn.LeakyReLU(0.2))
126 |
127 | self.layer1 = conv_ln_lrelu(in_dim, dim, 5, 2, 2)
128 | self.layer2 = conv_ln_lrelu(dim, dim*2, 5, 2, 2)
129 | self.layer3 = conv_ln_lrelu(dim*2, dim*4, 5, 2, 2)
130 | self.layer4 = conv_ln_lrelu(dim*4, dim*4, 3, 2, 1)
131 | self.fc_layer = nn.Linear(dim*4*4*4, self.n_classes)
132 |
133 | def forward(self, x):
134 | bs = x.shape[0]
135 | feat1 = self.layer1(x)
136 | feat2 = self.layer2(feat1)
137 | feat3 = self.layer3(feat2)
138 | feat4 = self.layer4(feat3)
139 | feat = feat4.view(bs, -1)
140 | y = self.fc_layer(feat)
141 |
142 | return feat, y
143 |
144 |
145 | class DiscriminatorMNIST(nn.Module):
146 | def __init__(self, d_input_dim=1024):
147 | super(DiscriminatorMNIST, self).__init__()
148 | self.fc1 = nn.Linear(d_input_dim, 1024)
149 | self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
150 | self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
151 | self.fc4 = nn.Linear(self.fc3.out_features, 1)
152 |
153 | # forward method
154 | def forward(self, x):
155 | x = x.view(x.size(0), -1)
156 | x = F.leaky_relu(self.fc1(x), 0.2)
157 | x = F.dropout(x, 0.3)
158 | x = F.leaky_relu(self.fc2(x), 0.2)
159 | x = F.dropout(x, 0.3)
160 | x = F.leaky_relu(self.fc3(x), 0.2)
161 | x = F.dropout(x, 0.3)
162 | y = self.fc4(x)
163 | y = y.view(-1)
164 |
165 | return y
166 |
167 | class DGWGAN32(nn.Module):
168 | def __init__(self, in_dim=1, dim=64):
169 | super(DGWGAN32, self).__init__()
170 | def conv_ln_lrelu(in_dim, out_dim):
171 | return nn.Sequential(
172 | nn.Conv2d(in_dim, out_dim, 5, 2, 2),
173 | # Since there is no effective implementation of LayerNorm,
174 | # we use InstanceNorm2d instead of LayerNorm here.
175 | nn.InstanceNorm2d(out_dim, affine=True),
176 | nn.LeakyReLU(0.2))
177 |
178 | self.layer1 = nn.Sequential(nn.Conv2d(in_dim, dim, 5, 2, 2), nn.LeakyReLU(0.2))
179 | self.layer2 = conv_ln_lrelu(dim, dim * 2)
180 | self.layer3 = conv_ln_lrelu(dim * 2, dim * 4)
181 | self.layer4 = nn.Conv2d(dim * 4, 1, 4)
182 |
183 | def forward(self, x):
184 | feat1 = self.layer1(x)
185 | feat2 = self.layer2(feat1)
186 | feat3 = self.layer3(feat2)
187 | y = self.layer4(feat3)
188 | y = y.view(-1)
189 | return y
190 |
191 | class DGWGAN(nn.Module):
192 | def __init__(self, in_dim=3, dim=64):
193 | super(DGWGAN, self).__init__()
194 | def conv_ln_lrelu(in_dim, out_dim):
195 | return nn.Sequential(
196 | nn.Conv2d(in_dim, out_dim, 5, 2, 2),
197 | # Since there is no effective implementation of LayerNorm,
198 | # we use InstanceNorm2d instead of LayerNorm here.
199 | nn.InstanceNorm2d(out_dim, affine=True),
200 | nn.LeakyReLU(0.2))
201 |
202 | self.ls = nn.Sequential(
203 | nn.Conv2d(in_dim, dim, 5, 2, 2), nn.LeakyReLU(0.2),
204 | conv_ln_lrelu(dim, dim * 2),
205 | conv_ln_lrelu(dim * 2, dim * 4),
206 | conv_ln_lrelu(dim * 4, dim * 8),
207 | nn.Conv2d(dim * 8, 1, 4))
208 |
209 | def forward(self, x):
210 | y = self.ls(x)
211 | y = y.view(-1)
212 | return y
213 |
214 | class DLWGAN(nn.Module):
215 | def __init__(self, in_dim=3, dim=64):
216 | super(DLWGAN, self).__init__()
217 |
218 | def conv_ln_lrelu(in_dim, out_dim):
219 | return nn.Sequential(
220 | nn.Conv2d(in_dim, out_dim, 5, 2, 2),
221 | # Since there is no effective implementation of LayerNorm,
222 | # we use InstanceNorm2d instead of LayerNorm here.
223 | nn.InstanceNorm2d(out_dim, affine=True),
224 | nn.LeakyReLU(0.2))
225 |
226 | self.layer1 = nn.Sequential(nn.Conv2d(in_dim, dim, 5, 2, 2), nn.LeakyReLU(0.2))
227 | self.layer2 = conv_ln_lrelu(dim, dim * 2)
228 | self.layer3 = conv_ln_lrelu(dim * 2, dim * 4)
229 | self.layer4 = nn.Conv2d(dim * 4, 1, 4)
230 |
231 |
232 | def forward(self, x):
233 | feat1 = self.layer1(x)
234 | feat2 = self.layer2(feat1)
235 | feat3 = self.layer3(feat2)
236 | y = self.layer4(feat3)
237 | return y
238 |
239 |
240 |
241 |
242 |
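A quick shape check for the minibatch-discrimination discriminator above (illustrative; the batch size and class count are arbitrary, and the 64x64 resolution matches the images produced by `models/generator.py`):

```
import torch
from models.discri import MinibatchDiscriminator

D = MinibatchDiscriminator(in_dim=3, dim=64, n_classes=1000)
x = torch.randn(8, 3, 64, 64)
feat, logits = D(x)
# feat:   flattened conv features, (8, 64*4*4*4) = (8, 4096)
# logits: per-class scores computed from the minibatch-augmented features, (8, 1000)
print(feat.shape, logits.shape)
```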
--------------------------------------------------------------------------------
/models/evolve.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, ReLU, Sigmoid, Dropout, MaxPool2d, \
4 | AdaptiveAvgPool2d, Sequential, Module
5 | from collections import namedtuple
6 |
7 |
8 | # Support: ['IR_50', 'IR_101', 'IR_152', 'IR_SE_50', 'IR_SE_101', 'IR_SE_152']
9 |
10 |
11 | class Flatten(Module):
12 | def forward(self, input):
13 | return input.view(input.size(0), -1)
14 |
15 |
16 | def l2_norm(input, axis=1):
17 | norm = torch.norm(input, 2, axis, True)
18 | output = torch.div(input, norm)
19 |
20 | return output
21 |
22 |
23 | class SEModule(Module):
24 | def __init__(self, channels, reduction):
25 | super(SEModule, self).__init__()
26 | self.avg_pool = AdaptiveAvgPool2d(1)
27 | self.fc1 = Conv2d(
28 | channels, channels // reduction, kernel_size=1, padding=0, bias=False)
29 |
30 | nn.init.xavier_uniform_(self.fc1.weight.data)
31 |
32 | self.relu = ReLU(inplace=True)
33 | self.fc2 = Conv2d(
34 | channels // reduction, channels, kernel_size=1, padding=0, bias=False)
35 |
36 | self.sigmoid = Sigmoid()
37 |
38 | def forward(self, x):
39 | module_input = x
40 | x = self.avg_pool(x)
41 | x = self.fc1(x)
42 | x = self.relu(x)
43 | x = self.fc2(x)
44 | x = self.sigmoid(x)
45 |
46 | return module_input * x
47 |
48 |
49 | class bottleneck_IR(Module):
50 | def __init__(self, in_channel, depth, stride):
51 | super(bottleneck_IR, self).__init__()
52 | if in_channel == depth:
53 | self.shortcut_layer = MaxPool2d(1, stride)
54 | else:
55 | self.shortcut_layer = Sequential(
56 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth))
57 | self.res_layer = Sequential(
58 | BatchNorm2d(in_channel),
59 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
60 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth))
61 |
62 | def forward(self, x):
63 | shortcut = self.shortcut_layer(x)
64 | res = self.res_layer(x)
65 |
66 | return res + shortcut
67 |
68 |
69 | class bottleneck_IR_SE(Module):
70 | def __init__(self, in_channel, depth, stride):
71 | super(bottleneck_IR_SE, self).__init__()
72 | if in_channel == depth:
73 | self.shortcut_layer = MaxPool2d(1, stride)
74 | else:
75 | self.shortcut_layer = Sequential(
76 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
77 | BatchNorm2d(depth))
78 | self.res_layer = Sequential(
79 | BatchNorm2d(in_channel),
80 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
81 | PReLU(depth),
82 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
83 | BatchNorm2d(depth),
84 | SEModule(depth, 16)
85 | )
86 |
87 | def forward(self, x):
88 | shortcut = self.shortcut_layer(x)
89 | res = self.res_layer(x)
90 |
91 | return res + shortcut
92 |
93 |
94 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
95 | '''A named tuple describing a ResNet block.'''
96 |
97 |
98 | def get_block(in_channel, depth, num_units, stride=2):
99 |
100 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
101 |
102 |
103 | def get_blocks(num_layers):
104 | if num_layers == 50:
105 | blocks = [
106 | get_block(in_channel=64, depth=64, num_units=3),
107 | get_block(in_channel=64, depth=128, num_units=4),
108 | get_block(in_channel=128, depth=256, num_units=14),
109 | get_block(in_channel=256, depth=512, num_units=3)
110 | ]
111 | elif num_layers == 100:
112 | blocks = [
113 | get_block(in_channel=64, depth=64, num_units=3),
114 | get_block(in_channel=64, depth=128, num_units=13),
115 | get_block(in_channel=128, depth=256, num_units=30),
116 | get_block(in_channel=256, depth=512, num_units=3)
117 | ]
118 | elif num_layers == 152:
119 | blocks = [
120 | get_block(in_channel=64, depth=64, num_units=3),
121 | get_block(in_channel=64, depth=128, num_units=8),
122 | get_block(in_channel=128, depth=256, num_units=36),
123 | get_block(in_channel=256, depth=512, num_units=3)
124 | ]
125 |
126 | return blocks
127 |
128 |
129 | class Backbone64(Module):
130 | def __init__(self, input_size, num_layers, mode='ir'):
131 | super(Backbone64, self).__init__()
132 |         assert input_size[0] in [64, 112, 224], "input_size should be [64, 64], [112, 112] or [224, 224]"
133 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
134 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
135 | blocks = get_blocks(num_layers)
136 | if mode == 'ir':
137 | unit_module = bottleneck_IR
138 | elif mode == 'ir_se':
139 | unit_module = bottleneck_IR_SE
140 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
141 | BatchNorm2d(64),
142 | PReLU(64))
143 |
144 | modules = []
145 | for block in blocks:
146 | for bottleneck in block:
147 | modules.append(
148 | unit_module(bottleneck.in_channel,
149 | bottleneck.depth,
150 | bottleneck.stride))
151 | self.body = Sequential(*modules)
152 |
153 | self._initialize_weights()
154 |
155 | def forward(self, x):
156 | x = self.input_layer(x)
157 | x = self.body(x)
158 | #x = self.output_layer(x)
159 |
160 | return x
161 |
162 | def _initialize_weights(self):
163 | for m in self.modules():
164 | if isinstance(m, nn.Conv2d):
165 | nn.init.xavier_uniform_(m.weight.data)
166 | if m.bias is not None:
167 | m.bias.data.zero_()
168 | elif isinstance(m, nn.BatchNorm2d):
169 | m.weight.data.fill_(1)
170 | m.bias.data.zero_()
171 | elif isinstance(m, nn.BatchNorm1d):
172 | m.weight.data.fill_(1)
173 | m.bias.data.zero_()
174 | elif isinstance(m, nn.Linear):
175 | nn.init.xavier_uniform_(m.weight.data)
176 | if m.bias is not None:
177 | m.bias.data.zero_()
178 |
179 | class Backbone112(Module):
180 | def __init__(self, input_size, num_layers, mode='ir'):
181 | super(Backbone112, self).__init__()
182 | assert input_size[0] in [64, 112, 224], "input_size should be [112, 112] or [224, 224]"
183 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
184 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
185 | blocks = get_blocks(num_layers)
186 | if mode == 'ir':
187 | unit_module = bottleneck_IR
188 | elif mode == 'ir_se':
189 | unit_module = bottleneck_IR_SE
190 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
191 | BatchNorm2d(64),
192 | PReLU(64))
193 |
194 | if input_size[0] == 112:
195 | self.output_layer = Sequential(BatchNorm2d(512),
196 | Dropout(),
197 | Flatten(),
198 | Linear(512 * 7 * 7, 512),
199 | BatchNorm1d(512))
200 | else:
201 | self.output_layer = Sequential(BatchNorm2d(512),
202 | Dropout(),
203 | Flatten(),
204 | Linear(512 * 14 * 14, 512),
205 | BatchNorm1d(512))
206 |
207 | modules = []
208 | for block in blocks:
209 | for bottleneck in block:
210 | modules.append(
211 | unit_module(bottleneck.in_channel,
212 | bottleneck.depth,
213 | bottleneck.stride))
214 | self.body = Sequential(*modules)
215 |
216 | self._initialize_weights()
217 |
218 | def forward(self, x):
219 | x = self.input_layer(x)
220 | x = self.body(x)
221 | x = self.output_layer(x)
222 |
223 | return x
224 |
225 | def _initialize_weights(self):
226 | for m in self.modules():
227 | if isinstance(m, nn.Conv2d):
228 | nn.init.xavier_uniform_(m.weight.data)
229 | if m.bias is not None:
230 | m.bias.data.zero_()
231 | elif isinstance(m, nn.BatchNorm2d):
232 | m.weight.data.fill_(1)
233 | m.bias.data.zero_()
234 | elif isinstance(m, nn.BatchNorm1d):
235 | m.weight.data.fill_(1)
236 | m.bias.data.zero_()
237 | elif isinstance(m, nn.Linear):
238 | nn.init.xavier_uniform_(m.weight.data)
239 | if m.bias is not None:
240 | m.bias.data.zero_()
241 |
242 |
243 | def IR_50_64(input_size):
244 | """Constructs a ir-50 model.
245 | """
246 | model = Backbone64(input_size, 50, 'ir')
247 |
248 | return model
249 |
250 | def IR_50_112(input_size):
251 | """Constructs a ir-50 model.
252 | """
253 | model = Backbone112(input_size, 50, 'ir')
254 |
255 | return model
256 |
257 |
258 | def IR_100(input_size):
259 | """Constructs a ir-100 model.
260 | """
261 |     model = Backbone64(input_size, 100, 'ir') if input_size[0] == 64 else Backbone112(input_size, 100, 'ir')  # this file defines Backbone64/Backbone112, not a plain Backbone
262 |
263 | return model
264 |
265 | def IR_152_64(input_size):
266 | """Constructs a ir-152 model.
267 | """
268 | model = Backbone64(input_size, 152, 'ir')
269 |
270 | return model
271 |
272 |
273 | def IR_152_112(input_size):
274 | """Constructs a ir-152 model.
275 | """
276 | model = Backbone112(input_size, 152, 'ir')
277 |
278 | return model
279 |
280 | def IR_SE_50(input_size):
281 | """Constructs a ir_se-50 model.
282 | """
283 |     model = Backbone64(input_size, 50, 'ir_se') if input_size[0] == 64 else Backbone112(input_size, 50, 'ir_se')
284 |
285 | return model
286 |
287 |
288 | def IR_SE_101(input_size):
289 | """Constructs a ir_se-101 model.
290 | """
291 |     model = Backbone64(input_size, 100, 'ir_se') if input_size[0] == 64 else Backbone112(input_size, 100, 'ir_se')
292 |
293 | return model
294 |
295 |
296 | def IR_SE_152(input_size):
297 | """Constructs a ir_se-152 model.
298 | """
299 |     model = Backbone64(input_size, 152, 'ir_se') if input_size[0] == 64 else Backbone112(input_size, 152, 'ir_se')
300 |
301 | return model
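A quick sanity check of the 64-pixel backbone defined above (illustrative; the classifier heads in `classify.py` and `facenet.py` expect exactly this 512x4x4 output):

```
import torch
from models import evolve

backbone = evolve.IR_50_64((64, 64)).eval()
with torch.no_grad():
    feat = backbone(torch.randn(2, 3, 64, 64))
print(feat.shape)  # torch.Size([2, 512, 4, 4]) -> flattened to 512 * 4 * 4 by the heads
```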
--------------------------------------------------------------------------------
/models/facenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, ReLU, Sigmoid, Dropout, MaxPool2d, \
4 | AdaptiveAvgPool2d, Sequential, Module
5 | from collections import namedtuple
6 |
7 |
8 | # Support: ['IR_50', 'IR_101', 'IR_152', 'IR_SE_50', 'IR_SE_101', 'IR_SE_152']
9 | class FaceNet(nn.Module):
10 | def __init__(self, num_classes = 1000):
11 | super(FaceNet, self).__init__()
12 | self.feature = IR_50_112((112, 112))
13 | self.feat_dim = 512
14 | self.num_classes = num_classes
15 | self.fc_layer = nn.Sequential(
16 | nn.Linear(self.feat_dim, self.num_classes),
17 | nn.Softmax(dim = 1))
18 |
19 | def forward(self, x):
20 | feat = self.feature(x)
21 | feat = feat.view(feat.size(0), -1)
22 | out = self.fc_layer(feat)
23 | __, iden = torch.max(out, dim = 1)
24 | iden = iden.view(-1, 1)
25 | return feat, out, iden
26 |
27 | class FaceNet64(nn.Module):
28 | def __init__(self, num_classes = 1000):
29 | super(FaceNet64, self).__init__()
30 | self.feature = IR_50_64((64, 64))
31 | self.feat_dim = 512
32 | self.num_classes = num_classes
33 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512),
34 | nn.Dropout(),
35 | Flatten(),
36 | nn.Linear(512 * 4 * 4, 512),
37 | nn.BatchNorm1d(512))
38 |
39 | self.fc_layer = nn.Sequential(
40 | nn.Linear(self.feat_dim, self.num_classes),
41 | nn.Softmax(dim = 1))
42 |
43 | def forward(self, x):
44 | feat = self.feature(x)
45 | feat = self.output_layer(feat)
46 | feat = feat.view(feat.size(0), -1)
47 | out = self.fc_layer(feat)
48 | __, iden = torch.max(out, dim = 1)
49 | iden = iden.view(-1, 1)
50 | return feat, out, iden
51 |
52 | class Flatten(Module):
53 | def forward(self, input):
54 | return input.view(input.size(0), -1)
55 |
56 |
57 | def l2_norm(input, axis=1):
58 | norm = torch.norm(input, 2, axis, True)
59 | output = torch.div(input, norm)
60 |
61 | return output
62 |
63 |
64 | class SEModule(Module):
65 | def __init__(self, channels, reduction):
66 | super(SEModule, self).__init__()
67 | self.avg_pool = AdaptiveAvgPool2d(1)
68 | self.fc1 = Conv2d(
69 | channels, channels // reduction, kernel_size=1, padding=0, bias=False)
70 |
71 | nn.init.xavier_uniform_(self.fc1.weight.data)
72 |
73 | self.relu = ReLU(inplace=True)
74 | self.fc2 = Conv2d(
75 | channels // reduction, channels, kernel_size=1, padding=0, bias=False)
76 |
77 | self.sigmoid = Sigmoid()
78 |
79 | def forward(self, x):
80 | module_input = x
81 | x = self.avg_pool(x)
82 | x = self.fc1(x)
83 | x = self.relu(x)
84 | x = self.fc2(x)
85 | x = self.sigmoid(x)
86 |
87 | return module_input * x
88 |
89 |
90 | class bottleneck_IR(Module):
91 | def __init__(self, in_channel, depth, stride):
92 | super(bottleneck_IR, self).__init__()
93 | if in_channel == depth:
94 | self.shortcut_layer = MaxPool2d(1, stride)
95 | else:
96 | self.shortcut_layer = Sequential(
97 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth))
98 | self.res_layer = Sequential(
99 | BatchNorm2d(in_channel),
100 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
101 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth))
102 |
103 | def forward(self, x):
104 | shortcut = self.shortcut_layer(x)
105 | res = self.res_layer(x)
106 |
107 | return res + shortcut
108 |
109 |
110 | class bottleneck_IR_SE(Module):
111 | def __init__(self, in_channel, depth, stride):
112 | super(bottleneck_IR_SE, self).__init__()
113 | if in_channel == depth:
114 | self.shortcut_layer = MaxPool2d(1, stride)
115 | else:
116 | self.shortcut_layer = Sequential(
117 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
118 | BatchNorm2d(depth))
119 | self.res_layer = Sequential(
120 | BatchNorm2d(in_channel),
121 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
122 | PReLU(depth),
123 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
124 | BatchNorm2d(depth),
125 | SEModule(depth, 16)
126 | )
127 |
128 | def forward(self, x):
129 | shortcut = self.shortcut_layer(x)
130 | res = self.res_layer(x)
131 |
132 | return res + shortcut
133 |
134 |
135 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
136 | '''A named tuple describing a ResNet block.'''
137 |
138 |
139 | def get_block(in_channel, depth, num_units, stride=2):
140 |
141 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
142 |
143 |
144 | def get_blocks(num_layers):
145 | if num_layers == 50:
146 | blocks = [
147 | get_block(in_channel=64, depth=64, num_units=3),
148 | get_block(in_channel=64, depth=128, num_units=4),
149 | get_block(in_channel=128, depth=256, num_units=14),
150 | get_block(in_channel=256, depth=512, num_units=3)
151 | ]
152 | elif num_layers == 100:
153 | blocks = [
154 | get_block(in_channel=64, depth=64, num_units=3),
155 | get_block(in_channel=64, depth=128, num_units=13),
156 | get_block(in_channel=128, depth=256, num_units=30),
157 | get_block(in_channel=256, depth=512, num_units=3)
158 | ]
159 | elif num_layers == 152:
160 | blocks = [
161 | get_block(in_channel=64, depth=64, num_units=3),
162 | get_block(in_channel=64, depth=128, num_units=8),
163 | get_block(in_channel=128, depth=256, num_units=36),
164 | get_block(in_channel=256, depth=512, num_units=3)
165 | ]
166 |
167 | return blocks
168 |
169 |
170 | class Backbone64(Module):
171 | def __init__(self, input_size, num_layers, mode='ir'):
172 | super(Backbone64, self).__init__()
173 |         assert input_size[0] in [64], "input_size should be [64, 64]"
174 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
175 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
176 | blocks = get_blocks(num_layers)
177 | if mode == 'ir':
178 | unit_module = bottleneck_IR
179 | elif mode == 'ir_se':
180 | unit_module = bottleneck_IR_SE
181 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
182 | BatchNorm2d(64),
183 | PReLU(64))
184 |
185 | self.output_layer = Sequential(BatchNorm2d(512),
186 | Dropout(),
187 | Flatten(),
188 | Linear(512 * 14 * 14, 512),
189 | BatchNorm1d(512))
190 |
191 | modules = []
192 | for block in blocks:
193 | for bottleneck in block:
194 | modules.append(
195 | unit_module(bottleneck.in_channel,
196 | bottleneck.depth,
197 | bottleneck.stride))
198 | self.body = Sequential(*modules)
199 |
200 | self._initialize_weights()
201 |
202 | def forward(self, x):
203 | x = self.input_layer(x)
204 | x = self.body(x)
205 |
206 | return x
207 |
208 | def _initialize_weights(self):
209 | for m in self.modules():
210 | if isinstance(m, nn.Conv2d):
211 | nn.init.xavier_uniform_(m.weight.data)
212 | if m.bias is not None:
213 | m.bias.data.zero_()
214 | elif isinstance(m, nn.BatchNorm2d):
215 | m.weight.data.fill_(1)
216 | m.bias.data.zero_()
217 | elif isinstance(m, nn.BatchNorm1d):
218 | m.weight.data.fill_(1)
219 | m.bias.data.zero_()
220 | elif isinstance(m, nn.Linear):
221 | nn.init.xavier_uniform_(m.weight.data)
222 | if m.bias is not None:
223 | m.bias.data.zero_()
224 |
225 | class Backbone112(Module):
226 | def __init__(self, input_size, num_layers, mode='ir'):
227 | super(Backbone112, self).__init__()
228 |         assert input_size[0] in [112], "input_size should be [112, 112]"
229 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
230 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
231 | blocks = get_blocks(num_layers)
232 | if mode == 'ir':
233 | unit_module = bottleneck_IR
234 | elif mode == 'ir_se':
235 | unit_module = bottleneck_IR_SE
236 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
237 | BatchNorm2d(64),
238 | PReLU(64))
239 |
240 | if input_size[0] == 112:
241 | self.output_layer = Sequential(BatchNorm2d(512),
242 | Dropout(),
243 | Flatten(),
244 | Linear(512 * 7 * 7, 512),
245 | BatchNorm1d(512))
246 |
247 | modules = []
248 | for block in blocks:
249 | for bottleneck in block:
250 | modules.append(
251 | unit_module(bottleneck.in_channel,
252 | bottleneck.depth,
253 | bottleneck.stride))
254 | self.body = Sequential(*modules)
255 |
256 | self._initialize_weights()
257 |
258 | def forward(self, x):
259 | x = self.input_layer(x)
260 | x = self.body(x)
261 | x = self.output_layer(x)
262 |
263 | return x
264 |
265 | def _initialize_weights(self):
266 | for m in self.modules():
267 | if isinstance(m, nn.Conv2d):
268 | nn.init.xavier_uniform_(m.weight.data)
269 | if m.bias is not None:
270 | m.bias.data.zero_()
271 | elif isinstance(m, nn.BatchNorm2d):
272 | m.weight.data.fill_(1)
273 | m.bias.data.zero_()
274 | elif isinstance(m, nn.BatchNorm1d):
275 | m.weight.data.fill_(1)
276 | m.bias.data.zero_()
277 | elif isinstance(m, nn.Linear):
278 | nn.init.xavier_uniform_(m.weight.data)
279 | if m.bias is not None:
280 | m.bias.data.zero_()
281 |
282 |
283 | def IR_50_64(input_size):
284 | """Constructs a ir-50 model.
285 | """
286 | model = Backbone64(input_size, 50, 'ir')
287 |
288 | return model
289 |
290 | def IR_50_112(input_size):
291 | """Constructs a ir-50 model.
292 | """
293 | model = Backbone112(input_size, 50, 'ir')
294 |
295 | return model
296 |
297 |
298 | def IR_101(input_size):
299 | """Constructs a ir-101 model.
300 | """
301 |     model = Backbone64(input_size, 100, 'ir') if input_size[0] == 64 else Backbone112(input_size, 100, 'ir')  # this file defines Backbone64/Backbone112, not a plain Backbone
302 |
303 | return model
304 |
305 |
306 | def IR_152_64(input_size):
307 | """Constructs a ir-152 model.
308 | """
309 | model = Backbone64(input_size, 152, 'ir')
310 |
311 | return model
312 |
313 | def IR_152_112(input_size):
314 | """Constructs a ir-152 model.
315 | """
316 | model = Backbone112(input_size, 152, 'ir')
317 |
318 | return model
319 |
320 |
321 | def IR_SE_50(input_size):
322 | """Constructs a ir_se-50 model.
323 | """
324 |     model = Backbone64(input_size, 50, 'ir_se') if input_size[0] == 64 else Backbone112(input_size, 50, 'ir_se')
325 |
326 | return model
327 |
328 |
329 | def IR_SE_101(input_size):
330 | """Constructs a ir_se-101 model.
331 | """
332 |     model = Backbone64(input_size, 100, 'ir_se') if input_size[0] == 64 else Backbone112(input_size, 100, 'ir_se')
333 |
334 | return model
335 |
336 |
337 | def IR_SE_152(input_size):
338 | """Constructs a ir_se-152 model.
339 | """
340 |     model = Backbone64(input_size, 152, 'ir_se') if input_size[0] == 64 else Backbone112(input_size, 152, 'ir_se')
341 |
342 | return model
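A small forward-pass sketch for `FaceNet64` (illustrative; the weights are random here, so the predicted identities are meaningless):

```
import torch
from models.facenet import FaceNet64

model = FaceNet64(num_classes=1000).eval()
with torch.no_grad():
    feat, prob, iden = model(torch.rand(2, 3, 64, 64))
# feat: 512-d embeddings, prob: softmax over identities, iden: argmax identity per image
print(feat.shape, prob.shape, iden.shape)  # [2, 512] [2, 1000] [2, 1]
```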
--------------------------------------------------------------------------------
/models/generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class GeneratorCXR(nn.Module):
5 | def __init__(self, in_dim=100, dim=64):
6 | super(GeneratorCXR, self).__init__()
7 | def dconv_bn_relu(in_dim, out_dim):
8 | return nn.Sequential(
9 | nn.ConvTranspose2d(in_dim, out_dim, 5, 2,
10 | padding=2, output_padding=1, bias=False),
11 | nn.BatchNorm2d(out_dim),
12 | nn.ReLU())
13 |
14 | self.l1 = nn.Sequential(
15 | nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False),
16 | nn.BatchNorm1d(dim * 8 * 4 * 4),
17 | nn.ReLU())
18 | self.l2_5 = nn.Sequential(
19 | dconv_bn_relu(dim * 8, dim * 4),
20 | dconv_bn_relu(dim * 4, dim * 2),
21 | dconv_bn_relu(dim * 2, dim),
22 | nn.ConvTranspose2d(dim, 1, 5, 2, padding=2, output_padding=1),
23 | nn.Sigmoid())
24 |
25 | def forward(self, x):
26 | y = self.l1(x)
27 | y = y.view(y.size(0), -1, 4, 4)
28 | y = self.l2_5(y)
29 | return y
30 |
31 | class GeneratorMNIST(nn.Module):
32 | def __init__(self, in_dim=100, dim=64):
33 | super(GeneratorMNIST, self).__init__()
34 | def dconv_bn_relu(in_dim, out_dim):
35 | return nn.Sequential(
36 | nn.ConvTranspose2d(in_dim, out_dim, 5, 2,
37 | padding=2, output_padding=1, bias=False),
38 | nn.BatchNorm2d(out_dim),
39 | nn.ReLU())
40 |
41 | self.l1 = nn.Sequential(
42 | nn.Linear(in_dim, dim * 4 * 4 * 4, bias=False),
43 | nn.BatchNorm1d(dim * 4 * 4 * 4),
44 | nn.ReLU())
45 | self.l2_5 = nn.Sequential(
46 | dconv_bn_relu(dim * 4, dim * 2),
47 | dconv_bn_relu(dim * 2, dim),
48 | nn.ConvTranspose2d(dim, 1, 5, 2, padding=2, output_padding=1),
49 | nn.Sigmoid())
50 |
51 | def forward(self, x):
52 | y = self.l1(x)
53 | y = y.view(y.size(0), -1, 4, 4)
54 | y = self.l2_5(y)
55 | return y
56 |
57 | class Generator(nn.Module):
58 | def __init__(self, in_dim=100, dim=64):
59 | super(Generator, self).__init__()
60 | def dconv_bn_relu(in_dim, out_dim):
61 | return nn.Sequential(
62 | nn.ConvTranspose2d(in_dim, out_dim, 5, 2,
63 | padding=2, output_padding=1, bias=False),
64 | nn.BatchNorm2d(out_dim),
65 | nn.ReLU())
66 |
67 | self.l1 = nn.Sequential(
68 | nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False),
69 | nn.BatchNorm1d(dim * 8 * 4 * 4),
70 | nn.ReLU())
71 | self.l2_5 = nn.Sequential(
72 | dconv_bn_relu(dim * 8, dim * 4),
73 | dconv_bn_relu(dim * 4, dim * 2),
74 | dconv_bn_relu(dim * 2, dim),
75 | nn.ConvTranspose2d(dim, 3, 5, 2, padding=2, output_padding=1),
76 | nn.Sigmoid())
77 |
78 | def forward(self, x):
79 | y = self.l1(x)
80 | y = y.view(y.size(0), -1, 4, 4)
81 | y = self.l2_5(y)
82 | return y
83 |
110 | class CompletionNetwork(nn.Module):
111 | def __init__(self):
112 | super(CompletionNetwork, self).__init__()
113 | # input_shape: (None, 4, img_h, img_w)
114 | self.conv1 = nn.Conv2d(4, 32, kernel_size=5, stride=1, padding=2)
115 | self.bn1 = nn.BatchNorm2d(32)
116 | self.act1 = nn.ReLU()
117 |         # input_shape: (None, 32, img_h, img_w)
118 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
119 | self.bn2 = nn.BatchNorm2d(64)
120 | self.act2 = nn.ReLU()
121 |         # input_shape: (None, 64, img_h//2, img_w//2)
122 | self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
123 | self.bn3 = nn.BatchNorm2d(64)
124 | self.act3 = nn.ReLU()
125 |         # input_shape: (None, 64, img_h//2, img_w//2)
126 | self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
127 | self.bn4 = nn.BatchNorm2d(128)
128 | self.act4 = nn.ReLU()
129 |         # input_shape: (None, 128, img_h//4, img_w//4)
130 | self.conv5 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
131 | self.bn5 = nn.BatchNorm2d(128)
132 | self.act5 = nn.ReLU()
133 |         # input_shape: (None, 128, img_h//4, img_w//4)
134 | self.conv6 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
135 | self.bn6 = nn.BatchNorm2d(128)
136 | self.act6 = nn.ReLU()
137 |         # input_shape: (None, 128, img_h//4, img_w//4)
138 | self.conv7 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=2, padding=2)
139 | self.bn7 = nn.BatchNorm2d(128)
140 | self.act7 = nn.ReLU()
141 |         # input_shape: (None, 128, img_h//4, img_w//4)
142 | self.conv8 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=4, padding=4)
143 | self.bn8 = nn.BatchNorm2d(128)
144 | self.act8 = nn.ReLU()
145 |         # input_shape: (None, 128, img_h//4, img_w//4)
146 | self.conv9 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=8, padding=8)
147 | self.bn9 = nn.BatchNorm2d(128)
148 | self.act9 = nn.ReLU()
149 |         # input_shape: (None, 128, img_h//4, img_w//4)
150 | self.conv10 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=16, padding=16)
151 | self.bn10 = nn.BatchNorm2d(128)
152 | self.act10 = nn.ReLU()
153 |         # input_shape: (None, 128, img_h//4, img_w//4)
154 | self.conv11 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
155 | self.bn11 = nn.BatchNorm2d(128)
156 | self.act11 = nn.ReLU()
157 |         # input_shape: (None, 128, img_h//4, img_w//4)
158 | self.conv12 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
159 | self.bn12 = nn.BatchNorm2d(128)
160 | self.act12 = nn.ReLU()
161 |         # input_shape: (None, 128, img_h//4, img_w//4)
162 | self.deconv13 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
163 | self.bn13 = nn.BatchNorm2d(64)
164 | self.act13 = nn.ReLU()
165 |         # input_shape: (None, 64, img_h//2, img_w//2)
166 | self.conv14 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
167 | self.bn14 = nn.BatchNorm2d(64)
168 | self.act14 = nn.ReLU()
169 |         # input_shape: (None, 64, img_h//2, img_w//2)
170 | self.deconv15 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
171 | self.bn15 = nn.BatchNorm2d(32)
172 | self.act15 = nn.ReLU()
173 |         # input_shape: (None, 32, img_h, img_w)
174 | self.conv16 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
175 | self.bn16 = nn.BatchNorm2d(32)
176 | self.act16 = nn.ReLU()
177 | # input_shape: (None, 32, img_h, img_w)
178 | self.conv17 = nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1)
179 | self.act17 = nn.Sigmoid()
180 |         # output_shape: (None, 3, img_h, img_w)
181 |
182 | def forward(self, x):
183 | x = self.bn1(self.act1(self.conv1(x)))
184 | x = self.bn2(self.act2(self.conv2(x)))
185 | x = self.bn3(self.act3(self.conv3(x)))
186 | x = self.bn4(self.act4(self.conv4(x)))
187 | x = self.bn5(self.act5(self.conv5(x)))
188 | x = self.bn6(self.act6(self.conv6(x)))
189 | x = self.bn7(self.act7(self.conv7(x)))
190 | x = self.bn8(self.act8(self.conv8(x)))
191 | x = self.bn9(self.act9(self.conv9(x)))
192 | x = self.bn10(self.act10(self.conv10(x)))
193 | x = self.bn11(self.act11(self.conv11(x)))
194 | x = self.bn12(self.act12(self.conv12(x)))
195 | x = self.bn13(self.act13(self.deconv13(x)))
196 | x = self.bn14(self.act14(self.conv14(x)))
197 | x = self.bn15(self.act15(self.deconv15(x)))
198 | x = self.bn16(self.act16(self.conv16(x)))
199 | x = self.act17(self.conv17(x))
200 | return x
201 |
202 | def dconv_bn_relu(in_dim, out_dim):
203 | return nn.Sequential(
204 | nn.ConvTranspose2d(in_dim, out_dim, 5, 2,
205 | padding=2, output_padding=1, bias=False),
206 | nn.BatchNorm2d(out_dim),
207 | nn.ReLU())
208 |
209 | class ContextNetwork(nn.Module):
210 | def __init__(self):
211 | super(ContextNetwork, self).__init__()
212 | # input_shape: (None, 4, img_h, img_w)
213 | self.conv1 = nn.Conv2d(4, 32, kernel_size=5, stride=1, padding=2)
214 | self.bn1 = nn.BatchNorm2d(32)
215 | self.act1 = nn.ReLU()
216 | # input_shape: (None, 32, img_h, img_w)
217 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
218 | self.bn2 = nn.BatchNorm2d(64)
219 | self.act2 = nn.ReLU()
220 | # input_shape: (None, 64, img_h//2, img_w//2)
221 | self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
222 | self.bn3 = nn.BatchNorm2d(64)
223 | self.act3 = nn.ReLU()
224 |         # input_shape: (None, 64, img_h//2, img_w//2)
225 | self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
226 | self.bn4 = nn.BatchNorm2d(128)
227 | self.act4 = nn.ReLU()
228 | # input_shape: (None, 128, img_h//4, img_w//4)
229 | self.conv5 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
230 | self.bn5 = nn.BatchNorm2d(128)
231 | self.act5 = nn.ReLU()
232 | # input_shape: (None, 128, img_h//4, img_w//4)
233 | self.conv6 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
234 | self.bn6 = nn.BatchNorm2d(128)
235 | self.act6 = nn.ReLU()
236 | # input_shape: (None, 128, img_h//4, img_w//4)
237 | self.conv7 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=2, padding=2)
238 | self.bn7 = nn.BatchNorm2d(128)
239 | self.act7 = nn.ReLU()
240 | # input_shape: (None, 128, img_h//4, img_w//4)
241 | self.conv8 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=4, padding=4)
242 | self.bn8 = nn.BatchNorm2d(128)
243 | self.act8 = nn.ReLU()
244 | # input_shape: (None, 128, img_h//4, img_w//4)
245 | self.conv9 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=8, padding=8)
246 | self.bn9 = nn.BatchNorm2d(128)
247 | self.act9 = nn.ReLU()
248 | # input_shape: (None, 128, img_h//4, img_w//4)
249 | self.conv10 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=16, padding=16)
250 | self.bn10 = nn.BatchNorm2d(128)
251 | self.act10 = nn.ReLU()
252 |
253 |
254 |
255 | def forward(self, x):
256 | x = self.bn1(self.act1(self.conv1(x)))
257 | x = self.bn2(self.act2(self.conv2(x)))
258 | x = self.bn3(self.act3(self.conv3(x)))
259 | x = self.bn4(self.act4(self.conv4(x)))
260 | x = self.bn5(self.act5(self.conv5(x)))
261 | x = self.bn6(self.act6(self.conv6(x)))
262 | x = self.bn7(self.act7(self.conv7(x)))
263 | x = self.bn8(self.act8(self.conv8(x)))
264 | x = self.bn9(self.act9(self.conv9(x)))
265 | x = self.bn10(self.act10(self.conv10(x)))
266 | return x
267 |
268 | class IdentityGenerator(nn.Module):
269 |
270 | def __init__(self, in_dim = 100, dim=64):
271 | super(IdentityGenerator, self).__init__()
272 |
273 | self.l1 = nn.Sequential(
274 | nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False),
275 | nn.BatchNorm1d(dim * 8 * 4 * 4),
276 | nn.ReLU())
277 | self.l2_5 = nn.Sequential(
278 | dconv_bn_relu(dim * 8, dim * 4),
279 | dconv_bn_relu(dim * 4, dim * 2))
280 |
281 | def forward(self, x):
282 | y = self.l1(x)
283 | y = y.view(y.size(0), -1, 4, 4)
284 | y = self.l2_5(y)
285 | return y
286 |
287 | class InversionNet(nn.Module):
288 | def __init__(self, out_dim = 128):
289 | super(InversionNet, self).__init__()
290 |
291 |         # input [4, h, w] output [128, h // 4, w // 4]
292 | self.ContextNetwork = ContextNetwork()
293 | # input [z_dim] output[128, 16, 16]
294 | self.IdentityGenerator = IdentityGenerator()
295 |
296 | self.dim = 128 + 128
297 | self.out_dim = out_dim
298 |
299 | self.Dconv = nn.Sequential(
300 | dconv_bn_relu(self.dim, self.out_dim),
301 | dconv_bn_relu(self.out_dim, self.out_dim // 2))
302 |
303 | self.Conv = nn.Sequential(
304 | nn.Conv2d(self.out_dim // 2, self.out_dim // 4, kernel_size=3, stride=1, padding=1),
305 | nn.BatchNorm2d(self.out_dim // 4),
306 | nn.ReLU(),
307 | nn.Conv2d(self.out_dim // 4, 3, kernel_size=3, stride=1, padding=1),
308 | nn.Sigmoid())
309 |
310 |
311 | def forward(self, inp):
312 | # x.shape [4, h, w] z.shape [100]
313 | x, z = inp
314 | context_info = self.ContextNetwork(x)
315 | identity_info = self.IdentityGenerator(z)
316 | y = torch.cat((context_info, identity_info), dim=1)
317 | y = self.Dconv(y)
318 | y = self.Conv(y)
319 |
320 | return y
321 |
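A latent-to-image sketch for the CelebA `Generator` above (illustrative; random weights, so the samples are noise):

```
import torch
from models.generator import Generator

G = Generator(in_dim=100, dim=64).eval()
with torch.no_grad():
    fake = G(torch.randn(4, 100))
print(fake.shape)  # torch.Size([4, 3, 64, 64]); pixel values lie in (0, 1) because of the final Sigmoid
```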
--------------------------------------------------------------------------------
/recovery.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | from models.classify import *
3 | from models.generator import *
4 | from models.discri import *
5 | import torch
6 | import os
7 | import numpy as np
8 | from attack import inversion, dist_inversion
9 | from argparse import ArgumentParser
10 |
11 |
12 | torch.manual_seed(9)
13 |
14 | parser = ArgumentParser(description='Inversion')
15 | parser.add_argument('--configs', type=str, default='./config/celeba/attacking/ffhq.json')
16 |
17 | args = parser.parse_args()
18 |
19 |
20 |
21 | def init_attack_args(cfg):
22 | if cfg["attack"]["method"] =='kedmi':
23 | args.improved_flag = True
24 | args.clipz = True
25 | args.num_seeds = 1
26 | else:
27 | args.improved_flag = False
28 | args.clipz = False
29 | args.num_seeds = 5
30 |
31 | if cfg["attack"]["variant"] == 'L_logit' or cfg["attack"]["variant"] == 'ours':
32 | args.loss = 'logit_loss'
33 | else:
34 | args.loss = 'cel'
35 |
36 | if cfg["attack"]["variant"] == 'L_aug' or cfg["attack"]["variant"] == 'ours':
37 | args.classid = '0,1,2,3'
38 | else:
39 | args.classid = '0'
40 |
41 |
42 |
43 | if __name__ == "__main__":
44 | # global args, logger
45 |
46 | cfg = load_json(json_file=args.configs)
47 | init_attack_args(cfg=cfg)
48 |
49 | # Save dir
50 | if args.improved_flag == True:
51 | prefix = os.path.join(cfg["root_path"], "kedmi_300ids")
52 | else:
53 | prefix = os.path.join(cfg["root_path"], "gmi_300ids")
54 | save_folder = os.path.join("{}_{}".format(cfg["dataset"]["name"], cfg["dataset"]["model_name"]), cfg["attack"]["variant"])
55 | prefix = os.path.join(prefix, save_folder)
56 | save_dir = os.path.join(prefix, "latent")
57 | save_img_dir = os.path.join(prefix, "imgs_{}".format(cfg["attack"]["variant"]))
58 | args.log_path = os.path.join(prefix, "invertion_logs")
59 |
60 | os.makedirs(prefix, exist_ok=True)
61 | os.makedirs(args.log_path, exist_ok=True)
62 | os.makedirs(save_img_dir, exist_ok=True)
63 | os.makedirs(save_dir, exist_ok=True)
64 |
65 |
66 | # Load models
67 | targetnets, E, G, D, n_classes, fea_mean, fea_logvar = get_attack_model(args, cfg)
68 | N = 5
69 | bs = 60
70 |
71 |
72 | # Begin attacking
73 | for i in range(1):
74 | iden = torch.from_numpy(np.arange(bs))
75 |
76 | # evaluate on the first 300 identities only
77 | target_cosines = 0
78 | eval_cosines = 0
79 | for idx in range(5):
80 | iden = iden %n_classes
81 | print("--------------------- Attack batch [%s]------------------------------" % idx)
82 | print('Iden:{}'.format(iden))
83 | save_dir_z = '{}/{}_{}'.format(save_dir,i,idx)
84 |
85 | if args.improved_flag == True:
86 | #KEDMI
87 | print('kedmi')
88 |
89 | dist_inversion(G, D, targetnets, E, iden,
90 | lr=cfg["attack"]["lr"], iter_times=cfg["attack"]["iters_mi"],
91 | momentum=0.9, lamda=100,
92 | clip_range=1, improved=args.improved_flag,
93 | num_seeds=args.num_seeds,
94 | used_loss=args.loss,
95 | prefix=save_dir_z,
96 | save_img_dir=os.path.join(save_img_dir, '{}_'.format(idx)),
97 | fea_mean=fea_mean,
98 | fea_logvar=fea_logvar,
99 | lam=cfg["attack"]["lam"],
100 | clipz=args.clipz)
101 | else:
102 | #GMI
103 | print('gmi')
104 | if cfg["attack"]["same_z"] =='':
105 | inversion(G, D, targetnets, E, iden,
106 | lr=cfg["attack"]["lr"], iter_times=cfg["attack"]["iters_mi"],
107 | momentum=0.9, lamda=100,
108 | clip_range=1, improved=args.improved_flag,
109 | used_loss=args.loss,
110 | prefix=save_dir_z,
111 | save_img_dir=save_img_dir,
112 | num_seeds=args.num_seeds,
113 | fea_mean=fea_mean,
114 | fea_logvar=fea_logvar,lam=cfg["attack"]["lam"],
115 | istart=args.istart)
116 | else:
117 | inversion(G, D, targetnets, E, iden,
118 | lr=args.lr, iter_times=args.iters_mi,
119 | momentum=0.9, lamda=100,
120 | clip_range=1, improved=args.improved_flag,
121 | used_loss=args.loss,
122 | prefix=save_dir_z,
123 | save_img_dir=save_img_dir,
124 | num_seeds=args.num_seeds,
125 | fea_mean=fea_mean,
126 | fea_logvar=fea_logvar,lam=cfg["attack"]["lam"],
127 | istart=args.istart,
128 | same_z='{}/{}_{}'.format(args.same_z,i,idx))
129 | iden = iden + bs
130 |
131 |
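The loop above covers the 300 evaluated identities in five batches of `bs = 60` consecutive labels; a stand-alone sketch of that indexing (illustrative, with `n_classes = 1000` as in the CelebA target classifiers):

```
import numpy as np
import torch

bs, n_classes = 60, 1000
iden = torch.from_numpy(np.arange(bs))
for idx in range(5):
    iden = iden % n_classes                             # wrap around if the batches exceed n_classes
    print(idx, iden.min().item(), iden.max().item())    # 0-59, 60-119, ..., 240-299
    iden = iden + bs
```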
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.4.0
2 | brotlipy==0.7.0
3 | cachetools==5.3.0
4 | cycler==0.11.0
5 | fonttools==4.38.0
6 | google-auth==2.16.1
7 | google-auth-oauthlib==0.4.6
8 | grpcio==1.51.3
9 | importlib-metadata==6.0.0
10 | joblib==1.2.0
11 | kiwisolver==1.4.4
12 | Markdown==3.4.1
13 | MarkupSafe==2.1.2
14 | matplotlib==3.5.3
15 | mkl-fft==1.3.1
16 | mkl-service==2.4.0
17 | oauthlib==3.2.2
18 | opencv-python==4.7.0.72
19 | packaging==23.0
20 | pandas==1.3.5
21 | Pillow==9.4.0
22 | protobuf==3.20.3
23 | pyasn1==0.4.8
24 | pyasn1-modules==0.2.8
25 | pyparsing==3.0.9
26 | python-dateutil==2.8.2
27 | pytz==2022.7.1
28 | requests-oauthlib==1.3.1
29 | rsa==4.9
30 | scikit-learn==1.0.2
31 | scipy==1.7.3
32 | sklearn==0.0.post1
33 | tensorboard==2.11.2
34 | tensorboard-data-server==0.6.1
35 | tensorboard-plugin-wit==1.8.1
36 | tensorboardX==2.6
37 | threadpoolctl==3.1.0
38 | torchaudio==0.11.0
39 | torchvision==0.12.0
40 | tqdm==4.64.1
41 | Werkzeug==2.2.3
42 | zipp==3.15.0
43 |
--------------------------------------------------------------------------------
/train_augmented_model.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 |
6 | from utils import *
7 | from models.classify import *
8 | from engine import train_augmodel
9 | from argparse import ArgumentParser
10 |
11 |
12 | # Training settings
13 | parser = ArgumentParser(description='Train Augmented model')
14 | parser.add_argument('--configs', type=str, default='./config/celeba/training_augmodel/ffhq.json')
15 | args = parser.parse_args()
16 |
17 |
18 | if __name__ == '__main__':
19 | file = args.configs
20 | cfg = load_json(json_file=file)
21 |
22 | train_augmodel(cfg)
--------------------------------------------------------------------------------
/train_classifier.py:
--------------------------------------------------------------------------------
1 | import torch, os, engine, utils
2 | import torch.nn as nn
3 | from argparse import ArgumentParser
4 | from models import classify
5 |
6 |
7 | parser = ArgumentParser(description='Train Classifier')
8 | parser.add_argument('--configs', type=str, default='./config/celeba/training_classifiers/classify.json')
9 |
10 | args = parser.parse_args()
11 |
12 |
13 |
14 | def main(args, model_name, trainloader, testloader):
15 | n_classes = args["dataset"]["n_classes"]
16 | mode = args["dataset"]["mode"]
17 |
18 | resume_path = args[args["dataset"]["model_name"]]["resume"]
19 | net = classify.get_classifier(model_name=model_name, mode=mode, n_classes=n_classes, resume_path=resume_path)
20 |
21 | print(net)
22 |
23 | optimizer = torch.optim.SGD(params=net.parameters(),
24 | lr=args[model_name]['lr'],
25 | momentum=args[model_name]['momentum'],
26 | weight_decay=args[model_name]['weight_decay'])
27 |
28 | criterion = nn.CrossEntropyLoss().cuda()
29 | net = torch.nn.DataParallel(net).to(args['dataset']['device'])
30 |
31 | mode = args["dataset"]["mode"]
32 | n_epochs = args[model_name]['epochs']
33 | print("Start Training!")
34 |
35 | if mode == "reg":
36 | best_model, best_acc = engine.train_reg(args, net, criterion, optimizer, trainloader, testloader, n_epochs)
37 | elif mode == "vib":
38 | best_model, best_acc = engine.train_vib(args, net, criterion, optimizer, trainloader, testloader, n_epochs)
39 |
40 | torch.save({'state_dict':best_model.state_dict()}, os.path.join(model_path, "{}_{:.2f}_allclass.tar").format(model_name, best_acc[0]))
41 |
42 |
43 | if __name__ == '__main__':
44 |
45 | cfg = utils.load_json(json_file=args.configs)
46 |
47 | root_path = cfg["root_path"]
48 | log_path = os.path.join(root_path, "target_logs")
49 | model_path = os.path.join(root_path, "target_ckp")
50 | os.makedirs(model_path, exist_ok=True)
51 | os.makedirs(log_path, exist_ok=True)
52 |
53 |
54 | model_name = cfg['dataset']['model_name']
55 | log_file = "{}.txt".format(model_name)
56 | utils.Tee(os.path.join(log_path, log_file), 'w')
57 |
58 | print("TRAINING %s" % model_name)
59 | utils.print_params(cfg["dataset"], cfg[model_name], dataset=cfg['dataset']['name'])
60 |
61 | train_file = cfg['dataset']['train_file_path']
62 | test_file = cfg['dataset']['test_file_path']
63 | _, trainloader = utils.init_dataloader(cfg, train_file, mode="train")
64 | _, testloader = utils.init_dataloader(cfg, test_file, mode="test")
65 |
66 | main(cfg, model_name, trainloader, testloader)
67 |
--------------------------------------------------------------------------------
/train_gan.py:
--------------------------------------------------------------------------------
1 | import engine
2 | from utils import load_json
3 | from argparse import ArgumentParser
4 |
5 |
6 | parser = ArgumentParser(description='Train GAN')
7 | parser.add_argument('--configs', type=str, default='./config/celeba/training_GAN/specific_gan/celeba.json')
8 | parser.add_argument('--mode', type=str, choices=['specific', 'general'], default='specific')
9 | args = parser.parse_args()
10 |
11 |
12 | if __name__ == "__main__":
13 | # os.environ["CUDA_VISIBLE_DEVICES"] = '4,5,6,7'
14 | file = args.configs
15 | cfg = load_json(json_file=file)
16 |
17 | if args.mode == 'specific':
18 | engine.train_specific_gan(cfg=cfg)
19 | elif args.mode == 'general':
20 | engine.train_general_gan(cfg=cfg)
21 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn.init as init
2 | import os, models.facenet as facenet, sys
3 | import json, time, random, torch
4 | from models import classify
5 | from models.classify import *
6 | from models.discri import *
7 | from models.generator import *
8 | import numpy as np
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torchvision.utils as tvls
12 | from torchvision import transforms
13 | from datetime import datetime
14 | import dataloader
15 | from torch.autograd import grad
16 |
17 | device = "cuda"
18 |
19 | class Tee(object):
20 | def __init__(self, name, mode):
21 | self.file = open(name, mode)
22 | self.stdout = sys.stdout
23 | sys.stdout = self
24 | def __del__(self):
25 | sys.stdout = self.stdout
26 | self.file.close()
27 | def write(self, data):
28 | if not '...' in data:
29 | self.file.write(data)
30 | self.stdout.write(data)
31 | self.flush()
32 | def flush(self):
33 | self.file.flush()
34 |
35 | def init_dataloader(args, file_path, batch_size=64, mode="gan", iterator=False):
36 | tf = time.time()
37 |
38 | if mode == "attack":
39 | shuffle_flag = False
40 | else:
41 | shuffle_flag = True
42 |
43 |
44 | data_set = dataloader.ImageFolder(args, file_path, mode)
45 |
46 | if iterator:
47 | data_loader = torch.utils.data.DataLoader(data_set,
48 | batch_size=batch_size,
49 | shuffle=shuffle_flag,
50 | drop_last=True,
51 | num_workers=0,
52 | pin_memory=True).__iter__()
53 | else:
54 | data_loader = torch.utils.data.DataLoader(data_set,
55 | batch_size=batch_size,
56 | shuffle=shuffle_flag,
57 | drop_last=True,
58 | num_workers=2,
59 | pin_memory=True)
60 | interval = time.time() - tf
61 | print('Initializing data loader took %ds' % interval)
62 |
63 | return data_set, data_loader
64 |
65 | def load_pretrain(self, state_dict):
66 | own_state = self.state_dict()
67 | for name, param in state_dict.items():
68 | if name.startswith("module.fc_layer"):
69 | continue
70 | if name not in own_state:
71 | print(name)
72 | continue
73 | own_state[name].copy_(param.data)
74 |
75 | def load_state_dict(self, state_dict):
76 | own_state = self.state_dict()
77 | for name, param in state_dict.items():
78 | if name not in own_state:
79 | print(name)
80 | continue
81 | own_state[name].copy_(param.data)
82 |
83 | def load_json(json_file):
84 | with open(json_file) as data_file:
85 | data = json.load(data_file)
86 | return data
87 |
88 | def print_params(info, params, dataset=None):
89 | print('-----------------------------------------------------------------')
90 | print("Running time: %s" % datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
91 | for i, (key, value) in enumerate(info.items()):
92 | print("%s: %s" % (key, str(value)))
93 | for i, (key, value) in enumerate(params.items()):
94 | print("%s: %s" % (key, str(value)))
95 | print('-----------------------------------------------------------------')
96 |
97 | def save_tensor_images(images, filename, nrow = None, normalize = True):
98 | if not nrow:
99 | tvls.save_image(images, filename, normalize = normalize, padding=0)
100 | else:
101 | tvls.save_image(images, filename, normalize = normalize, nrow=nrow, padding=0)
102 |
103 |
104 | def get_deprocessor():
105 | # resize 112,112
106 | proc = []
107 | proc.append(transforms.Resize((112, 112)))
108 | proc.append(transforms.ToTensor())
109 | return transforms.Compose(proc)
110 |
111 | def low2high(img):
112 | # 0 and 1, 64 to 112
113 | bs = img.size(0)
114 | proc = get_deprocessor()
115 | img_tensor = img.detach().cpu().float()
116 | img = torch.zeros(bs, 3, 112, 112)
117 | for i in range(bs):
118 | img_i = transforms.ToPILImage()(img_tensor[i, :, :, :]).convert('RGB')
119 | img_i = proc(img_i)
120 | img[i, :, :, :] = img_i[:, :, :]
121 |
122 | img = img.cuda()
123 | return img
124 |
125 |
126 | def get_model(attack_name, classes):
127 | if attack_name.startswith("VGG16"):
128 | T = classify.VGG16(classes)
129 | elif attack_name.startswith("IR50"):
130 | T = classify.IR50(classes)
131 | elif attack_name.startswith("IR152"):
132 | T = classify.IR152(classes)
133 | elif attack_name.startswith("FaceNet64"):
134 | T = facenet.FaceNet64(classes)
135 | else:
136 | print("Model doesn't exist")
137 | exit()
138 |
139 | T = torch.nn.DataParallel(T).cuda()
140 | return T
141 |
142 | def get_augmodel(model_name, nclass, path_T=None, dataset='celeba'):
143 | if model_name=="VGG16":
144 | model = VGG16(nclass)
145 | elif model_name=="FaceNet":
146 | model = FaceNet(nclass)
147 | elif model_name=="FaceNet64":
148 | model = FaceNet64(nclass)
149 | elif model_name=="IR152":
150 | model = IR152(nclass)
151 | elif model_name =="efficientnet_b0":
152 | model = classify.EfficientNet_b0(nclass)
153 | elif model_name =="efficientnet_b1":
154 | model = classify.EfficientNet_b1(nclass)
155 | elif model_name =="efficientnet_b2":
156 | model = classify.EfficientNet_b2(nclass)
157 |
158 | model = torch.nn.DataParallel(model).cuda()
159 | if path_T is not None:
160 |
161 | ckp_T = torch.load(path_T)
162 | model.load_state_dict(ckp_T['state_dict'], strict=True)
163 | return model
164 |
165 |
166 | def log_sum_exp(x, axis = 1):
167 |     m = torch.max(x, dim = axis)[0]
168 | return m + torch.log(torch.sum(torch.exp(x - m.unsqueeze(1)), dim = axis))
169 |
170 | # define "soft" cross-entropy with pytorch tensor operations
171 | def softXEnt (input, target):
172 | targetprobs = nn.functional.softmax (target, dim = 1)
173 | logprobs = nn.functional.log_softmax (input, dim = 1)
174 | return -(targetprobs * logprobs).sum() / input.shape[0]
175 |
176 | class HLoss(nn.Module):
177 | def __init__(self):
178 | super(HLoss, self).__init__()
179 |
180 | def forward(self, x):
181 | b = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)
182 | b = -1.0 * b.sum()
183 | return b
184 |
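A tiny sketch of the two losses above (illustrative; `softXEnt` is the soft-label cross-entropy, e.g. for distillation-style targets, and `HLoss` returns the entropy of the softmax distribution summed over the batch):

```
import torch
from utils import softXEnt, HLoss  # assumes the repo root is the working directory

student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)
print(softXEnt(student_logits, teacher_logits))  # scalar: soft-label cross-entropy
print(HLoss()(student_logits))                   # scalar: entropy of the softmax outputs, summed over the batch
```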
185 | def freeze(net):
186 | for p in net.parameters():
187 | p.requires_grad_(False)
188 |
189 | def unfreeze(net):
190 | for p in net.parameters():
191 | p.requires_grad_(True)
192 |
193 | def gradient_penalty(x, y, DG):
194 | # interpolation
195 | shape = [x.size(0)] + [1] * (x.dim() - 1)
196 | alpha = torch.rand(shape).cuda()
197 | z = x + alpha * (y - x)
198 | z = z.cuda()
199 | z.requires_grad = True
200 |
201 | o = DG(z)
202 | g = grad(o, z, grad_outputs = torch.ones(o.size()).cuda(), create_graph = True)[0].view(z.size(0), -1)
203 | gp = ((g.norm(p = 2, dim = 1) - 1) ** 2).mean()
204 |
205 | return gp
206 |
211 | def count_parameters(model):
212 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
213 |
214 |
215 |
216 | def get_GAN(dataset, gan_type, gan_model_dir, n_classes, z_dim, target_model):
217 |
218 | G = Generator(z_dim)
219 | if gan_type == True:
220 | D = MinibatchDiscriminator(n_classes=n_classes)
221 | else:
222 | D = DGWGAN(3)
223 |
224 | if gan_type == True:
225 | path = os.path.join(os.path.join(gan_model_dir, dataset), target_model)
226 | path_G = os.path.join(path, "improved_{}_G.tar".format(dataset))
227 | path_D = os.path.join(path, "improved_{}_D.tar".format(dataset))
228 | else:
229 | path = os.path.join(gan_model_dir, dataset)
230 | path_G = os.path.join(path, "{}_G.tar".format(dataset))
231 | path_D = os.path.join(path, "{}_D.tar".format(dataset))
232 |
233 | print('path_G',path_G)
234 | print('path_D',path_D)
235 |
236 | G = torch.nn.DataParallel(G).to(device)
237 | D = torch.nn.DataParallel(D).to(device)
238 | ckp_G = torch.load(path_G)
239 | G.load_state_dict(ckp_G['state_dict'], strict=True)
240 | ckp_D = torch.load(path_D)
241 | D.load_state_dict(ckp_D['state_dict'], strict=True)
242 |
243 | return G, D
244 |
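# Illustrative usage (the checkpoint directory is hypothetical): gan_type=True loads the
# inversion-specific (KEDMI-style) G/D trained against `target_model`, while
# gan_type=False loads the general (GMI-style) G/D.
#
#   G, D = get_GAN("celeba", gan_type=True, gan_model_dir="path/to/checkpoints/GAN",
#                  n_classes=1000, z_dim=100, target_model="VGG16")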
245 |
246 | def get_attack_model(args, args_json, eval_mode=False):
247 | now = datetime.now() # current date and time
248 |
249 | if not eval_mode:
250 |         log_file = "inversion_logs_{}_{}.txt".format(args.loss, now.strftime("%m_%d_%Y_%H_%M_%S"))
251 | utils.Tee(os.path.join(args.log_path, log_file), 'w')
252 |
253 | n_classes=args_json['dataset']['n_classes']
254 |
255 |
256 | model_types_ = args_json['train']['model_types'].split(',')
257 | checkpoints = args_json['train']['cls_ckpts'].split(',')
258 |
259 | G, D = get_GAN(args_json['dataset']['name'],gan_type=args.improved_flag,
260 | gan_model_dir=args_json['train']['gan_model_dir'],
261 | n_classes=n_classes,z_dim=100,target_model=model_types_[0])
262 |
263 | dataset = args_json['dataset']['name']
264 | cid = args.classid.split(',')
265 | # target and student classifiers
266 | for i in range(len(cid)):
267 | id_ = int(cid[i])
268 | model_types_[id_] = model_types_[id_].strip()
269 | checkpoints[id_] = checkpoints[id_].strip()
270 | print('Load classifier {} at {}'.format(model_types_[id_], checkpoints[id_]))
271 | model = get_augmodel(model_types_[id_],n_classes,checkpoints[id_],dataset)
272 | model = model.to(device)
273 | model = model.eval()
274 | if i==0:
275 | targetnets = [model]
276 | else:
277 | targetnets.append(model)
278 |
279 | # p_reg
280 | if args.loss=='logit_loss':
281 |             if model_types_[id_] in ("IR152", "VGG16", "FaceNet64"):
282 |                 # target model
283 |                 p_reg = os.path.join(args_json["dataset"]["p_reg_path"], '{}_{}_p_reg.pt'.format(dataset, model_types_[id_]))
284 |             else:
285 |                 # augmented model
286 |                 p_reg = os.path.join(args_json["dataset"]["p_reg_path"], '{}_{}_{}_p_reg.pt'.format(dataset, model_types_[0], model_types_[id_]))
287 | # print('p_reg',p_reg)
288 | if not os.path.exists(p_reg):
289 | _, dataloader_gan = init_dataloader(args_json, args_json['dataset']['gan_file_path'], 50, mode="gan")
290 | from attack import get_act_reg
291 | fea_mean_,fea_logvar_ = get_act_reg(dataloader_gan,model,device)
292 | torch.save({'fea_mean':fea_mean_,'fea_logvar':fea_logvar_},p_reg)
293 | else:
294 | fea_reg = torch.load(p_reg)
295 | fea_mean_ = fea_reg['fea_mean']
296 | fea_logvar_ = fea_reg['fea_logvar']
297 | if i == 0:
298 | fea_mean = [fea_mean_.to(device)]
299 | fea_logvar = [fea_logvar_.to(device)]
300 | else:
301 |                 fea_mean.append(fea_mean_.to(device))
302 |                 fea_logvar.append(fea_logvar_.to(device))
303 | # print('fea_logvar_',i,fea_logvar_.shape,fea_mean_.shape)
304 |
305 | else:
306 | fea_mean,fea_logvar = 0,0
307 |
308 | # evaluation classifier
309 | E = get_augmodel(args_json['train']['eval_model'],n_classes,args_json['train']['eval_dir'])
310 | E.eval()
311 | G.eval()
312 | D.eval()
313 |
314 | return targetnets, E, G, D, n_classes, fea_mean, fea_logvar
315 |
--------------------------------------------------------------------------------