├── README.md
├── attack.py
├── config
│   └── celeba
│       ├── attacking
│       │   ├── celeba.json
│       │   └── ffhq.json
│       ├── training_GAN
│       │   ├── general_gan
│       │   │   ├── celeba.json
│       │   │   └── ffhq.json
│       │   └── specific_gan
│       │       ├── celeba.json
│       │       └── ffhq.json
│       ├── training_augmodel
│       │   ├── celeba.json
│       │   └── ffhq.json
│       └── training_classifiers
│           └── classify.json
├── dataloader.py
├── engine.py
├── evaluation.py
├── losses.py
├── metrics
│   ├── KNN_dist.py
│   ├── __init__.py
│   ├── eval_accuracy.py
│   ├── fid.py
│   └── metric_utils.py
├── models
│   ├── __init__.py
│   ├── classify.py
│   ├── discri.py
│   ├── evolve.py
│   ├── facenet.py
│   └── generator.py
├── recovery.py
├── requirements.txt
├── train_augmented_model.py
├── train_classifier.py
├── train_gan.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# Implementation of the paper "Re-thinking Model Inversion Attacks Against Deep Neural Networks" - CVPR 2023
[Paper](https://arxiv.org/pdf/2304.01669.pdf) | [Project page](https://ngoc-nguyen-0.github.io/re-thinking_model_inversion_attacks/)

## 1. Setup Environment
This code has been tested with Python 3.7, PyTorch 1.11.0, and CUDA 11.3.

```
conda create -n MI python=3.7

conda activate MI

pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113

pip install -r requirements.txt
```

## 2. Prepare Dataset & Checkpoints

* Download the CelebA and FFHQ datasets from their official websites.
  - CelebA: download and extract [CelebA](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). Then, place the `img_align_celeba` folder in `./datasets/celeba`.

  - FFHQ: download and extract [FFHQ](https://github.com/NVlabs/ffhq-dataset). Then, place the `thumbnails128x128` folder in `./datasets/ffhq`.

* Download the metadata for the experiments at: https://drive.google.com/drive/folders/1kq4ArFiPmCWYKY7iiV0WxxUSXtP70bFQ?usp=sharing

* We use the same target models and GANs as previous papers. You can download the target models and generators at https://drive.google.com/drive/folders/1kq4ArFiPmCWYKY7iiV0WxxUSXtP70bFQ?usp=sharing

Otherwise, you can train the target classifier and GAN as follows:

### 2.1. Training the target classifier (Optional)

- Modify the configuration in `./config/celeba/training_classifiers/classify.json`
- Then, run the following command line to get the target model
```
python train_classifier.py
```

### 2.2. Training GAN (Optional)

SOTA MI attacks work with a general GAN [1]; however, Inversion-Specific GANs [2] improve the attack accuracy. This repo provides code for training both a general GAN and an Inversion-Specific GAN.

#### 2.2.1. Build an Inversion-Specific GAN
* Modify the configuration in
  * `./config/celeba/training_GAN/specific_gan/celeba.json` if training an Inversion-Specific GAN on CelebA (KEDMI [2]).
  * `./config/celeba/training_GAN/specific_gan/ffhq.json` if training an Inversion-Specific GAN on FFHQ (KEDMI [2]).

* Then, run the following command line to get the Inversion-Specific GAN
```
python train_gan.py --configs path/to/config.json --mode "specific"
```
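After training finishes, `engine.py` saves the generator and discriminator weights as `.tar` files containing a `state_dict` under the configured `root_path`. The snippet below is a minimal, optional sanity check, not part of the provided scripts, that loads a trained generator and samples a few images; the checkpoint path is an assumption matching an Inversion-Specific GAN trained against VGG16 on CelebA, so adjust it to your own configuration.

```
import torch
from models.generator import Generator
from utils import save_tensor_images

# Rebuild the generator as in engine.py (z_dim = 100 in the provided configs) and
# wrap it in DataParallel, since the checkpoints are saved from a DataParallel model.
G = Generator(100)
G = torch.nn.DataParallel(G).cuda()

# Assumed checkpoint path: Inversion-Specific GAN trained against VGG16 on CelebA.
ckpt = torch.load("./checkpoints/GAN/celeba/VGG16/improved_celeba_G.tar")
G.load_state_dict(ckpt['state_dict'])
G.eval()

# Sample a small batch from random latent codes and save an image grid.
with torch.no_grad():
    fake = G(torch.randn(16, 100).cuda())
save_tensor_images(fake, "gan_sanity_check.png", nrow=4)
```

The same check works for a general GAN (Section 2.2.2); its generator is saved as `celeba_G.tar` under `<root_path>/<dataset>/general_GAN/` instead.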
#### 2.2.2. Build a general GAN
* Modify the configuration in
  * `./config/celeba/training_GAN/general_gan/celeba.json` if training a general GAN on CelebA (GMI [1]).
  * `./config/celeba/training_GAN/general_gan/ffhq.json` if training a general GAN on FFHQ (GMI [1]).

* Then, run the following command line to get the general GAN
```
python train_gan.py --configs path/to/config.json --mode "general"
```

## 3. Learn augmented models
We provide code to train augmented models (i.e., `efficientnet_b0`, `efficientnet_b1`, and `efficientnet_b2`) from a ***target model***.
* Modify the configuration in
  * `./config/celeba/training_augmodel/celeba.json` if training an augmented model on CelebA
  * `./config/celeba/training_augmodel/ffhq.json` if training an augmented model on FFHQ

* Then, run the following command line to train the augmented models
```
python train_augmented_model.py --configs path/to/config.json
```

Pretrained augmented models and p_reg can be downloaded at https://drive.google.com/drive/u/2/folders/1kq4ArFiPmCWYKY7iiV0WxxUSXtP70bFQ

***We remark that if you train your own augmented models, please do not use our p_reg***. Delete the files in `./p_reg/` before running the inversion; our code will automatically re-estimate p_reg with the new augmented models.

## 4. Model Inversion Attack

* Modify the configuration in
  * `./config/celeba/attacking/celeba.json` if attacking with a GAN trained on CelebA as the public dataset
  * `./config/celeba/attacking/ffhq.json` if attacking with a GAN trained on FFHQ as the public dataset

* Important arguments:
  * `method`: select the method, either ***gmi*** or ***kedmi***
  * `variant`: select the variant, either ***baseline***, ***L_aug***, ***L_logit***, or ***ours***

* Then, run the following command line to attack
```
python recovery.py --configs path/to/config.json
```

## 5. Evaluation

After the attack, use the same configuration file and run the following command line to get the results:
```
python evaluation.py --configs path/to/config.json
```

## Acknowledgements
We gratefully acknowledge the following works:
- Knowledge-Enriched-Distributional-Model-Inversion-Attacks: https://github.com/SCccc21/Knowledge-Enriched-DMI
- EfficientNet (PyTorch): https://pytorch.org/vision/stable/models/efficientnet.html
- Experiment Tracking with Weights and Biases: https://www.wandb.com/

## References
[1] Zhang, Yuheng, et al. "The secret revealer: Generative model-inversion attacks against deep neural networks." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020.

[2] Si Chen, Mostafa Kahla, Ruoxi Jia, and Guo-Jun Qi. Knowledge-enriched distributional model inversion attacks.
In Proceedings of the IEEE/CVF international conference on computer vision, pages 16178–16187, 2021 113 | -------------------------------------------------------------------------------- /attack.py: -------------------------------------------------------------------------------- 1 | import torch, os, time, utils 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from utils import log_sum_exp, save_tensor_images 6 | from torch.autograd import Variable 7 | import torch.optim as optim 8 | 9 | 10 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 11 | 12 | def reg_loss(featureT,fea_mean, fea_logvar): 13 | 14 | fea_reg = reparameterize(fea_mean, fea_logvar) 15 | fea_reg = fea_mean.repeat(featureT.shape[0],1) 16 | loss_reg = torch.mean((featureT - fea_reg).pow(2)) 17 | # print('loss_reg',loss_reg) 18 | return loss_reg 19 | 20 | def attack_acc(fake,iden,E,): 21 | 22 | eval_prob = E(utils.low2high(fake))[-1] 23 | 24 | eval_iden = torch.argmax(eval_prob, dim=1).view(-1) 25 | 26 | cnt, cnt5 = 0, 0 27 | bs = fake.shape[0] 28 | # print('correct id') 29 | for i in range(bs): 30 | gt = iden[i].item() 31 | if eval_iden[i].item() == gt: 32 | cnt += 1 33 | # print(gt) 34 | _, top5_idx = torch.topk(eval_prob[i], 5) 35 | if gt in top5_idx: 36 | cnt5 += 1 37 | return cnt*100.0/bs, cnt5*100.0/bs 38 | 39 | def reparameterize(mu, logvar): 40 | """ 41 | Reparameterization trick to sample from N(mu, var) from 42 | N(0,1). 43 | :param mu: (Tensor) Mean of the latent Gaussian [B x D] 44 | :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] 45 | :return: (Tensor) [B x D] 46 | """ 47 | std = torch.exp(0.5 * logvar) 48 | eps = torch.randn_like(std) 49 | 50 | return eps * std + mu 51 | 52 | def find_criterion(used_loss): 53 | criterion = None 54 | if used_loss=='logit_loss': 55 | criterion = nn.NLLLoss().to(device) 56 | print('criterion:{}'.format(used_loss)) 57 | elif used_loss=='cel': 58 | criterion = nn.CrossEntropyLoss().to(device) 59 | print('criterion',criterion) 60 | else: 61 | print('criterion:{}'.format(used_loss)) 62 | return criterion 63 | 64 | def get_act_reg(train_loader,T,device,Nsample=5000): 65 | all_fea = [] 66 | with torch.no_grad(): 67 | for batch_idx, data in enumerate(train_loader): # batchsize =100 68 | # print(data.shape) 69 | if batch_idx*len(data) > Nsample: 70 | break 71 | data = data.to(device) 72 | fea,_ = T(data) 73 | if batch_idx == 0: 74 | all_fea = fea 75 | else: 76 | all_fea = torch.cat((all_fea,fea)) 77 | fea_mean = torch.mean(all_fea,dim=0) 78 | fea_logvar = torch.std(all_fea,dim=0) 79 | 80 | print(fea_mean.shape, fea_logvar.shape, all_fea.shape) 81 | return fea_mean,fea_logvar 82 | 83 | def iden_loss(T,fake, iden, used_loss,criterion,fea_mean=0, fea_logvar=0,lam=0.1): 84 | Iden_Loss = 0 85 | loss_reg = 0 86 | for tn in T: 87 | 88 | feat,out = tn(fake) 89 | if used_loss == 'logit_loss': #reg only with the target classifier, reg is randomly from distribution 90 | if Iden_Loss ==0: 91 | loss_sdt = criterion(out, iden) 92 | loss_reg = lam*reg_loss(feat,fea_mean[0], fea_logvar[0]) #reg only with the target classifier 93 | 94 | Iden_Loss = Iden_Loss + loss_sdt 95 | else: 96 | loss_sdt = criterion(out, iden) 97 | Iden_Loss = Iden_Loss + loss_sdt 98 | 99 | else: 100 | loss_sdt = criterion(out, iden) 101 | Iden_Loss = Iden_Loss + loss_sdt 102 | 103 | Iden_Loss = Iden_Loss/len(T) + loss_reg 104 | return Iden_Loss 105 | 106 | 107 | 108 | def dist_inversion(G, D, T, E, iden, lr=2e-2, momentum=0.9, lamda=100, \ 109 | 
iter_times=1500, clip_range=1.0, improved=False, num_seeds=5, \ 110 | used_loss='cel', prefix='', random_seed=0, save_img_dir='',fea_mean=0, \ 111 | fea_logvar=0, lam=0.1, clipz=False): 112 | 113 | iden = iden.view(-1).long().to(device) 114 | criterion = find_criterion(used_loss) 115 | bs = iden.shape[0] 116 | 117 | G.eval() 118 | D.eval() 119 | E.eval() 120 | 121 | #NOTE 122 | mu = Variable(torch.zeros(bs, 100), requires_grad=True) 123 | log_var = Variable(torch.ones(bs, 100), requires_grad=True) 124 | 125 | params = [mu, log_var] 126 | solver = optim.Adam(params, lr=lr) 127 | outputs_z = "{}_iter_{}_{}_dis.npy".format(prefix, random_seed, iter_times-1) 128 | 129 | if not os.path.exists(outputs_z): 130 | outputs_z = "{}_iter_{}_{}_dis".format(prefix, random_seed, 0) 131 | outputs_label = "{}_iter_{}_{}_label".format(prefix, random_seed, 0) 132 | np.save(outputs_z,{"mu":mu.detach().cpu().numpy(),"log_var":log_var.detach().cpu().numpy()}) 133 | np.save(outputs_label,iden.detach().cpu().numpy()) 134 | 135 | for i in range(iter_times): 136 | z = reparameterize(mu, log_var) 137 | if clipz==True: 138 | z = torch.clamp(z,-clip_range,clip_range).float() 139 | fake = G(z) 140 | 141 | if improved == True: 142 | _, label = D(fake) 143 | else: 144 | label = D(fake) 145 | 146 | for p in params: 147 | if p.grad is not None: 148 | p.grad.data.zero_() 149 | Iden_Loss = iden_loss(T,fake, iden, used_loss, criterion, fea_mean, fea_logvar, lam) 150 | 151 | if improved: 152 | Prior_Loss = torch.mean(F.softplus(log_sum_exp(label))) - torch.mean(log_sum_exp(label)) 153 | else: 154 | Prior_Loss = - label.mean() 155 | 156 | Total_Loss = Prior_Loss + lamda * Iden_Loss 157 | 158 | Total_Loss.backward() 159 | solver.step() 160 | 161 | Prior_Loss_val = Prior_Loss.item() 162 | Iden_Loss_val = Iden_Loss.item() 163 | 164 | if (i+1) % 300 == 0: 165 | outputs_z = "{}_iter_{}_{}_dis".format(prefix, random_seed, i) 166 | outputs_label = "{}_iter_{}_{}_label".format(prefix, random_seed, i) 167 | np.save(outputs_z,{"mu":mu.detach().cpu().numpy(),"log_var":log_var.detach().cpu().numpy()}) 168 | np.save(outputs_label,iden.detach().cpu().numpy()) 169 | 170 | with torch.no_grad(): 171 | z = reparameterize(mu, log_var) 172 | if clipz==True: 173 | z = torch.clamp(z,-clip_range, clip_range).float() 174 | fake_img = G(z.detach()) 175 | eval_prob = E(utils.low2high(fake_img))[-1] 176 | 177 | eval_iden = torch.argmax(eval_prob, dim=1).view(-1) 178 | acc = iden.eq(eval_iden.long()).sum().item() * 100.0 / bs 179 | save_tensor_images(fake_img, save_img_dir + '{}.png'.format(i+1)) 180 | print("Iteration:{}\tPrior Loss:{:.2f}\tIden Loss:{:.2f}\tAttack Acc:{:.2f}".format(i+1, Prior_Loss_val, Iden_Loss_val, acc)) 181 | 182 | 183 | outputs_z = "{}_iter_{}_{}_dis".format(prefix, random_seed, iter_times) 184 | outputs_label = "{}_iter_{}_{}_label".format(prefix, random_seed, iter_times) 185 | np.save(outputs_z,{"mu":mu.detach().cpu().numpy(),"log_var":log_var.detach().cpu().numpy()}) 186 | np.save(outputs_label,iden.detach().cpu().numpy()) 187 | 188 | def inversion(G, D, T, E, iden, lr=2e-2, momentum=0.9, lamda=100, \ 189 | iter_times=1500, clip_range=1, improved=False, num_seeds=5, \ 190 | used_loss='cel', prefix='', save_img_dir='', fea_mean=0, \ 191 | fea_logvar=0, lam=0.1, istart=0, same_z=''): 192 | 193 | iden = iden.view(-1).long().to(device) 194 | criterion = find_criterion(used_loss) 195 | bs = iden.shape[0] 196 | 197 | G.eval() 198 | D.eval() 199 | E.eval() 200 | 201 | for random_seed in range(istart, num_seeds, 1): 202 | outputs_z = 
"{}_iter_{}_{}_z.npy".format(prefix, random_seed, iter_times-1) 203 | 204 | if not os.path.exists(outputs_z): 205 | tf = time.time() 206 | if same_z=='': #no prior z 207 | z = torch.randn(bs, 100).to(device).float() 208 | else: 209 | z_path = "{}_iter_{}_{}_z.npy".format(same_z, random_seed, 0) 210 | print('load z ', z_path) 211 | z = torch.from_numpy(np.load(z_path)).to(device).float() 212 | print('z',z) 213 | # exit() 214 | z.requires_grad = True 215 | v = torch.zeros(bs, 100).to(device).float() 216 | 217 | outputs_z = "{}_iter_{}_{}_z".format(prefix, random_seed, 0) 218 | outputs_label = "{}_iter_{}_label".format(prefix, random_seed, 0) 219 | np.save(outputs_z,z.detach().cpu().numpy()) 220 | np.save(outputs_label,iden.detach().cpu().numpy()) 221 | 222 | for i in range(iter_times): 223 | fake = G(z) 224 | if improved == True: 225 | _, label = D(fake) 226 | else: 227 | label = D(fake) 228 | 229 | if z.grad is not None: 230 | z.grad.data.zero_() 231 | 232 | if improved: 233 | Prior_Loss = torch.mean(F.softplus(log_sum_exp(label))) - torch.mean(log_sum_exp(label)) 234 | else: 235 | Prior_Loss = - label.mean() 236 | 237 | Iden_Loss = iden_loss(T,fake, iden, used_loss, criterion, fea_mean, fea_logvar, lam) 238 | 239 | Total_Loss = Prior_Loss + lamda*Iden_Loss 240 | 241 | Total_Loss.backward() 242 | 243 | v_prev = v.clone() 244 | gradient = z.grad.data 245 | v = momentum*v - lr*gradient 246 | z = z + ( - momentum*v_prev + (1 + momentum)*v) 247 | z = torch.clamp(z.detach(), -clip_range, clip_range).float() 248 | z.requires_grad = True 249 | 250 | Prior_Loss_val = Prior_Loss.item() 251 | Iden_Loss_val = Iden_Loss.item() 252 | 253 | if (i+1) % 300 == 0: 254 | outputs_z = "{}_iter_{}_{}_z".format(prefix, random_seed, i) 255 | outputs_label = "{}_iter_{}_{}_label".format(prefix, random_seed, i) 256 | np.save(outputs_z, z.detach().cpu().numpy()) 257 | np.save(outputs_label, iden.detach().cpu().numpy()) 258 | with torch.no_grad(): 259 | fake_img = G(z.detach()) 260 | 261 | eval_prob = E(utils.low2high(fake_img))[-1] 262 | eval_iden = torch.argmax(eval_prob, dim=1).view(-1) 263 | acc = iden.eq(eval_iden.long()).sum().item() * 100.0 / bs 264 | print("Iteration:{}\tPrior Loss:{:.2f}\tIden Loss:{:.2f}\tAttack Acc:{:.3f}".format(i+1, Prior_Loss_val, Iden_Loss_val, acc)) 265 | 266 | -------------------------------------------------------------------------------- /config/celeba/attacking/celeba.json: -------------------------------------------------------------------------------- 1 | { 2 | "root_path": "./attack_results/", 3 | "dataset":{ 4 | "model_name": "VGG16", 5 | "test_file_path": "./datasets/celeba/meta/testset.txt", 6 | "gan_file_path": "./datasets/celeba/meta/ganset.txt", 7 | "name": "celeba", 8 | "img_path": "./datasets/celeba/img_align_celeba", 9 | "img_gan_path": "./datasets/celeba/img_align_celeba", 10 | "n_classes":1000, 11 | "fid_real_path": "./datasets/celeba/meta/celeba_target_300ids.npy", 12 | "KNN_real_path": "./datasets/celeba/meta/fea_target_300ids.npy", 13 | "p_reg_path": "./checkpoints/p_reg" 14 | }, 15 | 16 | "train":{ 17 | "model_types": "VGG16,efficientnet_b0,efficientnet_b1,efficientnet_b2", 18 | "cls_ckpts": "./checkpoints/target_model/target_ckp/VGG16_88.26.tar,./checkpoints/aug_ckp/celeba/VGG16_efficientnet_b0_0.02_1.0/VGG16_efficientnet_b0_kd_0_20.pt,./checkpoints/aug_ckp/celeba/VGG16_efficientnet_b1_0.02_1.0/VGG16_efficientnet_b1_kd_0_20.pt,./checkpoints/aug_ckp/celeba/VGG16_efficientnet_b2_0.02_1.0/VGG16_efficientnet_b2_kd_0_20.pt", 19 | "num_seeds": 5, 20 | "Nclass": 300, 
21 | "gan_model_dir": "./checkpoints/GAN", 22 | "eval_model": "FaceNet", 23 | "eval_dir": "./checkpoints/target_model/target_ckp/FaceNet_95.88.tar" 24 | }, 25 | 26 | "attack":{ 27 | "method": "kedmi", 28 | "variant": "L_logit", 29 | "iters_mi": 2400, 30 | "lr": 0.02, 31 | "lam": 1.0, 32 | "same_z":"", 33 | "eval_metric": "fid, acc, knn" 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /config/celeba/attacking/ffhq.json: -------------------------------------------------------------------------------- 1 | { 2 | "root_path": "./attack_results/", 3 | "dataset":{ 4 | "model_name": "VGG16", 5 | "test_file_path": "./datasets/celeba/meta/testset.txt", 6 | "gan_file_path": "./datasets/ffhq/meta/ganset_ffhq.txt", 7 | "name": "ffhq", 8 | "img_path": "./datasets/celeba/img_align_celeba", 9 | "img_gan_path": "./datasets/ffhq/", 10 | "n_classes":1000, 11 | "fid_real_path": "./datasets/celeba/meta/celeba_target_300ids.npy", 12 | "KNN_real_path": "./datasets/celeba/meta/fea_target_300ids.npy", 13 | "p_reg_path": "./checkpoints/p_reg" 14 | }, 15 | 16 | "train":{ 17 | "model_types": "VGG16,efficientnet_b0,efficientnet_b1,efficientnet_b2", 18 | "cls_ckpts": "checkpoints/target_model/target_ckp/VGG16_88.26.tar,./checkpoints/aug_ckp/ffhq/VGG16_efficientnet_b0_0.02_1.0/VGG16_efficientnet_b0_kd_0_20.pt,./checkpoints/aug_ckp/ffhq/VGG16_efficientnet_b1_0.02_1.0/VGG16_efficientnet_b1_kd_0_20.pt,./checkpoints/aug_ckp/ffhq/VGG16_efficientnet_b2_0.02_1.0/VGG16_efficientnet_b2_kd_0_20.pt", 19 | "num_seeds": 5, 20 | "Nclass": 300, 21 | "gan_model_dir": "./checkpoints/GAN", 22 | "eval_model": "FaceNet", 23 | "eval_dir": "./checkpoints/target_model/target_ckp/FaceNet_95.88.tar" 24 | }, 25 | 26 | "attack":{ 27 | "method": "kedmi", 28 | "variant": "L_logit", 29 | "iters_mi": 2400, 30 | "lr": 0.02, 31 | "lam": 1.0, 32 | "same_z":"", 33 | "eval_metric": "fid, acc, knn" 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /config/celeba/training_GAN/general_gan/celeba.json: -------------------------------------------------------------------------------- 1 | { 2 | "root_path":"./checkpoints/GAN", 3 | 4 | "dataset":{ 5 | "gan_file_path": "./datasets/celeba/meta/ganset.txt", 6 | "model_name": "train_gan - first stage", 7 | "name": "celeba", 8 | "img_gan_path": "./datasets/celeba/img_align_celeba", 9 | "n_classes":1000 10 | }, 11 | 12 | "train_gan - first stage":{ 13 | "lr": 0.0002, 14 | "batch_size": 64, 15 | "z_dim": 100, 16 | "epochs": 120, 17 | "n_critic": 5, 18 | "unlabel_weight": 10 19 | } 20 | } -------------------------------------------------------------------------------- /config/celeba/training_GAN/general_gan/ffhq.json: -------------------------------------------------------------------------------- 1 | { 2 | "root_path":"./checkpoints/GAN", 3 | 4 | "dataset":{ 5 | "gan_file_path": "./datasets/ffhq/meta/ganset_ffhq.txt", 6 | "model_name": "train_gan - first stage", 7 | "name": "ffhq", 8 | "img_gan_path": "./datasets/ffhq/thumbnails128x128", 9 | "n_classes":1000 10 | }, 11 | 12 | "train_gan - first stage":{ 13 | "lr": 0.0002, 14 | "batch_size": 64, 15 | "z_dim": 100, 16 | "epochs": 120, 17 | "n_critic": 5, 18 | "unlabel_weight": 10 19 | } 20 | } -------------------------------------------------------------------------------- /config/celeba/training_GAN/specific_gan/celeba.json: -------------------------------------------------------------------------------- 1 | { 2 | "root_path":"./checkpoints/GAN", 3 | 4 | "dataset":{ 5 | 
"gan_file_path": "./datasets/celeba/meta/ganset.txt", 6 | "model_name": "VGG16", 7 | "name": "celeba", 8 | "img_gan_path": "./datasets/celeba/img_align_celeba", 9 | "n_classes":1000 10 | }, 11 | 12 | "train":{ 13 | "model_types": "VGG16", 14 | "num_seeds": 5, 15 | "Nclass": 300 16 | }, 17 | 18 | "FaceNet64":{ 19 | "lr": 0.0002, 20 | "batch_size": 64, 21 | "z_dim": 100, 22 | "epochs": 120, 23 | "n_critic": 5, 24 | "unlabel_weight": 10, 25 | "cls_ckpts": "./checkpoints/download/target_model/target_ckp/FaceNet64_88.50.tar" 26 | }, 27 | 28 | "VGG16":{ 29 | "lr": 0.0002, 30 | "batch_size": 64, 31 | "z_dim": 100, 32 | "epochs": 120, 33 | "n_critic": 5, 34 | "unlabel_weight": 10, 35 | "cls_ckpts": "./checkpoints/download/target_model/target_ckp/VGG16_87.10_allclass.tar" 36 | }, 37 | 38 | "IR152":{ 39 | "lr": 0.0002, 40 | "batch_size": 64, 41 | "z_dim": 100, 42 | "epochs": 120, 43 | "n_critic": 5, 44 | "unlabel_weight": 10, 45 | "cls_ckpts": "./checkpoints/download/target_model/target_ckp/IR152_91.16.tar" 46 | } 47 | } -------------------------------------------------------------------------------- /config/celeba/training_GAN/specific_gan/ffhq.json: -------------------------------------------------------------------------------- 1 | { 2 | "root_path":"./checkpoints/GAN", 3 | 4 | "dataset":{ 5 | "gan_file_path": "./datasets/ffhq/meta/ganset_ffhq.txt", 6 | "model_name": "VGG16", 7 | "name": "ffhq", 8 | "img_gan_path": "./datasets/ffhq/thumbnails128x128", 9 | "n_classes":1000 10 | }, 11 | 12 | "train":{ 13 | "model_types": "VGG16", 14 | "num_seeds": 5, 15 | "Nclass": 300 16 | }, 17 | 18 | "FaceNet64":{ 19 | "lr": 0.0002, 20 | "batch_size": 64, 21 | "z_dim": 100, 22 | "epochs": 120, 23 | "n_critic": 5, 24 | "unlabel_weight": 10, 25 | "cls_ckpts": "./checkpoints/download/target_model/target_ckp/FaceNet64_88.50.tar" 26 | }, 27 | 28 | "VGG16":{ 29 | "lr": 0.0002, 30 | "batch_size": 64, 31 | "z_dim": 100, 32 | "epochs": 120, 33 | "n_critic": 5, 34 | "unlabel_weight": 10, 35 | "cls_ckpts": "./checkpoints/download/target_model/target_ckp/VGG16_87.10_allclass.tar" 36 | }, 37 | 38 | "IR152":{ 39 | "lr": 0.0002, 40 | "batch_size": 64, 41 | "z_dim": 100, 42 | "epochs": 120, 43 | "n_critic": 5, 44 | "unlabel_weight": 10, 45 | "cls_ckpts": "./checkpoints/download/target_model/target_ckp/IR152_91.16.tar" 46 | } 47 | } -------------------------------------------------------------------------------- /config/celeba/training_augmodel/celeba.json: -------------------------------------------------------------------------------- 1 | { 2 | "root_path": "./checkpoints/aug_ckp/", 3 | "dataset":{ 4 | "gan_file_path": "./datasets/celeba/meta/ganset.txt", 5 | "test_file_path": "./datasets/celeba/meta/testset.txt", 6 | "name": "celeba", 7 | "model_name": "VGG16", 8 | "img_path": "./datasets/celeba/img_align_celeba", 9 | "img_gan_path": "./datasets/celeba/img_align_celeba", 10 | "n_classes": 1000, 11 | "batch_size": 64 12 | }, 13 | 14 | "train":{ 15 | "epochs": 20, 16 | "target_model_name": "VGG16", 17 | "target_model_ckpt": "./checkpoints/target_model/target_ckp/VGG16_88.26.tar", 18 | "student_model_name": "efficientnet_b0", 19 | "device": "cuda", 20 | "lr": 0.01, 21 | "temperature": 1.0, 22 | "seed": 1, 23 | "log_interval": -1 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /config/celeba/training_augmodel/ffhq.json: -------------------------------------------------------------------------------- 1 | { 2 | "root_path": "./checkpoints/aug_ckp/", 3 | "dataset":{ 4 | 
"gan_file_path": "./datasets/ffhq/meta/ganset_ffhq.txt", 5 | "test_file_path": "./datasets/celeba/meta/testset.txt", 6 | "name": "ffhq", 7 | "model_name": "VGG16", 8 | "img_path": "./datasets/celeba/img_align_celeba", 9 | "img_gan_path": "./datasets/ffhq/thumbnails128x128", 10 | "n_classes": 1000, 11 | "batch_size": 64 12 | }, 13 | 14 | "train":{ 15 | "epochs": 20, 16 | "target_model_name": "VGG16", 17 | "target_model_ckpt": "./checkpoints/target_model/target_ckp/VGG16_88.26.tar", 18 | "student_model_name": "efficientnet_b0", 19 | "device": "cuda", 20 | "lr": 0.01, 21 | "temperature": 1.0, 22 | "seed": 1, 23 | "log_interval": -1 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /config/celeba/training_classifiers/classify.json: -------------------------------------------------------------------------------- 1 | { 2 | "root_path":"./checkpoints/target_model", 3 | 4 | "dataset":{ 5 | "name":"celeba", 6 | "train_file_path":"./datasets/celeba/meta/trainset.txt", 7 | "test_file_path":"./datasets/celeba/meta/testset.txt", 8 | "img_path": "./datasets/celeba/img_align_celeba", 9 | "model_name":"FaceNet", 10 | "mode":"reg", 11 | "n_classes":1000, 12 | "device":"cuda" 13 | }, 14 | 15 | "VGG16":{ 16 | "epochs":50, 17 | "batch_size":64, 18 | "instance":4, 19 | "lr":1e-2, 20 | "momentum":0.9, 21 | "weight_decay":1e-4, 22 | "gamma":0.2, 23 | "adjust_epochs":[20, 35], 24 | "resume":"" 25 | }, 26 | 27 | "FaceNet":{ 28 | "epochs":30, 29 | "batch_size":64, 30 | "instance":4, 31 | "lr":1e-2, 32 | "momentum":0.9, 33 | "weight_decay":1e-4, 34 | "adjust_lr":[1e-3, 1e-4], 35 | "adjust_epochs":[15, 25], 36 | "resume":"./checkpoints/backbone/backbone_ir50_ms1m_epoch120.pth" 37 | }, 38 | 39 | "FaceNet_all":{ 40 | "epochs":100, 41 | "batch_size":64, 42 | "instance":4, 43 | "lr":1e-2, 44 | "momentum":0.9, 45 | "weight_decay":1e-4, 46 | "adjust_lr":[1e-3, 1e-4], 47 | "adjust_epochs":[15, 25], 48 | "resume":"./checkpoints/backbone/backbone_ir50_ms1m_epoch120.pth" 49 | }, 50 | 51 | "FaceNet64":{ 52 | "epochs":50, 53 | "batch_size":64, 54 | "lr":1e-2, 55 | "momentum":0.9, 56 | "weight_decay":1e-4, 57 | "lrdecay_epoch":10, 58 | "lrdecay":0.1, 59 | "resume":"./checkpoints/backbone/backbone_ir50_ms1m_epoch120.pth" 60 | }, 61 | 62 | "IR152":{ 63 | "epochs":40, 64 | "batch_size":64, 65 | "lr":1e-2, 66 | "momentum":0.9, 67 | "weight_decay":1e-4, 68 | "lrdecay_epoch":10, 69 | "lrdecay":0.1, 70 | "resume":"./checkpoints/backbone/Backbone_IR_152_Epoch_112_Batch_2547328_Time_2019-07-13-02-59_checkpoint.pth" 71 | }, 72 | 73 | "IR50":{ 74 | "epochs":40, 75 | "batch_size":64, 76 | "lr":1e-2, 77 | "momentum":0.9, 78 | "weight_decay":1e-4, 79 | "lrdecay_epoch":10, 80 | "lrdecay":0.1, 81 | "resume":"./checkpoints/backbone/ir50.pth" 82 | } 83 | } 84 | 85 | 86 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import os, torchvision, PIL 2 | import torch 3 | from PIL import Image 4 | import torch.nn.functional as F 5 | import torch.utils.data as data 6 | from torchvision import transforms 7 | from torch.utils.data import DataLoader 8 | from torch.nn.modules.loss import _Loss 9 | from torch.utils.data.sampler import SubsetRandomSampler 10 | 11 | 12 | class ImageFolder(data.Dataset): 13 | def __init__(self, args, file_path, mode): 14 | self.args = args 15 | self.mode = mode 16 | if mode == 'gan': 17 | self.img_path = args["dataset"]["img_gan_path"] 18 | else: 
19 | self.img_path = args["dataset"]["img_path"] 20 | self.model_name = args["dataset"]["model_name"] 21 | # self.img_list = os.listdir(self.img_path) 22 | self.processor = self.get_processor() 23 | self.name_list, self.label_list = self.get_list(file_path) 24 | self.image_list = self.load_img() 25 | self.num_img = len(self.image_list) 26 | self.n_classes = args["dataset"]["n_classes"] 27 | if self.mode is not "gan": 28 | print("Load " + str(self.num_img) + " images") 29 | 30 | def get_list(self, file_path): 31 | name_list, label_list = [], [] 32 | f = open(file_path, "r") 33 | for line in f.readlines(): 34 | if self.mode == "gan": 35 | img_name = line.strip() 36 | else: 37 | img_name, iden = line.strip().split(' ') 38 | label_list.append(int(iden)) 39 | name_list.append(img_name) 40 | 41 | 42 | return name_list, label_list 43 | 44 | 45 | def load_img(self): 46 | img_list = [] 47 | for i, img_name in enumerate(self.name_list): 48 | if img_name.endswith(".png") or img_name.endswith(".jpg") or img_name.endswith(".jpeg") : 49 | path = self.img_path + "/" + img_name 50 | img = PIL.Image.open(path) 51 | img = img.convert('RGB') 52 | img_list.append(img) 53 | return img_list 54 | 55 | 56 | def get_processor(self): 57 | if self.model_name in ("FaceNet", "FaceNet_all"): 58 | re_size = 112 59 | else: 60 | re_size = 64 61 | if self.args["dataset"]["name"] =='celeba': 62 | crop_size = 108 63 | offset_height = (218 - crop_size) // 2 64 | offset_width = (178 - crop_size) // 2 65 | elif self.args["dataset"]["name"] == 'facescrub': 66 | # NOTE: dataset face scrub 67 | if self.mode=='gan': 68 | crop_size = 54 69 | offset_height = (64 - crop_size) // 2 70 | offset_width = (64 - crop_size) // 2 71 | else: 72 | crop_size = 108 73 | offset_height = (218 - crop_size) // 2 74 | offset_width = (178 - crop_size) // 2 75 | elif self.args["dataset"]["name"] == 'ffhq': 76 | # print('ffhq') 77 | #NOTE: dataset ffhq 78 | if self.mode=='gan': 79 | crop_size = 88 80 | offset_height = (128 - crop_size) // 2 81 | offset_width = (128 - crop_size) // 2 82 | else: 83 | crop_size = 108 84 | offset_height = (218 - crop_size) // 2 85 | offset_width = (178 - crop_size) // 2 86 | 87 | # #NOTE: dataset pf83 88 | # crop_size = 176 89 | # offset_height = (256 - crop_size) // 2 90 | # offset_width = (256 - crop_size) // 2 91 | crop = lambda x: x[:, offset_height:offset_height + crop_size, offset_width:offset_width + crop_size] 92 | 93 | proc = [] 94 | if self.mode == "train": 95 | proc.append(transforms.ToTensor()) 96 | proc.append(transforms.Lambda(crop)) 97 | proc.append(transforms.ToPILImage()) 98 | proc.append(transforms.Resize((re_size, re_size))) 99 | proc.append(transforms.RandomHorizontalFlip(p=0.5)) 100 | proc.append(transforms.ToTensor()) 101 | else: 102 | 103 | proc.append(transforms.ToTensor()) 104 | if self.mode=='test' or self.mode=='train' or self.args["dataset"]["name"] != 'facescrub': 105 | proc.append(transforms.Lambda(crop)) 106 | proc.append(transforms.ToPILImage()) 107 | proc.append(transforms.Resize((re_size, re_size))) 108 | proc.append(transforms.ToTensor()) 109 | 110 | 111 | return transforms.Compose(proc) 112 | 113 | def __getitem__(self, index): 114 | processer = self.get_processor() 115 | img = processer(self.image_list[index]) 116 | if self.mode == "gan": 117 | return img 118 | label = self.label_list[index] 119 | 120 | return img, label 121 | 122 | def __len__(self): 123 | return self.num_img 124 | 125 | class GrayFolder(data.Dataset): 126 | def __init__(self, args, file_path, mode): 127 | self.args = 
args 128 | self.mode = mode 129 | self.img_path = args["dataset"]["img_path"] 130 | self.img_list = os.listdir(self.img_path) 131 | self.processor = self.get_processor() 132 | self.name_list, self.label_list = self.get_list(file_path) 133 | self.image_list = self.load_img() 134 | self.num_img = len(self.image_list) 135 | self.n_classes = args["dataset"]["n_classes"] 136 | print("Load " + str(self.num_img) + " images") 137 | 138 | def get_list(self, file_path): 139 | name_list, label_list = [], [] 140 | f = open(file_path, "r") 141 | for line in f.readlines(): 142 | if self.mode == "gan": 143 | img_name = line.strip() 144 | else: 145 | img_name, iden = line.strip().split(' ') 146 | label_list.append(int(iden)) 147 | name_list.append(img_name) 148 | 149 | return name_list, label_list 150 | 151 | 152 | def load_img(self): 153 | img_list = [] 154 | for i, img_name in enumerate(self.name_list): 155 | if img_name.endswith(".png"): 156 | path = self.img_path + "/" + img_name 157 | img = PIL.Image.open(path) 158 | img = img.convert('L') 159 | img_list.append(img) 160 | return img_list 161 | 162 | def get_processor(self): 163 | proc = [] 164 | if self.args['dataset']['name'] == "mnist": 165 | re_size = 32 166 | else: 167 | re_size = 64 168 | proc.append(transforms.Resize((re_size, re_size))) 169 | proc.append(transforms.ToTensor()) 170 | 171 | return transforms.Compose(proc) 172 | 173 | def __getitem__(self, index): 174 | processer = self.get_processor() 175 | img = processer(self.image_list[index]) 176 | if self.mode == "gan": 177 | return img 178 | label = self.label_list[index] 179 | 180 | return img, label 181 | 182 | def __len__(self): 183 | return self.num_img 184 | 185 | def load_mnist(): 186 | transform = transforms.Compose([transforms.ToTensor()]) 187 | trainset = torchvision.datasets.MNIST(mnist_path, train=True, transform=transform, download=True) 188 | testset = torchvision.datasets.MNIST(mnist_path, train=False, transform=transform, download=True) 189 | 190 | train_loader = DataLoader(trainset, batch_size=1) 191 | test_loader = DataLoader(testset, batch_size=1) 192 | cnt = 0 193 | 194 | for imgs, labels in train_loader: 195 | cnt += 1 196 | img_name = str(cnt) + '_' + str(labels.item()) + '.png' 197 | # utils.save_tensor_images(imgs, os.path.join(mnist_img_path, img_name)) 198 | print("number of train files:", cnt) 199 | 200 | for imgs, labels in test_loader: 201 | cnt += 1 202 | img_name = str(cnt) + '_' + str(labels.item()) + '.png' 203 | # utils.save_tensor_images(imgs, os.path.join(mnist_img_path, img_name)) 204 | 205 | class celeba(data.Dataset): 206 | def __init__(self, data_path=None, label_path=None): 207 | self.data_path = data_path 208 | self.label_path = label_path 209 | 210 | # Data transforms 211 | crop_size = 108 212 | offset_height = (218 - crop_size) // 2 213 | offset_width = (178 - crop_size) // 2 214 | proc = [] 215 | proc.append(transforms.ToTensor()) 216 | proc.append(transforms.Lambda(crop)) 217 | proc.append(transforms.ToPILImage()) 218 | proc.append(transforms.Resize((112, 112))) 219 | proc.append(transforms.ToTensor()) 220 | proc.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) 221 | 222 | self.transform = transforms.Compose(proc) 223 | 224 | def __len__(self): 225 | return len(self.data_path) 226 | 227 | def __getitem__(self, idx): 228 | image_set = Image.open(self.data_path[idx]) 229 | image_tensor = self.transform(image_set) 230 | image_label = torch.Tensor(self.label_path[idx]) 231 | return image_tensor, image_label 232 | 233 | def 
load_attri(file_path): 234 | data_path = sorted(glob.glob('./data/img_align_celeba_png/*.png')) 235 | print(len(data_path)) 236 | # get label 237 | att_path = './data/list_attr_celeba.txt' 238 | att_list = open(att_path).readlines()[2:] # start from 2nd row 239 | data_label = [] 240 | for i in range(len(att_list)): 241 | data_label.append(att_list[i].split()) 242 | 243 | # transform label into 0 and 1 244 | for m in range(len(data_label)): 245 | data_label[m] = [n.replace('-1', '0') for n in data_label[m]][1:] 246 | data_label[m] = [int(p) for p in data_label[m]] 247 | 248 | dataset = celeba(data_path, data_label) 249 | # split data into train, valid, test set 7:2:1 250 | indices = list(range(202599)) 251 | split_train = 141819 252 | split_valid = 182339 253 | train_idx, valid_idx, test_idx = indices[:split_train], indices[split_train:split_valid], indices[split_valid:] 254 | 255 | train_sampler = SubsetRandomSampler(train_idx) 256 | valid_sampler = SubsetRandomSampler(valid_idx) 257 | test_sampler = SubsetRandomSampler(test_idx) 258 | 259 | trainloader = torch.utils.data.DataLoader(dataset, batch_size=64, sampler=train_sampler) 260 | 261 | validloader = torch.utils.data.DataLoader(dataset, sampler=valid_sampler) 262 | 263 | testloader = torch.utils.data.DataLoader(dataset, sampler=test_sampler) 264 | 265 | print(len(trainloader)) 266 | print(len(validloader)) 267 | print(len(testloader)) 268 | 269 | return trainloader, validloader, testloader 270 | 271 | 272 | if __name__ == "__main__": 273 | print("ok") -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import torch, os, time, utils 2 | import torch.nn as nn 3 | from copy import deepcopy 4 | from utils import * 5 | from models.discri import MinibatchDiscriminator, DGWGAN 6 | from models.generator import Generator 7 | from models.classify import * 8 | from tensorboardX import SummaryWriter 9 | 10 | 11 | def test(model, criterion=None, dataloader=None, device='cuda'): 12 | tf = time.time() 13 | model.eval() 14 | loss, cnt, ACC, correct_top5 = 0.0, 0, 0,0 15 | with torch.no_grad(): 16 | for i,(img, iden) in enumerate(dataloader): 17 | img, iden = img.to(device), iden.to(device) 18 | 19 | bs = img.size(0) 20 | iden = iden.view(-1) 21 | _,out_prob = model(img) 22 | out_iden = torch.argmax(out_prob, dim=1).view(-1) 23 | ACC += torch.sum(iden == out_iden).item() 24 | 25 | 26 | _, top5 = torch.topk(out_prob,5, dim = 1) 27 | for ind,top5pred in enumerate(top5): 28 | if iden[ind] in top5pred: 29 | correct_top5 += 1 30 | 31 | cnt += bs 32 | 33 | return ACC*100.0/cnt, correct_top5*100.0/cnt 34 | 35 | def train_reg(args, model, criterion, optimizer, trainloader, testloader, n_epochs, device='cuda'): 36 | best_ACC = (0.0, 0.0) 37 | 38 | for epoch in range(n_epochs): 39 | tf = time.time() 40 | ACC, cnt, loss_tot = 0, 0, 0.0 41 | model.train() 42 | 43 | for i, (img, iden) in enumerate(trainloader): 44 | img, iden = img.to(device), iden.to(device) 45 | bs = img.size(0) 46 | iden = iden.view(-1) 47 | 48 | feats, out_prob = model(img) 49 | cross_loss = criterion(out_prob, iden) 50 | loss = cross_loss 51 | 52 | optimizer.zero_grad() 53 | loss.backward() 54 | optimizer.step() 55 | 56 | out_iden = torch.argmax(out_prob, dim=1).view(-1) 57 | ACC += torch.sum(iden == out_iden).item() 58 | loss_tot += loss.item() * bs 59 | cnt += bs 60 | 61 | train_loss, train_acc = loss_tot * 1.0 / cnt, ACC * 100.0 / cnt 62 | test_acc = test(model, 
criterion, testloader) 63 | 64 | interval = time.time() - tf 65 | if test_acc[0] > best_ACC[0]: 66 | best_ACC = test_acc 67 | best_model = deepcopy(model) 68 | 69 | print("Epoch:{}\tTime:{:.2f}\tTrain Loss:{:.2f}\tTrain Acc:{:.2f}\tTest Acc:{:.2f}".format(epoch, interval, train_loss, train_acc, test_acc[0])) 70 | 71 | print("Best Acc:{:.2f}".format(best_ACC[0])) 72 | return best_model, best_ACC 73 | 74 | def train_vib(args, model, criterion, optimizer, trainloader, testloader, n_epochs, device='cuda'): 75 | best_ACC = (0.0, 0.0) 76 | 77 | for epoch in range(n_epochs): 78 | tf = time.time() 79 | ACC, cnt, loss_tot = 0, 0, 0.0 80 | 81 | for i, (img, iden) in enumerate(trainloader): 82 | img, one_hot, iden = img.to(device), one_hot.to(device), iden.to(device) 83 | bs = img.size(0) 84 | iden = iden.view(-1) 85 | 86 | ___, out_prob, mu, std = model(img, "train") 87 | cross_loss = criterion(out_prob, one_hot) 88 | info_loss = - 0.5 * (1 + 2 * std.log() - mu.pow(2) - std.pow(2)).sum(dim=1).mean() 89 | loss = cross_loss + beta * info_loss 90 | 91 | optimizer.zero_grad() 92 | loss.backward() 93 | optimizer.step() 94 | 95 | out_iden = torch.argmax(out_prob, dim=1).view(-1) 96 | ACC += torch.sum(iden == out_iden).item() 97 | loss_tot += loss.item() * bs 98 | cnt += bs 99 | 100 | train_loss, train_acc = loss_tot * 1.0 / cnt, ACC * 100.0 / cnt 101 | test_loss, test_acc = test(model, criterion, testloader) 102 | 103 | interval = time.time() - tf 104 | if test_acc[0] > best_ACC[0]: 105 | best_ACC = test_acc 106 | best_model = deepcopy(model) 107 | 108 | print("Epoch:{}\tTime:{:.2f}\tTrain Loss:{:.2f}\tTrain Acc:{:.2f}\tTest Acc:{:.2f}".format(epoch, interval, train_loss, train_acc, test_acc[0])) 109 | 110 | 111 | print("Best Acc:{:.2f}".format(best_ACC[0])) 112 | return best_model, best_ACC 113 | 114 | 115 | def get_T(model_name_T, cfg): 116 | if model_name_T.startswith("VGG16"): 117 | T = VGG16(cfg['dataset']["n_classes"]) 118 | elif model_name_T.startswith('IR152'): 119 | T = IR152(cfg['dataset']["n_classes"]) 120 | elif model_name_T == "FaceNet64": 121 | T = FaceNet64(cfg['dataset']["n_classes"]) 122 | T = torch.nn.DataParallel(T).cuda() 123 | ckp_T = torch.load(cfg[cfg['dataset']['model_name']]['cls_ckpts']) 124 | T.load_state_dict(ckp_T['state_dict'], strict=False) 125 | 126 | return T 127 | 128 | 129 | def train_specific_gan(cfg): 130 | # Hyperparams 131 | file_path = cfg['dataset']['gan_file_path'] 132 | model_name_T = cfg['dataset']['model_name'] 133 | batch_size = cfg[model_name_T]['batch_size'] 134 | z_dim = cfg[model_name_T]['z_dim'] 135 | n_critic = cfg[model_name_T]['n_critic'] 136 | dataset_name = cfg['dataset']['name'] 137 | 138 | # Create save folders 139 | root_path = cfg["root_path"] 140 | save_model_dir = os.path.join(root_path, os.path.join(dataset_name, model_name_T)) 141 | save_img_dir = os.path.join(save_model_dir, "imgs") 142 | os.makedirs(save_model_dir, exist_ok=True) 143 | os.makedirs(save_img_dir, exist_ok=True) 144 | 145 | 146 | # Log file 147 | log_path = os.path.join(save_model_dir, "attack_logs") 148 | os.makedirs(log_path, exist_ok=True) 149 | log_file = "improvedGAN_{}.txt".format(model_name_T) 150 | utils.Tee(os.path.join(log_path, log_file), 'w') 151 | writer = SummaryWriter(log_path) 152 | 153 | 154 | # Load target model 155 | T = get_T(model_name_T=model_name_T, cfg=cfg) 156 | 157 | # Dataset 158 | dataset, dataloader = utils.init_dataloader(cfg, file_path, cfg[model_name_T]['batch_size'], mode="gan") 159 | 160 | # Start Training 161 | print("Training GAN for %s" % 
model_name_T) 162 | utils.print_params(cfg["dataset"], cfg[model_name_T]) 163 | 164 | G = Generator(cfg[model_name_T]['z_dim']) 165 | DG = MinibatchDiscriminator() 166 | 167 | G = torch.nn.DataParallel(G).cuda() 168 | DG = torch.nn.DataParallel(DG).cuda() 169 | 170 | dg_optimizer = torch.optim.Adam(DG.parameters(), lr=cfg[model_name_T]['lr'], betas=(0.5, 0.999)) 171 | g_optimizer = torch.optim.Adam(G.parameters(), lr=cfg[model_name_T]['lr'], betas=(0.5, 0.999)) 172 | 173 | entropy = HLoss() 174 | 175 | step = 0 176 | for epoch in range(cfg[model_name_T]['epochs']): 177 | start = time.time() 178 | _, unlabel_loader1 = init_dataloader(cfg, file_path, batch_size, mode="gan", iterator=True) 179 | _, unlabel_loader2 = init_dataloader(cfg, file_path, batch_size, mode="gan", iterator=True) 180 | 181 | for i, imgs in enumerate(dataloader): 182 | current_iter = epoch * len(dataloader) + i + 1 183 | 184 | step += 1 185 | imgs = imgs.cuda() 186 | bs = imgs.size(0) 187 | x_unlabel = unlabel_loader1.next() 188 | x_unlabel2 = unlabel_loader2.next() 189 | 190 | freeze(G) 191 | unfreeze(DG) 192 | 193 | z = torch.randn(bs, z_dim).cuda() 194 | f_imgs = G(z) 195 | 196 | y_prob = T(imgs)[-1] 197 | y = torch.argmax(y_prob, dim=1).view(-1) 198 | 199 | 200 | _, output_label = DG(imgs) 201 | _, output_unlabel = DG(x_unlabel) 202 | _, output_fake = DG(f_imgs) 203 | 204 | loss_lab = softXEnt(output_label, y_prob) 205 | loss_unlab = 0.5*(torch.mean(F.softplus(log_sum_exp(output_unlabel)))-torch.mean(log_sum_exp(output_unlabel))+torch.mean(F.softplus(log_sum_exp(output_fake)))) 206 | dg_loss = loss_lab + loss_unlab 207 | 208 | acc = torch.mean((output_label.max(1)[1] == y).float()) 209 | 210 | dg_optimizer.zero_grad() 211 | dg_loss.backward() 212 | dg_optimizer.step() 213 | 214 | writer.add_scalar('loss_label_batch', loss_lab, current_iter) 215 | writer.add_scalar('loss_unlabel_batch', loss_unlab, current_iter) 216 | writer.add_scalar('DG_loss_batch', dg_loss, current_iter) 217 | writer.add_scalar('Acc_batch', acc, current_iter) 218 | 219 | # train G 220 | if step % n_critic == 0: 221 | freeze(DG) 222 | unfreeze(G) 223 | z = torch.randn(bs, z_dim).cuda() 224 | f_imgs = G(z) 225 | mom_gen, output_fake = DG(f_imgs) 226 | mom_unlabel, _ = DG(x_unlabel2) 227 | 228 | mom_gen = torch.mean(mom_gen, dim = 0) 229 | mom_unlabel = torch.mean(mom_unlabel, dim = 0) 230 | 231 | Hloss = entropy(output_fake) 232 | g_loss = torch.mean((mom_gen - mom_unlabel).abs()) + 1e-4 * Hloss 233 | 234 | g_optimizer.zero_grad() 235 | g_loss.backward() 236 | g_optimizer.step() 237 | 238 | writer.add_scalar('G_loss_batch', g_loss, current_iter) 239 | 240 | end = time.time() 241 | interval = end - start 242 | 243 | print("Epoch:%d \tTime:%.2f\tG_loss:%.2f\t train_acc:%.2f" % (epoch, interval, g_loss, acc)) 244 | 245 | torch.save({'state_dict':G.state_dict()}, os.path.join(save_model_dir, "improved_{}_G.tar".format(dataset_name))) 246 | torch.save({'state_dict':DG.state_dict()}, os.path.join(save_model_dir, "improved_{}_D.tar".format(dataset_name))) 247 | 248 | if (epoch+1) % 10 == 0: 249 | z = torch.randn(32, z_dim).cuda() 250 | fake_image = G(z) 251 | save_tensor_images(fake_image.detach(), os.path.join(save_img_dir, "improved_celeba_img_{}.png".format(epoch)), nrow = 8) 252 | 253 | 254 | def train_general_gan(cfg): 255 | # Hyperparams 256 | file_path = cfg['dataset']['gan_file_path'] 257 | model_name = cfg['dataset']['model_name'] 258 | lr = cfg[model_name]['lr'] 259 | batch_size = cfg[model_name]['batch_size'] 260 | z_dim = 
cfg[model_name]['z_dim'] 261 | epochs = cfg[model_name]['epochs'] 262 | n_critic = cfg[model_name]['n_critic'] 263 | dataset_name = cfg['dataset']['name'] 264 | 265 | 266 | # Create save folders 267 | root_path = cfg["root_path"] 268 | save_model_dir = os.path.join(root_path, os.path.join(dataset_name, 'general_GAN')) 269 | save_img_dir = os.path.join(save_model_dir, "imgs") 270 | os.makedirs(save_model_dir, exist_ok=True) 271 | os.makedirs(save_img_dir, exist_ok=True) 272 | 273 | 274 | # Log file 275 | log_path = os.path.join(save_model_dir, "logs") 276 | os.makedirs(log_path, exist_ok=True) 277 | log_file = "GAN_{}.txt".format(dataset_name) 278 | utils.Tee(os.path.join(log_path, log_file), 'w') 279 | writer = SummaryWriter(log_path) 280 | 281 | 282 | # Dataset 283 | dataset, dataloader = init_dataloader(cfg, file_path, batch_size, mode="gan") 284 | 285 | 286 | # Start Training 287 | print("Training general GAN for %s" % dataset_name) 288 | utils.print_params(cfg["dataset"], cfg[model_name]) 289 | 290 | G = Generator(z_dim) 291 | DG = DGWGAN(3) 292 | 293 | G = torch.nn.DataParallel(G).cuda() 294 | DG = torch.nn.DataParallel(DG).cuda() 295 | 296 | dg_optimizer = torch.optim.Adam(DG.parameters(), lr=lr, betas=(0.5, 0.999)) 297 | g_optimizer = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999)) 298 | 299 | step = 0 300 | 301 | for epoch in range(epochs): 302 | start = time.time() 303 | for i, imgs in enumerate(dataloader): 304 | 305 | step += 1 306 | imgs = imgs.cuda() 307 | bs = imgs.size(0) 308 | 309 | freeze(G) 310 | unfreeze(DG) 311 | 312 | z = torch.randn(bs, z_dim).cuda() 313 | f_imgs = G(z) 314 | 315 | r_logit = DG(imgs) 316 | f_logit = DG(f_imgs) 317 | 318 | wd = r_logit.mean() - f_logit.mean() # Wasserstein-1 Distance 319 | gp = gradient_penalty(imgs.data, f_imgs.data, DG=DG) 320 | dg_loss = - wd + gp * 10.0 321 | 322 | dg_optimizer.zero_grad() 323 | dg_loss.backward() 324 | dg_optimizer.step() 325 | 326 | # train G 327 | 328 | if step % n_critic == 0: 329 | freeze(DG) 330 | unfreeze(G) 331 | z = torch.randn(bs, z_dim).cuda() 332 | f_imgs = G(z) 333 | logit_dg = DG(f_imgs) 334 | # calculate g_loss 335 | g_loss = - logit_dg.mean() 336 | 337 | g_optimizer.zero_grad() 338 | g_loss.backward() 339 | g_optimizer.step() 340 | 341 | end = time.time() 342 | interval = end - start 343 | 344 | print("Epoch:%d \t Time:%.2f\t Generator loss:%.2f" % (epoch, interval, g_loss)) 345 | if (epoch+1) % 10 == 0: 346 | z = torch.randn(32, z_dim).cuda() 347 | fake_image = G(z) 348 | save_tensor_images(fake_image.detach(), os.path.join(save_img_dir, "result_image_{}.png".format(epoch)), nrow = 8) 349 | 350 | torch.save({'state_dict':G.state_dict()}, os.path.join(save_model_dir, "celeba_G.tar")) 351 | torch.save({'state_dict':DG.state_dict()}, os.path.join(save_model_dir, "celeba_D.tar")) 352 | 353 | def train_augmodel(cfg): 354 | # Hyperparams 355 | target_model_name = cfg['train']['target_model_name'] 356 | student_model_name = cfg['train']['student_model_name'] 357 | device = cfg['train']['device'] 358 | lr = cfg['train']['lr'] 359 | temperature = cfg['train']['temperature'] 360 | dataset_name = cfg['dataset']['name'] 361 | n_classes = cfg['dataset']['n_classes'] 362 | batch_size = cfg['dataset']['batch_size'] 363 | seed = cfg['train']['seed'] 364 | epochs = cfg['train']['epochs'] 365 | log_interval = cfg['train']['log_interval'] 366 | 367 | 368 | # Create save folder 369 | save_dir = os.path.join(cfg['root_path'], dataset_name) 370 | save_dir = os.path.join(save_dir, 
'{}_{}_{}_{}'.format(target_model_name, student_model_name, lr, temperature)) 371 | os.makedirs(save_dir, exist_ok=True) 372 | 373 | # Log file 374 | now = datetime.now() # current date and time 375 | log_file = "studentKD_logs_{}.txt".format(now.strftime("%m_%d_%Y_%H_%M_%S")) 376 | utils.Tee(os.path.join(save_dir, log_file), 'w') 377 | torch.manual_seed(seed) 378 | 379 | 380 | kwargs = {'batch_size': batch_size} 381 | if device == 'cuda': 382 | kwargs.update({'num_workers': 1, 383 | 'pin_memory': True, 384 | 'shuffle': True}, 385 | ) 386 | 387 | # Get models 388 | teacher_model = get_augmodel(target_model_name, n_classes, cfg['train']['target_model_ckpt']) 389 | model = get_augmodel(student_model_name, n_classes) 390 | model = model.to(device) 391 | print('Target model {}: {} params'.format(target_model_name, count_parameters(model))) 392 | print('Augmented model {}: {} params'.format(student_model_name, count_parameters(teacher_model))) 393 | 394 | 395 | # Optimizer 396 | optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum = 0.9, weight_decay = 1e-4) 397 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5) 398 | 399 | 400 | # Get dataset 401 | _, dataloader_train = init_dataloader(cfg, cfg['dataset']['gan_file_path'], batch_size, mode="gan") 402 | _, dataloader_test = init_dataloader(cfg, cfg['dataset']['test_file_path'], batch_size, mode="test") 403 | 404 | 405 | # Start training 406 | top1,top5 = test(teacher_model, dataloader=dataloader_test) 407 | print("Target model {}: top 1 = {}, top 5 = {}".format(target_model_name, top1, top5)) 408 | 409 | 410 | loss_function = nn.KLDivLoss(reduction='sum') 411 | teacher_model.eval() 412 | for epoch in range(1, epochs + 1): 413 | model.train() 414 | for batch_idx, data in enumerate(dataloader_train): 415 | data = data.to(device) 416 | 417 | curr_batch_size = len(data) 418 | optimizer.zero_grad() 419 | _, output_t = teacher_model(data) 420 | _, output = model(data) 421 | 422 | loss = loss_function( 423 | F.log_softmax(output / temperature, dim=-1), 424 | F.softmax(output_t / temperature, dim=-1) 425 | ) / (temperature * temperature) / curr_batch_size 426 | 427 | 428 | loss.backward() 429 | optimizer.step() 430 | 431 | if (log_interval > 0) and (batch_idx % log_interval == 0): 432 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 433 | epoch, batch_idx * len(data), len(dataloader_train.dataset), 434 | 100. 
* batch_idx / len(dataloader_train), loss.item())) 435 | 436 | scheduler.step() 437 | top1, top5 = test(model, dataloader=dataloader_test) 438 | print("epoch {}: top 1 = {}, top 5 = {}".format(epoch, top1, top5)) 439 | 440 | if (epoch+1)%10 == 0: 441 | save_path = os.path.join(save_dir, "{}_{}_kd_{}_{}.pt".format(target_model_name, student_model_name, seed, epoch+1)) 442 | torch.save({'state_dict':model.state_dict()}, save_path) 443 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from metrics.KNN_dist import eval_KNN 3 | from metrics.eval_accuracy import eval_accuracy, eval_acc_class 4 | from metrics.fid import eval_fid 5 | from utils import load_json, get_attack_model 6 | import os 7 | import csv 8 | 9 | parser = ArgumentParser(description='Evaluation') 10 | parser.add_argument('--configs', type=str, default='./config/celeba/attacking/ffhq.json') 11 | 12 | args = parser.parse_args() 13 | 14 | 15 | def init_attack_args(cfg): 16 | if cfg["attack"]["method"] =='kedmi': 17 | args.improved_flag = True 18 | args.clipz = True 19 | args.num_seeds = 1 20 | else: 21 | args.improved_flag = False 22 | args.clipz = False 23 | args.num_seeds = 5 24 | 25 | if cfg["attack"]["variant"] == 'L_logit' or cfg["attack"]["variant"] == 'ours': 26 | args.loss = 'logit_loss' 27 | else: 28 | args.loss = 'cel' 29 | 30 | if cfg["attack"]["variant"] == 'L_aug' or cfg["attack"]["variant"] == 'ours': 31 | args.classid = '0,1,2,3' 32 | else: 33 | args.classid = '0' 34 | 35 | 36 | if __name__ == '__main__': 37 | # Load Data 38 | cfg = load_json(json_file=args.configs) 39 | init_attack_args(cfg=cfg) 40 | 41 | # Save dir 42 | if args.improved_flag == True: 43 | prefix = os.path.join(cfg["root_path"], "kedmi_300ids") 44 | else: 45 | prefix = os.path.join(cfg["root_path"], "gmi_300ids") 46 | save_folder = os.path.join("{}_{}".format(cfg["dataset"]["name"], cfg["dataset"]["model_name"]), cfg["attack"]["variant"]) 47 | prefix = os.path.join(prefix, save_folder) 48 | save_dir = os.path.join(prefix, "latent") 49 | save_img_dir = os.path.join(prefix, "imgs_{}".format(cfg["attack"]["variant"])) 50 | 51 | # Load models 52 | _, E, G, _, _, _, _ = get_attack_model(args, cfg, eval_mode=True) 53 | 54 | # Metrics 55 | metric = cfg["attack"]["eval_metric"].split(',') 56 | fid = 0 57 | aver_acc, aver_acc5, aver_std, aver_std5 = 0, 0, 0, 0 58 | knn = 0, 0 59 | nsamples = 0 60 | dataset, model_types = '', '' 61 | 62 | 63 | 64 | for metric_ in metric: 65 | metric_ = metric_.strip() 66 | if metric_ == 'fid': 67 | fid, nsamples = eval_fid(G=G, E=E, save_dir=save_dir, cfg=cfg, args=args) 68 | elif metric_ == 'acc': 69 | aver_acc, aver_acc5, aver_std, aver_std5 = eval_accuracy(G=G, E=E, save_dir=save_dir, args=args) 70 | elif metric_ == 'knn': 71 | knn = eval_KNN(G=G, E=E, save_dir=save_dir, KNN_real_path=cfg["dataset"]["KNN_real_path"], args=args) 72 | 73 | csv_file = os.path.join(prefix, 'Eval_results.csv') 74 | if not os.path.exists(csv_file): 75 | header = ['Save_dir', 'Method', 'Succesful_samples', 76 | 'acc','std','acc5','std5', 77 | 'fid','knn'] 78 | with open(csv_file, 'w') as f: 79 | writer = csv.writer(f) 80 | writer.writerow(header) 81 | 82 | fields=['{}'.format(save_dir), 83 | '{}'.format(cfg["attack"]["method"]), 84 | '{}'.format(cfg["attack"]["variant"]), 85 | '{:.2f}'.format(aver_acc), 86 | '{:.2f}'.format(aver_std), 87 | '{:.2f}'.format(aver_acc5), 88 | 
'{:.2f}'.format(aver_std5), 89 | '{:.2f}'.format(fid), 90 | '{:.2f}'.format(knn)] 91 | 92 | print("---------------Evaluation---------------") 93 | print('Method: {} '.format(cfg["attack"]["method"])) 94 | 95 | print('Variant: {}'.format(cfg["attack"]["variant"])) 96 | print('Top 1 attack accuracy:{:.2f} +/- {:.2f} '.format(aver_acc, aver_std)) 97 | print('Top 5 attack accuracy:{:.2f} +/- {:.2f} '.format(aver_acc5, aver_std5)) 98 | print('KNN distance: {:.3f}'.format(knn)) 99 | print('FID score: {:.3f}'.format(fid)) 100 | 101 | print("----------------------------------------") 102 | with open(csv_file, 'a') as f: 103 | writer = csv.writer(f) 104 | writer.writerow(fields) 105 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.loss import _Loss 3 | 4 | def completion_network_loss(input, output, mask): 5 | bs = input.size(0) 6 | loss = torch.sum(torch.abs(output * mask - input * mask)) / bs 7 | #return mse_loss(output * mask, input * mask) 8 | return loss 9 | 10 | def noise_loss(V, img1, img2): 11 | feat1 = V(img1)[0] 12 | feat2 = V(img2)[0] 13 | loss = torch.mean(torch.abs(feat1 - feat2)) 14 | return loss 15 | 16 | class ContextLoss(_Loss): 17 | def forward(self, mask, gen, images): 18 | bs = gen.size(0) 19 | context_loss = torch.sum(torch.abs(torch.mul(mask, gen) - torch.mul(mask, images))) / bs 20 | return context_loss 21 | 22 | class CrossEntropyLoss(_Loss): 23 | def forward(self, out, gt): 24 | bs = out.size(0) 25 | #print(out.size(), gt.size()) 26 | loss = - torch.mul(gt.float(), torch.log(out.float() + 1e-7)) 27 | loss = torch.sum(loss) / bs 28 | return loss 29 | -------------------------------------------------------------------------------- /metrics/KNN_dist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from metrics.fid import concatenate_list, gen_samples 5 | from utils import load_json, save_tensor_images 6 | 7 | 8 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 9 | 10 | def find_shortest_dist(fea_target,fea_fake): 11 | shortest_dist = 0 12 | pdist = torch.nn.PairwiseDistance(p=2) 13 | 14 | fea_target = torch.from_numpy(fea_target).to(device) 15 | fea_fake = torch.from_numpy(fea_fake).to(device) 16 | # print('---fea_fake.shape[0]',fea_fake.shape[0]) 17 | for i in range(fea_fake.shape[0]): 18 | dist = pdist(fea_fake[i,:], fea_target) 19 | 20 | min_d = min(dist) 21 | 22 | # print('--KNN dist',min_d) 23 | shortest_dist = shortest_dist + min_d*min_d 24 | # print('--KNN dist',shortest_dist) 25 | 26 | return shortest_dist 27 | 28 | def run_KNN(target_dir, fake_dir): 29 | knn = 0 30 | target = np.load(target_dir,allow_pickle=True) 31 | fake = np.load(fake_dir,allow_pickle=True) 32 | target_fea = target.item().get('fea') 33 | target_y = target.item().get('label') 34 | fake_fea = fake.item().get('fea') 35 | fake_y = fake.item().get('label') 36 | 37 | fake_fea = concatenate_list(fake_fea) 38 | fake_y = concatenate_list(fake_y) 39 | 40 | N = fake_fea.shape[0] 41 | for id in range(300): 42 | id_f = fake_y == id 43 | id_t = target_y == id 44 | fea_f = fake_fea[id_f,:] 45 | fea_t = target_fea[id_t] 46 | 47 | shorted_dist = find_shortest_dist(fea_t,fea_f) 48 | knn = knn + shorted_dist 49 | 50 | return knn/N 51 | 52 | def eval_KNN(G, E, save_dir, KNN_real_path, args): 53 | 54 | fea_path, _ = gen_samples(G, E, save_dir, 
args.improved_flag)
54 | 
55 | 
56 | fea_path = fea_path + 'full.npy'
57 | 
58 | knn = run_KNN(KNN_real_path, fea_path)
59 | print("KNN:{:.3f} ".format(knn))
60 | return knn
61 | 
-------------------------------------------------------------------------------- /metrics/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 
9 | # empty
10 | 
-------------------------------------------------------------------------------- /metrics/eval_accuracy.py: --------------------------------------------------------------------------------
1 | from utils import *
2 | from models.classify import *
3 | from models.generator import *
4 | from models.discri import *
5 | import torch
6 | import numpy as np
7 | 
8 | from attack import attack_acc
9 | import statistics
10 | 
11 | from metrics.fid import concatenate_list, gen_samples
12 | 
13 | 
14 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
15 | 
16 | def accuracy(fake_dir, E):
17 | 
18 | aver_acc, aver_acc5, aver_std, aver_std5 = 0, 0, 0, 0
19 | 
20 | N = 5
21 | E.eval()
22 | for i in range(1):
23 | all_fake = np.load(fake_dir+'full.npy',allow_pickle=True)
24 | all_imgs = all_fake.item().get('imgs')
25 | all_label = all_fake.item().get('label')
26 | 
27 | # calculate attack accuracy
28 | with torch.no_grad():
29 | N_succesful = 0
30 | N_failure = 0
31 | 
32 | for random_seed in range(len(all_imgs)):
33 | if random_seed % N == 0:
34 | res, res5 = [], []
35 | 
36 | #################### attack accuracy #################
37 | fake = all_imgs[random_seed]
38 | label = all_label[random_seed]
39 | 
40 | label = torch.from_numpy(label)
41 | fake = torch.from_numpy(fake)
42 | 
43 | acc,acc5 = attack_acc(fake,label,E)
44 | 
45 | 
46 | print("Seed:{} Top1/Top5:{:.3f}/{:.3f}\t".format(random_seed, acc,acc5))
47 | res.append(acc)
48 | res5.append(acc5)
49 | 
50 | 
51 | if (random_seed+1)%5 == 0:
52 | acc, acc_5 = statistics.mean(res), statistics.mean(res5)
53 | std = statistics.stdev(res)
54 | std5 = statistics.stdev(res5)
55 | 
56 | print("Top1/Top5:{:.3f}/{:.3f}, std top1/top5:{:.3f}/{:.3f}".format(acc, acc_5, std, std5))
57 | 
58 | aver_acc += acc / N
59 | aver_acc5 += acc_5 / N
60 | aver_std += std / N
61 | aver_std5 += std5 / N
62 | print('N_succesful',N_succesful,N_failure)
63 | 
64 | 
65 | return aver_acc, aver_acc5, aver_std, aver_std5
66 | 
67 | 
68 | 
69 | def eval_accuracy(G, E, save_dir, args):
70 | 
71 | successful_imgs, _ = gen_samples(G, E, save_dir, args.improved_flag)
72 | 
73 | aver_acc, aver_acc5, \
74 | aver_std, aver_std5 = accuracy(successful_imgs, E)
75 | 
76 | 
77 | return aver_acc, aver_acc5, aver_std, aver_std5
78 | 
79 | def acc_class(filename,fake_dir,E):
80 | 
81 | E.eval()
82 | 
83 | sucessful_fake = np.load(fake_dir + 'success.npy',allow_pickle=True)
84 | sucessful_imgs = sucessful_fake.item().get('sucessful_imgs')
85 | sucessful_label = sucessful_fake.item().get('label')
86 | sucessful_imgs = concatenate_list(sucessful_imgs)
87 | sucessful_label = concatenate_list(sucessful_label)
88 | 
89 | N_img = 5
90 | N_id = 300
91 | with torch.no_grad():
92 | acc = np.zeros(N_id)
93 | for id in range(N_id):
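# For each of the N_id identities, count how many of its reconstructions the evaluation
# model E classified correctly; the count is converted below to a percentage of the
# N_img images per identity (acc = acc*100.0/N_img).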
94 | index = sucessful_label == id 95 | acc[id] = sum(index) 96 | 97 | acc=acc*100.0/N_img 98 | print('acc',acc) 99 | csv_file = '{}acc_class.csv'.format(filename) 100 | print('csv_file',csv_file) 101 | import csv 102 | with open(csv_file, 'a') as f: 103 | writer = csv.writer(f) 104 | for i in range(N_id): 105 | # writer.writerow(['{}'.format(i),'{}'.format(acc[i])]) 106 | writer.writerow([i,acc[i]]) 107 | 108 | def eval_acc_class(G, E, save_dir, prefix, args): 109 | 110 | successful_imgs, _ = gen_samples(G, E, save_dir, args.improved_flag) 111 | 112 | filename = "{}/{}_".format(prefix, args.loss) 113 | 114 | acc_class(filename,successful_imgs,E) 115 | 116 | -------------------------------------------------------------------------------- /metrics/fid.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch 3 | import numpy as np 4 | from scipy import linalg 5 | from metrics import metric_utils 6 | import utils 7 | from attack import reparameterize 8 | import os 9 | from utils import save_tensor_images 10 | 11 | 12 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 13 | _feature_detector_cache = None 14 | def get_feature_detector(): 15 | global _feature_detector_cache 16 | if _feature_detector_cache is None: 17 | _feature_detector_cache = metric_utils.get_feature_detector( 18 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/' 19 | 'metrics/inception-2015-12-05.pt', device) 20 | 21 | return _feature_detector_cache 22 | 23 | 24 | def postprocess(x): 25 | """.""" 26 | return ((x * .5 + .5) * 255).to(torch.uint8) 27 | 28 | 29 | def run_fid(x1, x2): 30 | # Extract features 31 | x1 = run_batch_extract(x1, device) 32 | x2 = run_batch_extract(x2, device) 33 | 34 | npx1 = x1.detach().cpu().numpy() 35 | npx2 = x2.detach().cpu().numpy() 36 | mu1 = np.mean(npx1, axis=0) 37 | sigma1 = np.cov(npx1, rowvar=False) 38 | mu2 = np.mean(npx2, axis=0) 39 | sigma2 = np.cov(npx2, rowvar=False) 40 | frechet = calculate_frechet_distance(mu1, sigma1, mu2, sigma2) 41 | return frechet 42 | 43 | 44 | def run_feature_extractor(x): 45 | assert x.dtype == torch.uint8 46 | assert x.min() >= 0 47 | assert x.max() <= 255 48 | assert len(x.shape) == 4 49 | assert x.shape[1] == 3 50 | feature_extractor = get_feature_detector() 51 | return feature_extractor(x, return_features=True) 52 | 53 | 54 | def run_batch_extract(x, device, bs=500): 55 | z = [] 56 | with torch.no_grad(): 57 | for start in tqdm(range(0, len(x), bs), desc='run_batch_extract'): 58 | stop = start + bs 59 | x_ = x[start:stop].to(device) 60 | z_ = run_feature_extractor(postprocess(x_)).cpu() 61 | z.append(z_) 62 | z = torch.cat(z) 63 | return z 64 | 65 | 66 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6, return_details=False): 67 | """Numpy implementation of the Frechet Distance. 68 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 69 | and X_2 ~ N(mu_2, C_2) is 70 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 71 | 72 | Stable version by Dougal J. Sutherland. 73 | 74 | Params: 75 | -- mu1 : Numpy array containing the activations of a layer of the 76 | inception net (like returned by the function 'get_predictions') 77 | for generated samples. 78 | -- mu2 : The sample mean over activations, precalculated on an 79 | representative data set. 80 | -- sigma1: The covariance matrix over activations for generated samples. 
81 | -- sigma2: The covariance matrix over activations, precalculated on an 82 | representative data set. 83 | 84 | Returns: 85 | -- : The Frechet Distance. 86 | """ 87 | 88 | mu1 = np.atleast_1d(mu1) 89 | mu2 = np.atleast_1d(mu2) 90 | 91 | sigma1 = np.atleast_2d(sigma1) 92 | sigma2 = np.atleast_2d(sigma2) 93 | 94 | assert mu1.shape == mu2.shape, \ 95 | 'Training and test mean vectors have different lengths' 96 | assert sigma1.shape == sigma2.shape, \ 97 | 'Training and test covariances have different dimensions' 98 | 99 | diff = mu1 - mu2 100 | 101 | # Product might be almost singular 102 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 103 | if not np.isfinite(covmean).all(): 104 | msg = ('fid calculation produces singular product; ' 105 | 'adding %s to diagonal of cov estimates') % eps 106 | print(msg) 107 | offset = np.eye(sigma1.shape[0]) * eps 108 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 109 | 110 | # Numerical error might give slight imaginary component 111 | if np.iscomplexobj(covmean): 112 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 113 | m = np.max(np.abs(covmean.imag)) 114 | raise ValueError('Imaginary component {}'.format(m)) 115 | covmean = covmean.real 116 | 117 | tr_covmean = np.trace(covmean) 118 | if not return_details: 119 | return (diff.dot(diff) + np.trace(sigma1) + 120 | np.trace(sigma2) - 2 * tr_covmean) 121 | else: 122 | t1 = diff.dot(diff) 123 | t2 = np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 124 | return (t1 + t2), t1, t2 125 | 126 | 127 | def get_z(improved_gan, save_dir, loop, i, j): 128 | if improved_gan==True: #KEDMI 129 | outputs_z = os.path.join(save_dir, "{}_{}_iter_0_{}_dis.npy".format(loop, i, 2399)) 130 | outputs_label = os.path.join(save_dir, "{}_{}_iter_0_{}_label.npy".format(loop, i, 2399)) 131 | 132 | dis = np.load(outputs_z, allow_pickle=True) 133 | mu = torch.from_numpy(dis.item().get('mu')).to(device) 134 | log_var = torch.from_numpy(dis.item().get('log_var')).to(device) 135 | iden = np.load(outputs_label) 136 | z = reparameterize(mu, log_var) 137 | else: #GMI 138 | outputs_z = os.path.join(save_dir, "{}_{}_iter_{}_{}_z.npy".format(save_dir, loop, i, j, 2399)) 139 | outputs_label = os.path.join(save_dir, "{}_{}_iter_{}_{}_label.npy".format(save_dir, loop, i, j, 2399)) 140 | 141 | z = np.load(outputs_z) 142 | iden = np.load(outputs_label) 143 | z = torch.from_numpy(z).to(device) 144 | return z, iden 145 | 146 | def gen_samples(G, E, save_dir, improved_gan, n_iden=5, n_img=5): 147 | total_gen = 0 148 | seed = 9 149 | torch.manual_seed(seed) 150 | img_ids_path = os.path.join(save_dir, 'attack{}_'.format(seed)) 151 | 152 | all_sucessful_imgs = [] 153 | all_failure_imgs = [] 154 | 155 | all_imgs = [] 156 | all_fea = [] 157 | all_id = [] 158 | all_sucessful_imgs = [] 159 | all_sucessful_id =[] 160 | all_sucessful_fea=[] 161 | 162 | all_failure_imgs = [] 163 | all_failure_fea = [] 164 | all_failure_id = [] 165 | 166 | E.eval() 167 | G.eval() 168 | if not os.path.exists(img_ids_path + 'full.npy'): 169 | for loop in range(1): 170 | for i in range(n_iden): #300 ides 171 | for j in range(n_img): #5 images/iden 172 | z, iden = get_z(improved_gan, save_dir, loop, i, j) 173 | z = torch.clamp(z, -1.0, 1.0).float() 174 | total_gen = total_gen + z.shape[0] 175 | # calculate attack accuracy 176 | with torch.no_grad(): 177 | fake = G(z.to(device)) 178 | save_tensor_images(fake, os.path.join(save_dir, "gen_{}_{}.png".format(i,j)), nrow = 60) 179 | 180 | eval_fea, eval_prob = E(utils.low2high(fake)) 181 | 
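# eval_fea / eval_prob are the evaluation model's penultimate features and logits for the
# upsampled fakes; a sample is kept as a successful attack below only when
# argmax(eval_prob) matches its target identity in iden.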
182 | ### successfully attacked samples 183 | eval_iden = torch.argmax(eval_prob, dim=1).view(-1) 184 | eval_iden = torch.argmax(eval_prob, dim=1).view(-1) 185 | sucessful_iden = [] 186 | failure_iden = [] 187 | for id in range(iden.shape[0]): 188 | if eval_iden[id]==iden[id]: 189 | sucessful_iden.append(id) 190 | else: 191 | failure_iden.append(id) 192 | 193 | 194 | fake = fake.detach().cpu().numpy() 195 | eval_fea = eval_fea.detach().cpu().numpy() 196 | 197 | all_imgs.append(fake) 198 | all_fea.append(eval_fea) 199 | all_id.append(iden) 200 | 201 | if len(sucessful_iden)>0: 202 | sucessful_iden = np.array(sucessful_iden) 203 | sucessful_fake = fake[sucessful_iden,:,:,:] 204 | sucessful_eval_fea = eval_fea[sucessful_iden,:] 205 | sucessful_iden = iden[sucessful_iden] 206 | else: 207 | sucessful_fake = [] 208 | sucessful_iden = [] 209 | sucessful_eval_fea = [] 210 | 211 | all_sucessful_imgs.append(sucessful_fake) 212 | all_sucessful_id.append(sucessful_iden) 213 | all_sucessful_fea.append(sucessful_eval_fea) 214 | 215 | if len(failure_iden)>0: 216 | failure_iden = np.array(failure_iden) 217 | failure_fake = fake[failure_iden,:,:,:] 218 | failure_eval_fea = eval_fea[failure_iden,:] 219 | failure_iden = iden[failure_iden] 220 | else: 221 | failure_fake = [] 222 | failure_iden = [] 223 | failure_eval_fea = [] 224 | 225 | all_failure_imgs.append(failure_fake) 226 | all_failure_id.append(failure_iden) 227 | all_failure_fea.append(failure_eval_fea) 228 | np.save(img_ids_path+'full',{'imgs':all_imgs,'label':all_id,'fea':all_fea}) 229 | np.save(img_ids_path+'success',{'sucessful_imgs':all_sucessful_imgs,'label':all_sucessful_id,'sucessful_fea':all_sucessful_fea}) 230 | np.save(img_ids_path+'failure',{'failure_imgs':all_failure_imgs,'label':all_failure_id,'failure_fea':all_failure_fea}) 231 | 232 | return img_ids_path, total_gen 233 | 234 | 235 | def concatenate_list(listA): 236 | result = [] 237 | for i in range(len(listA)): 238 | val = listA [i] 239 | if len(val)>0: 240 | if len(result)==0: 241 | result = listA[i] 242 | else: 243 | result = np.concatenate((result,val)) 244 | return result 245 | 246 | 247 | def eval_fid(G, E, save_dir, cfg, args): 248 | 249 | successful_imgs,_ = gen_samples(G, E, save_dir, args.improved_flag) 250 | 251 | #real data 252 | target_x = np.load(cfg['dataset']['fid_real_path']) 253 | 254 | # Load Samples 255 | sucessful_data = np.load(successful_imgs+'success.npy',allow_pickle=True) 256 | fake = sucessful_data.item().get('sucessful_imgs') 257 | 258 | fake = concatenate_list(fake) 259 | 260 | fake = torch.from_numpy(fake).to(device) 261 | target_x = torch.from_numpy(target_x).to(device) 262 | fid = run_fid(target_x, fake) 263 | 264 | return fid, fake.shape[0] 265 | -------------------------------------------------------------------------------- /metrics/metric_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
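# This module provides the Inception feature detector loader used by metrics/fid.py,
# together with the FeatureStats and ProgressMonitor helpers used by the
# feature-statistics routines defined below.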
8 | 9 | import os 10 | import time 11 | import hashlib 12 | import pickle 13 | import copy 14 | import uuid 15 | import numpy as np 16 | import torch 17 | import dnnlib 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | class MetricOptions: 22 | def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True): 23 | assert 0 <= rank < num_gpus 24 | self.G = G 25 | self.G_kwargs = dnnlib.EasyDict(G_kwargs) 26 | self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs) 27 | self.num_gpus = num_gpus 28 | self.rank = rank 29 | self.device = device if device is not None else torch.device('cuda', rank) 30 | self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor() 31 | self.cache = cache 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | _feature_detector_cache = dict() 36 | 37 | def get_feature_detector_name(url): 38 | return os.path.splitext(url.split('/')[-1])[0] 39 | 40 | def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False): 41 | assert 0 <= rank < num_gpus 42 | print('device',device) 43 | key = (url, device) 44 | if key not in _feature_detector_cache: 45 | is_leader = (rank == 0) 46 | if not is_leader and num_gpus > 1: 47 | torch.distributed.barrier() # leader goes first 48 | with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f: 49 | _feature_detector_cache[key] = torch.jit.load(f).eval().to(device) 50 | if is_leader and num_gpus > 1: 51 | torch.distributed.barrier() # others follow 52 | return _feature_detector_cache[key] 53 | 54 | #---------------------------------------------------------------------------- 55 | 56 | class FeatureStats: 57 | def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None): 58 | self.capture_all = capture_all 59 | self.capture_mean_cov = capture_mean_cov 60 | self.max_items = max_items 61 | self.num_items = 0 62 | self.num_features = None 63 | self.all_features = None 64 | self.raw_mean = None 65 | self.raw_cov = None 66 | 67 | def set_num_features(self, num_features): 68 | if self.num_features is not None: 69 | assert num_features == self.num_features 70 | else: 71 | self.num_features = num_features 72 | self.all_features = [] 73 | self.raw_mean = np.zeros([num_features], dtype=np.float64) 74 | self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64) 75 | 76 | def is_full(self): 77 | return (self.max_items is not None) and (self.num_items >= self.max_items) 78 | 79 | def append(self, x): 80 | x = np.asarray(x, dtype=np.float32) 81 | assert x.ndim == 2 82 | if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items): 83 | if self.num_items >= self.max_items: 84 | return 85 | x = x[:self.max_items - self.num_items] 86 | 87 | self.set_num_features(x.shape[1]) 88 | self.num_items += x.shape[0] 89 | if self.capture_all: 90 | self.all_features.append(x) 91 | if self.capture_mean_cov: 92 | x64 = x.astype(np.float64) 93 | self.raw_mean += x64.sum(axis=0) 94 | self.raw_cov += x64.T @ x64 95 | 96 | def append_torch(self, x, num_gpus=1, rank=0): 97 | assert isinstance(x, torch.Tensor) and x.ndim == 2 98 | assert 0 <= rank < num_gpus 99 | if num_gpus > 1: 100 | ys = [] 101 | for src in range(num_gpus): 102 | y = x.clone() 103 | torch.distributed.broadcast(y, src=src) 104 | ys.append(y) 105 | x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples 106 | self.append(x.cpu().numpy()) 
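# FeatureStats accumulates raw first and second moments (raw_mean, raw_cov), so
# get_mean_cov() below returns mean = E[x] and cov = E[x x^T] - mean*mean^T.
# Minimal usage sketch (illustrative only; uses the methods defined in this class):
#   stats = FeatureStats(capture_mean_cov=True, max_items=50000)
#   for feats in feature_batches:        # each feats: [N, num_features] array
#       stats.append(feats)
#   mu, sigma = stats.get_mean_cov()     # e.g. the inputs to a Frechet distance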
107 | 108 | def get_all(self): 109 | assert self.capture_all 110 | return np.concatenate(self.all_features, axis=0) 111 | 112 | def get_all_torch(self): 113 | return torch.from_numpy(self.get_all()) 114 | 115 | def get_mean_cov(self): 116 | assert self.capture_mean_cov 117 | mean = self.raw_mean / self.num_items 118 | cov = self.raw_cov / self.num_items 119 | cov = cov - np.outer(mean, mean) 120 | return mean, cov 121 | 122 | def save(self, pkl_file): 123 | with open(pkl_file, 'wb') as f: 124 | pickle.dump(self.__dict__, f) 125 | 126 | @staticmethod 127 | def load(pkl_file): 128 | with open(pkl_file, 'rb') as f: 129 | s = dnnlib.EasyDict(pickle.load(f)) 130 | obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items) 131 | obj.__dict__.update(s) 132 | return obj 133 | 134 | #---------------------------------------------------------------------------- 135 | 136 | class ProgressMonitor: 137 | def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000): 138 | self.tag = tag 139 | self.num_items = num_items 140 | self.verbose = verbose 141 | self.flush_interval = flush_interval 142 | self.progress_fn = progress_fn 143 | self.pfn_lo = pfn_lo 144 | self.pfn_hi = pfn_hi 145 | self.pfn_total = pfn_total 146 | self.start_time = time.time() 147 | self.batch_time = self.start_time 148 | self.batch_items = 0 149 | if self.progress_fn is not None: 150 | self.progress_fn(self.pfn_lo, self.pfn_total) 151 | 152 | def update(self, cur_items): 153 | assert (self.num_items is None) or (cur_items <= self.num_items) 154 | if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items): 155 | return 156 | cur_time = time.time() 157 | total_time = cur_time - self.start_time 158 | time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1) 159 | if (self.verbose) and (self.tag is not None): 160 | print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}') 161 | self.batch_time = cur_time 162 | self.batch_items = cur_items 163 | 164 | if (self.progress_fn is not None) and (self.num_items is not None): 165 | self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total) 166 | 167 | def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1): 168 | return ProgressMonitor( 169 | tag = tag, 170 | num_items = num_items, 171 | flush_interval = flush_interval, 172 | verbose = self.verbose, 173 | progress_fn = self.progress_fn, 174 | pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo, 175 | pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi, 176 | pfn_total = self.pfn_total, 177 | ) 178 | 179 | #---------------------------------------------------------------------------- 180 | 181 | def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs): 182 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) 183 | if data_loader_kwargs is None: 184 | data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2) 185 | 186 | # Try to lookup from cache. 187 | cache_file = None 188 | if opts.cache: 189 | # Choose cache file name. 
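# The cache key hashes the dataset, detector and stats arguments, so the real-data
# feature statistics only have to be computed once and are re-used by later runs.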
190 | args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs) 191 | md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8')) 192 | cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}' 193 | cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl') 194 | 195 | # Check if the file exists (all processes must agree). 196 | flag = os.path.isfile(cache_file) if opts.rank == 0 else False 197 | if opts.num_gpus > 1: 198 | flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device) 199 | torch.distributed.broadcast(tensor=flag, src=0) 200 | flag = (float(flag.cpu()) != 0) 201 | 202 | # Load. 203 | if flag: 204 | return FeatureStats.load(cache_file) 205 | 206 | # Initialize. 207 | num_items = len(dataset) 208 | if max_items is not None: 209 | num_items = min(num_items, max_items) 210 | stats = FeatureStats(max_items=num_items, **stats_kwargs) 211 | progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi) 212 | detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose) 213 | 214 | # Main loop. 215 | item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)] 216 | for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs): 217 | if images.shape[1] == 1: 218 | images = images.repeat([1, 3, 1, 1]) 219 | features = detector(images.to(opts.device), **detector_kwargs) 220 | stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) 221 | progress.update(stats.num_items) 222 | 223 | # Save to cache. 224 | if cache_file is not None and opts.rank == 0: 225 | os.makedirs(os.path.dirname(cache_file), exist_ok=True) 226 | temp_file = cache_file + '.' + uuid.uuid4().hex 227 | stats.save(temp_file) 228 | os.replace(temp_file, cache_file) # atomic 229 | return stats 230 | 231 | #---------------------------------------------------------------------------- 232 | 233 | def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, **stats_kwargs): 234 | if batch_gen is None: 235 | batch_gen = min(batch_size, 4) 236 | assert batch_size % batch_gen == 0 237 | 238 | # Setup generator and load labels. 239 | G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device) 240 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) 241 | 242 | # Image generation func. 243 | def run_generator(z, c): 244 | img = G(z=z, c=c, **opts.G_kwargs) 245 | img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) 246 | return img 247 | 248 | # JIT. 249 | if jit: 250 | z = torch.zeros([batch_gen, G.z_dim], device=opts.device) 251 | c = torch.zeros([batch_gen, G.c_dim], device=opts.device) 252 | run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False) 253 | 254 | # Initialize. 255 | stats = FeatureStats(**stats_kwargs) 256 | assert stats.max_items is not None 257 | progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi) 258 | detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose) 259 | 260 | # Main loop. 
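# Each pass samples batch_gen latents and random labels from the dataset, generates a
# batch of images, and feeds their detector features to stats until max_items is reached.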
261 | while not stats.is_full(): 262 | images = [] 263 | for _i in range(batch_size // batch_gen): 264 | z = torch.randn([batch_gen, G.z_dim], device=opts.device) 265 | c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)] 266 | c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device) 267 | images.append(run_generator(z, c)) 268 | images = torch.cat(images) 269 | if images.shape[1] == 1: 270 | images = images.repeat([1, 3, 1, 1]) 271 | features = detector(images, **detector_kwargs) 272 | stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) 273 | progress.update(stats.num_items) 274 | return stats 275 | 276 | #---------------------------------------------------------------------------- 277 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sutd-visual-computing-group/Re-thinking_MI/ed4b6b8900e4d7fed8f70d7e6a444893963d920e/models/__init__.py -------------------------------------------------------------------------------- /models/classify.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.models 5 | import torch.nn.functional as F 6 | # from torch.nn.modules.loss import _Loss 7 | import models.evolve as evolve 8 | import utils 9 | 10 | class Flatten(nn.Module): 11 | def forward(self, input): 12 | return input.view(input.size(0), -1) 13 | 14 | class MCNN2(nn.Module): 15 | def __init__(self, num_classes=10): 16 | super(MCNN2, self).__init__() 17 | self.feat_dim = 12800 18 | self.num_classes = num_classes 19 | self.feature = nn.Sequential( 20 | nn.Conv2d(1, 64, 7, stride=1, padding=1), 21 | nn.BatchNorm2d(64), 22 | nn.LeakyReLU(0.2), 23 | nn.MaxPool2d(2, 2), 24 | nn.Conv2d(64, 128, 5, stride=1), 25 | nn.BatchNorm2d(128), 26 | nn.LeakyReLU(0.2)) 27 | 28 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes) 29 | 30 | def forward(self, x): 31 | feature = self.feature(x) 32 | feature = feature.view(feature.size(0), -1) 33 | out = self.fc_layer(feature) 34 | return feature,out 35 | 36 | class MCNN4(nn.Module): 37 | def __init__(self, num_classes=10): 38 | super(MCNN4, self).__init__() 39 | self.feat_dim = 128 40 | self.num_classes = num_classes 41 | self.feature = nn.Sequential( 42 | nn.Conv2d(1, 32, 7, stride=1, padding=1), 43 | nn.BatchNorm2d(32), 44 | nn.LeakyReLU(0.2), 45 | nn.MaxPool2d(2, 2), 46 | nn.Conv2d(32, 64, 5, stride=1), 47 | nn.BatchNorm2d(64), 48 | nn.LeakyReLU(0.2), 49 | nn.MaxPool2d(2, 2), 50 | nn.Conv2d(64, 128, 5, stride=1), 51 | nn.BatchNorm2d(128), 52 | nn.LeakyReLU(0.2)) 53 | 54 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes) 55 | 56 | def forward(self, x): 57 | feature = self.feature(x) 58 | feature = feature.view(feature.size(0), -1) 59 | out = self.fc_layer(feature) 60 | return feature,out 61 | 62 | class MCNN(nn.Module): 63 | def __init__(self, num_classes=10): 64 | super(MCNN, self).__init__() 65 | self.feat_dim = 256 66 | self.num_classes = num_classes 67 | self.feature = nn.Sequential( 68 | nn.Conv2d(1, 64, 7, stride=1, padding=1), 69 | nn.BatchNorm2d(64), 70 | nn.LeakyReLU(0.2), 71 | nn.MaxPool2d(2, 2), 72 | nn.Conv2d(64, 128, 5, stride=1), 73 | nn.BatchNorm2d(128), 74 | nn.LeakyReLU(0.2), 75 | nn.MaxPool2d(2, 2), 76 | nn.Conv2d(128, 256, 5, stride=1), 77 | nn.BatchNorm2d(256), 78 | nn.LeakyReLU(0.2)) 79 | 80 
| self.fc_layer = nn.Linear(self.feat_dim, self.num_classes) 81 | 82 | def forward(self, x): 83 | feature = self.feature(x) 84 | feature = feature.view(feature.size(0), -1) 85 | out = self.fc_layer(feature) 86 | return feature,out 87 | 88 | class SCNN(nn.Module): 89 | def __init__(self, num_classes=10): 90 | super(SCNN, self).__init__() 91 | self.feat_dim = 512 92 | self.num_classes = num_classes 93 | self.feature = nn.Sequential( 94 | nn.Conv2d(1, 32, 7, stride=1, padding=1), 95 | nn.BatchNorm2d(32), 96 | nn.LeakyReLU(0.2), 97 | nn.MaxPool2d(2, 2), 98 | nn.Conv2d(32, 64, 5, stride=1), 99 | nn.BatchNorm2d(64), 100 | nn.LeakyReLU(0.2), 101 | nn.MaxPool2d(2, 2), 102 | nn.Conv2d(64, 128, 5, stride=1), 103 | nn.BatchNorm2d(128), 104 | nn.LeakyReLU(0.2), 105 | nn.Conv2d(128, 256, 3, stride=1, padding=1), 106 | nn.BatchNorm2d(256), 107 | nn.LeakyReLU(0.2), 108 | nn.Conv2d(256, 512, 3, stride=1, padding=1), 109 | nn.BatchNorm2d(512), 110 | nn.LeakyReLU(0.2)) 111 | 112 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes) 113 | 114 | def forward(self, x): 115 | feature = self.feature(x) 116 | feature = feature.view(feature.size(0), -1) 117 | out = self.fc_layer(feature) 118 | return feature,out 119 | 120 | class Mnist_CNN(nn.Module): 121 | def __init__(self): 122 | super(Mnist_CNN, self).__init__() 123 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 124 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 125 | self.conv2_drop = nn.Dropout2d() 126 | self.fc1 = nn.Linear(500, 50) 127 | self.fc2 = nn.Linear(50, 5) 128 | 129 | def forward(self, x): 130 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 131 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 132 | x = x.view(x.size(0), -1) 133 | x = F.relu(self.fc1(x)) 134 | x = F.dropout(x, training=self.training) 135 | res = self.fc2(x) 136 | return [x, res] 137 | 138 | 139 | class VGG16_xray8(nn.Module): 140 | def __init__(self, num_classes=7): 141 | super(VGG16_xray8, self).__init__() 142 | model = torchvision.models.vgg16_bn(pretrained=True) 143 | model.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 144 | self.feature = model.features 145 | self.feat_dim = 2048 146 | self.num_classes = num_classes 147 | self.bn = nn.BatchNorm1d(self.feat_dim) 148 | self.bn.bias.requires_grad_(False) # no shift 149 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes) 150 | self.model = model 151 | 152 | def forward(self, x): 153 | feature = self.feature(x) 154 | feature = feature.view(feature.size(0), -1) 155 | feature = self.bn(feature) 156 | res = self.fc_layer(feature) 157 | 158 | return feature,res 159 | 160 | def predict(self, x): 161 | feature = self.feature(x) 162 | feature = feature.view(feature.size(0), -1) 163 | feature = self.bn(feature) 164 | res = self.fc_layer(feature) 165 | 166 | return res 167 | 168 | class VGG16(nn.Module): 169 | def __init__(self, n_classes): 170 | super(VGG16, self).__init__() 171 | model = torchvision.models.vgg16_bn(pretrained=True) 172 | self.feature = model.features 173 | self.feat_dim = 512 * 2 * 2 174 | self.n_classes = n_classes 175 | self.bn = nn.BatchNorm1d(self.feat_dim) 176 | self.bn.bias.requires_grad_(False) # no shift 177 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes) 178 | 179 | def forward(self, x): 180 | feature = self.feature(x) 181 | feature = feature.view(feature.size(0), -1) 182 | feature = self.bn(feature) 183 | res = self.fc_layer(feature) 184 | return feature,res 185 | 186 | def predict(self, x): 187 | feature = self.feature(x) 188 | feature = 
feature.view(feature.size(0), -1) 189 | feature = self.bn(feature) 190 | res = self.fc_layer(feature) 191 | out = F.softmax(res, dim=1) 192 | 193 | return feature,out 194 | 195 | class VGG16_vib(nn.Module): 196 | def __init__(self, n_classes): 197 | super(VGG16_vib, self).__init__() 198 | model = torchvision.models.vgg16_bn(pretrained=True) 199 | self.feature = model.features 200 | self.feat_dim = 512 * 2 * 2 201 | self.k = self.feat_dim // 2 202 | self.n_classes = n_classes 203 | self.st_layer = nn.Linear(self.feat_dim, self.k * 2) 204 | self.fc_layer = nn.Linear(self.k, self.n_classes) 205 | 206 | def forward(self, x, mode="train"): 207 | feature = self.feature(x) 208 | feature = feature.view(feature.size(0), -1) 209 | statis = self.st_layer(feature) 210 | mu, std = statis[:, :self.k], statis[:, self.k:] 211 | 212 | std = F.softplus(std-5, beta=1) 213 | eps = torch.FloatTensor(std.size()).normal_().cuda() 214 | res = mu + std * eps 215 | out = self.fc_layer(res) 216 | 217 | return [feature, out, mu, std] 218 | 219 | def predict(self, x): 220 | feature = self.feature(x) 221 | feature = feature.view(feature.size(0), -1) 222 | statis = self.st_layer(feature) 223 | mu, std = statis[:, :self.k], statis[:, self.k:] 224 | 225 | std = F.softplus(std-5, beta=1) 226 | eps = torch.FloatTensor(std.size()).normal_().cuda() 227 | res = mu + std * eps 228 | out = self.fc_layer(res) 229 | 230 | return out 231 | 232 | class VGG19(nn.Module): 233 | def __init__(self, num_of_classes): 234 | super(VGG19, self).__init__() 235 | model = torchvision.models.vgg19_bn(pretrained=True) 236 | self.feature = nn.Sequential(*list(model.children())[:-2]) 237 | 238 | self.feat_dim = 512 * 2 * 2 239 | self.num_of_classes = num_of_classes 240 | self.fc_layer = nn.Linear(self.feat_dim, self.num_of_classes) 241 | def classifier(self, x): 242 | out = self.fc_layer(x) 243 | __, iden = torch.max(out, dim = 1) 244 | iden = iden.view(-1, 1) 245 | return out, iden 246 | 247 | def forward(self, x): 248 | feature = self.feature(x) 249 | feature = feature.view(feature.size(0), -1) 250 | out, iden = self.classifier(feature) 251 | return feature, out 252 | 253 | class VGG19_xray8(nn.Module): 254 | def __init__(self, num_classes): 255 | super(VGG19_xray8, self).__init__() 256 | model = torchvision.models.vgg19_bn(pretrained=True) 257 | model.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 258 | self.feature = model.features 259 | self.feat_dim = 2048 260 | self.num_classes = num_classes 261 | self.bn = nn.BatchNorm1d(self.feat_dim) 262 | self.bn.bias.requires_grad_(False) # no shift 263 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes) 264 | self.model = model 265 | 266 | def classifier(self, x): 267 | out = self.fc_layer(x) 268 | __, iden = torch.max(out, dim = 1) 269 | iden = iden.view(-1, 1) 270 | return out, iden 271 | 272 | def forward(self, x): 273 | feature = self.feature(x) 274 | # print(feature.shape) 275 | feature = feature.view(feature.size(0), -1) 276 | out, iden = self.classifier(feature) 277 | return feature, out 278 | 279 | 280 | class EfficientNet_b0(nn.Module): 281 | def __init__(self, n_classes): 282 | super(EfficientNet_b0, self).__init__() 283 | model = torchvision.models.efficientnet.efficientnet_b0(pretrained=True) 284 | self.feature = nn.Sequential(*list(model.children())[:-1]) 285 | self.n_classes = n_classes 286 | self.feat_dim = 1280 287 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes) 288 | 289 | def forward(self, x): 290 | feature = self.feature(x) 291 | 
feature = feature.view(feature.size(0), -1) 292 | res = self.fc_layer(feature) 293 | return feature,res 294 | 295 | def predict(self, x): 296 | feature = self.feature(x) 297 | feature = feature.view(feature.size(0), -1) 298 | res = self.fc_layer(feature) 299 | out = F.softmax(res, dim=1) 300 | 301 | return feature,out 302 | 303 | class EfficientNet_b1(nn.Module): 304 | def __init__(self, n_classes): 305 | super(EfficientNet_b1, self).__init__() 306 | model = torchvision.models.efficientnet.efficientnet_b1(pretrained=True) 307 | self.feature = nn.Sequential(*list(model.children())[:-1]) 308 | self.n_classes = n_classes 309 | self.feat_dim = 1280 310 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes) 311 | 312 | def forward(self, x): 313 | feature = self.feature(x) 314 | feature = feature.view(feature.size(0), -1) 315 | res = self.fc_layer(feature) 316 | return feature,res 317 | 318 | def predict(self, x): 319 | feature = self.feature(x) 320 | feature = feature.view(feature.size(0), -1) 321 | res = self.fc_layer(feature) 322 | out = F.softmax(res, dim=1) 323 | 324 | return feature,out 325 | 326 | class EfficientNet_b2(nn.Module): 327 | def __init__(self, n_classes): 328 | super(EfficientNet_b2, self).__init__() 329 | model = torchvision.models.efficientnet.efficientnet_b2(pretrained=True) 330 | self.feature = nn.Sequential(*list(model.children())[:-1]) 331 | self.n_classes = n_classes 332 | self.feat_dim = 1408 333 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes) 334 | 335 | def forward(self, x): 336 | feature = self.feature(x) 337 | feature = feature.view(feature.size(0), -1) 338 | res = self.fc_layer(feature) 339 | return feature,res 340 | 341 | def predict(self, x): 342 | feature = self.feature(x) 343 | feature = feature.view(feature.size(0), -1) 344 | res = self.fc_layer(feature) 345 | out = F.softmax(res, dim=1) 346 | 347 | return feature,out 348 | 349 | class EfficientNet_v2_s2(nn.Module): 350 | def __init__(self, n_classes,dataset='celeba'): 351 | super(EfficientNet_v2_s2, self).__init__() 352 | model = torchvision.models.efficientnet.efficientnet_v2_s(pretrained=True) 353 | self.feature = nn.Sequential(*list(model.children())[:-1]) 354 | self.n_classes = n_classes 355 | self.feat_dim = 1028 356 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes) 357 | 358 | def forward(self, x): 359 | feature = self.feature(x) 360 | feature = feature.view(feature.size(0), -1) 361 | res = self.fc_layer(feature) 362 | return feature,res 363 | 364 | def predict(self, x): 365 | feature = self.feature(x) 366 | 367 | feature = feature.view(feature.size(0), -1) 368 | res = self.fc_layer(feature) 369 | out = F.softmax(res, dim=1) 370 | 371 | return feature,out 372 | 373 | class EfficientNet_v2_m2(nn.Module): 374 | def __init__(self, n_classes,dataset='celeba'): 375 | super(EfficientNet_v2_m2, self).__init__() 376 | model = torchvision.models.efficientnet.efficientnet_v2_m(pretrained=True) 377 | 378 | self.feature = nn.Sequential(*list(model.children())[:-1]) 379 | self.n_classes = n_classes 380 | self.feat_dim = 1028 381 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes) 382 | 383 | def forward(self, x): 384 | feature = self.feature(x) 385 | feature = feature.view(feature.size(0), -1) 386 | res = self.fc_layer(feature) 387 | return feature,res 388 | 389 | def predict(self, x): 390 | feature = self.feature(x) 391 | feature = feature.view(feature.size(0), -1) 392 | res = self.fc_layer(feature) 393 | out = F.softmax(res, dim=1) 394 | 395 | return feature,out 396 | 397 | class 
EfficientNet_v2_l2(nn.Module): 398 | def __init__(self, n_classes, dataset='celeba'): 399 | super(EfficientNet_v2_l2, self).__init__() 400 | model = torchvision.models.efficientnet.efficientnet_v2_l(pretrained=True) 401 | self.feature = nn.Sequential(*list(model.children())[:-1]) 402 | self.n_classes = n_classes 403 | self.feat_dim = 1028 404 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes) 405 | 406 | def forward(self, x): 407 | feature = self.feature(x) 408 | feature = feature.view(feature.size(0), -1) 409 | res = self.fc_layer(feature) 410 | return feature,res 411 | 412 | def predict(self, x): 413 | feature = self.feature(x) 414 | feature = feature.view(feature.size(0), -1) 415 | res = self.fc_layer(feature) 416 | out = F.softmax(res, dim=1) 417 | 418 | return feature,out 419 | 420 | class EfficientNet_v2_s(nn.Module): 421 | def __init__(self, n_classes, dataset='celeba'): 422 | super(EfficientNet_v2_s, self).__init__() 423 | model = torchvision.models.efficientnet.efficientnet_v2_s(pretrained=True) 424 | self.feature = nn.Sequential(*list(model.children())[:-2]) 425 | self.n_classes = n_classes 426 | self.feat_dim = 5120 427 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes) 428 | 429 | def forward(self, x): 430 | feature = self.feature(x) 431 | feature = feature.view(feature.size(0), -1) 432 | res = self.fc_layer(feature) 433 | return feature,res 434 | 435 | def predict(self, x): 436 | feature = self.feature(x) 437 | 438 | feature = feature.view(feature.size(0), -1) 439 | res = self.fc_layer(feature) 440 | out = F.softmax(res, dim=1) 441 | 442 | return feature,out 443 | 444 | class EfficientNet_v2_m(nn.Module): 445 | def __init__(self, n_classes, dataset='celeba'): 446 | super(EfficientNet_v2_m, self).__init__() 447 | model = torchvision.models.efficientnet.efficientnet_v2_m(pretrained=True) 448 | 449 | self.feature = nn.Sequential(*list(model.children())[:-2]) 450 | self.n_classes = n_classes 451 | self.feat_dim = 5120 452 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes) 453 | 454 | def forward(self, x): 455 | feature = self.feature(x) 456 | feature = feature.view(feature.size(0), -1) 457 | res = self.fc_layer(feature) 458 | return feature,res 459 | 460 | def predict(self, x): 461 | feature = self.feature(x) 462 | feature = feature.view(feature.size(0), -1) 463 | res = self.fc_layer(feature) 464 | out = F.softmax(res, dim=1) 465 | 466 | return feature,out 467 | 468 | class EfficientNet_v2_l(nn.Module): 469 | def __init__(self, n_classes, dataset='celeba'): 470 | super(EfficientNet_v2_l, self).__init__() 471 | model = torchvision.models.efficientnet.efficientnet_v2_l(pretrained=True) 472 | self.feature = nn.Sequential(*list(model.children())[:-2]) 473 | self.n_classes = n_classes 474 | self.feat_dim = 5120 475 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes) 476 | 477 | def forward(self, x): 478 | feature = self.feature(x) 479 | feature = feature.view(feature.size(0), -1) 480 | res = self.fc_layer(feature) 481 | return feature,res 482 | 483 | def predict(self, x): 484 | feature = self.feature(x) 485 | feature = feature.view(feature.size(0), -1) 486 | res = self.fc_layer(feature) 487 | out = F.softmax(res, dim=1) 488 | 489 | return feature,out 490 | 491 | class efficientNet_v2_l_xray(nn.Module): 492 | def __init__(self, n_classes): 493 | super(efficientNet_v2_l_xray, self).__init__() 494 | model = torchvision.models.efficientnet.efficientnet_v2_l(pretrained=True) 495 | model.features[0][0] =nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), 
padding=(1, 1), bias=False) 496 | self.feature = model.features 497 | self.n_classes = n_classes 498 | self.feat_dim = 5120 499 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes) 500 | 501 | def forward(self, x): 502 | feature = self.feature(x) 503 | feature = feature.view(feature.size(0), -1) 504 | res = self.fc_layer(feature) 505 | return feature,res 506 | 507 | def predict(self, x): 508 | feature = self.feature(x) 509 | feature = feature.view(feature.size(0), -1) 510 | res = self.fc_layer(feature) 511 | out = F.softmax(res, dim=1) 512 | 513 | return feature,out 514 | 515 | class efficientnet_v2_m_xray(nn.Module): 516 | def __init__(self, n_classes): 517 | super(efficientnet_v2_m_xray, self).__init__() 518 | model = torchvision.models.efficientnet.efficientnet_v2_m(pretrained=True) 519 | 520 | model.features[0][0] =nn.Conv2d(1, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) 521 | self.feature = model.features 522 | self.n_classes = n_classes 523 | self.feat_dim = 5120 524 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes) 525 | 526 | def forward(self, x): 527 | feature = self.feature(x) 528 | feature = feature.view(feature.size(0), -1) 529 | res = self.fc_layer(feature) 530 | return feature,res 531 | 532 | def predict(self, x): 533 | feature = self.feature(x) 534 | feature = feature.view(feature.size(0), -1) 535 | res = self.fc_layer(feature) 536 | out = F.softmax(res, dim=1) 537 | 538 | return feature,out 539 | 540 | class efficientnet_v2_s_xray(nn.Module): 541 | def __init__(self, n_classes): 542 | super(efficientnet_v2_s_xray, self).__init__() 543 | model = torchvision.models.efficientnet.efficientnet_v2_s(pretrained=True) 544 | model.features[0][0] =nn.Conv2d(1, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) 545 | 546 | self.feature = model.features 547 | self.n_classes = n_classes 548 | self.feat_dim = 5120 549 | self.fc_layer = nn.Linear(self.feat_dim, self.n_classes) 550 | 551 | def forward(self, x): 552 | feature = self.feature(x) 553 | feature = feature.view(feature.size(0), -1) 554 | res = self.fc_layer(feature) 555 | return feature,res 556 | 557 | def predict(self, x): 558 | feature = self.feature(x) 559 | feature = feature.view(feature.size(0), -1) 560 | res = self.fc_layer(feature) 561 | out = F.softmax(res, dim=1) 562 | 563 | return feature,out 564 | 565 | 566 | class ResNet18(nn.Module): 567 | def __init__(self, num_of_classes): 568 | super(ResNet18, self).__init__() 569 | model = torchvision.models.resnet18(pretrained=True) 570 | self.feature = nn.Sequential(*list(model.children())[:-2]) 571 | self.feat_dim = 2048 * 1 * 1 572 | self.num_of_classes = num_of_classes 573 | self.fc_layer = nn.Linear(self.feat_dim, self.num_of_classes) 574 | 575 | def classifier(self, x): 576 | out = self.fc_layer(x) 577 | __, iden = torch.max(out, dim = 1) 578 | iden = iden.view(-1, 1) 579 | return out, iden 580 | 581 | def forward(self, x): 582 | feature = self.feature(x) 583 | feature = feature.view(feature.size(0), -1) 584 | out, iden = self.classifier(feature) 585 | return feature, out 586 | 587 | class ResNet34(nn.Module): 588 | def __init__(self, num_of_classes): 589 | super(ResNet34, self).__init__() 590 | model = torchvision.models.resnet34(pretrained=True) 591 | self.feature = nn.Sequential(*list(model.children())[:-2]) 592 | self.feat_dim = 2048 * 1 * 1 593 | self.num_of_classes = num_of_classes 594 | 595 | self.fc_layer = nn.Linear(self.feat_dim, self.num_of_classes) 596 | 597 | def classifier(self, x): 598 | out = 
self.fc_layer(x) 599 | __, iden = torch.max(out, dim = 1) 600 | iden = iden.view(-1, 1) 601 | return out, iden 602 | 603 | def forward(self, x): 604 | feature = self.feature(x) 605 | feature = feature.view(feature.size(0), -1) 606 | out, iden = self.classifier(feature) 607 | return feature, out 608 | 609 | class ResNet34_xray8(nn.Module): 610 | def __init__(self, num_of_classes): 611 | super(ResNet34_xray8, self).__init__() 612 | model = torchvision.models.resnet34(pretrained=True) 613 | model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) 614 | 615 | self.feature = nn.Sequential(*list(model.children())[:-2]) 616 | self.feat_dim = 2048 * 1 * 1 617 | self.num_of_classes = num_of_classes 618 | 619 | self.fc_layer = nn.Linear(self.feat_dim, self.num_of_classes) 620 | 621 | def classifier(self, x): 622 | out = self.fc_layer(x) 623 | __, iden = torch.max(out, dim = 1) 624 | iden = iden.view(-1, 1) 625 | return out, iden 626 | 627 | def forward(self, x): 628 | feature = self.feature(x) 629 | feature = feature.view(feature.size(0), -1) 630 | out, iden = self.classifier(feature) 631 | return feature, out 632 | 633 | 634 | class Mobilenet_v3_small(nn.Module): 635 | def __init__(self, num_of_classes): 636 | super(Mobilenet_v3_small, self).__init__() 637 | model = torchvision.models.mobilenet_v3_small(pretrained=True) 638 | self.feature = nn.Sequential(*list(model.children())[:-2]) 639 | self.feat_dim = 2304 640 | self.num_of_classes = num_of_classes 641 | self.fc_layer = nn.Sequential( 642 | nn.Linear(self.feat_dim, self.num_of_classes),) 643 | 644 | def classifier(self, x): 645 | out = self.fc_layer(x) 646 | __, iden = torch.max(out, dim = 1) 647 | iden = iden.view(-1, 1) 648 | return out, iden 649 | 650 | def forward(self, x): 651 | feature = self.feature(x) 652 | feature = feature.view(feature.size(0), -1) 653 | out, iden = self.classifier(feature) 654 | return feature,out 655 | 656 | class Mobilenet_v2(nn.Module): 657 | def __init__(self, num_of_classes): 658 | super(Mobilenet_v2, self).__init__() 659 | model = torchvision.models.mobilenet_v2(pretrained=True) 660 | self.feature = nn.Sequential(*list(model.children())[:-2]) 661 | self.feat_dim = 12288 662 | self.num_of_classes = num_of_classes 663 | self.fc_layer = nn.Sequential( 664 | nn.Linear(self.feat_dim, self.num_of_classes),) 665 | 666 | def classifier(self, x): 667 | out = self.fc_layer(x) 668 | __, iden = torch.max(out, dim = 1) 669 | iden = iden.view(-1, 1) 670 | return out, iden 671 | 672 | def forward(self, x): 673 | feature = self.feature(x) 674 | feature = feature.view(feature.size(0), -1) 675 | out, iden = self.classifier(feature) 676 | return feature,out 677 | 678 | class EvolveFace(nn.Module): 679 | def __init__(self, num_of_classes, IR152): 680 | super(EvolveFace, self).__init__() 681 | if IR152: 682 | model = evolve.IR_152_64((64,64)) 683 | else: 684 | model = evolve.IR_50_64((64,64)) 685 | self.model = model 686 | self.feat_dim = 512 687 | self.num_classes = num_of_classes 688 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512), 689 | nn.Dropout(p=0.15), 690 | Flatten(), 691 | nn.Linear(512 * 4 * 4, 512), 692 | nn.BatchNorm1d(512)) 693 | 694 | self.fc_layer = nn.Sequential( 695 | nn.Linear(self.feat_dim, self.num_classes),) 696 | 697 | 698 | def classifier(self, x): 699 | out = self.fc_layer(x) 700 | __, iden = torch.max(out, dim = 1) 701 | iden = iden.view(-1, 1) 702 | return out, iden 703 | 704 | def forward(self,x): 705 | feature = self.model(x) 706 | feature = self.output_layer(feature) 707 | 
feature = feature.view(feature.size(0), -1) 708 | out, iden = self.classifier(feature) 709 | 710 | return out 711 | 712 | class FaceNet(nn.Module): 713 | def __init__(self, num_classes=1000): 714 | super(FaceNet, self).__init__() 715 | self.feature = evolve.IR_50_112((112, 112)) 716 | self.feat_dim = 512 717 | self.num_classes = num_classes 718 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes) 719 | 720 | def predict(self, x): 721 | feat = self.feature(x) 722 | feat = feat.view(feat.size(0), -1) 723 | out = self.fc_layer(feat) 724 | return out 725 | 726 | def forward(self, x): 727 | feat = self.feature(x) 728 | feat = feat.view(feat.size(0), -1) 729 | out = self.fc_layer(feat) 730 | return [feat, out] 731 | 732 | class FaceNet64(nn.Module): 733 | def __init__(self, num_classes = 1000): 734 | super(FaceNet64, self).__init__() 735 | self.feature = evolve.IR_50_64((64, 64)) 736 | self.feat_dim = 512 737 | self.num_classes = num_classes 738 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512), 739 | nn.Dropout(), 740 | Flatten(), 741 | nn.Linear(512 * 4 * 4, 512), 742 | nn.BatchNorm1d(512)) 743 | 744 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes) 745 | 746 | def forward(self, x): 747 | feat = self.feature(x) 748 | feat = self.output_layer(feat) 749 | feat = feat.view(feat.size(0), -1) 750 | out = self.fc_layer(feat) 751 | __, iden = torch.max(out, dim=1) 752 | iden = iden.view(-1, 1) 753 | return feat, out 754 | 755 | 756 | class IR152(nn.Module): 757 | def __init__(self, num_classes=1000): 758 | super(IR152, self).__init__() 759 | self.feature = evolve.IR_152_64((64, 64)) 760 | self.feat_dim = 512 761 | self.num_classes = num_classes 762 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512), 763 | nn.Dropout(), 764 | Flatten(), 765 | nn.Linear(512 * 4 * 4, 512), 766 | nn.BatchNorm1d(512)) 767 | 768 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes) 769 | 770 | def forward(self, x): 771 | feat = self.feature(x) 772 | feat = self.output_layer(feat) 773 | feat = feat.view(feat.size(0), -1) 774 | out = self.fc_layer(feat) 775 | return feat, out 776 | 777 | class IR152_vib(nn.Module): 778 | def __init__(self, num_classes=1000): 779 | super(IR152_vib, self).__init__() 780 | self.feature = evolve.IR_152_64((64, 64)) 781 | self.feat_dim = 512 782 | self.k = self.feat_dim // 2 783 | self.n_classes = num_classes 784 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512), 785 | nn.Dropout(), 786 | Flatten(), 787 | nn.Linear(512 * 4 * 4, 512), 788 | nn.BatchNorm1d(512)) 789 | 790 | self.st_layer = nn.Linear(self.feat_dim, self.k * 2) 791 | self.fc_layer = nn.Sequential( 792 | nn.Linear(self.k, self.n_classes), 793 | nn.Softmax(dim = 1)) 794 | 795 | def forward(self, x): 796 | feature = self.output_layer(self.feature(x)) 797 | feature = feature.view(feature.size(0), -1) 798 | statis = self.st_layer(feature) 799 | mu, std = statis[:, :self.k], statis[:, self.k:] 800 | 801 | std = F.softplus(std-5, beta=1) 802 | eps = torch.FloatTensor(std.size()).normal_().cuda() 803 | res = mu + std * eps 804 | out = self.fc_layer(res) 805 | __, iden = torch.max(out, dim=1) 806 | iden = iden.view(-1, 1) 807 | 808 | return feature, out, iden, mu#, st 809 | 810 | 811 | class IR50(nn.Module): 812 | def __init__(self, num_classes=1000): 813 | super(IR50, self).__init__() 814 | self.feature = evolve.IR_50_64((64, 64)) 815 | self.feat_dim = 512 816 | self.num_classes = num_classes 817 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512), 818 | nn.Dropout(), 819 | Flatten(), 820 | 
nn.Linear(512 * 4 * 4, 512), 821 | nn.BatchNorm1d(512)) 822 | 823 | self.st_layer = nn.Linear(self.feat_dim, self.k * 2) 824 | self.fc_layer = nn.Sequential( 825 | nn.Linear(self.k, self.n_classes), 826 | nn.Softmax(dim = 1)) 827 | 828 | def forward(self, x): 829 | feature = self.output_layer(self.feature(x)) 830 | feature = feature.view(feature.size(0), -1) 831 | statis = self.st_layer(feature) 832 | mu, std = statis[:, :self.k], statis[:, self.k:] 833 | 834 | std = F.softplus(std-5, beta=1) 835 | eps = torch.FloatTensor(std.size()).normal_().cuda() 836 | res = mu + std * eps 837 | out = self.fc_layer(res) 838 | __, iden = torch.max(out, dim=1) 839 | iden = iden.view(-1, 1) 840 | 841 | return feature, out, iden, mu, std 842 | 843 | class IR50_vib(nn.Module): 844 | def __init__(self, num_classes=1000): 845 | super(IR50_vib, self).__init__() 846 | self.feature = evolve.IR_50_64((64, 64)) 847 | self.feat_dim = 512 848 | self.n_classes = num_classes 849 | self.k = self.feat_dim // 2 850 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512), 851 | nn.Dropout(), 852 | Flatten(), 853 | nn.Linear(512 * 4 * 4, 512), 854 | nn.BatchNorm1d(512)) 855 | 856 | self.st_layer = nn.Linear(self.feat_dim, self.k * 2) 857 | self.fc_layer = nn.Sequential( 858 | nn.Linear(self.k, self.n_classes), 859 | nn.Softmax(dim=1)) 860 | 861 | def forward(self, x): 862 | feat = self.output_layer(self.feature(x)) 863 | feat = feat.view(feat.size(0), -1) 864 | statis = self.st_layer(feat) 865 | mu, std = statis[:, :self.k], statis[:, self.k:] 866 | 867 | std = F.softplus(std-5, beta=1) 868 | eps = torch.FloatTensor(std.size()).normal_().cuda() 869 | res = mu + std * eps 870 | out = self.fc_layer(res) 871 | __, iden = torch.max(out, dim=1) 872 | iden = iden.view(-1, 1) 873 | 874 | return feat, out, iden, mu, std 875 | 876 | 877 | 878 | def get_classifier(model_name, mode, n_classes, resume_path): 879 | if model_name == "VGG16": 880 | if mode == "reg": 881 | net = VGG16(n_classes) 882 | elif mode == "vib": 883 | net = VGG16_vib(n_classes) 884 | 885 | elif model_name == "FaceNet": 886 | net = FaceNet(n_classes) 887 | 888 | elif model_name == "FaceNet_all": 889 | net = FaceNet(202599) 890 | 891 | elif model_name == "FaceNet64": 892 | net = FaceNet64(n_classes) 893 | 894 | elif model_name == "IR50": 895 | if mode == "reg": 896 | net = IR50(n_classes) 897 | elif mode == "vib": 898 | net = IR50_vib(n_classes) 899 | 900 | elif model_name == "IR152": 901 | if mode == "reg": 902 | net = IR152(n_classes) 903 | else: 904 | net = IR152_vib(n_classes) 905 | 906 | else: 907 | print("Model name Error") 908 | exit() 909 | 910 | if model_name in ['FaceNet', 'FaceNet_all', 'FaceNet_64', 'IR50', 'IR152']: 911 | if resume_path is not "": 912 | print("Resume") 913 | utils.load_state_dict(net.feature, torch.load(resume_path)) 914 | else: 915 | print("No Resume") 916 | 917 | return net 918 | -------------------------------------------------------------------------------- /models/discri.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | # from utils import LinearWeightNorm 5 | import torch.nn.init as init 6 | 7 | class MinibatchDiscriminator_MNIST(nn.Module): 8 | def __init__(self,in_dim=3, dim=64, n_classes=1000): 9 | super(MinibatchDiscriminator_MNIST, self).__init__() 10 | self.n_classes = n_classes 11 | 12 | def conv_ln_lrelu(in_dim, out_dim, k, s, p): 13 | return nn.Sequential( 14 | nn.Conv2d(in_dim, out_dim, k, s, p), 15 | # 
Since there is no effective implementation of LayerNorm, 16 | # we use InstanceNorm2d instead of LayerNorm here. 17 | nn.InstanceNorm2d(out_dim, affine=True), 18 | nn.LeakyReLU(0.2)) 19 | 20 | self.layer1 = conv_ln_lrelu(in_dim, dim, 5, 2, 2) 21 | self.layer2 = conv_ln_lrelu(dim, dim*2, 5, 2, 2) 22 | self.layer3 = conv_ln_lrelu(dim*2, dim*4, 5, 2, 2) 23 | self.layer4 = conv_ln_lrelu(dim*4, dim*4, 3, 2, 1) 24 | self.mbd1 = MinibatchDiscrimination(dim*4*4, 64, 50) 25 | self.fc_layer = nn.Linear(dim*4*4+64, self.n_classes) 26 | 27 | def forward(self, x): 28 | out = [] 29 | bs = x.shape[0] 30 | feat1 = self.layer1(x) 31 | out.append(feat1) 32 | feat2 = self.layer2(feat1) 33 | out.append(feat2) 34 | feat3 = self.layer3(feat2) 35 | out.append(feat3) 36 | feat4 = self.layer4(feat3) 37 | out.append(feat4) 38 | feat = feat4.view(bs, -1) 39 | # print('feat:', feat.shape) 40 | mb_out = self.mbd1(feat) # Nx(A+B) 41 | y = self.fc_layer(mb_out) 42 | 43 | return feat, y 44 | # return mb_out, y 45 | 46 | 47 | class MinibatchDiscrimination(nn.Module): 48 | def __init__(self, in_features, out_features, kernel_dims, mean=False): 49 | super().__init__() 50 | self.in_features = in_features 51 | self.out_features = out_features 52 | self.kernel_dims = kernel_dims 53 | self.mean = mean 54 | self.T = nn.Parameter(torch.Tensor(in_features, out_features, kernel_dims)) 55 | init.normal(self.T, 0, 1) 56 | 57 | def forward(self, x): 58 | # x is NxA 59 | # T is AxBxC 60 | matrices = x.mm(self.T.view(self.in_features, -1)) 61 | matrices = matrices.view(-1, self.out_features, self.kernel_dims) 62 | 63 | M = matrices.unsqueeze(0) # 1xNxBxC 64 | M_T = M.permute(1, 0, 2, 3) # Nx1xBxC 65 | norm = torch.abs(M - M_T).sum(3) # NxNxB 66 | expnorm = torch.exp(-norm) 67 | o_b = (expnorm.sum(0) - 1) # NxB, subtract self distance 68 | if self.mean: 69 | o_b /= x.size(0) - 1 70 | 71 | x = torch.cat([x, o_b], 1) 72 | return x 73 | 74 | class MinibatchDiscriminator(nn.Module): 75 | def __init__(self,in_dim=3, dim=64, n_classes=1000): 76 | super(MinibatchDiscriminator, self).__init__() 77 | self.n_classes = n_classes 78 | 79 | def conv_ln_lrelu(in_dim, out_dim, k, s, p): 80 | return nn.Sequential( 81 | nn.Conv2d(in_dim, out_dim, k, s, p), 82 | # Since there is no effective implementation of LayerNorm, 83 | # we use InstanceNorm2d instead of LayerNorm here. 
84 | nn.InstanceNorm2d(out_dim, affine=True), 85 | nn.LeakyReLU(0.2)) 86 | 87 | self.layer1 = conv_ln_lrelu(in_dim, dim, 5, 2, 2) 88 | self.layer2 = conv_ln_lrelu(dim, dim*2, 5, 2, 2) 89 | self.layer3 = conv_ln_lrelu(dim*2, dim*4, 5, 2, 2) 90 | self.layer4 = conv_ln_lrelu(dim*4, dim*4, 3, 2, 1) 91 | self.mbd1 = MinibatchDiscrimination(dim*4*4*4, 64, 50) 92 | self.fc_layer = nn.Linear(dim*4*4*4+64, self.n_classes) 93 | 94 | def forward(self, x): 95 | out = [] 96 | bs = x.shape[0] 97 | feat1 = self.layer1(x) 98 | out.append(feat1) 99 | feat2 = self.layer2(feat1) 100 | out.append(feat2) 101 | feat3 = self.layer3(feat2) 102 | out.append(feat3) 103 | feat4 = self.layer4(feat3) 104 | out.append(feat4) 105 | feat = feat4.view(bs, -1) 106 | # print('feat:', feat.shape) 107 | mb_out = self.mbd1(feat) # Nx(A+B) 108 | y = self.fc_layer(mb_out) 109 | 110 | return feat, y 111 | # return mb_out, y 112 | 113 | 114 | class Discriminator(nn.Module): 115 | def __init__(self,in_dim=3, dim=64, n_classes=1000): 116 | super(Discriminator, self).__init__() 117 | self.n_classes = n_classes 118 | 119 | def conv_ln_lrelu(in_dim, out_dim, k, s, p): 120 | return nn.Sequential( 121 | nn.Conv2d(in_dim, out_dim, k, s, p), 122 | # Since there is no effective implementation of LayerNorm, 123 | # we use InstanceNorm2d instead of LayerNorm here. 124 | nn.InstanceNorm2d(out_dim, affine=True), 125 | nn.LeakyReLU(0.2)) 126 | 127 | self.layer1 = conv_ln_lrelu(in_dim, dim, 5, 2, 2) 128 | self.layer2 = conv_ln_lrelu(dim, dim*2, 5, 2, 2) 129 | self.layer3 = conv_ln_lrelu(dim*2, dim*4, 5, 2, 2) 130 | self.layer4 = conv_ln_lrelu(dim*4, dim*4, 3, 2, 1) 131 | self.fc_layer = nn.Linear(dim*4*4*4, self.n_classes) 132 | 133 | def forward(self, x): 134 | bs = x.shape[0] 135 | feat1 = self.layer1(x) 136 | feat2 = self.layer2(feat1) 137 | feat3 = self.layer3(feat2) 138 | feat4 = self.layer4(feat3) 139 | feat = feat4.view(bs, -1) 140 | y = self.fc_layer(feat) 141 | 142 | return feat, y 143 | 144 | 145 | class DiscriminatorMNIST(nn.Module): 146 | def __init__(self, d_input_dim=1024): 147 | super(DiscriminatorMNIST, self).__init__() 148 | self.fc1 = nn.Linear(d_input_dim, 1024) 149 | self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2) 150 | self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2) 151 | self.fc4 = nn.Linear(self.fc3.out_features, 1) 152 | 153 | # forward method 154 | def forward(self, x): 155 | x = x.view(x.size(0), -1) 156 | x = F.leaky_relu(self.fc1(x), 0.2) 157 | x = F.dropout(x, 0.3) 158 | x = F.leaky_relu(self.fc2(x), 0.2) 159 | x = F.dropout(x, 0.3) 160 | x = F.leaky_relu(self.fc3(x), 0.2) 161 | x = F.dropout(x, 0.3) 162 | y = self.fc4(x) 163 | y = y.view(-1) 164 | 165 | return y 166 | 167 | class DGWGAN32(nn.Module): 168 | def __init__(self, in_dim=1, dim=64): 169 | super(DGWGAN32, self).__init__() 170 | def conv_ln_lrelu(in_dim, out_dim): 171 | return nn.Sequential( 172 | nn.Conv2d(in_dim, out_dim, 5, 2, 2), 173 | # Since there is no effective implementation of LayerNorm, 174 | # we use InstanceNorm2d instead of LayerNorm here. 
175 | nn.InstanceNorm2d(out_dim, affine=True), 176 | nn.LeakyReLU(0.2)) 177 | 178 | self.layer1 = nn.Sequential(nn.Conv2d(in_dim, dim, 5, 2, 2), nn.LeakyReLU(0.2)) 179 | self.layer2 = conv_ln_lrelu(dim, dim * 2) 180 | self.layer3 = conv_ln_lrelu(dim * 2, dim * 4) 181 | self.layer4 = nn.Conv2d(dim * 4, 1, 4) 182 | 183 | def forward(self, x): 184 | feat1 = self.layer1(x) 185 | feat2 = self.layer2(feat1) 186 | feat3 = self.layer3(feat2) 187 | y = self.layer4(feat3) 188 | y = y.view(-1) 189 | return y 190 | 191 | class DGWGAN(nn.Module): 192 | def __init__(self, in_dim=3, dim=64): 193 | super(DGWGAN, self).__init__() 194 | def conv_ln_lrelu(in_dim, out_dim): 195 | return nn.Sequential( 196 | nn.Conv2d(in_dim, out_dim, 5, 2, 2), 197 | # Since there is no effective implementation of LayerNorm, 198 | # we use InstanceNorm2d instead of LayerNorm here. 199 | nn.InstanceNorm2d(out_dim, affine=True), 200 | nn.LeakyReLU(0.2)) 201 | 202 | self.ls = nn.Sequential( 203 | nn.Conv2d(in_dim, dim, 5, 2, 2), nn.LeakyReLU(0.2), 204 | conv_ln_lrelu(dim, dim * 2), 205 | conv_ln_lrelu(dim * 2, dim * 4), 206 | conv_ln_lrelu(dim * 4, dim * 8), 207 | nn.Conv2d(dim * 8, 1, 4)) 208 | 209 | def forward(self, x): 210 | y = self.ls(x) 211 | y = y.view(-1) 212 | return y 213 | 214 | class DLWGAN(nn.Module): 215 | def __init__(self, in_dim=3, dim=64): 216 | super(DLWGAN, self).__init__() 217 | 218 | def conv_ln_lrelu(in_dim, out_dim): 219 | return nn.Sequential( 220 | nn.Conv2d(in_dim, out_dim, 5, 2, 2), 221 | # Since there is no effective implementation of LayerNorm, 222 | # we use InstanceNorm2d instead of LayerNorm here. 223 | nn.InstanceNorm2d(out_dim, affine=True), 224 | nn.LeakyReLU(0.2)) 225 | 226 | self.layer1 = nn.Sequential(nn.Conv2d(in_dim, dim, 5, 2, 2), nn.LeakyReLU(0.2)) 227 | self.layer2 = conv_ln_lrelu(dim, dim * 2) 228 | self.layer3 = conv_ln_lrelu(dim * 2, dim * 4) 229 | self.layer4 = nn.Conv2d(dim * 4, 1, 4) 230 | 231 | 232 | def forward(self, x): 233 | feat1 = self.layer1(x) 234 | feat2 = self.layer2(feat1) 235 | feat3 = self.layer3(feat2) 236 | y = self.layer4(feat3) 237 | return y 238 | 239 | 240 | 241 | 242 | -------------------------------------------------------------------------------- /models/evolve.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, ReLU, Sigmoid, Dropout, MaxPool2d, \ 4 | AdaptiveAvgPool2d, Sequential, Module 5 | from collections import namedtuple 6 | 7 | 8 | # Support: ['IR_50', 'IR_101', 'IR_152', 'IR_SE_50', 'IR_SE_101', 'IR_SE_152'] 9 | 10 | 11 | class Flatten(Module): 12 | def forward(self, input): 13 | return input.view(input.size(0), -1) 14 | 15 | 16 | def l2_norm(input, axis=1): 17 | norm = torch.norm(input, 2, axis, True) 18 | output = torch.div(input, norm) 19 | 20 | return output 21 | 22 | 23 | class SEModule(Module): 24 | def __init__(self, channels, reduction): 25 | super(SEModule, self).__init__() 26 | self.avg_pool = AdaptiveAvgPool2d(1) 27 | self.fc1 = Conv2d( 28 | channels, channels // reduction, kernel_size=1, padding=0, bias=False) 29 | 30 | nn.init.xavier_uniform_(self.fc1.weight.data) 31 | 32 | self.relu = ReLU(inplace=True) 33 | self.fc2 = Conv2d( 34 | channels // reduction, channels, kernel_size=1, padding=0, bias=False) 35 | 36 | self.sigmoid = Sigmoid() 37 | 38 | def forward(self, x): 39 | module_input = x 40 | x = self.avg_pool(x) 41 | x = self.fc1(x) 42 | x = self.relu(x) 43 | x = self.fc2(x) 44 | x = 
self.sigmoid(x) 45 | 46 | return module_input * x 47 | 48 | 49 | class bottleneck_IR(Module): 50 | def __init__(self, in_channel, depth, stride): 51 | super(bottleneck_IR, self).__init__() 52 | if in_channel == depth: 53 | self.shortcut_layer = MaxPool2d(1, stride) 54 | else: 55 | self.shortcut_layer = Sequential( 56 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth)) 57 | self.res_layer = Sequential( 58 | BatchNorm2d(in_channel), 59 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 60 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)) 61 | 62 | def forward(self, x): 63 | shortcut = self.shortcut_layer(x) 64 | res = self.res_layer(x) 65 | 66 | return res + shortcut 67 | 68 | 69 | class bottleneck_IR_SE(Module): 70 | def __init__(self, in_channel, depth, stride): 71 | super(bottleneck_IR_SE, self).__init__() 72 | if in_channel == depth: 73 | self.shortcut_layer = MaxPool2d(1, stride) 74 | else: 75 | self.shortcut_layer = Sequential( 76 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 77 | BatchNorm2d(depth)) 78 | self.res_layer = Sequential( 79 | BatchNorm2d(in_channel), 80 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 81 | PReLU(depth), 82 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 83 | BatchNorm2d(depth), 84 | SEModule(depth, 16) 85 | ) 86 | 87 | def forward(self, x): 88 | shortcut = self.shortcut_layer(x) 89 | res = self.res_layer(x) 90 | 91 | return res + shortcut 92 | 93 | 94 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 95 | '''A named tuple describing a ResNet block.''' 96 | 97 | 98 | def get_block(in_channel, depth, num_units, stride=2): 99 | 100 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 101 | 102 | 103 | def get_blocks(num_layers): 104 | if num_layers == 50: 105 | blocks = [ 106 | get_block(in_channel=64, depth=64, num_units=3), 107 | get_block(in_channel=64, depth=128, num_units=4), 108 | get_block(in_channel=128, depth=256, num_units=14), 109 | get_block(in_channel=256, depth=512, num_units=3) 110 | ] 111 | elif num_layers == 100: 112 | blocks = [ 113 | get_block(in_channel=64, depth=64, num_units=3), 114 | get_block(in_channel=64, depth=128, num_units=13), 115 | get_block(in_channel=128, depth=256, num_units=30), 116 | get_block(in_channel=256, depth=512, num_units=3) 117 | ] 118 | elif num_layers == 152: 119 | blocks = [ 120 | get_block(in_channel=64, depth=64, num_units=3), 121 | get_block(in_channel=64, depth=128, num_units=8), 122 | get_block(in_channel=128, depth=256, num_units=36), 123 | get_block(in_channel=256, depth=512, num_units=3) 124 | ] 125 | 126 | return blocks 127 | 128 | 129 | class Backbone64(Module): 130 | def __init__(self, input_size, num_layers, mode='ir'): 131 | super(Backbone64, self).__init__() 132 | assert input_size[0] in [64, 112, 224], "input_size should be [112, 112] or [224, 224]" 133 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 134 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 135 | blocks = get_blocks(num_layers) 136 | if mode == 'ir': 137 | unit_module = bottleneck_IR 138 | elif mode == 'ir_se': 139 | unit_module = bottleneck_IR_SE 140 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 141 | BatchNorm2d(64), 142 | PReLU(64)) 143 | 144 | modules = [] 145 | for block in blocks: 146 | for bottleneck in block: 147 | modules.append( 148 | unit_module(bottleneck.in_channel, 149 | 
bottleneck.depth, 150 | bottleneck.stride)) 151 | self.body = Sequential(*modules) 152 | 153 | self._initialize_weights() 154 | 155 | def forward(self, x): 156 | x = self.input_layer(x) 157 | x = self.body(x) 158 | #x = self.output_layer(x) 159 | 160 | return x 161 | 162 | def _initialize_weights(self): 163 | for m in self.modules(): 164 | if isinstance(m, nn.Conv2d): 165 | nn.init.xavier_uniform_(m.weight.data) 166 | if m.bias is not None: 167 | m.bias.data.zero_() 168 | elif isinstance(m, nn.BatchNorm2d): 169 | m.weight.data.fill_(1) 170 | m.bias.data.zero_() 171 | elif isinstance(m, nn.BatchNorm1d): 172 | m.weight.data.fill_(1) 173 | m.bias.data.zero_() 174 | elif isinstance(m, nn.Linear): 175 | nn.init.xavier_uniform_(m.weight.data) 176 | if m.bias is not None: 177 | m.bias.data.zero_() 178 | 179 | class Backbone112(Module): 180 | def __init__(self, input_size, num_layers, mode='ir'): 181 | super(Backbone112, self).__init__() 182 | assert input_size[0] in [64, 112, 224], "input_size should be [112, 112] or [224, 224]" 183 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 184 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 185 | blocks = get_blocks(num_layers) 186 | if mode == 'ir': 187 | unit_module = bottleneck_IR 188 | elif mode == 'ir_se': 189 | unit_module = bottleneck_IR_SE 190 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 191 | BatchNorm2d(64), 192 | PReLU(64)) 193 | 194 | if input_size[0] == 112: 195 | self.output_layer = Sequential(BatchNorm2d(512), 196 | Dropout(), 197 | Flatten(), 198 | Linear(512 * 7 * 7, 512), 199 | BatchNorm1d(512)) 200 | else: 201 | self.output_layer = Sequential(BatchNorm2d(512), 202 | Dropout(), 203 | Flatten(), 204 | Linear(512 * 14 * 14, 512), 205 | BatchNorm1d(512)) 206 | 207 | modules = [] 208 | for block in blocks: 209 | for bottleneck in block: 210 | modules.append( 211 | unit_module(bottleneck.in_channel, 212 | bottleneck.depth, 213 | bottleneck.stride)) 214 | self.body = Sequential(*modules) 215 | 216 | self._initialize_weights() 217 | 218 | def forward(self, x): 219 | x = self.input_layer(x) 220 | x = self.body(x) 221 | x = self.output_layer(x) 222 | 223 | return x 224 | 225 | def _initialize_weights(self): 226 | for m in self.modules(): 227 | if isinstance(m, nn.Conv2d): 228 | nn.init.xavier_uniform_(m.weight.data) 229 | if m.bias is not None: 230 | m.bias.data.zero_() 231 | elif isinstance(m, nn.BatchNorm2d): 232 | m.weight.data.fill_(1) 233 | m.bias.data.zero_() 234 | elif isinstance(m, nn.BatchNorm1d): 235 | m.weight.data.fill_(1) 236 | m.bias.data.zero_() 237 | elif isinstance(m, nn.Linear): 238 | nn.init.xavier_uniform_(m.weight.data) 239 | if m.bias is not None: 240 | m.bias.data.zero_() 241 | 242 | 243 | def IR_50_64(input_size): 244 | """Constructs a ir-50 model. 245 | """ 246 | model = Backbone64(input_size, 50, 'ir') 247 | 248 | return model 249 | 250 | def IR_50_112(input_size): 251 | """Constructs a ir-50 model. 252 | """ 253 | model = Backbone112(input_size, 50, 'ir') 254 | 255 | return model 256 | 257 | 258 | def IR_100(input_size): 259 | """Constructs a ir-100 model. 260 | """ 261 | model = Backbone(input_size, 100, 'ir') 262 | 263 | return model 264 | 265 | def IR_152_64(input_size): 266 | """Constructs a ir-152 model. 267 | """ 268 | model = Backbone64(input_size, 152, 'ir') 269 | 270 | return model 271 | 272 | 273 | def IR_152_112(input_size): 274 | """Constructs a ir-152 model. 
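Built on Backbone112: with a 112x112 input the trunk's output_layer produces a
512-d embedding (224x224 inputs are also handled). Note that IR_100 and the
IR_SE_* factories in this file reference a `Backbone` class that is not defined
in this module.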
275 | """ 276 | model = Backbone112(input_size, 152, 'ir') 277 | 278 | return model 279 | 280 | def IR_SE_50(input_size): 281 | """Constructs a ir_se-50 model. 282 | """ 283 | model = Backbone(input_size, 50, 'ir_se') 284 | 285 | return model 286 | 287 | 288 | def IR_SE_101(input_size): 289 | """Constructs a ir_se-101 model. 290 | """ 291 | model = Backbone(input_size, 100, 'ir_se') 292 | 293 | return model 294 | 295 | 296 | def IR_SE_152(input_size): 297 | """Constructs a ir_se-152 model. 298 | """ 299 | model = Backbone(input_size, 152, 'ir_se') 300 | 301 | return model -------------------------------------------------------------------------------- /models/facenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, ReLU, Sigmoid, Dropout, MaxPool2d, \ 4 | AdaptiveAvgPool2d, Sequential, Module 5 | from collections import namedtuple 6 | 7 | 8 | # Support: ['IR_50', 'IR_101', 'IR_152', 'IR_SE_50', 'IR_SE_101', 'IR_SE_152'] 9 | class FaceNet(nn.Module): 10 | def __init__(self, num_classes = 1000): 11 | super(FaceNet, self).__init__() 12 | self.feature = IR_50_112((112, 112)) 13 | self.feat_dim = 512 14 | self.num_classes = num_classes 15 | self.fc_layer = nn.Sequential( 16 | nn.Linear(self.feat_dim, self.num_classes), 17 | nn.Softmax(dim = 1)) 18 | 19 | def forward(self, x): 20 | feat = self.feature(x) 21 | feat = feat.view(feat.size(0), -1) 22 | out = self.fc_layer(feat) 23 | __, iden = torch.max(out, dim = 1) 24 | iden = iden.view(-1, 1) 25 | return feat, out, iden 26 | 27 | class FaceNet64(nn.Module): 28 | def __init__(self, num_classes = 1000): 29 | super(FaceNet64, self).__init__() 30 | self.feature = IR_50_64((64, 64)) 31 | self.feat_dim = 512 32 | self.num_classes = num_classes 33 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512), 34 | nn.Dropout(), 35 | Flatten(), 36 | nn.Linear(512 * 4 * 4, 512), 37 | nn.BatchNorm1d(512)) 38 | 39 | self.fc_layer = nn.Sequential( 40 | nn.Linear(self.feat_dim, self.num_classes), 41 | nn.Softmax(dim = 1)) 42 | 43 | def forward(self, x): 44 | feat = self.feature(x) 45 | feat = self.output_layer(feat) 46 | feat = feat.view(feat.size(0), -1) 47 | out = self.fc_layer(feat) 48 | __, iden = torch.max(out, dim = 1) 49 | iden = iden.view(-1, 1) 50 | return feat, out, iden 51 | 52 | class Flatten(Module): 53 | def forward(self, input): 54 | return input.view(input.size(0), -1) 55 | 56 | 57 | def l2_norm(input, axis=1): 58 | norm = torch.norm(input, 2, axis, True) 59 | output = torch.div(input, norm) 60 | 61 | return output 62 | 63 | 64 | class SEModule(Module): 65 | def __init__(self, channels, reduction): 66 | super(SEModule, self).__init__() 67 | self.avg_pool = AdaptiveAvgPool2d(1) 68 | self.fc1 = Conv2d( 69 | channels, channels // reduction, kernel_size=1, padding=0, bias=False) 70 | 71 | nn.init.xavier_uniform_(self.fc1.weight.data) 72 | 73 | self.relu = ReLU(inplace=True) 74 | self.fc2 = Conv2d( 75 | channels // reduction, channels, kernel_size=1, padding=0, bias=False) 76 | 77 | self.sigmoid = Sigmoid() 78 | 79 | def forward(self, x): 80 | module_input = x 81 | x = self.avg_pool(x) 82 | x = self.fc1(x) 83 | x = self.relu(x) 84 | x = self.fc2(x) 85 | x = self.sigmoid(x) 86 | 87 | return module_input * x 88 | 89 | 90 | class bottleneck_IR(Module): 91 | def __init__(self, in_channel, depth, stride): 92 | super(bottleneck_IR, self).__init__() 93 | if in_channel == depth: 94 | self.shortcut_layer = 
MaxPool2d(1, stride) 95 | else: 96 | self.shortcut_layer = Sequential( 97 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth)) 98 | self.res_layer = Sequential( 99 | BatchNorm2d(in_channel), 100 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 101 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)) 102 | 103 | def forward(self, x): 104 | shortcut = self.shortcut_layer(x) 105 | res = self.res_layer(x) 106 | 107 | return res + shortcut 108 | 109 | 110 | class bottleneck_IR_SE(Module): 111 | def __init__(self, in_channel, depth, stride): 112 | super(bottleneck_IR_SE, self).__init__() 113 | if in_channel == depth: 114 | self.shortcut_layer = MaxPool2d(1, stride) 115 | else: 116 | self.shortcut_layer = Sequential( 117 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 118 | BatchNorm2d(depth)) 119 | self.res_layer = Sequential( 120 | BatchNorm2d(in_channel), 121 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 122 | PReLU(depth), 123 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 124 | BatchNorm2d(depth), 125 | SEModule(depth, 16) 126 | ) 127 | 128 | def forward(self, x): 129 | shortcut = self.shortcut_layer(x) 130 | res = self.res_layer(x) 131 | 132 | return res + shortcut 133 | 134 | 135 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 136 | '''A named tuple describing a ResNet block.''' 137 | 138 | 139 | def get_block(in_channel, depth, num_units, stride=2): 140 | 141 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 142 | 143 | 144 | def get_blocks(num_layers): 145 | if num_layers == 50: 146 | blocks = [ 147 | get_block(in_channel=64, depth=64, num_units=3), 148 | get_block(in_channel=64, depth=128, num_units=4), 149 | get_block(in_channel=128, depth=256, num_units=14), 150 | get_block(in_channel=256, depth=512, num_units=3) 151 | ] 152 | elif num_layers == 100: 153 | blocks = [ 154 | get_block(in_channel=64, depth=64, num_units=3), 155 | get_block(in_channel=64, depth=128, num_units=13), 156 | get_block(in_channel=128, depth=256, num_units=30), 157 | get_block(in_channel=256, depth=512, num_units=3) 158 | ] 159 | elif num_layers == 152: 160 | blocks = [ 161 | get_block(in_channel=64, depth=64, num_units=3), 162 | get_block(in_channel=64, depth=128, num_units=8), 163 | get_block(in_channel=128, depth=256, num_units=36), 164 | get_block(in_channel=256, depth=512, num_units=3) 165 | ] 166 | 167 | return blocks 168 | 169 | 170 | class Backbone64(Module): 171 | def __init__(self, input_size, num_layers, mode='ir'): 172 | super(Backbone64, self).__init__() 173 | assert input_size[0] in [64], "input_size should be [112, 112] or [224, 224]" 174 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 175 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 176 | blocks = get_blocks(num_layers) 177 | if mode == 'ir': 178 | unit_module = bottleneck_IR 179 | elif mode == 'ir_se': 180 | unit_module = bottleneck_IR_SE 181 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 182 | BatchNorm2d(64), 183 | PReLU(64)) 184 | 185 | self.output_layer = Sequential(BatchNorm2d(512), 186 | Dropout(), 187 | Flatten(), 188 | Linear(512 * 14 * 14, 512), 189 | BatchNorm1d(512)) 190 | 191 | modules = [] 192 | for block in blocks: 193 | for bottleneck in block: 194 | modules.append( 195 | unit_module(bottleneck.in_channel, 196 | bottleneck.depth, 197 | bottleneck.stride)) 198 | self.body 
= Sequential(*modules) 199 | 200 | self._initialize_weights() 201 | 202 | def forward(self, x): 203 | x = self.input_layer(x) 204 | x = self.body(x) 205 | 206 | return x 207 | 208 | def _initialize_weights(self): 209 | for m in self.modules(): 210 | if isinstance(m, nn.Conv2d): 211 | nn.init.xavier_uniform_(m.weight.data) 212 | if m.bias is not None: 213 | m.bias.data.zero_() 214 | elif isinstance(m, nn.BatchNorm2d): 215 | m.weight.data.fill_(1) 216 | m.bias.data.zero_() 217 | elif isinstance(m, nn.BatchNorm1d): 218 | m.weight.data.fill_(1) 219 | m.bias.data.zero_() 220 | elif isinstance(m, nn.Linear): 221 | nn.init.xavier_uniform_(m.weight.data) 222 | if m.bias is not None: 223 | m.bias.data.zero_() 224 | 225 | class Backbone112(Module): 226 | def __init__(self, input_size, num_layers, mode='ir'): 227 | super(Backbone112, self).__init__() 228 | assert input_size[0] in [112], "input_size should be [112, 112] or [224, 224]" 229 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 230 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 231 | blocks = get_blocks(num_layers) 232 | if mode == 'ir': 233 | unit_module = bottleneck_IR 234 | elif mode == 'ir_se': 235 | unit_module = bottleneck_IR_SE 236 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 237 | BatchNorm2d(64), 238 | PReLU(64)) 239 | 240 | if input_size[0] == 112: 241 | self.output_layer = Sequential(BatchNorm2d(512), 242 | Dropout(), 243 | Flatten(), 244 | Linear(512 * 7 * 7, 512), 245 | BatchNorm1d(512)) 246 | 247 | modules = [] 248 | for block in blocks: 249 | for bottleneck in block: 250 | modules.append( 251 | unit_module(bottleneck.in_channel, 252 | bottleneck.depth, 253 | bottleneck.stride)) 254 | self.body = Sequential(*modules) 255 | 256 | self._initialize_weights() 257 | 258 | def forward(self, x): 259 | x = self.input_layer(x) 260 | x = self.body(x) 261 | x = self.output_layer(x) 262 | 263 | return x 264 | 265 | def _initialize_weights(self): 266 | for m in self.modules(): 267 | if isinstance(m, nn.Conv2d): 268 | nn.init.xavier_uniform_(m.weight.data) 269 | if m.bias is not None: 270 | m.bias.data.zero_() 271 | elif isinstance(m, nn.BatchNorm2d): 272 | m.weight.data.fill_(1) 273 | m.bias.data.zero_() 274 | elif isinstance(m, nn.BatchNorm1d): 275 | m.weight.data.fill_(1) 276 | m.bias.data.zero_() 277 | elif isinstance(m, nn.Linear): 278 | nn.init.xavier_uniform_(m.weight.data) 279 | if m.bias is not None: 280 | m.bias.data.zero_() 281 | 282 | 283 | def IR_50_64(input_size): 284 | """Constructs a ir-50 model. 285 | """ 286 | model = Backbone64(input_size, 50, 'ir') 287 | 288 | return model 289 | 290 | def IR_50_112(input_size): 291 | """Constructs a ir-50 model. 292 | """ 293 | model = Backbone112(input_size, 50, 'ir') 294 | 295 | return model 296 | 297 | 298 | def IR_101(input_size): 299 | """Constructs a ir-101 model. 300 | """ 301 | model = Backbone(input_size, 100, 'ir') 302 | 303 | return model 304 | 305 | 306 | def IR_152_64(input_size): 307 | """Constructs a ir-152 model. 308 | """ 309 | model = Backbone64(input_size, 152, 'ir') 310 | 311 | return model 312 | 313 | def IR_152_112(input_size): 314 | """Constructs a ir-152 model. 315 | """ 316 | model = Backbone112(input_size, 152, 'ir') 317 | 318 | return model 319 | 320 | 321 | def IR_SE_50(input_size): 322 | """Constructs a ir_se-50 model. 323 | """ 324 | model = Backbone(input_size, 50, 'ir_se') 325 | 326 | return model 327 | 328 | 329 | def IR_SE_101(input_size): 330 | """Constructs a ir_se-101 model. 
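Note: this factory (like IR_101, IR_SE_50 and IR_SE_152 in this file) references
a `Backbone` class that is not defined in this module; only the Backbone64 /
Backbone112 based builders used by FaceNet and FaceNet64 are usable as-is.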
331 | """ 332 | model = Backbone(input_size, 100, 'ir_se') 333 | 334 | return model 335 | 336 | 337 | def IR_SE_152(input_size): 338 | """Constructs a ir_se-152 model. 339 | """ 340 | model = Backbone(input_size, 152, 'ir_se') 341 | 342 | return model -------------------------------------------------------------------------------- /models/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class GeneratorCXR(nn.Module): 5 | def __init__(self, in_dim=100, dim=64): 6 | super(GeneratorCXR, self).__init__() 7 | def dconv_bn_relu(in_dim, out_dim): 8 | return nn.Sequential( 9 | nn.ConvTranspose2d(in_dim, out_dim, 5, 2, 10 | padding=2, output_padding=1, bias=False), 11 | nn.BatchNorm2d(out_dim), 12 | nn.ReLU()) 13 | 14 | self.l1 = nn.Sequential( 15 | nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False), 16 | nn.BatchNorm1d(dim * 8 * 4 * 4), 17 | nn.ReLU()) 18 | self.l2_5 = nn.Sequential( 19 | dconv_bn_relu(dim * 8, dim * 4), 20 | dconv_bn_relu(dim * 4, dim * 2), 21 | dconv_bn_relu(dim * 2, dim), 22 | nn.ConvTranspose2d(dim, 1, 5, 2, padding=2, output_padding=1), 23 | nn.Sigmoid()) 24 | 25 | def forward(self, x): 26 | y = self.l1(x) 27 | y = y.view(y.size(0), -1, 4, 4) 28 | y = self.l2_5(y) 29 | return y 30 | 31 | class GeneratorMNIST(nn.Module): 32 | def __init__(self, in_dim=100, dim=64): 33 | super(GeneratorMNIST, self).__init__() 34 | def dconv_bn_relu(in_dim, out_dim): 35 | return nn.Sequential( 36 | nn.ConvTranspose2d(in_dim, out_dim, 5, 2, 37 | padding=2, output_padding=1, bias=False), 38 | nn.BatchNorm2d(out_dim), 39 | nn.ReLU()) 40 | 41 | self.l1 = nn.Sequential( 42 | nn.Linear(in_dim, dim * 4 * 4 * 4, bias=False), 43 | nn.BatchNorm1d(dim * 4 * 4 * 4), 44 | nn.ReLU()) 45 | self.l2_5 = nn.Sequential( 46 | dconv_bn_relu(dim * 4, dim * 2), 47 | dconv_bn_relu(dim * 2, dim), 48 | nn.ConvTranspose2d(dim, 1, 5, 2, padding=2, output_padding=1), 49 | nn.Sigmoid()) 50 | 51 | def forward(self, x): 52 | y = self.l1(x) 53 | y = y.view(y.size(0), -1, 4, 4) 54 | y = self.l2_5(y) 55 | return y 56 | 57 | class Generator(nn.Module): 58 | def __init__(self, in_dim=100, dim=64): 59 | super(Generator, self).__init__() 60 | def dconv_bn_relu(in_dim, out_dim): 61 | return nn.Sequential( 62 | nn.ConvTranspose2d(in_dim, out_dim, 5, 2, 63 | padding=2, output_padding=1, bias=False), 64 | nn.BatchNorm2d(out_dim), 65 | nn.ReLU()) 66 | 67 | self.l1 = nn.Sequential( 68 | nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False), 69 | nn.BatchNorm1d(dim * 8 * 4 * 4), 70 | nn.ReLU()) 71 | self.l2_5 = nn.Sequential( 72 | dconv_bn_relu(dim * 8, dim * 4), 73 | dconv_bn_relu(dim * 4, dim * 2), 74 | dconv_bn_relu(dim * 2, dim), 75 | nn.ConvTranspose2d(dim, 3, 5, 2, padding=2, output_padding=1), 76 | nn.Sigmoid()) 77 | 78 | def forward(self, x): 79 | y = self.l1(x) 80 | y = y.view(y.size(0), -1, 4, 4) 81 | y = self.l2_5(y) 82 | return y 83 | 84 | class GeneratorMNIST(nn.Module): 85 | def __init__(self, in_dim=100, dim=64): 86 | super(GeneratorMNIST, self).__init__() 87 | def dconv_bn_relu(in_dim, out_dim): 88 | return nn.Sequential( 89 | nn.ConvTranspose2d(in_dim, out_dim, 5, 2, 90 | padding=2, output_padding=1, bias=False), 91 | nn.BatchNorm2d(out_dim), 92 | nn.ReLU()) 93 | 94 | self.l1 = nn.Sequential( 95 | nn.Linear(in_dim, dim * 4 * 4 * 4, bias=False), 96 | nn.BatchNorm1d(dim * 4 * 4 * 4), 97 | nn.ReLU()) 98 | self.l2_5 = nn.Sequential( 99 | dconv_bn_relu(dim * 4, dim * 2), 100 | dconv_bn_relu(dim * 2, dim), 101 | nn.ConvTranspose2d(dim, 1, 5, 2, 
padding=2, output_padding=1), 102 | nn.Sigmoid()) 103 | 104 | def forward(self, x): 105 | y = self.l1(x) 106 | y = y.view(y.size(0), -1, 4, 4) 107 | y = self.l2_5(y) 108 | return y 109 | 110 | class CompletionNetwork(nn.Module): 111 | def __init__(self): 112 | super(CompletionNetwork, self).__init__() 113 | # input_shape: (None, 4, img_h, img_w) 114 | self.conv1 = nn.Conv2d(4, 32, kernel_size=5, stride=1, padding=2) 115 | self.bn1 = nn.BatchNorm2d(32) 116 | self.act1 = nn.ReLU() 117 | # input_shape: (None, 64, img_h, img_w) 118 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1) 119 | self.bn2 = nn.BatchNorm2d(64) 120 | self.act2 = nn.ReLU() 121 | # input_shape: (None, 128, img_h//2, img_w//2) 122 | self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 123 | self.bn3 = nn.BatchNorm2d(64) 124 | self.act3 = nn.ReLU() 125 | # input_shape: (None, 128, img_h//2, img_w//2) 126 | self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1) 127 | self.bn4 = nn.BatchNorm2d(128) 128 | self.act4 = nn.ReLU() 129 | # input_shape: (None, 256, img_h//4, img_w//4) 130 | self.conv5 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 131 | self.bn5 = nn.BatchNorm2d(128) 132 | self.act5 = nn.ReLU() 133 | # input_shape: (None, 256, img_h//4, img_w//4) 134 | self.conv6 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 135 | self.bn6 = nn.BatchNorm2d(128) 136 | self.act6 = nn.ReLU() 137 | # input_shape: (None, 256, img_h//4, img_w//4) 138 | self.conv7 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=2, padding=2) 139 | self.bn7 = nn.BatchNorm2d(128) 140 | self.act7 = nn.ReLU() 141 | # input_shape: (None, 256, img_h//4, img_w//4) 142 | self.conv8 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=4, padding=4) 143 | self.bn8 = nn.BatchNorm2d(128) 144 | self.act8 = nn.ReLU() 145 | # input_shape: (None, 256, img_h//4, img_w//4) 146 | self.conv9 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=8, padding=8) 147 | self.bn9 = nn.BatchNorm2d(128) 148 | self.act9 = nn.ReLU() 149 | # input_shape: (None, 256, img_h//4, img_w//4) 150 | self.conv10 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=16, padding=16) 151 | self.bn10 = nn.BatchNorm2d(128) 152 | self.act10 = nn.ReLU() 153 | # input_shape: (None, 256, img_h//4, img_w//4) 154 | self.conv11 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 155 | self.bn11 = nn.BatchNorm2d(128) 156 | self.act11 = nn.ReLU() 157 | # input_shape: (None, 256, img_h//4, img_w//4) 158 | self.conv12 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 159 | self.bn12 = nn.BatchNorm2d(128) 160 | self.act12 = nn.ReLU() 161 | # input_shape: (None, 256, img_h//4, img_w//4) 162 | self.deconv13 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1) 163 | self.bn13 = nn.BatchNorm2d(64) 164 | self.act13 = nn.ReLU() 165 | # input_shape: (None, 128, img_h//2, img_w//2) 166 | self.conv14 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 167 | self.bn14 = nn.BatchNorm2d(64) 168 | self.act14 = nn.ReLU() 169 | # input_shape: (None, 128, img_h//2, img_w//2) 170 | self.deconv15 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1) 171 | self.bn15 = nn.BatchNorm2d(32) 172 | self.act15 = nn.ReLU() 173 | # input_shape: (None, 64, img_h, img_w) 174 | self.conv16 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1) 175 | self.bn16 = nn.BatchNorm2d(32) 176 | self.act16 = nn.ReLU() 177 | # input_shape: (None, 32, img_h, img_w) 178 | self.conv17 = nn.Conv2d(32, 3, kernel_size=3, stride=1, 
padding=1) 179 | self.act17 = nn.Sigmoid() 180 | # output_shape: (None, 3, img_h. img_w) 181 | 182 | def forward(self, x): 183 | x = self.bn1(self.act1(self.conv1(x))) 184 | x = self.bn2(self.act2(self.conv2(x))) 185 | x = self.bn3(self.act3(self.conv3(x))) 186 | x = self.bn4(self.act4(self.conv4(x))) 187 | x = self.bn5(self.act5(self.conv5(x))) 188 | x = self.bn6(self.act6(self.conv6(x))) 189 | x = self.bn7(self.act7(self.conv7(x))) 190 | x = self.bn8(self.act8(self.conv8(x))) 191 | x = self.bn9(self.act9(self.conv9(x))) 192 | x = self.bn10(self.act10(self.conv10(x))) 193 | x = self.bn11(self.act11(self.conv11(x))) 194 | x = self.bn12(self.act12(self.conv12(x))) 195 | x = self.bn13(self.act13(self.deconv13(x))) 196 | x = self.bn14(self.act14(self.conv14(x))) 197 | x = self.bn15(self.act15(self.deconv15(x))) 198 | x = self.bn16(self.act16(self.conv16(x))) 199 | x = self.act17(self.conv17(x)) 200 | return x 201 | 202 | def dconv_bn_relu(in_dim, out_dim): 203 | return nn.Sequential( 204 | nn.ConvTranspose2d(in_dim, out_dim, 5, 2, 205 | padding=2, output_padding=1, bias=False), 206 | nn.BatchNorm2d(out_dim), 207 | nn.ReLU()) 208 | 209 | class ContextNetwork(nn.Module): 210 | def __init__(self): 211 | super(ContextNetwork, self).__init__() 212 | # input_shape: (None, 4, img_h, img_w) 213 | self.conv1 = nn.Conv2d(4, 32, kernel_size=5, stride=1, padding=2) 214 | self.bn1 = nn.BatchNorm2d(32) 215 | self.act1 = nn.ReLU() 216 | # input_shape: (None, 32, img_h, img_w) 217 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1) 218 | self.bn2 = nn.BatchNorm2d(64) 219 | self.act2 = nn.ReLU() 220 | # input_shape: (None, 64, img_h//2, img_w//2) 221 | self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 222 | self.bn3 = nn.BatchNorm2d(64) 223 | self.act3 = nn.ReLU() 224 | # input_shape: (None, 128, img_h//2, img_w//2) 225 | self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1) 226 | self.bn4 = nn.BatchNorm2d(128) 227 | self.act4 = nn.ReLU() 228 | # input_shape: (None, 128, img_h//4, img_w//4) 229 | self.conv5 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 230 | self.bn5 = nn.BatchNorm2d(128) 231 | self.act5 = nn.ReLU() 232 | # input_shape: (None, 128, img_h//4, img_w//4) 233 | self.conv6 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 234 | self.bn6 = nn.BatchNorm2d(128) 235 | self.act6 = nn.ReLU() 236 | # input_shape: (None, 128, img_h//4, img_w//4) 237 | self.conv7 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=2, padding=2) 238 | self.bn7 = nn.BatchNorm2d(128) 239 | self.act7 = nn.ReLU() 240 | # input_shape: (None, 128, img_h//4, img_w//4) 241 | self.conv8 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=4, padding=4) 242 | self.bn8 = nn.BatchNorm2d(128) 243 | self.act8 = nn.ReLU() 244 | # input_shape: (None, 128, img_h//4, img_w//4) 245 | self.conv9 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=8, padding=8) 246 | self.bn9 = nn.BatchNorm2d(128) 247 | self.act9 = nn.ReLU() 248 | # input_shape: (None, 128, img_h//4, img_w//4) 249 | self.conv10 = nn.Conv2d(128, 128, kernel_size=3, stride=1, dilation=16, padding=16) 250 | self.bn10 = nn.BatchNorm2d(128) 251 | self.act10 = nn.ReLU() 252 | 253 | 254 | 255 | def forward(self, x): 256 | x = self.bn1(self.act1(self.conv1(x))) 257 | x = self.bn2(self.act2(self.conv2(x))) 258 | x = self.bn3(self.act3(self.conv3(x))) 259 | x = self.bn4(self.act4(self.conv4(x))) 260 | x = self.bn5(self.act5(self.conv5(x))) 261 | x = self.bn6(self.act6(self.conv6(x))) 262 | x = 
self.bn7(self.act7(self.conv7(x))) 263 | x = self.bn8(self.act8(self.conv8(x))) 264 | x = self.bn9(self.act9(self.conv9(x))) 265 | x = self.bn10(self.act10(self.conv10(x))) 266 | return x 267 | 268 | class IdentityGenerator(nn.Module): 269 | 270 | def __init__(self, in_dim = 100, dim=64): 271 | super(IdentityGenerator, self).__init__() 272 | 273 | self.l1 = nn.Sequential( 274 | nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False), 275 | nn.BatchNorm1d(dim * 8 * 4 * 4), 276 | nn.ReLU()) 277 | self.l2_5 = nn.Sequential( 278 | dconv_bn_relu(dim * 8, dim * 4), 279 | dconv_bn_relu(dim * 4, dim * 2)) 280 | 281 | def forward(self, x): 282 | y = self.l1(x) 283 | y = y.view(y.size(0), -1, 4, 4) 284 | y = self.l2_5(y) 285 | return y 286 | 287 | class InversionNet(nn.Module): 288 | def __init__(self, out_dim = 128): 289 | super(InversionNet, self).__init__() 290 | 291 | # input [4, h, w] output [256, h // 4, w // 4] 292 | self.ContextNetwork = ContextNetwork() 293 | # input [z_dim] output[128, 16, 16] 294 | self.IdentityGenerator = IdentityGenerator() 295 | 296 | self.dim = 128 + 128 297 | self.out_dim = out_dim 298 | 299 | self.Dconv = nn.Sequential( 300 | dconv_bn_relu(self.dim, self.out_dim), 301 | dconv_bn_relu(self.out_dim, self.out_dim // 2)) 302 | 303 | self.Conv = nn.Sequential( 304 | nn.Conv2d(self.out_dim // 2, self.out_dim // 4, kernel_size=3, stride=1, padding=1), 305 | nn.BatchNorm2d(self.out_dim // 4), 306 | nn.ReLU(), 307 | nn.Conv2d(self.out_dim // 4, 3, kernel_size=3, stride=1, padding=1), 308 | nn.Sigmoid()) 309 | 310 | 311 | def forward(self, inp): 312 | # x.shape [4, h, w] z.shape [100] 313 | x, z = inp 314 | context_info = self.ContextNetwork(x) 315 | identity_info = self.IdentityGenerator(z) 316 | y = torch.cat((context_info, identity_info), dim=1) 317 | y = self.Dconv(y) 318 | y = self.Conv(y) 319 | 320 | return y 321 | -------------------------------------------------------------------------------- /recovery.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from models.classify import * 3 | from models.generator import * 4 | from models.discri import * 5 | import torch 6 | import os 7 | import numpy as np 8 | from attack import inversion, dist_inversion 9 | from argparse import ArgumentParser 10 | 11 | 12 | torch.manual_seed(9) 13 | 14 | parser = ArgumentParser(description='Inversion') 15 | parser.add_argument('--configs', type=str, default='./config/celeba/attacking/ffhq.json') 16 | 17 | args = parser.parse_args() 18 | 19 | 20 | 21 | def init_attack_args(cfg): 22 | if cfg["attack"]["method"] =='kedmi': 23 | args.improved_flag = True 24 | args.clipz = True 25 | args.num_seeds = 1 26 | else: 27 | args.improved_flag = False 28 | args.clipz = False 29 | args.num_seeds = 5 30 | 31 | if cfg["attack"]["variant"] == 'L_logit' or cfg["attack"]["variant"] == 'ours': 32 | args.loss = 'logit_loss' 33 | else: 34 | args.loss = 'cel' 35 | 36 | if cfg["attack"]["variant"] == 'L_aug' or cfg["attack"]["variant"] == 'ours': 37 | args.classid = '0,1,2,3' 38 | else: 39 | args.classid = '0' 40 | 41 | 42 | 43 | if __name__ == "__main__": 44 | # global args, logger 45 | 46 | cfg = load_json(json_file=args.configs) 47 | init_attack_args(cfg=cfg) 48 | 49 | # Save dir 50 | if args.improved_flag == True: 51 | prefix = os.path.join(cfg["root_path"], "kedmi_300ids") 52 | else: 53 | prefix = os.path.join(cfg["root_path"], "gmi_300ids") 54 | save_folder = os.path.join("{}_{}".format(cfg["dataset"]["name"], cfg["dataset"]["model_name"]), 
cfg["attack"]["variant"]) 55 | prefix = os.path.join(prefix, save_folder) 56 | save_dir = os.path.join(prefix, "latent") 57 | save_img_dir = os.path.join(prefix, "imgs_{}".format(cfg["attack"]["variant"])) 58 | args.log_path = os.path.join(prefix, "invertion_logs") 59 | 60 | os.makedirs(prefix, exist_ok=True) 61 | os.makedirs(args.log_path, exist_ok=True) 62 | os.makedirs(save_img_dir, exist_ok=True) 63 | os.makedirs(save_dir, exist_ok=True) 64 | 65 | 66 | # Load models 67 | targetnets, E, G, D, n_classes, fea_mean, fea_logvar = get_attack_model(args, cfg) 68 | N = 5 69 | bs = 60 70 | 71 | 72 | # Begin attacking 73 | for i in range(1): 74 | iden = torch.from_numpy(np.arange(bs)) 75 | 76 | # evaluate on the first 300 identities only 77 | target_cosines = 0 78 | eval_cosines = 0 79 | for idx in range(5): 80 | iden = iden %n_classes 81 | print("--------------------- Attack batch [%s]------------------------------" % idx) 82 | print('Iden:{}'.format(iden)) 83 | save_dir_z = '{}/{}_{}'.format(save_dir,i,idx) 84 | 85 | if args.improved_flag == True: 86 | #KEDMI 87 | print('kedmi') 88 | 89 | dist_inversion(G, D, targetnets, E, iden, 90 | lr=cfg["attack"]["lr"], iter_times=cfg["attack"]["iters_mi"], 91 | momentum=0.9, lamda=100, 92 | clip_range=1, improved=args.improved_flag, 93 | num_seeds=args.num_seeds, 94 | used_loss=args.loss, 95 | prefix=save_dir_z, 96 | save_img_dir=os.path.join(save_img_dir, '{}_'.format(idx)), 97 | fea_mean=fea_mean, 98 | fea_logvar=fea_logvar, 99 | lam=cfg["attack"]["lam"], 100 | clipz=args.clipz) 101 | else: 102 | #GMI 103 | print('gmi') 104 | if cfg["attack"]["same_z"] =='': 105 | inversion(G, D, targetnets, E, iden, 106 | lr=cfg["attack"]["lr"], iter_times=cfg["attack"]["iters_mi"], 107 | momentum=0.9, lamda=100, 108 | clip_range=1, improved=args.improved_flag, 109 | used_loss=args.loss, 110 | prefix=save_dir_z, 111 | save_img_dir=save_img_dir, 112 | num_seeds=args.num_seeds, 113 | fea_mean=fea_mean, 114 | fea_logvar=fea_logvar,lam=cfg["attack"]["lam"], 115 | istart=args.istart) 116 | else: 117 | inversion(G, D, targetnets, E, iden, 118 | lr=args.lr, iter_times=args.iters_mi, 119 | momentum=0.9, lamda=100, 120 | clip_range=1, improved=args.improved_flag, 121 | used_loss=args.loss, 122 | prefix=save_dir_z, 123 | save_img_dir=save_img_dir, 124 | num_seeds=args.num_seeds, 125 | fea_mean=fea_mean, 126 | fea_logvar=fea_logvar,lam=cfg["attack"]["lam"], 127 | istart=args.istart, 128 | same_z='{}/{}_{}'.format(args.same_z,i,idx)) 129 | iden = iden + bs 130 | 131 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | brotlipy==0.7.0 3 | cachetools==5.3.0 4 | cycler==0.11.0 5 | fonttools==4.38.0 6 | google-auth==2.16.1 7 | google-auth-oauthlib==0.4.6 8 | grpcio==1.51.3 9 | importlib-metadata==6.0.0 10 | joblib==1.2.0 11 | kiwisolver==1.4.4 12 | Markdown==3.4.1 13 | MarkupSafe==2.1.2 14 | matplotlib==3.5.3 15 | mkl-fft==1.3.1 16 | mkl-service==2.4.0 17 | oauthlib==3.2.2 18 | opencv-python==4.7.0.72 19 | packaging==23.0 20 | pandas==1.3.5 21 | Pillow==9.4.0 22 | protobuf==3.20.3 23 | pyasn1==0.4.8 24 | pyasn1-modules==0.2.8 25 | pyparsing==3.0.9 26 | python-dateutil==2.8.2 27 | pytz==2022.7.1 28 | requests-oauthlib==1.3.1 29 | rsa==4.9 30 | scikit-learn==1.0.2 31 | scipy==1.7.3 32 | sklearn==0.0.post1 33 | tensorboard==2.11.2 34 | tensorboard-data-server==0.6.1 35 | tensorboard-plugin-wit==1.8.1 36 | tensorboardX==2.6 37 | 
threadpoolctl==3.1.0 38 | torchaudio==0.11.0 39 | torchvision==0.12.0 40 | tqdm==4.64.1 41 | Werkzeug==2.2.3 42 | zipp==3.15.0 43 | -------------------------------------------------------------------------------- /train_augmented_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | 6 | from utils import * 7 | from models.classify import * 8 | from engine import train_augmodel 9 | from argparse import ArgumentParser 10 | 11 | 12 | # Training settings 13 | parser = ArgumentParser(description='Train Augmented model') 14 | parser.add_argument('--configs', type=str, default='./config/celeba/training_augmodel/ffhq.json') 15 | args = parser.parse_args() 16 | 17 | 18 | if __name__ == '__main__': 19 | file = args.configs 20 | cfg = load_json(json_file=file) 21 | 22 | train_augmodel(cfg) -------------------------------------------------------------------------------- /train_classifier.py: -------------------------------------------------------------------------------- 1 | import torch, os, engine, utils 2 | import torch.nn as nn 3 | from argparse import ArgumentParser 4 | from models import classify 5 | 6 | 7 | parser = ArgumentParser(description='Train Classifier') 8 | parser.add_argument('--configs', type=str, default='./config/celeba/training_classifiers/classify.json') 9 | 10 | args = parser.parse_args() 11 | 12 | 13 | 14 | def main(args, model_name, trainloader, testloader): 15 | n_classes = args["dataset"]["n_classes"] 16 | mode = args["dataset"]["mode"] 17 | 18 | resume_path = args[args["dataset"]["model_name"]]["resume"] 19 | net = classify.get_classifier(model_name=model_name, mode=mode, n_classes=n_classes, resume_path=resume_path) 20 | 21 | print(net) 22 | 23 | optimizer = torch.optim.SGD(params=net.parameters(), 24 | lr=args[model_name]['lr'], 25 | momentum=args[model_name]['momentum'], 26 | weight_decay=args[model_name]['weight_decay']) 27 | 28 | criterion = nn.CrossEntropyLoss().cuda() 29 | net = torch.nn.DataParallel(net).to(args['dataset']['device']) 30 | 31 | mode = args["dataset"]["mode"] 32 | n_epochs = args[model_name]['epochs'] 33 | print("Start Training!") 34 | 35 | if mode == "reg": 36 | best_model, best_acc = engine.train_reg(args, net, criterion, optimizer, trainloader, testloader, n_epochs) 37 | elif mode == "vib": 38 | best_model, best_acc = engine.train_vib(args, net, criterion, optimizer, trainloader, testloader, n_epochs) 39 | 40 | torch.save({'state_dict':best_model.state_dict()}, os.path.join(model_path, "{}_{:.2f}_allclass.tar").format(model_name, best_acc[0])) 41 | 42 | 43 | if __name__ == '__main__': 44 | 45 | cfg = utils.load_json(json_file=args.configs) 46 | 47 | root_path = cfg["root_path"] 48 | log_path = os.path.join(root_path, "target_logs") 49 | model_path = os.path.join(root_path, "target_ckp") 50 | os.makedirs(model_path, exist_ok=True) 51 | os.makedirs(log_path, exist_ok=True) 52 | 53 | 54 | model_name = cfg['dataset']['model_name'] 55 | log_file = "{}.txt".format(model_name) 56 | utils.Tee(os.path.join(log_path, log_file), 'w') 57 | 58 | print("TRAINING %s" % model_name) 59 | utils.print_params(cfg["dataset"], cfg[model_name], dataset=cfg['dataset']['name']) 60 | 61 | train_file = cfg['dataset']['train_file_path'] 62 | test_file = cfg['dataset']['test_file_path'] 63 | _, trainloader = utils.init_dataloader(cfg, train_file, mode="train") 64 | _, testloader = 
utils.init_dataloader(cfg, test_file, mode="test") 65 | 66 | main(cfg, model_name, trainloader, testloader) 67 | -------------------------------------------------------------------------------- /train_gan.py: -------------------------------------------------------------------------------- 1 | import engine 2 | from utils import load_json 3 | from argparse import ArgumentParser 4 | 5 | 6 | parser = ArgumentParser(description='Train GAN') 7 | parser.add_argument('--configs', type=str, default='./config/celeba/training_GAN/specific_gan/celeba.json') 8 | parser.add_argument('--mode', type=str, choices=['specific', 'general'], default='specific') 9 | args = parser.parse_args() 10 | 11 | 12 | if __name__ == "__main__": 13 | # os.environ["CUDA_VISIBLE_DEVICES"] = '4,5,6,7' 14 | file = args.configs 15 | cfg = load_json(json_file=file) 16 | 17 | if args.mode == 'specific': 18 | engine.train_specific_gan(cfg=cfg) 19 | elif args.mode == 'general': 20 | engine.train_general_gan(cfg=cfg) 21 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn.init as init 2 | import os, models.facenet as facenet, sys 3 | import json, time, random, torch 4 | from models import classify 5 | from models.classify import * 6 | from models.discri import * 7 | from models.generator import * 8 | import numpy as np 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torchvision.utils as tvls 12 | from torchvision import transforms 13 | from datetime import datetime 14 | import dataloader 15 | from torch.autograd import grad 16 | 17 | device = "cuda" 18 | 19 | class Tee(object): 20 | def __init__(self, name, mode): 21 | self.file = open(name, mode) 22 | self.stdout = sys.stdout 23 | sys.stdout = self 24 | def __del__(self): 25 | sys.stdout = self.stdout 26 | self.file.close() 27 | def write(self, data): 28 | if not '...' 
in data: 29 | self.file.write(data) 30 | self.stdout.write(data) 31 | self.flush() 32 | def flush(self): 33 | self.file.flush() 34 | 35 | def init_dataloader(args, file_path, batch_size=64, mode="gan", iterator=False): 36 | tf = time.time() 37 | 38 | if mode == "attack": 39 | shuffle_flag = False 40 | else: 41 | shuffle_flag = True 42 | 43 | 44 | data_set = dataloader.ImageFolder(args, file_path, mode) 45 | 46 | if iterator: 47 | data_loader = torch.utils.data.DataLoader(data_set, 48 | batch_size=batch_size, 49 | shuffle=shuffle_flag, 50 | drop_last=True, 51 | num_workers=0, 52 | pin_memory=True).__iter__() 53 | else: 54 | data_loader = torch.utils.data.DataLoader(data_set, 55 | batch_size=batch_size, 56 | shuffle=shuffle_flag, 57 | drop_last=True, 58 | num_workers=2, 59 | pin_memory=True) 60 | interval = time.time() - tf 61 | print('Initializing data loader took %ds' % interval) 62 | 63 | return data_set, data_loader 64 | 65 | def load_pretrain(self, state_dict): 66 | own_state = self.state_dict() 67 | for name, param in state_dict.items(): 68 | if name.startswith("module.fc_layer"): 69 | continue 70 | if name not in own_state: 71 | print(name) 72 | continue 73 | own_state[name].copy_(param.data) 74 | 75 | def load_state_dict(self, state_dict): 76 | own_state = self.state_dict() 77 | for name, param in state_dict.items(): 78 | if name not in own_state: 79 | print(name) 80 | continue 81 | own_state[name].copy_(param.data) 82 | 83 | def load_json(json_file): 84 | with open(json_file) as data_file: 85 | data = json.load(data_file) 86 | return data 87 | 88 | def print_params(info, params, dataset=None): 89 | print('-----------------------------------------------------------------') 90 | print("Running time: %s" % datetime.now().strftime('%Y-%m-%d_%H-%M-%S')) 91 | for i, (key, value) in enumerate(info.items()): 92 | print("%s: %s" % (key, str(value))) 93 | for i, (key, value) in enumerate(params.items()): 94 | print("%s: %s" % (key, str(value))) 95 | print('-----------------------------------------------------------------') 96 | 97 | def save_tensor_images(images, filename, nrow = None, normalize = True): 98 | if not nrow: 99 | tvls.save_image(images, filename, normalize = normalize, padding=0) 100 | else: 101 | tvls.save_image(images, filename, normalize = normalize, nrow=nrow, padding=0) 102 | 103 | 104 | def get_deprocessor(): 105 | # resize 112,112 106 | proc = [] 107 | proc.append(transforms.Resize((112, 112))) 108 | proc.append(transforms.ToTensor()) 109 | return transforms.Compose(proc) 110 | 111 | def low2high(img): 112 | # 0 and 1, 64 to 112 113 | bs = img.size(0) 114 | proc = get_deprocessor() 115 | img_tensor = img.detach().cpu().float() 116 | img = torch.zeros(bs, 3, 112, 112) 117 | for i in range(bs): 118 | img_i = transforms.ToPILImage()(img_tensor[i, :, :, :]).convert('RGB') 119 | img_i = proc(img_i) 120 | img[i, :, :, :] = img_i[:, :, :] 121 | 122 | img = img.cuda() 123 | return img 124 | 125 | 126 | def get_model(attack_name, classes): 127 | if attack_name.startswith("VGG16"): 128 | T = classify.VGG16(classes) 129 | elif attack_name.startswith("IR50"): 130 | T = classify.IR50(classes) 131 | elif attack_name.startswith("IR152"): 132 | T = classify.IR152(classes) 133 | elif attack_name.startswith("FaceNet64"): 134 | T = facenet.FaceNet64(classes) 135 | else: 136 | print("Model doesn't exist") 137 | exit() 138 | 139 | T = torch.nn.DataParallel(T).cuda() 140 | return T 141 | 142 | def get_augmodel(model_name, nclass, path_T=None, dataset='celeba'): 143 | if model_name=="VGG16": 
144 | model = VGG16(nclass) 145 | elif model_name=="FaceNet": 146 | model = FaceNet(nclass) 147 | elif model_name=="FaceNet64": 148 | model = FaceNet64(nclass) 149 | elif model_name=="IR152": 150 | model = IR152(nclass) 151 | elif model_name =="efficientnet_b0": 152 | model = classify.EfficientNet_b0(nclass) 153 | elif model_name =="efficientnet_b1": 154 | model = classify.EfficientNet_b1(nclass) 155 | elif model_name =="efficientnet_b2": 156 | model = classify.EfficientNet_b2(nclass) 157 | 158 | model = torch.nn.DataParallel(model).cuda() 159 | if path_T is not None: 160 | 161 | ckp_T = torch.load(path_T) 162 | model.load_state_dict(ckp_T['state_dict'], strict=True) 163 | return model 164 | 165 | 166 | def log_sum_exp(x, axis = 1): 167 | m = torch.max(x, dim = 1)[0] 168 | return m + torch.log(torch.sum(torch.exp(x - m.unsqueeze(1)), dim = axis)) 169 | 170 | # define "soft" cross-entropy with pytorch tensor operations 171 | def softXEnt (input, target): 172 | targetprobs = nn.functional.softmax (target, dim = 1) 173 | logprobs = nn.functional.log_softmax (input, dim = 1) 174 | return -(targetprobs * logprobs).sum() / input.shape[0] 175 | 176 | class HLoss(nn.Module): 177 | def __init__(self): 178 | super(HLoss, self).__init__() 179 | 180 | def forward(self, x): 181 | b = F.softmax(x, dim=1) * F.log_softmax(x, dim=1) 182 | b = -1.0 * b.sum() 183 | return b 184 | 185 | def freeze(net): 186 | for p in net.parameters(): 187 | p.requires_grad_(False) 188 | 189 | def unfreeze(net): 190 | for p in net.parameters(): 191 | p.requires_grad_(True) 192 | 193 | def gradient_penalty(x, y, DG): 194 | # interpolation 195 | shape = [x.size(0)] + [1] * (x.dim() - 1) 196 | alpha = torch.rand(shape).cuda() 197 | z = x + alpha * (y - x) 198 | z = z.cuda() 199 | z.requires_grad = True 200 | 201 | o = DG(z) 202 | g = grad(o, z, grad_outputs = torch.ones(o.size()).cuda(), create_graph = True)[0].view(z.size(0), -1) 203 | gp = ((g.norm(p = 2, dim = 1) - 1) ** 2).mean() 204 | 205 | return gp 206 | 207 | def log_sum_exp(x, axis = 1): 208 | m = torch.max(x, dim = 1)[0] 209 | return m + torch.log(torch.sum(torch.exp(x - m.unsqueeze(1)), dim = axis)) 210 | 211 | def count_parameters(model): 212 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 213 | 214 | 215 | 216 | def get_GAN(dataset, gan_type, gan_model_dir, n_classes, z_dim, target_model): 217 | 218 | G = Generator(z_dim) 219 | if gan_type == True: 220 | D = MinibatchDiscriminator(n_classes=n_classes) 221 | else: 222 | D = DGWGAN(3) 223 | 224 | if gan_type == True: 225 | path = os.path.join(os.path.join(gan_model_dir, dataset), target_model) 226 | path_G = os.path.join(path, "improved_{}_G.tar".format(dataset)) 227 | path_D = os.path.join(path, "improved_{}_D.tar".format(dataset)) 228 | else: 229 | path = os.path.join(gan_model_dir, dataset) 230 | path_G = os.path.join(path, "{}_G.tar".format(dataset)) 231 | path_D = os.path.join(path, "{}_D.tar".format(dataset)) 232 | 233 | print('path_G',path_G) 234 | print('path_D',path_D) 235 | 236 | G = torch.nn.DataParallel(G).to(device) 237 | D = torch.nn.DataParallel(D).to(device) 238 | ckp_G = torch.load(path_G) 239 | G.load_state_dict(ckp_G['state_dict'], strict=True) 240 | ckp_D = torch.load(path_D) 241 | D.load_state_dict(ckp_D['state_dict'], strict=True) 242 | 243 | return G, D 244 | 245 | 246 | def get_attack_model(args, args_json, eval_mode=False): 247 | now = datetime.now() # current date and time 248 | 249 | if not eval_mode: 250 | log_file = 
"invertion_logs_{}_{}.txt".format(args.loss,now.strftime("%m_%d_%Y_%H_%M_%S")) 251 | utils.Tee(os.path.join(args.log_path, log_file), 'w') 252 | 253 | n_classes=args_json['dataset']['n_classes'] 254 | 255 | 256 | model_types_ = args_json['train']['model_types'].split(',') 257 | checkpoints = args_json['train']['cls_ckpts'].split(',') 258 | 259 | G, D = get_GAN(args_json['dataset']['name'],gan_type=args.improved_flag, 260 | gan_model_dir=args_json['train']['gan_model_dir'], 261 | n_classes=n_classes,z_dim=100,target_model=model_types_[0]) 262 | 263 | dataset = args_json['dataset']['name'] 264 | cid = args.classid.split(',') 265 | # target and student classifiers 266 | for i in range(len(cid)): 267 | id_ = int(cid[i]) 268 | model_types_[id_] = model_types_[id_].strip() 269 | checkpoints[id_] = checkpoints[id_].strip() 270 | print('Load classifier {} at {}'.format(model_types_[id_], checkpoints[id_])) 271 | model = get_augmodel(model_types_[id_],n_classes,checkpoints[id_],dataset) 272 | model = model.to(device) 273 | model = model.eval() 274 | if i==0: 275 | targetnets = [model] 276 | else: 277 | targetnets.append(model) 278 | 279 | # p_reg 280 | if args.loss=='logit_loss': 281 | if model_types_[id_] == "IR152" or model_types_[id_]=="VGG16" or model_types_[id_]=="FaceNet64": 282 | #target model 283 | p_reg = os.path.join(args_json["dataset"]["p_reg_path"], '{}_{}_p_reg.pt'.format(dataset,model_types_[id_])) #'./p_reg/{}_{}_p_reg.pt'.format(dataset,model_types_[id_]) 284 | else: 285 | #aug model 286 | p_reg = os.path.join(args_json["dataset"]["p_reg_path"], '{}_{}_{}_p_reg.pt'.format(dataset,model_types_[0],model_types_[id_])) #'./p_reg/{}_{}_{}_p_reg.pt'.format(dataset,model_types_[0],model_types_[id_]) 287 | # print('p_reg',p_reg) 288 | if not os.path.exists(p_reg): 289 | _, dataloader_gan = init_dataloader(args_json, args_json['dataset']['gan_file_path'], 50, mode="gan") 290 | from attack import get_act_reg 291 | fea_mean_,fea_logvar_ = get_act_reg(dataloader_gan,model,device) 292 | torch.save({'fea_mean':fea_mean_,'fea_logvar':fea_logvar_},p_reg) 293 | else: 294 | fea_reg = torch.load(p_reg) 295 | fea_mean_ = fea_reg['fea_mean'] 296 | fea_logvar_ = fea_reg['fea_logvar'] 297 | if i == 0: 298 | fea_mean = [fea_mean_.to(device)] 299 | fea_logvar = [fea_logvar_.to(device)] 300 | else: 301 | fea_mean.append(fea_mean_) 302 | fea_logvar.append(fea_logvar_) 303 | # print('fea_logvar_',i,fea_logvar_.shape,fea_mean_.shape) 304 | 305 | else: 306 | fea_mean,fea_logvar = 0,0 307 | 308 | # evaluation classifier 309 | E = get_augmodel(args_json['train']['eval_model'],n_classes,args_json['train']['eval_dir']) 310 | E.eval() 311 | G.eval() 312 | D.eval() 313 | 314 | return targetnets, E, G, D, n_classes, fea_mean, fea_logvar 315 | --------------------------------------------------------------------------------