├── .gitignore ├── FLGradientInversion ├── README.md ├── config │ └── config_inversion.json ├── fl_gradient_inversion.py ├── main.py ├── orig.png ├── prior │ └── prior_1.jpg ├── recon.png ├── requirements.txt └── torchvision_class.py ├── LICENSE ├── README.md ├── cifar10 ├── README.md ├── deepinversion_cifar10.py ├── images │ └── better_last.png └── resnet_cifar.py ├── deepinversion.py ├── example_logs ├── fp16_set0_rn50.log ├── fp16_set0_rn50_adi02.log ├── fp16_set0_rn50_adi02_output_00030_gpu_0.jpg ├── fp16_set0_rn50_output_00030_gpu_0.jpg ├── fp16_set1_rn50.log ├── fp16_set1_rn50_output_00020_gpu_0.jpg ├── fp32_set0_mnv2.log ├── fp32_set0_mnv2_output_00030_gpu_0.jpg ├── fp32_set0_rn50.log ├── fp32_set0_rn50_first_bn_scaled.jpg ├── fp32_set0_rn50_first_bn_scaled.log ├── fp32_set0_rn50_output_00030_gpu_0.jpg └── teaser.png ├── imagenet_inversion.py ├── models └── resnetv15.py └── utils └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | generation/ 2 | temp/ 3 | __pycache__ 4 | .idea/ 5 | *.tar.gz 6 | *.zip 7 | *.pkl 8 | *.pyc 9 | -------------------------------------------------------------------------------- /FLGradientInversion/README.md: -------------------------------------------------------------------------------- 1 | # FL Gradient Inversion 2 | 3 | This directory contains the tools necessary to recreate the chest X-ray 4 | experiments described in 5 | 6 | 7 | ### Do Gradient Inversion Attacks Make Federated Learning Unsafe? [arXiv:2202.06924](https://arxiv.org/abs/2202.06924) 8 | 9 | ###### Abstract: 10 | 11 | > Federated learning (FL) allows the collaborative training of AI models without needing to share raw data. This capability makes it especially interesting for healthcare applications where patient and data privacy is of utmost concern. However, recent works on the inversion of deep neural networks from model gradients raised concerns about the security of FL in preventing the leakage of training data. In this work, we show that these attacks presented in the literature are impractical in real FL use-cases and provide a new baseline attack that works for more realistic scenarios where the clients' training involves updating the Batch Normalization (BN) statistics. Furthermore, we present new ways to measure and visualize potential data leakage in FL. Our work is a step towards establishing reproducible methods of measuring data leakage in FL and could help determine the optimal tradeoffs between privacy-preserving techniques, such as differential privacy, and model accuracy based on quantifiable metrics. 12 | 13 | 14 | ## Updates 15 | 16 | ***01/16/2023*** 17 | 18 | 1. Our FL Gradient Inversion [paper](https://arxiv.org/pdf/2202.06924.pdf) has been accepted to [IEEE Transactions on Medical Imaging (TMI)](https://www.embs.org/tmi/). 19 | 20 | 2. We release the code for the FL Gradient Inversion model. 21 | 22 | ## Quick-start 23 | 24 | First, install the requirements. The code was tested with Python 3.10. 25 | ```setup 26 | pip install -r requirements.txt 27 | ``` 28 | 29 | To run an example gradient inversion attack from pre-recorded FL gradients 30 | from a "high-risk" client sending updates in the 10th training round based on 31 | just one image (batch size 1), execute the following. 32 | 33 | #### 1. Download the pre-recorded weights 34 | 35 | Download the weights [here](https://drive.google.com/file/d/1o6aZy2oBSD7ayPgkHfZ41lzANhldTVyr/view?usp=share_link) 36 | and extract them to `./weights`. 
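Optionally, you can sanity-check that the extracted files load correctly before running the attack. This is a minimal sketch (not part of the original instructions) that reuses the same loading calls as `main.py`; the file names match the listing below:

```python
import numpy as np
import torch

# model updates and BN statistics recorded from the "high-risk" client
updates = np.load("weights/updates_round10_client9.npz", allow_pickle=True)["weights"].item()
bn_stats = np.load("weights/batchnorm_round10_client9.npz", allow_pickle=True)["batchnorm"].item()

# global model that the updates were computed against
global_model = torch.load("weights/FL_global_model_round10.pt", map_location="cpu")

print(f"{len(updates)} update tensors, {len(bn_stats)} BN statistic entries loaded")
```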
37 | 38 | The extracted folder should have the following content. 39 | ``` 40 | weights 41 | ├── batchnorm_round10_client9.npz 42 | ├── FL_global_model_round10.pt 43 | └── updates_round10_client9.npz 44 | ``` 45 | 46 | #### 2. Run the inversion code 47 | ``` 48 | ./main.py 49 | ``` 50 | 51 | 52 | ## Federated Learning Experiments 53 | 54 | To reproduce the experiments in the paper, we use [NVIDIA FLARE](https://github.com/NVIDIA/NVFlare) to produce 55 | the model updates shared in federated learning. Please visit 56 | [here](https://nvidia.github.io/NVFlare/research/gradient-inversion) for 57 | details. 58 | 59 | The expected result of the inversion run above is saved under [./outputs/recon.png](./outputs/recon.png). For larger 60 | training set sizes, several images will be reconstructed; see the 61 | `local_num_images` config option. 62 | 63 | #### Reconstruction 64 | 65 | | Original | Inversion | 66 | |-----------------|------------------| 67 | | ![](./orig.png) | ![](./recon.png) | 68 | 69 | > Note that the original image is from the [COVID-19 Radiography Database](https://www.kaggle.com/tawsifurrahman/covid19-radiography-database) (Normal-4085.png), 70 | > with a random patient name and date of birth overlaid. 71 | 72 | ## Citation 73 | 74 | > A. Hatamizadeh et al., "Do Gradient Inversion Attacks Make Federated Learning Unsafe?," in IEEE Transactions on Medical Imaging, doi: 10.1109/TMI.2023.3239391. 75 | 76 | BibTeX 77 | ``` 78 | @ARTICLE{10025466, 79 | author={Hatamizadeh, Ali and Yin, Hongxu and Molchanov, Pavlo and Myronenko, Andriy and Li, Wenqi and Dogra, Prerna and Feng, Andrew and Flores, Mona G. and Kautz, Jan and Xu, Daguang and Roth, Holger R.}, 80 | journal={IEEE Transactions on Medical Imaging}, 81 | title={Do Gradient Inversion Attacks Make Federated Learning Unsafe?}, 82 | year={2023}, 83 | volume={}, 84 | number={}, 85 | pages={1-1}, 86 | doi={10.1109/TMI.2023.3239391}} 87 | ``` 88 | 89 | ## License 90 | 91 | Copyright (C) 2023 NVIDIA Corporation. All rights reserved. 92 | 93 | This work is made available under the Nvidia Source Code License (1-Way Commercial). To view a copy of this license, visit https://github.com/NVlabs/DeepInversion/blob/master/LICENSE 94 | -------------------------------------------------------------------------------- /FLGradientInversion/config/config_inversion.json: -------------------------------------------------------------------------------- 1 | { 2 | "checkpoint_file": "./weights/FL_global_model_round10.pt", 3 | "weights_file": "./weights/updates_round10_client9.npz", 4 | "batchnorm_file": "./weights/batchnorm_round10_client9.npz", 5 | "img_prior": "./prior/prior_1.jpg", 6 | "save_path": "./outputs/", 7 | "model_name": "resnet18", 8 | "criterion": "BCEWithLogitsLoss", 9 | "num_classes": 2, 10 | "batch_size": 1, 11 | "iterations": 40000, 12 | "resolution": 224, 13 | "pretrained": false, 14 | "start_rand": false, 15 | "init_target_rand": true, 16 | "no_lr_decay": false, 17 | "grad_l2": 1e-3, 18 | "original_bn_l2": 1e-1, 19 | "energy_l2": 1e-1, 20 | "tv_l1": 0.0, 21 | "tv_l2": 1e-4, 22 | "lr": 1e-1, 23 | "l2": 1e-5, 24 | "lr_local": 1e-2, 25 | "local_bs": 1, 26 | "local_epoch": 1, 27 | "local_num_images": 1, 28 | "local_optim": "sgd", 29 | "save_every": 500 30 | } 31 | -------------------------------------------------------------------------------- /FLGradientInversion/fl_gradient_inversion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import collections 10 | import json 11 | import logging 12 | import os 13 | from copy import deepcopy 14 | from typing import Callable, Dict, Iterable, Optional, Tuple, Union 15 | 16 | import matplotlib.pyplot as plt 17 | import numpy as np 18 | import torch 19 | import torchvision 20 | from ignite.engine import Engine 21 | from monai.data import DataLoader 22 | from monai.engines import SupervisedTrainer 23 | from monai.engines.utils import IterationEvents, default_prepare_batch 24 | from monai.inferers import SimpleInferer 25 | from monai.utils.enums import CommonKeys as Keys 26 | from PIL import Image 27 | 28 | 29 | class FLGradientInversion(object): 30 | def __init__( 31 | self, 32 | network, 33 | grad_lst, 34 | bn_stats, 35 | model_bn, 36 | prior_transforms=None, 37 | save_transforms=None, 38 | ): 39 | """FLGradientInversion is used to reconstruct training images and 40 | targets (ground truth labels) by attempting to invert the gradients 41 | (model updates) shared in a federated learning framework. 42 | 43 | Args: 44 | network: network for which the gradients are being inverted, 45 | i.e. the current global model the models updates are being 46 | computed with respect to. 47 | grad_lst: model updates. 48 | bn_stats: updated batch norm statistics. 49 | model_bn: updated model containing current batch norm statistics. 50 | prior_transforms: Optional custom transforms to read the prior 51 | image. Defaults to None. 52 | save_transforms: Optional transforms to save the reconstructed 53 | images. Defaults to None. 54 | Returns: 55 | __call__() function returns the reconstructions. 56 | """ 57 | self.network = network 58 | self.bn_stats = bn_stats 59 | self.model_bn = model_bn 60 | self.loss_r_feature_layers = [] 61 | self.grad_lst = grad_lst 62 | self.logger = logging.getLogger(self.__class__.__name__) 63 | self.prior_transforms = prior_transforms 64 | self.save_transforms = save_transforms 65 | 66 | def __call__(self, cfg): 67 | """Run the gradient inversion attack. 68 | 69 | Args: 70 | cfg: Configuration dictionary containing the following keys used 71 | in this call. 72 | - img_prior: full path to prior image file used to initialize 73 | the attack. 74 | - save_path: Optional save directory where reconstructed 75 | images and targets are being saved. 76 | - criterion: Loss used for training the classification 77 | network, e.g. "BCEWithLogitsLoss". 78 | - iterations: number of iterations to run the attack. 79 | - resolution: x/y dimension of the images to be reconstructed. 80 | - start_rand: Whether to start from random initialization. 81 | If `False`, the `img_prior` is used. 82 | - init_target_rand: Whether to initialize the reconstructed 83 | targets using a uniform distribution. If `False`, targets 84 | are initialized as all zeros. 85 | - no_lr_decay: Disable the learning rate decay of the 86 | optimizer. 87 | - grad_l2: L2 scaling factor on the gradient loss. 88 | - original_bn_l2: Scaling factor for batchnorm matching loss. 89 | - energy_l2: This adds gaussian noise to find global minimums. 90 | - tv_l1: Coefficient for total variation L1 loss. 
91 | - tv_l2: Coefficient for total variation L2 loss. 92 | - lr: Learning rate for optimization. 93 | - l2: L2 loss on the image. 94 | - local_epoch: Local number of epochs used by the FL client. 95 | - local_optim: Local optimizer used by the FL client, Either 96 | "sgd" or "adam". 97 | - save_every: How often to save the reconstructions to file. 98 | Returns: 99 | Reconstructed images. 100 | """ 101 | self.save_path = cfg["save_path"] 102 | save_every = cfg["save_every"] 103 | if save_every > 0: 104 | self.create_folder(self.save_path) 105 | 106 | if cfg["criterion"] == "BCEWithLogitsLoss": 107 | criterion = torch.nn.BCEWithLogitsLoss() 108 | elif cfg["criterion"] == "CrossEntropyLoss": 109 | criterion = torch.nn.CrossEntropyLoss() 110 | else: 111 | raise ValueError( 112 | "criterion should be BCEWithLogitsLoss or CrossEntropyLoss." 113 | ) 114 | 115 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 116 | network = self.network 117 | local_rank = torch.cuda.current_device() 118 | if cfg["start_rand"]: 119 | inputs_1 = torch.randn( 120 | (cfg["batch_size"], 1, cfg["resolution"], cfg["resolution"]), 121 | requires_grad=True, 122 | device=device, 123 | dtype=torch.float, 124 | ) 125 | else: 126 | prior_file = cfg["img_prior"] 127 | if self.prior_transforms: 128 | _img = self.prior_transforms(prior_file) 129 | else: # use default prior loading transforms 130 | pil_img = Image.open(prior_file) 131 | self.prior_transforms = torchvision.transforms.Compose( 132 | [ 133 | torchvision.transforms.Resize( 134 | (cfg["resolution"], cfg["resolution"]) 135 | ), 136 | torchvision.transforms.ToTensor(), 137 | ] 138 | ) 139 | _img = self.prior_transforms(pil_img) 140 | 141 | # make init batch 142 | images = torch.empty( 143 | size=( 144 | cfg["local_num_images"], 145 | 1, 146 | cfg["resolution"], 147 | cfg["resolution"], 148 | ) 149 | ) 150 | for i in range(cfg["local_num_images"]): 151 | images[i] = _img.unsqueeze_(0) 152 | inputs_1 = images.to(device) 153 | inputs_1.requires_grad_(True) 154 | 155 | if cfg["init_target_rand"]: 156 | targets_in = torch.rand( 157 | (cfg["local_num_images"], 2), 158 | requires_grad=True, 159 | device=device, 160 | dtype=torch.float, 161 | ) 162 | else: 163 | targets_in = torch.zeros( 164 | (cfg["local_num_images"], 2), 165 | requires_grad=True, 166 | device=device, 167 | dtype=torch.float, 168 | ) 169 | 170 | iteration = -1 171 | for lr_it, _ in enumerate([2, 1]): 172 | iterations_per_layer = cfg["iterations"] 173 | if lr_it == 0: 174 | continue 175 | optimizer = torch.optim.Adam( 176 | [inputs_1, targets_in], 177 | lr=cfg["lr"], 178 | betas=[0.9, 0.9], 179 | eps=1e-8, 180 | ) 181 | lr_scheduler = self.lr_cosine_policy(cfg["lr"], 100, iterations_per_layer) 182 | local_trainer = self.create_trainer( 183 | cfg=cfg, 184 | network=network, 185 | inputs=( 186 | inputs_1 * torch.ones((1, 3, 1, 1)).cuda() 187 | ), # turn grayscale to RGB (3-channel inputs) 188 | targets=targets_in, 189 | criterion=criterion, 190 | device=torch.device("cuda"), 191 | ) 192 | for iteration_loc in range(iterations_per_layer): 193 | iteration += 1 194 | if not cfg["no_lr_decay"]: 195 | lr_scheduler(optimizer, iteration_loc, iteration_loc) 196 | inputs = inputs_1 * torch.ones((1, 3, 1, 1)).cuda() 197 | optimizer.zero_grad() 198 | network.zero_grad() 199 | network.train() 200 | loss_var_l1, loss_var_l2 = self.img_prior(inputs) 201 | loss_l2 = torch.norm( 202 | inputs.view(cfg["local_num_images"], -1), dim=1 203 | ).mean() 204 | loss_aux = ( 205 | cfg["tv_l2"] * loss_var_l2 206 | + 
cfg["tv_l1"] * loss_var_l1 207 | + cfg["l2"] * loss_l2 208 | ) 209 | loss = loss_aux 210 | if cfg["grad_l2"] > 0: 211 | new_grad = self.sim_local_updates( 212 | cfg=cfg, 213 | trainer=local_trainer, 214 | network=network, 215 | inputs=inputs, 216 | targets=targets_in, 217 | use_sigmoid=True, 218 | use_softmax=False, 219 | ) 220 | loss_grad = 0 221 | for a, b in zip(new_grad, self.grad_lst): 222 | loss_grad += cfg["grad_l2"] * (torch.norm(a - b[1])) 223 | loss = loss + loss_grad 224 | 225 | # add batch norm loss 226 | bn_hooks = [] 227 | self.model_bn.train() 228 | for name, module in self.model_bn.named_modules(): 229 | if isinstance(module, torch.nn.BatchNorm2d): 230 | bn_hooks.append( 231 | DeepInversionFeatureHook( 232 | module=module, 233 | bn_stats=self.bn_stats, 234 | name=name, 235 | ) 236 | ) 237 | # run forward path once to compute bn_hooks 238 | self.model_bn(inputs) 239 | loss_bn_tmp = 0 240 | for hook in bn_hooks: 241 | loss_bn_tmp += hook.r_feature 242 | hook.close() 243 | loss_bn = cfg["original_bn_l2"] * loss_bn_tmp 244 | loss += loss_bn 245 | loss.backward(retain_graph=True) 246 | optimizer.step() 247 | if local_rank == 0: 248 | if iteration % save_every == 0: 249 | self.logger.info(f"------------iteration {iteration}----------") 250 | self.logger.info(f"total loss {loss.item()}") 251 | self.logger.info( 252 | f"mean targets {torch.mean(targets_in, 0).detach().cpu().numpy()}" 253 | ) 254 | self.logger.info(f"gradient loss {loss_grad.item()}") 255 | self.logger.info(f"bn matching loss {loss_bn.item()}") 256 | self.logger.info( 257 | f"tvl2 loss {cfg['tv_l2'] * loss_var_l2.item()}" 258 | ) 259 | best_inputs = inputs.clone() 260 | if iteration % save_every == 0 and (save_every > 0): 261 | self.save_results( 262 | images=best_inputs, targets=targets_in, name="recon" 263 | ) 264 | # save reconstruction collage 265 | torchvision.utils.save_image( 266 | best_inputs, 267 | os.path.join(self.save_path, "recon.png"), 268 | normalize=True, 269 | scale_each=True, 270 | nrow=int(int(cfg["local_num_images"]) ** 0.5), 271 | ) 272 | if cfg["energy_l2"] > 0.0: 273 | inputs_noise_add = torch.randn(inputs.size(), device=device) 274 | for param_group in optimizer.param_groups: 275 | current_lr = param_group["lr"] 276 | break 277 | std = cfg["energy_l2"] * current_lr 278 | if iteration % save_every == 0: 279 | if local_rank == 0: 280 | self.logger.info( 281 | f"Energy method waken up, " 282 | f"adding Gaussian of std {std}" 283 | ) 284 | inputs.data = inputs.data + inputs_noise_add * std 285 | 286 | if save_every > 0: 287 | self.save_results(images=best_inputs, targets=targets_in, name="recon") 288 | 289 | optimizer.state = collections.defaultdict(dict) 290 | 291 | return best_inputs, targets_in 292 | 293 | @staticmethod 294 | def sim_local_updates( 295 | cfg, 296 | trainer, 297 | network, 298 | inputs, 299 | targets, 300 | use_softmax=False, 301 | use_sigmoid=True, 302 | ): 303 | """ 304 | Run the equivalent local optimization loop to get gradients 305 | which will be matched (using SupervisedTrainer) 306 | """ 307 | trainer.logger.setLevel(logging.WARNING) 308 | 309 | params_before = deepcopy(network.state_dict()) 310 | trainer.network.load_state_dict(params_before) 311 | if use_softmax and use_sigmoid: 312 | raise ValueError( 313 | "Only set one of `use_softmax` or `use_sigmoid` to be true." 
314 | ) 315 | if use_softmax: 316 | targets = torch.softmax(targets, dim=-1) 317 | if use_sigmoid: 318 | targets = torch.sigmoid(targets) 319 | data = [] 320 | for i in range(cfg["local_num_images"]): 321 | data.append({Keys.IMAGE: inputs[i, ...], Keys.LABEL: targets[i, ...]}) 322 | trainer.data_loader = DataLoader([data], batch_size=cfg["local_bs"]) 323 | if cfg["local_optim"] == "sgd": 324 | optimizer = torch.optim.SGD(network.parameters(), cfg["lr_local"]) 325 | elif cfg["local_optim"] == "adam": 326 | optimizer = torch.optim.Adam(network.parameters(), cfg["lr_local"]) 327 | else: 328 | raise ValueError( 329 | f"Local optimizer {cfg['local_optim']} " f"is not currently supported !" 330 | ) 331 | trainer.optimizer.load_state_dict(optimizer.state_dict()) 332 | trainer.optimizer.zero_grad() 333 | trainer.network.zero_grad() 334 | trainer.run() 335 | params_after = trainer.network.state_dict() 336 | new_grad = [] 337 | for name, _ in network.named_parameters(): 338 | new_grad.append(params_after[name] - params_before[name]) 339 | return new_grad 340 | 341 | @staticmethod 342 | def create_trainer(cfg, network, inputs, targets, criterion, device=None): 343 | if device is None: 344 | device = torch.device("cuda") 345 | 346 | data = [] 347 | for i in range(cfg["local_num_images"]): 348 | data.append({Keys.IMAGE: inputs[i, ...], Keys.LABEL: targets[i, ...]}) 349 | loader = DataLoader([data], batch_size=cfg["local_bs"]) 350 | if cfg["local_optim"] == "sgd": 351 | optimizer = torch.optim.SGD(network.parameters(), cfg["lr_local"]) 352 | elif cfg["local_optim"] == "adam": 353 | optimizer = torch.optim.Adam(network.parameters(), cfg["lr_local"]) 354 | else: 355 | raise ValueError( 356 | "Local optimizer {} is not currently supported !".format( 357 | cfg["local_optim"] 358 | ) 359 | ) 360 | optimizer.zero_grad() 361 | trainer = InversionSupervisedTrainer( 362 | device=device, 363 | max_epochs=cfg["local_epoch"], 364 | train_data_loader=loader, 365 | network=network, 366 | optimizer=optimizer, 367 | loss_function=criterion, 368 | amp=False, 369 | ) 370 | return trainer 371 | 372 | def img_prior(self, inputs_jit): 373 | # COMPUTE total variation regularization loss 374 | diff1 = inputs_jit[:, :, :, :-1] - inputs_jit[:, :, :, 1:] 375 | diff2 = inputs_jit[:, :, :-1, :] - inputs_jit[:, :, 1:, :] 376 | diff3 = inputs_jit[:, :, 1:, :-1] - inputs_jit[:, :, :-1, 1:] 377 | diff4 = inputs_jit[:, :, :-1, :-1] - inputs_jit[:, :, 1:, 1:] 378 | loss_var_l2 = ( 379 | torch.norm(diff1) 380 | + torch.norm(diff2) 381 | + torch.norm(diff3) 382 | + torch.norm(diff4) 383 | ) 384 | loss_var_l1 = ( 385 | (diff1.abs() / 255.0).mean() 386 | + (diff2.abs() / 255.0).mean() 387 | + (diff3.abs() / 255.0).mean() 388 | + (diff4.abs() / 255.0).mean() 389 | ) 390 | loss_var_l1 = loss_var_l1 * 255.0 391 | return loss_var_l1, loss_var_l2 392 | 393 | def denormalize(self, image_tensor, use_fp16=False): 394 | 395 | if use_fp16: 396 | mean = np.array([0.485, 0.456, 0.406], dtype=np.float16) 397 | std = np.array([0.229, 0.224, 0.225], dtype=np.float16) 398 | else: 399 | mean = np.array([0.485, 0.456, 0.406]) 400 | std = np.array([0.229, 0.224, 0.225]) 401 | 402 | for c in range(3): 403 | m, s = mean[c], std[c] 404 | 405 | if len(image_tensor.shape) == 4: 406 | image_tensor[:, c] = torch.clamp(image_tensor[:, c] * s + m, 0, 1) 407 | 408 | elif len(image_tensor.shape) == 3: 409 | image_tensor[c] = torch.clamp(image_tensor[c] * s + m, 0, 1) 410 | else: 411 | raise NotImplementedError() 412 | 413 | return image_tensor 414 | 415 | def 
create_folder(self, directory): 416 | 417 | if not os.path.exists(directory): 418 | os.makedirs(directory) 419 | 420 | def lr_policy(self, lr_fn): 421 | def _alr(optimizer, iteration, epoch): 422 | lr = lr_fn(iteration, epoch) 423 | for param_group in optimizer.param_groups: 424 | param_group["lr"] = lr 425 | 426 | return _alr 427 | 428 | def lr_cosine_policy(self, base_lr, warmup_length, epochs): 429 | def _lr_fn(iteration, epoch): 430 | if epoch < warmup_length: 431 | lr = base_lr * (epoch + 1) / warmup_length 432 | else: 433 | e = epoch - warmup_length 434 | es = epochs - warmup_length 435 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 436 | return lr 437 | 438 | return self.lr_policy(_lr_fn) 439 | 440 | def save_results(self, images, targets, name="recon"): 441 | # save reconstructed images 442 | for id in range(images.shape[0]): 443 | img = images[id, ...] 444 | if self.save_transforms: 445 | self.save_transforms(img) 446 | else: 447 | save_name = f"{name}_{id}.png" 448 | place_to_store = os.path.join(self.save_path, save_name) 449 | 450 | image_np = img.data.cpu().numpy() 451 | image_np = image_np.transpose((1, 2, 0)) 452 | image_np = np.array( 453 | (image_np - np.min(image_np)) 454 | / (np.max(image_np) - np.min(image_np)) 455 | ) 456 | plt.imsave(place_to_store, image_np) 457 | 458 | # save reconstructed targets 459 | place_to_store = os.path.join(self.save_path, f"{name}_targets.json") 460 | 461 | with open(place_to_store, "w") as f: 462 | json.dump(targets.detach().cpu().numpy().tolist(), f, indent=4) 463 | 464 | 465 | class InversionSupervisedTrainer(SupervisedTrainer): 466 | """ 467 | Same as MONAI's SupervisedTrainer but using 468 | retain_graph=True in backward() calls. 469 | """ 470 | 471 | def __init__( 472 | self, 473 | device: torch.device, 474 | max_epochs: int, 475 | train_data_loader: Union[Iterable, DataLoader], 476 | network: torch.nn.Module, 477 | optimizer: torch.optim.Optimizer, 478 | loss_function: Callable, 479 | epoch_length: Optional[int] = None, 480 | non_blocking: bool = False, 481 | prepare_batch: Callable = default_prepare_batch, 482 | amp: bool = False, 483 | ) -> None: 484 | super().__init__( 485 | device=device, 486 | max_epochs=max_epochs, 487 | train_data_loader=train_data_loader, 488 | network=network, 489 | optimizer=optimizer, 490 | loss_function=loss_function, 491 | epoch_length=epoch_length, 492 | non_blocking=non_blocking, 493 | prepare_batch=prepare_batch, 494 | iteration_update=None, 495 | inferer=SimpleInferer(), 496 | key_train_metric=None, 497 | additional_metrics=None, 498 | amp=amp, 499 | event_names=None, 500 | event_to_attr=None, 501 | ) 502 | 503 | def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): 504 | """ 505 | Callback function for the Supervised Training processing logic of 1 506 | iteration in Ignite Engine. 507 | Return below items in a dictionary: 508 | - IMAGE: image Tensor data for model input, already moved 509 | to device. 510 | - LABEL: label Tensor data corresponding to the image, already 511 | moved to device. 512 | - PRED: prediction result of model. 513 | - LOSS: loss value computed by loss function. 514 | 515 | Args: 516 | engine: Ignite Engine, it can be a trainer, validator or evaluator. 517 | batchdata: input data for this iteration, usually can be dictionary 518 | or tuple of Tensor data. 519 | 520 | Raises: 521 | ValueError: When ``batchdata`` is None. 
522 | 523 | """ 524 | if batchdata is None: 525 | raise ValueError("Must provide batch data for current iteration.") 526 | batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) 527 | if len(batch) == 2: 528 | inputs, targets = batch 529 | args: Tuple = () 530 | kwargs: Dict = {} 531 | else: 532 | inputs, targets, args, kwargs = batch 533 | engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} 534 | 535 | def _compute_pred_loss(): 536 | engine.state.output[Keys.PRED] = self.inferer( 537 | inputs, self.network, *args, **kwargs 538 | ) 539 | engine.fire_event(IterationEvents.FORWARD_COMPLETED) 540 | engine.state.output[Keys.LOSS] = self.loss_function( 541 | engine.state.output[Keys.PRED], targets 542 | ).mean() 543 | engine.fire_event(IterationEvents.LOSS_COMPLETED) 544 | 545 | self.network.train() 546 | self.network.zero_grad() 547 | self.optimizer.zero_grad() 548 | if self.amp and self.scaler is not None: 549 | with torch.cuda.amp.autocast(): 550 | _compute_pred_loss() 551 | self.scaler.scale(engine.state.output[Keys.LOSS]).backward( 552 | retain_graph=True 553 | ) 554 | engine.fire_event(IterationEvents.BACKWARD_COMPLETED) 555 | self.scaler.step(self.optimizer) 556 | self.scaler.update() 557 | else: 558 | _compute_pred_loss() 559 | engine.state.output[Keys.LOSS].backward(retain_graph=True) 560 | engine.fire_event(IterationEvents.BACKWARD_COMPLETED) 561 | self.optimizer.step() 562 | engine.fire_event(IterationEvents.MODEL_COMPLETED) 563 | return engine.state.output 564 | 565 | 566 | class DeepInversionFeatureHook: 567 | """ 568 | Implementation of the forward hook to track feature statistics and 569 | compute a loss on them. 570 | Will compute mean and variance, and will use l2 as a loss 571 | """ 572 | 573 | def __init__(self, module, bn_stats=None, name=None): 574 | self.hook = module.register_forward_hook(self.hook_fn) 575 | self.bn_stats = bn_stats 576 | self.name = name 577 | self.r_feature = None 578 | self.mean = None 579 | self.var = None 580 | 581 | def hook_fn(self, module, input, output): 582 | nch = input[0].shape[1] 583 | mean = input[0].mean([0, 2, 3]) 584 | var = ( 585 | input[0] 586 | .permute(1, 0, 2, 3) 587 | .contiguous() 588 | .view([nch, -1]) 589 | .var(1, unbiased=False) 590 | ) 591 | if self.bn_stats is None: 592 | var_feature = torch.norm(module.running_var.data - var, 2) 593 | mean_feature = torch.norm(module.running_mean.data - mean, 2) 594 | else: 595 | var_feature = torch.norm( 596 | torch.tensor( 597 | self.bn_stats[self.name + ".running_var"], device=input[0].device 598 | ) 599 | - var, 600 | 2, 601 | ) 602 | mean_feature = torch.norm( 603 | torch.tensor( 604 | self.bn_stats[self.name + ".running_mean"], device=input[0].device 605 | ) 606 | - mean, 607 | 2, 608 | ) 609 | 610 | rescale = 1.0 611 | self.r_feature = mean_feature + rescale * var_feature 612 | self.mean = mean 613 | self.var = var 614 | 615 | def close(self): 616 | self.hook.remove() 617 | -------------------------------------------------------------------------------- /FLGradientInversion/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. 
Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 10 | 11 | 12 | import json 13 | import re 14 | from copy import deepcopy 15 | 16 | import numpy as np 17 | import torch 18 | import torch.utils.data 19 | 20 | from fl_gradient_inversion import FLGradientInversion 21 | from torchvision_class import TorchVisionClassificationModel 22 | 23 | 24 | def run(cfg): 25 | """Run the gradient inversion attack. 26 | 27 | Args: 28 | cfg: Configuration dictionary containing the following keys used 29 | to set up the attack. Should also contain the keys expected by 30 | FLGradientInversion's __call__() function. 31 | - model_name: Used to select the model architecture, 32 | e.g. "resnet18". 33 | - num_classes: Number of output classes of the classification model. 34 | - pretrained: Whether to initialize the model with ImageNet-pretrained weights. 35 | - checkpoint_file: Path to the current global model checkpoint (.pt). 36 | - weights_file: Path to the recorded model updates (.npz). 37 | - batchnorm_file: Path to the recorded batch norm statistics (.npz). 38 | Returns: 39 | None. Reconstructed images and targets are written to ``save_path``. 40 | """ 41 | torch.backends.cudnn.deterministic = True 42 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 43 | net = TorchVisionClassificationModel( 44 | model_name=cfg["model_name"], 45 | num_classes=cfg["num_classes"], 46 | pretrained=cfg["pretrained"], 47 | ) 48 | 49 | checkpoint_file = cfg["checkpoint_file"] 50 | add_weights = cfg["weights_file"] 51 | batchnorm_file = cfg["batchnorm_file"] 52 | input_parameters = [] 53 | updates = np.load(add_weights, allow_pickle=True)["weights"].item() 54 | update_sum = 0.0 55 | n_excluded = 0 56 | weights = [] 57 | if checkpoint_file: 58 | model_data = torch.load(checkpoint_file) 59 | if "model" in model_data.keys(): 60 | net.load_state_dict(model_data["model"]) 61 | else: 62 | net.load_state_dict(model_data) 63 | exclude_vars = None 64 | if exclude_vars: 65 | re_pattern = re.compile(exclude_vars) 66 | for name, _ in net.named_parameters(): 67 | if exclude_vars: 68 | if re_pattern.search(name): 69 | n_excluded += 1 70 | weights.append(0.0) 71 | else: 72 | weights.append(1.0) 73 | val = updates[name] 74 | update_sum += np.sum(np.abs(val)) 75 | val = torch.from_numpy(val).to(device) 76 | input_parameters.append(val) 77 | assert update_sum > 0.0, "All updates are zero!" 78 | model_bn = deepcopy(net).cuda() 79 | update_sum = 0.0 80 | new_state_dict = model_bn.state_dict() 81 | for n in updates.keys(): 82 | val = updates[n] 83 | update_sum += np.sum(np.abs(val)) 84 | new_state_dict[n] += torch.tensor( 85 | val, dtype=new_state_dict[n].dtype, device=new_state_dict[n].device 86 | ) 87 | model_bn.load_state_dict(new_state_dict) 88 | assert update_sum > 0.0, "All updates are zero!" 89 | n_bn_updated = 0 90 | global_state_dict = net.state_dict() 91 | if batchnorm_file: 92 | bn_momentum = 0.1 93 | print( 94 | f"Using full BN stats from {batchnorm_file} " 95 | f"with momentum {bn_momentum} ! 
\n" 96 | ) 97 | bn_stats = np.load(batchnorm_file, allow_pickle=True)["batchnorm"].item() 98 | for n in bn_stats.keys(): 99 | if "running" in n: 100 | xt = ( 101 | bn_stats[n] - (1 - bn_momentum) * global_state_dict[n].numpy() 102 | ) / bn_momentum 103 | n_bn_updated += 1 104 | bn_stats[n] = xt 105 | 106 | net = net.to(device) 107 | grad_lst = [] 108 | grad_lst_orig = np.load(add_weights, allow_pickle=True)["weights"].item() 109 | for name, _ in net.named_parameters(): 110 | val = torch.from_numpy(grad_lst_orig[name]).cuda() 111 | grad_lst.append([name, val]) 112 | grad_inversion_engine = FLGradientInversion( 113 | network=net, 114 | grad_lst=grad_lst, 115 | bn_stats=bn_stats, 116 | model_bn=model_bn, 117 | ) 118 | grad_inversion_engine(cfg) 119 | 120 | 121 | def main(): 122 | with open("./config/config_inversion.json", "r") as f: 123 | cfg = json.load(f) 124 | 125 | run(cfg) 126 | 127 | 128 | if __name__ == "__main__": 129 | main() 130 | -------------------------------------------------------------------------------- /FLGradientInversion/orig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/FLGradientInversion/orig.png -------------------------------------------------------------------------------- /FLGradientInversion/prior/prior_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/FLGradientInversion/prior/prior_1.jpg -------------------------------------------------------------------------------- /FLGradientInversion/recon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/FLGradientInversion/recon.png -------------------------------------------------------------------------------- /FLGradientInversion/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.0 2 | torchvision==0.14.0 3 | pytorch-ignite==0.4.10 4 | numpy 5 | Pillow 6 | monai==1.1.0 7 | matplotlib -------------------------------------------------------------------------------- /FLGradientInversion/torchvision_class.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import torch 10 | from monai.utils import optional_import 11 | 12 | models, _ = optional_import("torchvision.models") 13 | 14 | 15 | class TorchVisionClassificationModel(torch.nn.Module): 16 | """ 17 | Customize TorchVision models to replace final linear/fully-connected layer to fit number of classes. 18 | 19 | Args: 20 | model_name: fully connected layer at the end from https://pytorch.org/vision/stable/models.html, e.g. 21 | ``resnet18`` (default), ``alexnet``, ``vgg16``, etc. 22 | num_classes: number of classes for the last classification layer. Default to 1. 23 | pretrained: whether to use the imagenet pretrained weights. 
Default to False. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | model_name: str = "resnet18", 29 | num_classes: int = 1, 30 | pretrained: bool = False, 31 | bias=True, 32 | ): 33 | super().__init__() 34 | self.model = getattr(models, model_name)(pretrained=pretrained) 35 | if "fc" in dir(self.model): 36 | self.model.fc = torch.nn.Linear( 37 | in_features=self.model.fc.in_features, 38 | out_features=num_classes, 39 | bias=bias, 40 | ) 41 | elif "classifier" in dir(self.model) and "vgg" not in model_name: 42 | self.model.classifier = torch.nn.Linear( 43 | in_features=self.model.classifier.in_features, 44 | out_features=num_classes, 45 | bias=bias, 46 | ) 47 | elif "vgg" in model_name: 48 | self.model.classifier[-1] = torch.nn.Linear( 49 | in_features=self.model.classifier[-1].in_features, 50 | out_features=num_classes, 51 | bias=bias, 52 | ) 53 | else: 54 | raise ValueError( 55 | f"Model ['{model_name}'] does not have a supported classifier attribute." 56 | ) 57 | 58 | def forward(self, x): 59 | return self.model(x) 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Nvidia Source Code License-NC 2 | 3 | 1. Definitions 4 | 5 | “Licensor” means any person or entity that distributes its Work. 6 | 7 | “Software” means the original work of authorship made available under this License. 8 | “Work” means the Software and any additions to or derivative works of the Software that are made available under 9 | this License. 10 | 11 | “Nvidia Processors” means any central processing unit (CPU), graphics processing unit (GPU), field-programmable gate 12 | array (FPGA), application-specific integrated circuit (ASIC) or any combination thereof designed, made, sold, or 13 | provided by Nvidia or its affiliates. 14 | 15 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. 16 | copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that 17 | remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 18 | 19 | Works, including the Software, are “made available” under this License by including in or with the Work either 20 | (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. 21 | 22 | 2. License Grants 23 | 24 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, 25 | worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, 26 | publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 27 | 28 | 3. Limitations 29 | 30 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include 31 | a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, 32 | trademark, or attribution notices that are present in the Work. 33 | 34 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and 35 | distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation 36 | in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to 37 | Your Terms. 
Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will 38 | continue to apply to the Work itself. 39 | 40 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. 41 | The Work or derivative works thereof may be used or intended for use by Nvidia or its affiliates commercially or 42 | non-commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 43 | 44 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, 45 | cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then 46 | your rights under this License from such Licensor (including the grants in Sections 2.1 and 2.2) will terminate 47 | immediately. 48 | 49 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or 50 | trademarks, except as necessary to reproduce the notices described in this License. 51 | 52 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the grants 53 | in Sections 2.1 and 2.2) will terminate immediately. 54 | 55 | 4. Disclaimer of Warranty. 56 | 57 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING 58 | WARRANTIES OR CONDITIONS OF M ERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. 59 | YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 60 | 61 | 5. Limitation of Liability. 62 | 63 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), 64 | CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, 65 | INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 66 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR 67 | MALFUNCTION, OR ANY OTHER COMM ERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY 68 | OF SUCH DAMAGES. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg) 2 | # Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion 3 | 4 | This repository is the official PyTorch implementation of [Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion](https://arxiv.org/abs/1912.08795) presented at CVPR 2020. 5 | 6 | The code will help to invert images from models of torchvision (pretrained on ImageNet), and run the images over another model to check generalization. We plan to update repo with CIFAR10 examples and teacher-student training. 7 | 8 | Useful links:
9 | * [Camera Ready PDF](https://drive.google.com/file/d/1jg4o458y70aCqUPRklMEy6dOGlZ0qMde/view?usp=sharing)
10 | * [ArXiv Full](https://arxiv.org/pdf/1912.08795.pdf)
11 | * [Dataset - Synthesized ImageNet](https://drive.google.com/open?id=1AXCW6_E_Qtr5qyb9jygGaLub13gQo10c): from ResNet50v1.5, ~2GB, organized by classes, ~140k images. These images were used in Section 4.4 (Data-free Knowledge Transfer); best viewed in gThumb. 12 | 13 | ![Teaser](example_logs/teaser.png "Teaser") 14 | 15 | ## License 16 | 17 | Copyright (C) 2020 NVIDIA Corporation. All rights reserved. 18 | 19 | This work is made available under the Nvidia Source Code License (1-Way Commercial). To view a copy of this license, visit https://github.com/NVlabs/DeepInversion/blob/master/LICENSE 20 | 21 | ## Updates 22 | 23 | - 2020 July 7. Added CIFAR10 inversion result for ResNet34 in the folder cifar10. Code on knowledge distillation will follow soon. 24 | - 2020 June 16. Added a new scaling factor `first_bn_multiplier` for the first BN layer. This improves fidelity. 25 | 26 | ## Requirements 27 | 28 | Code was tested in a virtual environment with Python 3.6. Install requirements: 29 | 30 | ```setup 31 | pip install torch==1.4.0 32 | pip install torchvision==0.5.0 33 | pip install numpy 34 | pip install Pillow 35 | ``` 36 | 37 | Additionally, install the APEX library for FP16 support (2x less memory, 2x faster): [Installing NVIDIA APEX](https://github.com/NVIDIA/apex#quick-start) 38 | 39 | The provided code was originally designed to invert a ResNet50v1.5 model trained for 90 epochs that achieves 77.26% top-1 accuracy on ImageNet. We are not able to share that model, but anyone can train it here: [ResNet50v1.5](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/resnet50v1.5). 40 | The code also works well for the default ResNet50 from the torchvision package. 41 | 42 | Code was tested on NVIDIA V100 GPU and Titan X Pascal. 43 | 44 | ## Running the code 45 | 46 | This snippet will generate 84 images by inverting the resnet50 model from the torchvision package. 47 | 48 | `python imagenet_inversion.py --bs=84 --do_flip --exp_name="rn50_inversion" --r_feature=0.01 --arch_name="resnet50" --verifier --adi_scale=0.0 --setting_id=0 --lr 0.25` 49 | 50 | Arguments: 51 | 52 | - `bs` - batch size; should be close to the original batch size used during training, but this is not required. 53 | - `lr` - learning rate for the optimizer of the input tensor for model inversion. 54 | - `do_flip` - apply random flipping between iterations. 55 | - `exp_name` - name of the experiment; a folder with this name is created in `./generations/` where intermediate generations are stored every 100 iterations. 56 | - `r_feature` - coefficient for feature distribution regularization; might need adjustment for other networks. 57 | - `arch_name` - name of the network architecture; should be one of the pretrained models from the torchvision package: `resnet50`, `resnet18`, `mobilenet_v2`, etc. 58 | - `fp16` - enables FP16 training if needed, using APEX AMP (O2 level). 59 | - `verifier` - enables checking the accuracy of generated images with another network (default: `mobilenet_v2`) after every 100 iterations. 60 | Useful to observe the generalizability of generated images. 61 | - `setting_id` - settings for optimization: 0 - multi-resolution scheme, 1 - 2k iterations at full resolution, 2 - 20k iterations (the closest to the ResNet50 experiments in the paper). Recommended to use setting_id={0, 1}. 62 | - `adi_scale` - competition coefficient. A positive value will lead to images that are good for the original model, but bad for the verifier. The value 0.2 was used in the paper. 63 | - `random_label` - randomly select target classes for inversion. Without this argument, the code will generate hand-picked classes. An example run combining several of these flags is shown below.
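For example, a run that additionally enables FP16 and randomized target classes could look like the following (a sketch assembled from the arguments documented above; the experiment name is arbitrary and `--bs` should be adjusted to your GPU memory):

`python imagenet_inversion.py --bs=84 --do_flip --exp_name="rn50_inversion_fp16" --r_feature=0.01 --arch_name="resnet50" --fp16 --verifier --adi_scale=0.0 --setting_id=0 --lr 0.25 --random_label`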
64 | 65 | After 3k iterations (~6 mins on an NVIDIA V100) generation is done: `Verifier accuracy: 91.6...%` (an experiment with >98% verifier accuracy can be found in `/example_logs`). We generated images by inverting a vanilla ResNet50 (not trained for image generation), and the classification accuracy by MobileNetv2 is >90%. A grid of images looks like this (from `/final_images/`; quality reduced due to JPEG compression): 66 | ![Generated grid of images](example_logs/fp32_set0_rn50_first_bn_scaled.jpg "ResNet50 Inverted images") 67 | 68 | Optimization is sensitive to hyper-parameters. Try local tuning for your setups/applications. Try changing the r_feature coefficient, the l2 regularization, and the betas of the Adam optimizer (beta=0 works well). Keep an eye on `loss_r_feature`, as it indicates how close the feature statistics are to the training distribution. 69 | 70 | Reduce the batch size if you run out of memory or run without FP16 optimization. In the paper, we used a batch size of 152; a larger batch size is preferred. This code will generate images from 41 hand-picked classes. To randomize the target classes, simply use the argument `--random_label`. 71 | 72 | Examples of running the code with different arguments and the resulting images can be found at `/example_logs/`. 73 | 74 | Check if you can invert other architectures, or even apply the method to other applications (keypoints, detection, etc.). 75 | The method has room for improvement: 76 | (a) improving the loss for feature regularization (we used MSE in the paper, but that may not be ideal for distribution matching), 77 | (b) making it even faster, 78 | (c) generating images for which multiple models are confident, 79 | (d) increasing diversity. 80 | 81 | Share your most exciting images on Twitter with the hashtags [#Deepinversion](https://twitter.com/hashtag/deepinversion?src=hash) and [#DeepInvert](https://twitter.com/hashtag/DeepInvert?src=hashtag_click). 82 | 83 | ## Citation 84 | 85 | ```bibtex 86 | @inproceedings{yin2020dreaming, 87 | title = {Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion}, 88 | author = {Yin, Hongxu and Molchanov, Pavlo and Alvarez, Jose M. and Li, Zhizhong and Mallya, Arun and Hoiem, Derek and Jha, Niraj K and Kautz, Jan}, 89 | booktitle = {The IEEE/CVF Conf. Computer Vision and Pattern Recognition (CVPR)}, 90 | month = June, 91 | year = {2020} 92 | } 93 | ``` 94 | -------------------------------------------------------------------------------- /cifar10/README.md: -------------------------------------------------------------------------------- 1 | ![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg) 2 | # CIFAR10 experiments 3 | 4 | ## License 5 | 6 | Copyright (C) 2020 NVIDIA Corporation. All rights reserved. 7 | 8 | This work is made available under the Nvidia Source Code License (1-Way Commercial). To view a copy of this license, visit https://github.com/NVlabs/DeepInversion/blob/master/LICENSE 9 | 10 | 11 | ## Requirements 12 | 13 | Code was tested in a virtual environment with Python 3.7. Install requirements: 14 | 15 | ```setup 16 | pip install torch==1.4.0 torchvision==0.5.0 numpy Pillow 17 | ``` 18 | 19 | Additionally, install the APEX library for FP16 support (2x less memory and 2x faster): [Installing NVIDIA APEX](https://github.com/NVIDIA/apex#quick-start) 20 | 21 | For CIFAR10 we first need to train a teacher model; for comparison purposes we choose the ResNet34 from the DAFL method. 22 | Instructions for training the teacher model can be found [here](https://github.com/huawei-noah/Data-Efficient-Model-Compression/tree/master/DAFL). 23 | Our model achieves 95.42% top-1 accuracy on the validation set.
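Note that `deepinversion_cifar10.py` loads the teacher via `torch.load()` followed by `load_state_dict()`, so the checkpoint is expected to be a plain `state_dict`. A minimal sketch for exporting a trained teacher in that format (the training loop itself is omitted; the path matches the commands below):

```
import torch
from resnet_cifar import ResNet34

net = ResNet34()
# ... train the teacher, e.g. following the DAFL recipe linked above ...
torch.save(net.state_dict(), './checkpoint/teacher_resnet34_only.weights')
```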
24 | 25 | Running inversion with parameters from the paper: 26 | ``` 27 | python deepinversion_cifar10.py --bs=256 --teacher_weights=./checkpoint/teacher_resnet34_only.weights\ 28 | --r_feature_weight=10 --di_lr=0.05 --exp_descr="paper_parameters" 29 | ``` 30 | 31 | Better reconstructed images can be obtained by tuning the parameters, for example by increasing the total variation coefficient: `--di_var_scale=0.001`. 32 | ``` 33 | python deepinversion_cifar10.py --bs=256 --teacher_weights=./checkpoint/teacher_resnet34_only.weights\ 34 | --r_feature_weight=10 --di_lr=0.1 --exp_descr="paper_parameters_better" --di_var_scale=0.001 --di_l2_scale=0.0 35 | ``` 36 | 37 | ![Resulting batch](images/better_last.png "Resulting batch") -------------------------------------------------------------------------------- /cifar10/deepinversion_cifar10.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ResNet model inversion for CIFAR10. 3 | 4 | Copyright (C) 2020 NVIDIA Corporation. All rights reserved. 5 | 6 | This work is made available under the Nvidia Source Code License (1-Way Commercial). To view a copy of this license, visit https://github.com/NVlabs/DeepInversion/blob/master/LICENSE 7 | ''' 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | import argparse 14 | import random 15 | import torch 16 | import torch.nn as nn 17 | # import torch.nn.parallel 18 | import torch.backends.cudnn as cudnn 19 | import torch.optim as optim 20 | # import torch.utils.data 21 | import torch.nn.functional as F 22 | import torchvision 23 | import torchvision.transforms as transforms 24 | import torchvision.utils as vutils 25 | import torchvision.transforms as transforms 26 | 27 | import numpy as np 28 | import os 29 | import glob 30 | import collections 31 | 32 | from resnet_cifar import ResNet34, ResNet18 33 | 34 | try: 35 | from apex.parallel import DistributedDataParallel as DDP 36 | from apex import amp, optimizers 37 | USE_APEX = True 38 | except ImportError: 39 | print("Please install apex from https://www.github.com/nvidia/apex to run this example.") 40 | print("will attempt to run without it") 41 | USE_APEX = False 42 | 43 | # provide intermediate information 44 | debug_output = False 45 | debug_output = True 46 | 47 | 48 | class DeepInversionFeatureHook(): 49 | ''' 50 | Implementation of the forward hook to track feature statistics and compute a loss on them. 51 | Will compute mean and variance, and will use l2 as a loss 52 | ''' 53 | 54 | def __init__(self, module): 55 | self.hook = module.register_forward_hook(self.hook_fn) 56 | 57 | def hook_fn(self, module, input, output): 58 | # hook to compute DeepInversion's feature distribution regularization 59 | nch = input[0].shape[1] 60 | 61 | mean = input[0].mean([0, 2, 3]) 62 | var = input[0].permute(1, 0, 2, 3).contiguous().view([nch, -1]).var(1, unbiased=False) 63 | 64 | # forcing mean and variance to match between two distributions 65 | # other ways might work better, e.g. 
KL divergence 66 | r_feature = torch.norm(module.running_var.data.type(var.type()) - var, 2) + torch.norm( 67 | module.running_mean.data.type(var.type()) - mean, 2) 68 | 69 | self.r_feature = r_feature 70 | # must have no output 71 | 72 | def close(self): 73 | self.hook.remove() 74 | 75 | def get_images(net, bs=256, epochs=1000, idx=-1, var_scale=0.00005, 76 | net_student=None, prefix=None, competitive_scale=0.01, train_writer = None, global_iteration=None, 77 | use_amp=False, 78 | optimizer = None, inputs = None, bn_reg_scale = 0.0, random_labels = False, l2_coeff=0.0): 79 | ''' 80 | Function returns inverted images from the pretrained model, parameters are tied to the CIFAR dataset 81 | args in: 82 | net: network to be inverted 83 | bs: batch size 84 | epochs: total number of iterations to generate inverted images, training longer helps a lot! 85 | idx: an external flag for printing purposes: only print in the first round, set as -1 to disable 86 | var_scale: the scaling factor for variance loss regularization. this may vary depending on bs 87 | larger - more blurred but less noise 88 | net_student: model to be used for Adaptive DeepInversion 89 | prefix: defines the path to store images 90 | competitive_scale: coefficient for Adaptive DeepInversion 91 | train_writer: tensorboardX object to store intermediate losses 92 | global_iteration: indexer to be used for tensorboard 93 | use_amp: boolean to indicate usage of APEX AMP for FP16 calculations - twice as fast and less memory on Tensor Cores 94 | optimizer: optimizer to be used for model inversion 95 | inputs: data placeholder for optimization, will be reinitialized to noise 96 | bn_reg_scale: weight for r_feature_regularization 97 | random_labels: sample labels from random distribution or use columns of the same class 98 | l2_coeff: coefficient for L2 loss on input 99 | return: 100 | A tensor on GPU with shape (bs, 3, 32, 32) for CIFAR 101 | ''' 102 | 103 | kl_loss = nn.KLDivLoss(reduction='batchmean').cuda() 104 | 105 | # preventing backpropagation through student for Adaptive DeepInversion 106 | net_student.eval() 107 | 108 | best_cost = 1e6 109 | 110 | # initialize gaussian inputs 111 | inputs.data = torch.randn((bs, 3, 32, 32), requires_grad=True, device='cuda') 112 | # if use_amp: 113 | # inputs.data = inputs.data.half() 114 | 115 | # set up criteria for optimization 116 | criterion = nn.CrossEntropyLoss() 117 | 118 | optimizer.state = collections.defaultdict(dict) # Reset state of optimizer 119 | 120 | # target outputs to generate 121 | if random_labels: 122 | targets = torch.LongTensor([random.randint(0,9) for _ in range(bs)]).to('cuda') 123 | else: 124 | targets = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9] * 25 + [0, 1, 2, 3, 4, 5]).to('cuda') 125 | 126 | ## Create hooks for catching feature statistics 127 | loss_r_feature_layers = [] 128 | for module in net.modules(): 129 | if isinstance(module, nn.BatchNorm2d): 130 | loss_r_feature_layers.append(DeepInversionFeatureHook(module)) 131 | 132 | # setting up the range for jitter 133 | lim_0, lim_1 = 2, 2 134 | 135 | for epoch in range(epochs): 136 | # apply random jitter offsets 137 | off1 = random.randint(-lim_0, lim_0) 138 | off2 = random.randint(-lim_1, lim_1) 139 | inputs_jit = torch.roll(inputs, shifts=(off1,off2), dims=(2,3)) 140 | 141 | # forward with jittered images 142 | optimizer.zero_grad() 143 | net.zero_grad() 144 | outputs = net(inputs_jit) 145 | loss = criterion(outputs, targets) 146 | loss_target = loss.item() 147 | 148 | # competition loss, Adaptive DeepInversion 149 | 
if competitive_scale != 0.0: 150 | net_student.zero_grad() 151 | outputs_student = net_student(inputs_jit) 152 | T = 3.0 153 | 154 | if 1: 155 | # jensen shanon divergence: 156 | # another way to force KL between negative probabilities 157 | P = F.softmax(outputs_student / T, dim=1) 158 | Q = F.softmax(outputs / T, dim=1) 159 | M = 0.5 * (P + Q) 160 | 161 | P = torch.clamp(P, 0.01, 0.99) 162 | Q = torch.clamp(Q, 0.01, 0.99) 163 | M = torch.clamp(M, 0.01, 0.99) 164 | eps = 0.0 165 | # loss_verifier_cig = 0.5 * kl_loss(F.log_softmax(outputs_verifier / T, dim=1), M) + 0.5 * kl_loss(F.log_softmax(outputs/T, dim=1), M) 166 | loss_verifier_cig = 0.5 * kl_loss(torch.log(P + eps), M) + 0.5 * kl_loss(torch.log(Q + eps), M) 167 | # JS criteria - 0 means full correlation, 1 - means completely different 168 | loss_verifier_cig = 1.0 - torch.clamp(loss_verifier_cig, 0.0, 1.0) 169 | 170 | loss = loss + competitive_scale * loss_verifier_cig 171 | 172 | # apply total variation regularization 173 | diff1 = inputs_jit[:,:,:,:-1] - inputs_jit[:,:,:,1:] 174 | diff2 = inputs_jit[:,:,:-1,:] - inputs_jit[:,:,1:,:] 175 | diff3 = inputs_jit[:,:,1:,:-1] - inputs_jit[:,:,:-1,1:] 176 | diff4 = inputs_jit[:,:,:-1,:-1] - inputs_jit[:,:,1:,1:] 177 | loss_var = torch.norm(diff1) + torch.norm(diff2) + torch.norm(diff3) + torch.norm(diff4) 178 | loss = loss + var_scale*loss_var 179 | 180 | # R_feature loss 181 | loss_distr = sum([mod.r_feature for mod in loss_r_feature_layers]) 182 | loss = loss + bn_reg_scale*loss_distr # best for noise before BN 183 | 184 | # l2 loss 185 | if 1: 186 | loss = loss + l2_coeff * torch.norm(inputs_jit, 2) 187 | 188 | if debug_output and epoch % 200==0: 189 | print(f"It {epoch}\t Losses: total: {loss.item():3.3f},\ttarget: {loss_target:3.3f} \tR_feature_loss unscaled:\t {loss_distr.item():3.3f}") 190 | vutils.save_image(inputs.data.clone(), 191 | './{}/output_{}.png'.format(prefix, epoch//200), 192 | normalize=True, scale_each=True, nrow=10) 193 | 194 | if best_cost > loss.item(): 195 | best_cost = loss.item() 196 | best_inputs = inputs.data 197 | 198 | # backward pass 199 | if use_amp: 200 | with amp.scale_loss(loss, optimizer) as scaled_loss: 201 | scaled_loss.backward() 202 | else: 203 | loss.backward() 204 | 205 | optimizer.step() 206 | 207 | outputs=net(best_inputs) 208 | _, predicted_teach = outputs.max(1) 209 | 210 | outputs_student=net_student(best_inputs) 211 | _, predicted_std = outputs_student.max(1) 212 | 213 | if idx == 0: 214 | print('Teacher correct out of {}: {}, loss at {}'.format(bs, predicted_teach.eq(targets).sum().item(), criterion(outputs, targets).item())) 215 | print('Student correct out of {}: {}, loss at {}'.format(bs, predicted_std.eq(targets).sum().item(), criterion(outputs_student, targets).item())) 216 | 217 | name_use = "best_images" 218 | if prefix is not None: 219 | name_use = prefix + name_use 220 | next_batch = len(glob.glob("./%s/*.png" % name_use)) // 1 221 | 222 | vutils.save_image(best_inputs[:20].clone(), 223 | './{}/output_{}.png'.format(name_use, next_batch), 224 | normalize=True, scale_each = True, nrow=10) 225 | 226 | if train_writer is not None: 227 | train_writer.add_scalar('gener_teacher_criteria', criterion(outputs, targets), global_iteration) 228 | train_writer.add_scalar('gener_student_criteria', criterion(outputs_student, targets), global_iteration) 229 | 230 | train_writer.add_scalar('gener_teacher_acc', predicted_teach.eq(targets).sum().item() / bs, global_iteration) 231 | train_writer.add_scalar('gener_student_acc', 
predicted_std.eq(targets).sum().item() / bs, global_iteration) 232 | 233 | train_writer.add_scalar('gener_loss_total', loss.item(), global_iteration) 234 | train_writer.add_scalar('gener_loss_var', (var_scale*loss_var).item(), global_iteration) 235 | 236 | net_student.train() 237 | 238 | return best_inputs 239 | 240 | 241 | def test(): 242 | print('==> Teacher validation') 243 | net_teacher.eval() 244 | test_loss = 0 245 | correct = 0 246 | total = 0 247 | 248 | with torch.no_grad(): 249 | for batch_idx, (inputs, targets) in enumerate(testloader): 250 | inputs, targets = inputs.to(device), targets.to(device) 251 | outputs = net_teacher(inputs) 252 | loss = criterion(outputs, targets) 253 | 254 | test_loss += loss.item() 255 | _, predicted = outputs.max(1) 256 | total += targets.size(0) 257 | correct += predicted.eq(targets).sum().item() 258 | 259 | print('Loss: %.3f | Acc: %.3f%% (%d/%d)' 260 | % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total)) 261 | 262 | 263 | if __name__ == "__main__": 264 | 265 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 DeepInversion') 266 | parser.add_argument('--bs', default=256, type=int, help='batch size') 267 | parser.add_argument('--iters_mi', default=2000, type=int, help='number of iterations for model inversion') 268 | parser.add_argument('--cig_scale', default=0.0, type=float, help='competition score') 269 | parser.add_argument('--di_lr', default=0.1, type=float, help='lr for deep inversion') 270 | parser.add_argument('--di_var_scale', default=2.5e-5, type=float, help='TV L2 regularization coefficient') 271 | parser.add_argument('--di_l2_scale', default=0.0, type=float, help='L2 regularization coefficient') 272 | parser.add_argument('--r_feature_weight', default=1e2, type=float, help='weight for BN regularization statistic') 273 | parser.add_argument('--amp', action='store_true', help='use APEX AMP O1 acceleration') 274 | parser.add_argument('--exp_descr', default="try1", type=str, help='name to be added to experiment name') 275 | parser.add_argument('--teacher_weights', default="'./checkpoint/teacher_resnet34_only.weights'", type=str, help='path to load weights of the model') 276 | 277 | args = parser.parse_args() 278 | 279 | print("loading resnet34") 280 | 281 | net_teacher = ResNet34() 282 | net_student = ResNet18() 283 | 284 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 285 | 286 | net_student = net_student.to(device) 287 | net_teacher = net_teacher.to(device) 288 | 289 | criterion = nn.CrossEntropyLoss() 290 | 291 | # place holder for inputs 292 | data_type = torch.half if args.amp else torch.float 293 | inputs = torch.randn((args.bs, 3, 32, 32), requires_grad=True, device='cuda', dtype=data_type) 294 | 295 | optimizer_di = optim.Adam([inputs], lr=args.di_lr) 296 | 297 | if args.amp: 298 | opt_level = "O1" 299 | loss_scale = 'dynamic' 300 | 301 | [net_student, net_teacher], optimizer_di = amp.initialize( 302 | [net_student, net_teacher], optimizer_di, 303 | opt_level=opt_level, 304 | loss_scale=loss_scale) 305 | 306 | checkpoint = torch.load(args.teacher_weights) 307 | net_teacher.load_state_dict(checkpoint) 308 | net_teacher.eval() #important, otherwise generated images will be non natural 309 | if args.amp: 310 | # need to do this trick for FP16 support of batchnorms 311 | net_teacher.train() 312 | for module in net_teacher.modules(): 313 | if isinstance(module, nn.BatchNorm2d): 314 | module.eval().half() 315 | 316 | cudnn.benchmark = True 317 | 318 | 319 | batch_idx = 0 320 | prefix = 
"runs/data_generation/"+args.exp_descr+"/" 321 | 322 | for create_folder in [prefix, prefix+"/best_images/"]: 323 | if not os.path.exists(create_folder): 324 | os.makedirs(create_folder) 325 | 326 | if 0: 327 | # loading 328 | transform_test = transforms.Compose([ 329 | transforms.ToTensor(), 330 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 331 | ]) 332 | 333 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) 334 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.bs, shuffle=True, num_workers=6, 335 | drop_last=True) 336 | # Checking teacher accuracy 337 | print("Checking teacher accuracy") 338 | test() 339 | 340 | 341 | train_writer = None # tensorboard writter 342 | global_iteration = 0 343 | 344 | print("Starting model inversion") 345 | 346 | inputs = get_images(net=net_teacher, bs=args.bs, epochs=args.iters_mi, idx=batch_idx, 347 | net_student=net_student, prefix=prefix, competitive_scale=args.cig_scale, 348 | train_writer=train_writer, global_iteration=global_iteration, use_amp=args.amp, 349 | optimizer=optimizer_di, inputs=inputs, bn_reg_scale=args.r_feature_weight, 350 | var_scale=args.di_var_scale, random_labels=False, l2_coeff=args.di_l2_scale) 351 | -------------------------------------------------------------------------------- /cifar10/images/better_last.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/cifar10/images/better_last.png -------------------------------------------------------------------------------- /cifar10/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | # 2019.07.24-Changed output of forward function 2 | # Huawei Technologies Co., Ltd. 
3 | # taken from https://github.com/huawei-noah/Data-Efficient-Model-Compression/blob/master/DAFL/resnet.py 4 | # for comparison with DAFL 5 | 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | class BasicBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, in_planes, planes, stride=1): 15 | super(BasicBlock, self).__init__() 16 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 19 | self.bn2 = nn.BatchNorm2d(planes) 20 | 21 | self.shortcut = nn.Sequential() 22 | if stride != 1 or in_planes != self.expansion*planes: 23 | self.shortcut = nn.Sequential( 24 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 25 | nn.BatchNorm2d(self.expansion*planes) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(self.conv1(x))) 30 | out = self.bn2(self.conv2(out)) 31 | out += self.shortcut(x) 32 | out = F.relu(out) 33 | return out 34 | 35 | 36 | class Bottleneck(nn.Module): 37 | expansion = 4 38 | 39 | def __init__(self, in_planes, planes, stride=1): 40 | super(Bottleneck, self).__init__() 41 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 42 | self.bn1 = nn.BatchNorm2d(planes) 43 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 44 | self.bn2 = nn.BatchNorm2d(planes) 45 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 46 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 47 | 48 | self.shortcut = nn.Sequential() 49 | if stride != 1 or in_planes != self.expansion*planes: 50 | self.shortcut = nn.Sequential( 51 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 52 | nn.BatchNorm2d(self.expansion*planes) 53 | ) 54 | 55 | def forward(self, x): 56 | out = F.relu(self.bn1(self.conv1(x))) 57 | out = F.relu(self.bn2(self.conv2(out))) 58 | out = self.bn3(self.conv3(out)) 59 | out += self.shortcut(x) 60 | out = F.relu(out) 61 | return out 62 | 63 | 64 | class ResNet(nn.Module): 65 | def __init__(self, block, num_blocks, num_classes=10): 66 | super(ResNet, self).__init__() 67 | self.in_planes = 64 68 | 69 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 70 | self.bn1 = nn.BatchNorm2d(64) 71 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 72 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 73 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 74 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 75 | self.linear = nn.Linear(512*block.expansion, num_classes) 76 | 77 | def _make_layer(self, block, planes, num_blocks, stride): 78 | strides = [stride] + [1]*(num_blocks-1) 79 | layers = [] 80 | for stride in strides: 81 | layers.append(block(self.in_planes, planes, stride)) 82 | self.in_planes = planes * block.expansion 83 | return nn.Sequential(*layers) 84 | 85 | def forward(self, x, out_feature=False): 86 | x = self.conv1(x) 87 | 88 | x = self.bn1(x) 89 | out = F.relu(x) 90 | 91 | out = self.layer1(out) 92 | out = self.layer2(out) 93 | out = self.layer3(out) 94 | out = self.layer4(out) 95 | out = F.avg_pool2d(out, 4) 96 | feature = out.view(out.size(0), -1) 97 | out = self.linear(feature) 98 | if out_feature == False: 99 | return out 100 | else: 101 | return out,feature 102 | 103 | 104 | def 
ResNet18(num_classes=10): 105 | return ResNet(BasicBlock, [2,2,2,2], num_classes) 106 | 107 | def ResNet34(num_classes=10): 108 | return ResNet(BasicBlock, [3,4,6,3], num_classes) 109 | 110 | def ResNet50(num_classes=10): 111 | return ResNet(Bottleneck, [3,4,6,3], num_classes) 112 | 113 | def ResNet101(num_classes=10): 114 | return ResNet(Bottleneck, [3,4,23,3], num_classes) 115 | 116 | def ResNet152(num_classes=10): 117 | return ResNet(Bottleneck, [3,8,36,3], num_classes) 118 | 119 | -------------------------------------------------------------------------------- /deepinversion.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (C) 2020 NVIDIA Corporation. All rights reserved. 3 | # Nvidia Source Code License-NC 4 | # Official PyTorch implementation of CVPR2020 paper 5 | # Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion 6 | # Hongxu Yin, Pavlo Molchanov, Zhizhong Li, Jose M. Alvarez, Arun Mallya, Derek 7 | # Hoiem, Niraj K. Jha, and Jan Kautz 8 | # -------------------------------------------------------- 9 | 10 | from __future__ import division, print_function 11 | from __future__ import absolute_import 12 | from __future__ import division 13 | from __future__ import unicode_literals 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.optim as optim 18 | import collections 19 | import torch.cuda.amp as amp 20 | import random 21 | import torch 22 | import torchvision.utils as vutils 23 | from PIL import Image 24 | import numpy as np 25 | 26 | from utils.utils import lr_cosine_policy, lr_policy, beta_policy, mom_cosine_policy, clip, denormalize, create_folder 27 | 28 | 29 | class DeepInversionFeatureHook(): 30 | ''' 31 | Implementation of the forward hook to track feature statistics and compute a loss on them. 32 | Will compute mean and variance, and will use l2 as a loss 33 | ''' 34 | def __init__(self, module): 35 | self.hook = module.register_forward_hook(self.hook_fn) 36 | 37 | def hook_fn(self, module, input, output): 38 | # hook to compute DeepInversion's feature distribution regularization 39 | nch = input[0].shape[1] 40 | mean = input[0].mean([0, 2, 3]) 41 | var = input[0].permute(1, 0, 2, 3).contiguous().view([nch, -1]).var(1, unbiased=False) 42 | 43 | #forcing mean and variance to match between two distributions 44 | #other ways might work better, e.g. 
KL divergence 45 | r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm( 46 | module.running_mean.data - mean, 2) 47 | 48 | self.r_feature = r_feature 49 | # must have no output 50 | 51 | def close(self): 52 | self.hook.remove() 53 | 54 | 55 | def get_image_prior_losses(inputs_jit): 56 | # COMPUTE total variation regularization loss 57 | diff1 = inputs_jit[:, :, :, :-1] - inputs_jit[:, :, :, 1:] 58 | diff2 = inputs_jit[:, :, :-1, :] - inputs_jit[:, :, 1:, :] 59 | diff3 = inputs_jit[:, :, 1:, :-1] - inputs_jit[:, :, :-1, 1:] 60 | diff4 = inputs_jit[:, :, :-1, :-1] - inputs_jit[:, :, 1:, 1:] 61 | 62 | loss_var_l2 = torch.norm(diff1) + torch.norm(diff2) + torch.norm(diff3) + torch.norm(diff4) 63 | loss_var_l1 = (diff1.abs() / 255.0).mean() + (diff2.abs() / 255.0).mean() + ( 64 | diff3.abs() / 255.0).mean() + (diff4.abs() / 255.0).mean() 65 | loss_var_l1 = loss_var_l1 * 255.0 66 | return loss_var_l1, loss_var_l2 67 | 68 | 69 | class DeepInversionClass(object): 70 | def __init__(self, bs=84, 71 | use_fp16=True, net_teacher=None, path="./gen_images/", 72 | final_data_path="/gen_images_final/", 73 | parameters=dict(), 74 | setting_id=0, 75 | jitter=30, 76 | criterion=None, 77 | coefficients=dict(), 78 | network_output_function=lambda x: x, 79 | hook_for_display = None): 80 | ''' 81 | :param bs: batch size per GPU for image generation 82 | :param use_fp16: use FP16 (or APEX AMP) for model inversion, uses less memory and is faster for GPUs with Tensor Cores 83 | :parameter net_teacher: Pytorch model to be inverted 84 | :param path: path where to write temporal images and data 85 | :param final_data_path: path to write final images into 86 | :param parameters: a dictionary of control parameters: 87 | "resolution": input image resolution, single value, assumed to be a square, 224 88 | "random_label" : for classification initialize target to be random values 89 | "start_noise" : start from noise, def True, other options are not supported at this time 90 | "detach_student": if computing Adaptive DI, should we detach student? 91 | :param setting_id: predefined settings for optimization: 92 | 0 - will run low resolution optimization for 1k and then full resolution for 1k; 93 | 1 - will run optimization on high resolution for 2k 94 | 2 - will run optimization on high resolution for 20k 95 | 96 | :param jitter: amount of random shift applied to image at every iteration 97 | :param coefficients: dictionary with parameters and coefficients for optimization. 
98 | keys: 99 | "r_feature" - coefficient for feature distribution regularization 100 | "tv_l1" - coefficient for total variation L1 loss 101 | "tv_l2" - coefficient for total variation L2 loss 102 | "l2" - l2 penalization weight 103 | "lr" - learning rate for optimization 104 | "main_loss_multiplier" - coefficient for the main loss optimization 105 | "adi_scale" - coefficient for Adaptive DeepInversion, competition, def =0 means no competition 106 | network_output_function: function to be applied to the output of the network to get the output 107 | hook_for_display: function to be executed at every print/save call, useful to check accuracy of verifier 108 | ''' 109 | 110 | print("Deep inversion class generation") 111 | # for reproducibility 112 | torch.manual_seed(torch.cuda.current_device()) 113 | 114 | self.net_teacher = net_teacher 115 | 116 | if "resolution" in parameters.keys(): 117 | self.image_resolution = parameters["resolution"] 118 | self.random_label = parameters["random_label"] 119 | self.start_noise = parameters["start_noise"] 120 | self.detach_student = parameters["detach_student"] 121 | self.do_flip = parameters["do_flip"] 122 | self.store_best_images = parameters["store_best_images"] 123 | else: 124 | self.image_resolution = 224 125 | self.random_label = False 126 | self.start_noise = True 127 | self.detach_student = False 128 | self.do_flip = True 129 | self.store_best_images = False 130 | 131 | self.setting_id = setting_id 132 | self.bs = bs # batch size 133 | self.use_fp16 = use_fp16 134 | self.save_every = 100 135 | self.jitter = jitter 136 | self.criterion = criterion 137 | self.network_output_function = network_output_function 138 | do_clip = True 139 | 140 | if "r_feature" in coefficients: 141 | self.bn_reg_scale = coefficients["r_feature"] 142 | self.first_bn_multiplier = coefficients["first_bn_multiplier"] 143 | self.var_scale_l1 = coefficients["tv_l1"] 144 | self.var_scale_l2 = coefficients["tv_l2"] 145 | self.l2_scale = coefficients["l2"] 146 | self.lr = coefficients["lr"] 147 | self.main_loss_multiplier = coefficients["main_loss_multiplier"] 148 | self.adi_scale = coefficients["adi_scale"] 149 | else: 150 | print("Provide a dictionary with ") 151 | 152 | self.num_generations = 0 153 | self.final_data_path = final_data_path 154 | 155 | ## Create folders for images and logs 156 | prefix = path 157 | self.prefix = prefix 158 | 159 | local_rank = torch.cuda.current_device() 160 | if local_rank==0: 161 | create_folder(prefix) 162 | create_folder(prefix + "/best_images/") 163 | create_folder(self.final_data_path) 164 | # save images to folders 165 | # for m in range(1000): 166 | # create_folder(self.final_data_path + "/s{:03d}".format(m)) 167 | 168 | ## Create hooks for feature statistics 169 | self.loss_r_feature_layers = [] 170 | 171 | for module in self.net_teacher.modules(): 172 | if isinstance(module, nn.BatchNorm2d): 173 | self.loss_r_feature_layers.append(DeepInversionFeatureHook(module)) 174 | 175 | self.hook_for_display = None 176 | if hook_for_display is not None: 177 | self.hook_for_display = hook_for_display 178 | 179 | def get_images(self, net_student=None, targets=None): 180 | print("get_images call") 181 | 182 | net_teacher = self.net_teacher 183 | use_fp16 = self.use_fp16 184 | save_every = self.save_every 185 | 186 | kl_loss = nn.KLDivLoss(reduction='batchmean').cuda() 187 | local_rank = torch.cuda.current_device() 188 | best_cost = 1e4 189 | criterion = self.criterion 190 | 191 | # setup target labels 192 | if targets is None: 193 | #only works for 
classification now, for other tasks need to provide target vector 194 | targets = torch.LongTensor([random.randint(0, 999) for _ in range(self.bs)]).to('cuda') 195 | if not self.random_label: 196 | # preselected classes, good for ResNet50v1.5 197 | targets = [1, 933, 946, 980, 25, 63, 92, 94, 107, 985, 151, 154, 207, 250, 270, 277, 283, 292, 294, 309, 198 | 311, 199 | 325, 340, 360, 386, 402, 403, 409, 530, 440, 468, 417, 590, 670, 817, 762, 920, 949, 963, 200 | 967, 574, 487] 201 | 202 | targets = torch.LongTensor(targets * (int(self.bs / len(targets)))).to('cuda') 203 | 204 | img_original = self.image_resolution 205 | 206 | data_type = torch.half if use_fp16 else torch.float 207 | inputs = torch.randn((self.bs, 3, img_original, img_original), requires_grad=True, device='cuda', 208 | dtype=data_type) 209 | pooling_function = nn.modules.pooling.AvgPool2d(kernel_size=2) 210 | 211 | if self.setting_id==0: 212 | skipfirst = False 213 | else: 214 | skipfirst = True 215 | 216 | iteration = 0 217 | for lr_it, lower_res in enumerate([2, 1]): 218 | if lr_it==0: 219 | iterations_per_layer = 2000 220 | else: 221 | iterations_per_layer = 1000 if not skipfirst else 2000 222 | if self.setting_id == 2: 223 | iterations_per_layer = 20000 224 | 225 | if lr_it==0 and skipfirst: 226 | continue 227 | 228 | lim_0, lim_1 = self.jitter // lower_res, self.jitter // lower_res 229 | 230 | if self.setting_id == 0: 231 | #multi resolution, 2k iterations with low resolution, 1k at normal, ResNet50v1.5 works the best, ResNet50 is ok 232 | optimizer = optim.Adam([inputs], lr=self.lr, betas=[0.5, 0.9], eps = 1e-8) 233 | do_clip = True 234 | elif self.setting_id == 1: 235 | #2k normal resolultion, for ResNet50v1.5; Resnet50 works as well 236 | optimizer = optim.Adam([inputs], lr=self.lr, betas=[0.5, 0.9], eps = 1e-8) 237 | do_clip = True 238 | elif self.setting_id == 2: 239 | #20k normal resolution the closes to the paper experiments for ResNet50 240 | optimizer = optim.Adam([inputs], lr=self.lr, betas=[0.9, 0.999], eps = 1e-8) 241 | do_clip = False 242 | 243 | if use_fp16: 244 | static_loss_scale = 256 245 | static_loss_scale = "dynamic" 246 | _, optimizer = amp.initialize([], optimizer, opt_level="O2", loss_scale=static_loss_scale) 247 | 248 | lr_scheduler = lr_cosine_policy(self.lr, 100, iterations_per_layer) 249 | 250 | for iteration_loc in range(iterations_per_layer): 251 | iteration += 1 252 | # learning rate scheduling 253 | lr_scheduler(optimizer, iteration_loc, iteration_loc) 254 | 255 | # perform downsampling if needed 256 | if lower_res!=1: 257 | inputs_jit = pooling_function(inputs) 258 | else: 259 | inputs_jit = inputs 260 | 261 | # apply random jitter offsets 262 | off1 = random.randint(-lim_0, lim_0) 263 | off2 = random.randint(-lim_1, lim_1) 264 | inputs_jit = torch.roll(inputs_jit, shifts=(off1, off2), dims=(2, 3)) 265 | 266 | # Flipping 267 | flip = random.random() > 0.5 268 | if flip and self.do_flip: 269 | inputs_jit = torch.flip(inputs_jit, dims=(3,)) 270 | 271 | # forward pass 272 | optimizer.zero_grad() 273 | net_teacher.zero_grad() 274 | 275 | outputs = net_teacher(inputs_jit) 276 | outputs = self.network_output_function(outputs) 277 | 278 | # R_cross classification loss 279 | loss = criterion(outputs, targets) 280 | 281 | # R_prior losses 282 | loss_var_l1, loss_var_l2 = get_image_prior_losses(inputs_jit) 283 | 284 | # R_feature loss 285 | rescale = [self.first_bn_multiplier] + [1. 
for _ in range(len(self.loss_r_feature_layers)-1)] 286 | loss_r_feature = sum([mod.r_feature * rescale[idx] for (idx, mod) in enumerate(self.loss_r_feature_layers)]) 287 | 288 | # R_ADI 289 | loss_verifier_cig = torch.zeros(1) 290 | if self.adi_scale!=0.0: 291 | if self.detach_student: 292 | outputs_student = net_student(inputs_jit).detach() 293 | else: 294 | outputs_student = net_student(inputs_jit) 295 | 296 | T = 3.0 297 | if 1: 298 | T = 3.0 299 | # Jensen Shanon divergence: 300 | # another way to force KL between negative probabilities 301 | P = nn.functional.softmax(outputs_student / T, dim=1) 302 | Q = nn.functional.softmax(outputs / T, dim=1) 303 | M = 0.5 * (P + Q) 304 | 305 | P = torch.clamp(P, 0.01, 0.99) 306 | Q = torch.clamp(Q, 0.01, 0.99) 307 | M = torch.clamp(M, 0.01, 0.99) 308 | eps = 0.0 309 | loss_verifier_cig = 0.5 * kl_loss(torch.log(P + eps), M) + 0.5 * kl_loss(torch.log(Q + eps), M) 310 | # JS criteria - 0 means full correlation, 1 - means completely different 311 | loss_verifier_cig = 1.0 - torch.clamp(loss_verifier_cig, 0.0, 1.0) 312 | 313 | if local_rank==0: 314 | if iteration % save_every==0: 315 | print('loss_verifier_cig', loss_verifier_cig.item()) 316 | 317 | # l2 loss on images 318 | loss_l2 = torch.norm(inputs_jit.view(self.bs, -1), dim=1).mean() 319 | 320 | # combining losses 321 | loss_aux = self.var_scale_l2 * loss_var_l2 + \ 322 | self.var_scale_l1 * loss_var_l1 + \ 323 | self.bn_reg_scale * loss_r_feature + \ 324 | self.l2_scale * loss_l2 325 | 326 | if self.adi_scale!=0.0: 327 | loss_aux += self.adi_scale * loss_verifier_cig 328 | 329 | loss = self.main_loss_multiplier * loss + loss_aux 330 | 331 | if local_rank==0: 332 | if iteration % save_every==0: 333 | print("------------iteration {}----------".format(iteration)) 334 | print("total loss", loss.item()) 335 | print("loss_r_feature", loss_r_feature.item()) 336 | print("main criterion", criterion(outputs, targets).item()) 337 | 338 | if self.hook_for_display is not None: 339 | self.hook_for_display(inputs, targets) 340 | 341 | # do image update 342 | if use_fp16: 343 | # optimizer.backward(loss) 344 | with amp.scale_loss(loss, optimizer) as scaled_loss: 345 | scaled_loss.backward() 346 | else: 347 | loss.backward() 348 | 349 | optimizer.step() 350 | 351 | # clip color outlayers 352 | if do_clip: 353 | inputs.data = clip(inputs.data, use_fp16=use_fp16) 354 | 355 | if best_cost > loss.item() or iteration == 1: 356 | best_inputs = inputs.data.clone() 357 | best_cost = loss.item() 358 | 359 | if iteration % save_every==0 and (save_every > 0): 360 | if local_rank==0: 361 | vutils.save_image(inputs, 362 | '{}/best_images/output_{:05d}_gpu_{}.png'.format(self.prefix, 363 | iteration // save_every, 364 | local_rank), 365 | normalize=True, scale_each=True, nrow=int(10)) 366 | 367 | if self.store_best_images: 368 | best_inputs = denormalize(best_inputs) 369 | self.save_images(best_inputs, targets) 370 | 371 | # to reduce memory consumption by states of the optimizer we deallocate memory 372 | optimizer.state = collections.defaultdict(dict) 373 | 374 | def save_images(self, images, targets): 375 | # method to store generated images locally 376 | local_rank = torch.cuda.current_device() 377 | for id in range(images.shape[0]): 378 | class_id = targets[id].item() 379 | if 0: 380 | #save into separate folders 381 | place_to_store = '{}/s{:03d}/img_{:05d}_id{:03d}_gpu_{}_2.jpg'.format(self.final_data_path, class_id, 382 | self.num_generations, id, 383 | local_rank) 384 | else: 385 | place_to_store = 
'{}/img_s{:03d}_{:05d}_id{:03d}_gpu_{}_2.jpg'.format(self.final_data_path, class_id, 386 | self.num_generations, id, 387 | local_rank) 388 | 389 | image_np = images[id].data.cpu().numpy().transpose((1, 2, 0)) 390 | pil_image = Image.fromarray((image_np * 255).astype(np.uint8)) 391 | pil_image.save(place_to_store) 392 | 393 | def generate_batch(self, net_student=None, targets=None): 394 | # for ADI detach student and add put to eval mode 395 | net_teacher = self.net_teacher 396 | 397 | use_fp16 = self.use_fp16 398 | 399 | # fix net_student 400 | if not (net_student is None): 401 | net_student = net_student.eval() 402 | 403 | if targets is not None: 404 | targets = torch.from_numpy(np.array(targets).squeeze()).cuda() 405 | if use_fp16: 406 | targets = targets.half() 407 | 408 | self.get_images(net_student=net_student, targets=targets) 409 | 410 | net_teacher.eval() 411 | 412 | self.num_generations += 1 413 | -------------------------------------------------------------------------------- /example_logs/fp16_set0_rn50.log: -------------------------------------------------------------------------------- 1 | $ python imagenet_inversion.py --bs=84 --do_flip --exp_name="test_rn50_fp16" --r_feature=0.01 --arch_name="resnet50" --verifier --setting_id=0 --lr=0.2 --adi_scale=0.0 --l2=0.00001 --fp16 2 | Namespace(adi_scale=0.0, arch_name='resnet50', bs=84, comment='', do_flip=True, epochs=20000, exp_name='test_rn50_fp16', fp16=True, jitter=30, l2=1e-05, local_rank=0, lr=0.2, main_loss_multiplier=1.0, no_cuda=False, r_feature=0.01, random_label=False, setting_id=0, tv_l1=0.0, tv_l2=0.0001, verifier=True, verifier_arch='mobilenet_v2', wd=0.01, worldsize=1) 3 | loading torchvision model for inversion with the name: resnet50 4 | Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights. 5 | 6 | Defaults for this optimization level are: 7 | enabled : True 8 | opt_level : O2 9 | cast_model_type : torch.float16 10 | patch_torch_functions : False 11 | keep_batchnorm_fp32 : True 12 | master_weights : True 13 | loss_scale : dynamic 14 | Processing user overrides (additional kwargs that are not None)... 15 | After processing overrides, optimization options are: 16 | enabled : True 17 | opt_level : O2 18 | cast_model_type : torch.float16 19 | patch_torch_functions : False 20 | keep_batchnorm_fp32 : True 21 | master_weights : True 22 | loss_scale : dynamic 23 | ==> Resuming from checkpoint.. 24 | ==> Getting BN params as feature statistics 25 | loading verifier: mobilenet_v2 26 | Deep inversion class generation 27 | get_images call 28 | Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights. 29 | 30 | Defaults for this optimization level are: 31 | enabled : True 32 | opt_level : O2 33 | cast_model_type : torch.float16 34 | patch_torch_functions : False 35 | keep_batchnorm_fp32 : True 36 | master_weights : True 37 | loss_scale : dynamic 38 | Processing user overrides (additional kwargs that are not None)... 39 | After processing overrides, optimization options are: 40 | enabled : True 41 | opt_level : O2 42 | cast_model_type : torch.float16 43 | patch_torch_functions : False 44 | keep_batchnorm_fp32 : True 45 | master_weights : True 46 | loss_scale : dynamic 47 | Gradient overflow. 
Skipping step, loss scaler 0 reducing loss scale to 32768.0 48 | ------------iteration 100---------- 49 | total loss 1.6876643896102905 50 | loss_r_feature 90.38799285888672 51 | main criterion 0.1843467503786087 52 | Verifier accuracy: 0.0 53 | ------------iteration 200---------- 54 | total loss 1.3778828382492065 55 | loss_r_feature 78.72280883789062 56 | main criterion 0.006536892615258694 57 | Verifier accuracy: 3.5714285373687744 58 | ------------iteration 300---------- 59 | total loss 1.1503252983093262 60 | loss_r_feature 64.16847229003906 61 | main criterion 0.03804704174399376 62 | Verifier accuracy: 5.952381134033203 63 | ------------iteration 400---------- 64 | total loss 1.122263789176941 65 | loss_r_feature 63.72652816772461 66 | main criterion 0.014303048141300678 67 | Verifier accuracy: 9.523809432983398 68 | ------------iteration 500---------- 69 | total loss 1.0680581331253052 70 | loss_r_feature 59.28545379638672 71 | main criterion 0.01540193147957325 72 | Verifier accuracy: 10.714285850524902 73 | ------------iteration 600---------- 74 | total loss 0.9677876830101013 75 | loss_r_feature 54.25212478637695 76 | main criterion 0.006902865134179592 77 | Verifier accuracy: 10.714285850524902 78 | ------------iteration 700---------- 79 | total loss 0.8655126094818115 80 | loss_r_feature 48.535343170166016 81 | main criterion 0.0010316940024495125 82 | Verifier accuracy: 10.714285850524902 83 | ------------iteration 800---------- 84 | total loss 0.7858816981315613 85 | loss_r_feature 43.149375915527344 86 | main criterion 0.0030028026085346937 87 | Verifier accuracy: 8.333333015441895 88 | ------------iteration 900---------- 89 | total loss 0.6898418068885803 90 | loss_r_feature 38.205223083496094 91 | main criterion 0.0007694562082178891 92 | Verifier accuracy: 14.285714149475098 93 | ------------iteration 1000---------- 94 | total loss 0.5907580256462097 95 | loss_r_feature 31.332698822021484 96 | main criterion 0.006730636116117239 97 | Verifier accuracy: 10.714285850524902 98 | ------------iteration 1100---------- 99 | total loss 0.5132060050964355 100 | loss_r_feature 26.838134765625 101 | main criterion 0.0011799221392720938 102 | Verifier accuracy: 9.523809432983398 103 | ------------iteration 1200---------- 104 | total loss 0.4708683490753174 105 | loss_r_feature 24.262351989746094 106 | main criterion 0.0008049465250223875 107 | Verifier accuracy: 10.714285850524902 108 | ------------iteration 1300---------- 109 | total loss 0.4237711727619171 110 | loss_r_feature 21.36450958251953 111 | main criterion 0.0014830997679382563 112 | Verifier accuracy: 5.952381134033203 113 | ------------iteration 1400---------- 114 | total loss 0.39408981800079346 115 | loss_r_feature 20.10096549987793 116 | main criterion 0.0006744748097844422 117 | Verifier accuracy: 9.523809432983398 118 | ------------iteration 1500---------- 119 | total loss 0.33674177527427673 120 | loss_r_feature 16.06620216369629 121 | main criterion 0.0008936041849665344 122 | Verifier accuracy: 10.714285850524902 123 | ------------iteration 1600---------- 124 | total loss 0.3079169690608978 125 | loss_r_feature 14.631080627441406 126 | main criterion 0.0008300145273096859 127 | Verifier accuracy: 13.095237731933594 128 | ------------iteration 1700---------- 129 | total loss 0.2896236181259155 130 | loss_r_feature 13.912619590759277 131 | main criterion 0.000469207763671875 132 | Verifier accuracy: 10.714285850524902 133 | ------------iteration 1800---------- 134 | total loss 0.27033284306526184 135 | 
loss_r_feature 12.821993827819824 136 | main criterion 0.000999643700197339 137 | Verifier accuracy: 8.333333015441895 138 | ------------iteration 1900---------- 139 | total loss 0.26189038157463074 140 | loss_r_feature 12.395000457763672 141 | main criterion 0.0004911195719614625 142 | Verifier accuracy: 8.333333015441895 143 | ------------iteration 2000---------- 144 | total loss 0.2615959644317627 145 | loss_r_feature 12.435093879699707 146 | main criterion 0.0006502469186671078 147 | Verifier accuracy: 8.333333015441895 148 | Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights. 149 | 150 | Defaults for this optimization level are: 151 | enabled : True 152 | opt_level : O2 153 | cast_model_type : torch.float16 154 | patch_torch_functions : False 155 | keep_batchnorm_fp32 : True 156 | master_weights : True 157 | loss_scale : dynamic 158 | Processing user overrides (additional kwargs that are not None)... 159 | After processing overrides, optimization options are: 160 | enabled : True 161 | opt_level : O2 162 | cast_model_type : torch.float16 163 | patch_torch_functions : False 164 | keep_batchnorm_fp32 : True 165 | master_weights : True 166 | loss_scale : dynamic 167 | Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 32768.0 168 | ------------iteration 2100---------- 169 | total loss 1.1602451801300049 170 | loss_r_feature 43.01371383666992 171 | main criterion 0.002820753026753664 172 | Verifier accuracy: 64.28571319580078 173 | ------------iteration 2200---------- 174 | total loss 1.2936094999313354 175 | loss_r_feature 46.4441032409668 176 | main criterion 0.0017376854084432125 177 | Verifier accuracy: 65.47618865966797 178 | ------------iteration 2300---------- 179 | total loss 0.9823821187019348 180 | loss_r_feature 38.5743408203125 181 | main criterion 0.0017328716348856688 182 | Verifier accuracy: 77.38095092773438 183 | ------------iteration 2400---------- 184 | total loss 0.8478145003318787 185 | loss_r_feature 31.85713768005371 186 | main criterion 0.002227487973868847 187 | Verifier accuracy: 88.0952377319336 188 | ------------iteration 2500---------- 189 | total loss 0.7831273674964905 190 | loss_r_feature 30.018978118896484 191 | main criterion 0.0015915121184661984 192 | Verifier accuracy: 88.0952377319336 193 | ------------iteration 2600---------- 194 | total loss 0.6972528100013733 195 | loss_r_feature 25.377971649169922 196 | main criterion 0.0012009029742330313 197 | Verifier accuracy: 90.47618865966797 198 | ------------iteration 2700---------- 199 | total loss 0.5916131734848022 200 | loss_r_feature 22.423397064208984 201 | main criterion 0.0010576248168945312 202 | Verifier accuracy: 96.42857360839844 203 | ------------iteration 2800---------- 204 | total loss 0.4864479899406433 205 | loss_r_feature 18.20646095275879 206 | main criterion 0.0008288564858958125 207 | Verifier accuracy: 98.80952453613281 208 | ------------iteration 2900---------- 209 | total loss 0.4284505844116211 210 | loss_r_feature 16.447580337524414 211 | main criterion 0.0007148924050852656 212 | Verifier accuracy: 100.0 213 | ------------iteration 3000---------- 214 | total loss 0.4071371555328369 215 | loss_r_feature 15.495633125305176 216 | main criterion 0.0005214327829889953 217 | Verifier accuracy: 98.80952453613281 218 | -------------------------------------------------------------------------------- /example_logs/fp16_set0_rn50_adi02.log: -------------------------------------------------------------------------------- 1 | $ python 
imagenet_inversion.py --bs=84 --do_flip --exp_name="test_rn50_adi02" --r_feature=0.01 --arch_name="resnet50" --verifier --setting_id=0 --lr=0.2 --adi_scale=0.2 --l2=0.00001 --fp16 > fp16_set0_rn50_adi02.log 2 | Namespace(adi_scale=0.2, arch_name='resnet50', bs=84, comment='', do_flip=True, epochs=20000, exp_name='test_rn50_adi02', fp16=True, jitter=30, l2=1e-05, local_rank=0, lr=0.2, main_loss_multiplier=1.0, no_cuda=False, r_feature=0.01, random_label=False, setting_id=0, tv_l1=0.0, tv_l2=0.0001, verifier=True, verifier_arch='mobilenet_v2', wd=0.01, worldsize=1) 3 | loading torchvision model for inversion with the name: resnet50 4 | Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights. 5 | 6 | Defaults for this optimization level are: 7 | enabled : True 8 | opt_level : O2 9 | cast_model_type : torch.float16 10 | patch_torch_functions : False 11 | keep_batchnorm_fp32 : True 12 | master_weights : True 13 | loss_scale : dynamic 14 | Processing user overrides (additional kwargs that are not None)... 15 | After processing overrides, optimization options are: 16 | enabled : True 17 | opt_level : O2 18 | cast_model_type : torch.float16 19 | patch_torch_functions : False 20 | keep_batchnorm_fp32 : True 21 | master_weights : True 22 | loss_scale : dynamic 23 | ==> Resuming from checkpoint.. 24 | ==> Getting BN params as feature statistics 25 | Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights. 26 | 27 | Defaults for this optimization level are: 28 | enabled : True 29 | opt_level : O2 30 | cast_model_type : torch.float16 31 | patch_torch_functions : False 32 | keep_batchnorm_fp32 : True 33 | master_weights : True 34 | loss_scale : dynamic 35 | Processing user overrides (additional kwargs that are not None)... 36 | After processing overrides, optimization options are: 37 | enabled : True 38 | opt_level : O2 39 | cast_model_type : torch.float16 40 | patch_torch_functions : False 41 | keep_batchnorm_fp32 : True 42 | master_weights : True 43 | loss_scale : dynamic 44 | Deep inversion class generation 45 | get_images call 46 | Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights. 47 | 48 | Defaults for this optimization level are: 49 | enabled : True 50 | opt_level : O2 51 | cast_model_type : torch.float16 52 | patch_torch_functions : False 53 | keep_batchnorm_fp32 : True 54 | master_weights : True 55 | loss_scale : dynamic 56 | Processing user overrides (additional kwargs that are not None)... 57 | After processing overrides, optimization options are: 58 | enabled : True 59 | opt_level : O2 60 | cast_model_type : torch.float16 61 | patch_torch_functions : False 62 | keep_batchnorm_fp32 : True 63 | master_weights : True 64 | loss_scale : dynamic 65 | Gradient overflow. 
Skipping step, loss scaler 0 reducing loss scale to 32768.0 66 | loss_verifier_cig 0.12461590766906738 67 | ------------iteration 100---------- 68 | total loss 1.5717917680740356 69 | loss_r_feature 83.22199249267578 70 | main criterion 0.18712279200553894 71 | Verifier accuracy: 1.1904761791229248 72 | loss_verifier_cig 0.0 73 | ------------iteration 200---------- 74 | total loss 1.2926644086837769 75 | loss_r_feature 75.16395568847656 76 | main criterion 0.009654657915234566 77 | Verifier accuracy: 3.5714285373687744 78 | loss_verifier_cig 0.029001593589782715 79 | ------------iteration 300---------- 80 | total loss 1.1569472551345825 81 | loss_r_feature 66.2757797241211 82 | main criterion 0.016784803941845894 83 | Verifier accuracy: 3.5714285373687744 84 | loss_verifier_cig 0.0 85 | ------------iteration 400---------- 86 | total loss 1.1502078771591187 87 | loss_r_feature 66.83061218261719 88 | main criterion 0.002368540968745947 89 | Verifier accuracy: 8.333333015441895 90 | loss_verifier_cig 0.0 91 | ------------iteration 500---------- 92 | total loss 1.0058181285858154 93 | loss_r_feature 58.08607482910156 94 | main criterion 0.0015441349241882563 95 | Verifier accuracy: 4.761904716491699 96 | loss_verifier_cig 0.0 97 | ------------iteration 600---------- 98 | total loss 1.006028175354004 99 | loss_r_feature 57.82709503173828 100 | main criterion 0.0010576475178822875 101 | Verifier accuracy: 2.3809523582458496 102 | loss_verifier_cig 0.0 103 | ------------iteration 700---------- 104 | total loss 0.8937127590179443 105 | loss_r_feature 50.34546661376953 106 | main criterion 0.0013211341574788094 107 | Verifier accuracy: 5.952381134033203 108 | loss_verifier_cig 0.0 109 | ------------iteration 800---------- 110 | total loss 0.7877690196037292 111 | loss_r_feature 44.32988357543945 112 | main criterion 0.0008785157115198672 113 | Verifier accuracy: 8.333333015441895 114 | loss_verifier_cig 0.0 115 | ------------iteration 900---------- 116 | total loss 0.7290716171264648 117 | loss_r_feature 40.597286224365234 118 | main criterion 0.002379996469244361 119 | Verifier accuracy: 7.142857074737549 120 | loss_verifier_cig 0.04963517189025879 121 | ------------iteration 1000---------- 122 | total loss 0.7164633870124817 123 | loss_r_feature 34.76940155029297 124 | main criterion 0.07151930779218674 125 | Verifier accuracy: 7.142857074737549 126 | loss_verifier_cig 0.0 127 | ------------iteration 1100---------- 128 | total loss 0.5727385878562927 129 | loss_r_feature 30.762893676757812 130 | main criterion 0.0021644660737365484 131 | Verifier accuracy: 5.952381134033203 132 | loss_verifier_cig 0.029660344123840332 133 | ------------iteration 1200---------- 134 | total loss 0.5012800097465515 135 | loss_r_feature 26.345243453979492 136 | main criterion 0.0016442026244476438 137 | Verifier accuracy: 4.761904716491699 138 | loss_verifier_cig 0.0 139 | ------------iteration 1300---------- 140 | total loss 0.45760539174079895 141 | loss_r_feature 24.4682674407959 142 | main criterion 0.0017067590961232781 143 | Verifier accuracy: 7.142857074737549 144 | loss_verifier_cig 0.0 145 | ------------iteration 1400---------- 146 | total loss 0.4178113341331482 147 | loss_r_feature 21.507671356201172 148 | main criterion 0.0004278591659385711 149 | Verifier accuracy: 8.333333015441895 150 | loss_verifier_cig 0.0 151 | ------------iteration 1500---------- 152 | total loss 0.3722783029079437 153 | loss_r_feature 18.461204528808594 154 | main criterion 0.00037799563142471015 155 | Verifier accuracy: 
4.761904716491699 156 | loss_verifier_cig 0.0 157 | ------------iteration 1600---------- 158 | total loss 0.3312247097492218 159 | loss_r_feature 16.000471115112305 160 | main criterion 0.0004169373423792422 161 | Verifier accuracy: 5.952381134033203 162 | loss_verifier_cig 0.0 163 | ------------iteration 1700---------- 164 | total loss 0.3149060606956482 165 | loss_r_feature 15.741662979125977 166 | main criterion 0.0002418699732515961 167 | Verifier accuracy: 8.333333015441895 168 | loss_verifier_cig 0.0 169 | ------------iteration 1800---------- 170 | total loss 0.28622013330459595 171 | loss_r_feature 13.761808395385742 172 | main criterion 0.00026948112645186484 173 | Verifier accuracy: 5.952381134033203 174 | loss_verifier_cig 0.0 175 | ------------iteration 1900---------- 176 | total loss 0.2793577015399933 177 | loss_r_feature 13.480212211608887 178 | main criterion 0.00037529354449361563 179 | Verifier accuracy: 5.952381134033203 180 | loss_verifier_cig 0.0 181 | ------------iteration 2000---------- 182 | total loss 0.2799757122993469 183 | loss_r_feature 13.588874816894531 184 | main criterion 0.0003949347010347992 185 | Verifier accuracy: 4.761904716491699 186 | Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights. 187 | 188 | Defaults for this optimization level are: 189 | enabled : True 190 | opt_level : O2 191 | cast_model_type : torch.float16 192 | patch_torch_functions : False 193 | keep_batchnorm_fp32 : True 194 | master_weights : True 195 | loss_scale : dynamic 196 | Processing user overrides (additional kwargs that are not None)... 197 | After processing overrides, optimization options are: 198 | enabled : True 199 | opt_level : O2 200 | cast_model_type : torch.float16 201 | patch_torch_functions : False 202 | keep_batchnorm_fp32 : True 203 | master_weights : True 204 | loss_scale : dynamic 205 | Gradient overflow. 
Skipping step, loss scaler 0 reducing loss scale to 32768.0 206 | loss_verifier_cig 0.1591559648513794 207 | ------------iteration 2100---------- 208 | total loss 1.305101990699768 209 | loss_r_feature 49.516441345214844 210 | main criterion 0.00047656468814238906 211 | Verifier accuracy: 9.523809432983398 212 | loss_verifier_cig 0.0571821928024292 213 | ------------iteration 2200---------- 214 | total loss 1.1693669557571411 215 | loss_r_feature 46.006778717041016 216 | main criterion 0.00037206921842880547 217 | Verifier accuracy: 10.714285850524902 218 | loss_verifier_cig 0.0 219 | ------------iteration 2300---------- 220 | total loss 0.9874729514122009 221 | loss_r_feature 38.61001205444336 222 | main criterion 0.0001040413262671791 223 | Verifier accuracy: 3.5714285373687744 224 | loss_verifier_cig 0.00839078426361084 225 | ------------iteration 2400---------- 226 | total loss 0.8956082463264465 227 | loss_r_feature 34.82595443725586 228 | main criterion 8.489972242387012e-05 229 | Verifier accuracy: 3.5714285373687744 230 | loss_verifier_cig 0.0 231 | ------------iteration 2500---------- 232 | total loss 0.8721445798873901 233 | loss_r_feature 34.64632797241211 234 | main criterion 0.00011521294072736055 235 | Verifier accuracy: 7.142857074737549 236 | loss_verifier_cig 0.056126534938812256 237 | ------------iteration 2600---------- 238 | total loss 0.7324535250663757 239 | loss_r_feature 28.194265365600586 240 | main criterion 0.0001758393773343414 241 | Verifier accuracy: 9.523809432983398 242 | loss_verifier_cig 0.001925349235534668 243 | ------------iteration 2700---------- 244 | total loss 0.6335565447807312 245 | loss_r_feature 24.50282096862793 246 | main criterion 7.415952859446406e-05 247 | Verifier accuracy: 10.714285850524902 248 | loss_verifier_cig 0.006530642509460449 249 | ------------iteration 2800---------- 250 | total loss 0.5335965752601624 251 | loss_r_feature 20.933963775634766 252 | main criterion 9.012222290039062e-05 253 | Verifier accuracy: 16.66666603088379 254 | loss_verifier_cig 0.0 255 | ------------iteration 2900---------- 256 | total loss 0.4517979919910431 257 | loss_r_feature 17.627702713012695 258 | main criterion 3.878275674651377e-05 259 | Verifier accuracy: 9.523809432983398 260 | loss_verifier_cig 0.0 261 | ------------iteration 3000---------- 262 | total loss 0.43130582571029663 263 | loss_r_feature 16.748857498168945 264 | main criterion 5.76518832531292e-05 265 | Verifier accuracy: 10.714285850524902 266 | -------------------------------------------------------------------------------- /example_logs/fp16_set0_rn50_adi02_output_00030_gpu_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/example_logs/fp16_set0_rn50_adi02_output_00030_gpu_0.jpg -------------------------------------------------------------------------------- /example_logs/fp16_set0_rn50_output_00030_gpu_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/example_logs/fp16_set0_rn50_output_00030_gpu_0.jpg -------------------------------------------------------------------------------- /example_logs/fp16_set1_rn50.log: -------------------------------------------------------------------------------- 1 | $ python imagenet_inversion.py --bs=84 --do_flip --exp_name="test_rn50_fp16_set1" --r_feature=0.01 --arch_name="resnet50" 
--verifier --setting_id=1 --lr=0.2 --adi_scale=0.0 --l2=0.00001 --fp16 2 | Namespace(adi_scale=0.0, arch_name='resnet50', bs=84, comment='', do_flip=True, epochs=20000, exp_name='test_rn50_fp16_set1', fp16=True, jitter=30, l2=1e-05, local_rank=0, lr=0.2, main_loss_multiplier=1.0, no_cuda=False, r_feature=0.01, random_label=False, setting_id=1, tv_l1=0.0, tv_l2=0.0001, verifier=True, verifier_arch='mobilenet_v2', wd=0.01, worldsize=1) 3 | loading torchvision model for inversion with the name: resnet50 4 | Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights. 5 | 6 | Defaults for this optimization level are: 7 | enabled : True 8 | opt_level : O2 9 | cast_model_type : torch.float16 10 | patch_torch_functions : False 11 | keep_batchnorm_fp32 : True 12 | master_weights : True 13 | loss_scale : dynamic 14 | Processing user overrides (additional kwargs that are not None)... 15 | After processing overrides, optimization options are: 16 | enabled : True 17 | opt_level : O2 18 | cast_model_type : torch.float16 19 | patch_torch_functions : False 20 | keep_batchnorm_fp32 : True 21 | master_weights : True 22 | loss_scale : dynamic 23 | ==> Resuming from checkpoint.. 24 | ==> Getting BN params as feature statistics 25 | loading verifier: mobilenet_v2 26 | Deep inversion class generation 27 | get_images call 28 | Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights. 29 | 30 | Defaults for this optimization level are: 31 | enabled : True 32 | opt_level : O2 33 | cast_model_type : torch.float16 34 | patch_torch_functions : False 35 | keep_batchnorm_fp32 : True 36 | master_weights : True 37 | loss_scale : dynamic 38 | Processing user overrides (additional kwargs that are not None)... 39 | After processing overrides, optimization options are: 40 | enabled : True 41 | opt_level : O2 42 | cast_model_type : torch.float16 43 | patch_torch_functions : False 44 | keep_batchnorm_fp32 : True 45 | master_weights : True 46 | loss_scale : dynamic 47 | Gradient overflow. 
Skipping step, loss scaler 0 reducing loss scale to 32768.0 48 | ------------iteration 100---------- 49 | total loss 1.8805938959121704 50 | loss_r_feature 76.06755828857422 51 | main criterion 0.006961311679333448 52 | Verifier accuracy: 7.142857074737549 53 | ------------iteration 200---------- 54 | total loss 1.3011034727096558 55 | loss_r_feature 58.76356887817383 56 | main criterion 0.0030642691999673843 57 | Verifier accuracy: 40.47618865966797 58 | ------------iteration 300---------- 59 | total loss 1.096110224723816 60 | loss_r_feature 50.20344924926758 61 | main criterion 0.0015656152972951531 62 | Verifier accuracy: 76.19047546386719 63 | ------------iteration 400---------- 64 | total loss 1.0675420761108398 65 | loss_r_feature 46.52493667602539 66 | main criterion 0.0015199638437479734 67 | Verifier accuracy: 67.85713958740234 68 | ------------iteration 500---------- 69 | total loss 1.0750524997711182 70 | loss_r_feature 42.88349533081055 71 | main criterion 0.001011076383292675 72 | Verifier accuracy: 70.23809814453125 73 | ------------iteration 600---------- 74 | total loss 0.9664787650108337 75 | loss_r_feature 38.243465423583984 76 | main criterion 0.001509393914602697 77 | Verifier accuracy: 63.095237731933594 78 | ------------iteration 700---------- 79 | total loss 0.9061623811721802 80 | loss_r_feature 36.1928825378418 81 | main criterion 0.002603746484965086 82 | Verifier accuracy: 64.28571319580078 83 | ------------iteration 800---------- 84 | total loss 0.8216720223426819 85 | loss_r_feature 32.11809539794922 86 | main criterion 0.0023030894808471203 87 | Verifier accuracy: 75.0 88 | ------------iteration 900---------- 89 | total loss 0.9163199067115784 90 | loss_r_feature 36.28833770751953 91 | main criterion 0.05596952140331268 92 | Verifier accuracy: 72.61904907226562 93 | ------------iteration 1000---------- 94 | total loss 0.7641246318817139 95 | loss_r_feature 30.2562313079834 96 | main criterion 0.0017005829140543938 97 | Verifier accuracy: 82.14285278320312 98 | ------------iteration 1100---------- 99 | total loss 0.7134910225868225 100 | loss_r_feature 27.443801879882812 101 | main criterion 0.0016521726502105594 102 | Verifier accuracy: 84.52381134033203 103 | ------------iteration 1200---------- 104 | total loss 0.6644324064254761 105 | loss_r_feature 25.020050048828125 106 | main criterion 0.002477305242791772 107 | Verifier accuracy: 85.71428680419922 108 | ------------iteration 1300---------- 109 | total loss 0.6082080006599426 110 | loss_r_feature 23.141693115234375 111 | main criterion 0.0009327275329269469 112 | Verifier accuracy: 91.66666412353516 113 | ------------iteration 1400---------- 114 | total loss 0.5574465394020081 115 | loss_r_feature 21.793500900268555 116 | main criterion 0.0010219528339803219 117 | Verifier accuracy: 95.23809814453125 118 | ------------iteration 1500---------- 119 | total loss 0.5180432796478271 120 | loss_r_feature 20.062559127807617 121 | main criterion 0.0009084542398341 122 | Verifier accuracy: 94.04761505126953 123 | ------------iteration 1600---------- 124 | total loss 0.46451741456985474 125 | loss_r_feature 17.89577865600586 126 | main criterion 0.0005521774291992188 127 | Verifier accuracy: 97.61904907226562 128 | ------------iteration 1700---------- 129 | total loss 0.42708495259284973 130 | loss_r_feature 17.111621856689453 131 | main criterion 0.0005137125845067203 132 | Verifier accuracy: 96.42857360839844 133 | ------------iteration 1800---------- 134 | total loss 0.3998515009880066 135 | loss_r_feature 
16.164155960083008 136 | main criterion 0.00046278181253001094 137 | Verifier accuracy: 96.42857360839844 138 | ------------iteration 1900---------- 139 | total loss 0.39311376214027405 140 | loss_r_feature 16.403812408447266 141 | main criterion 0.0006096022552810609 142 | Verifier accuracy: 96.42857360839844 143 | ------------iteration 2000---------- 144 | total loss 0.390480101108551 145 | loss_r_feature 16.30877113342285 146 | main criterion 0.0006353174103423953 147 | Verifier accuracy: 96.42857360839844 148 | -------------------------------------------------------------------------------- /example_logs/fp16_set1_rn50_output_00020_gpu_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/example_logs/fp16_set1_rn50_output_00020_gpu_0.jpg -------------------------------------------------------------------------------- /example_logs/fp32_set0_mnv2.log: -------------------------------------------------------------------------------- 1 | $ python imagenet_inversion.py --bs=84 --do_flip --exp_name="test_mnv2_set0" --r_feature=0.01 --arch_name="mobilenet_v2" --verifier --verifier_arch="r 2 | esnet18" --setting_id=0 --lr=0.2 --adi_scale=0.0 --l2=0.00001 3 | Namespace(adi_scale=0.0, arch_name='mobilenet_v2', bs=84, comment='', do_flip=True, epochs=20000, exp_name='test_mnv2_set0', fp16=False, jitter=30, l2=1e-05, local_rank=0, lr=0.2, main_loss_multiplier=1.0, no_cuda=False, r_feature=0.01, random_label=False, setting_id= 4 | 0, tv_l1=0.0, tv_l2=0.0001, verifier=True, verifier_arch='resnet18', wd=0.01, worldsize=1) 5 | loading torchvision model for inversion with the name: mobilenet_v2 6 | ==> Resuming from checkpoint.. 7 | ==> Getting BN params as feature statistics 8 | loading verifier: resnet18 9 | Deep inversion class generation 10 | get_images call 11 | ------------iteration 100---------- 12 | total loss 3.482088804244995 13 | loss_r_feature 188.9123077392578 14 | main criterion 0.8707727789878845 15 | Verifier accuracy: 0.0 16 | ------------iteration 200---------- 17 | total loss 3.0610427856445312 18 | loss_r_feature 207.21432495117188 19 | main criterion 0.16552281379699707 20 | Verifier accuracy: 0.0 21 | ------------iteration 300---------- 22 | total loss 3.0640416145324707 23 | loss_r_feature 197.27249145507812 24 | main criterion 0.31159523129463196 25 | Verifier accuracy: 5.952381134033203 26 | ------------iteration 400---------- 27 | total loss 2.802457332611084 28 | loss_r_feature 192.7845001220703 29 | main criterion 0.13211332261562347 30 | Verifier accuracy: 3.5714285373687744 31 | ------------iteration 500---------- 32 | total loss 2.019845724105835 33 | loss_r_feature 130.70864868164062 34 | main criterion 0.011200053617358208 35 | Verifier accuracy: 9.523809432983398 36 | ------------iteration 600---------- 37 | total loss 2.262993812561035 38 | loss_r_feature 154.48069763183594 39 | main criterion 0.05309981480240822 40 | Verifier accuracy: 5.952381134033203 41 | ------------iteration 700---------- 42 | total loss 1.8827401399612427 43 | loss_r_feature 122.04618072509766 44 | main criterion 0.02727217972278595 45 | Verifier accuracy: 7.142857074737549 46 | ------------iteration 800---------- 47 | total loss 1.8060630559921265 48 | loss_r_feature 121.67047882080078 49 | main criterion 0.007290567737072706 50 | Verifier accuracy: 15.476190567016602 51 | ------------iteration 900---------- 52 | total loss 1.6456643342971802 53 | loss_r_feature 
108.90213775634766 54 | main criterion 0.013873236253857613 55 | Verifier accuracy: 14.285714149475098 56 | ------------iteration 1000---------- 57 | total loss 1.403316855430603 58 | loss_r_feature 90.18961334228516 59 | main criterion 0.016842082142829895 60 | Verifier accuracy: 16.66666603088379 61 | ------------iteration 1100---------- 62 | total loss 1.5693998336791992 63 | loss_r_feature 112.5223617553711 64 | main criterion 0.0024354797787964344 65 | Verifier accuracy: 17.85714340209961 66 | ------------iteration 1200---------- 67 | total loss 1.3196629285812378 68 | loss_r_feature 88.94979858398438 69 | main criterion 0.019705001264810562 70 | Verifier accuracy: 28.571428298950195 71 | ------------iteration 1300---------- 72 | total loss 1.132418155670166 73 | loss_r_feature 75.13455963134766 74 | main criterion 0.007880108430981636 75 | Verifier accuracy: 29.761903762817383 76 | ------------iteration 1400---------- 77 | total loss 1.1630631685256958 78 | loss_r_feature 81.51448822021484 79 | main criterion 0.003338359761983156 80 | Verifier accuracy: 25.0 81 | ------------iteration 1500---------- 82 | total loss 0.849071204662323 83 | loss_r_feature 52.43641662597656 84 | main criterion 0.0011611892841756344 85 | Verifier accuracy: 27.380952835083008 86 | ------------iteration 1600---------- 87 | total loss 0.7553449869155884 88 | loss_r_feature 44.74074935913086 89 | main criterion 0.0010790483793243766 90 | Verifier accuracy: 15.476190567016602 91 | ------------iteration 1700---------- 92 | total loss 0.280352383852005 93 | loss_r_feature 13.329302787780762 94 | main criterion 0.0007046745158731937 95 | Verifier accuracy: 26.190475463867188 96 | ------------iteration 1800---------- 97 | total loss 0.26061105728149414 98 | loss_r_feature 11.900742530822754 99 | main criterion 0.0009304682607762516 100 | Verifier accuracy: 25.0 101 | ------------iteration 1900---------- 102 | total loss 0.2539810240268707 103 | loss_r_feature 11.516486167907715 104 | main criterion 0.0013499259948730469 105 | Verifier accuracy: 27.380952835083008 106 | ------------iteration 2000---------- 107 | total loss 0.25407370924949646 108 | loss_r_feature 11.625514030456543 109 | main criterion 0.0009400731069035828 110 | Verifier accuracy: 27.380952835083008 111 | ------------iteration 2100---------- 112 | total loss 1.186625361442566 113 | loss_r_feature 43.52707290649414 114 | main criterion 0.008745352737605572 115 | Verifier accuracy: 51.19047546386719 116 | ------------iteration 2200---------- 117 | total loss 1.031857967376709 118 | loss_r_feature 35.6275749206543 119 | main criterion 0.0017157054971903563 120 | Verifier accuracy: 73.80952453613281 121 | ------------iteration 2300---------- 122 | total loss 0.8248406648635864 123 | loss_r_feature 30.27305030822754 124 | main criterion 0.001007091486826539 125 | Verifier accuracy: 85.71428680419922 126 | ------------iteration 2400---------- 127 | total loss 0.8060064315795898 128 | loss_r_feature 28.234004974365234 129 | main criterion 0.0011983144795522094 130 | Verifier accuracy: 86.9047622680664 131 | ------------iteration 2500---------- 132 | total loss 0.7223383188247681 133 | loss_r_feature 25.458057403564453 134 | main criterion 0.001484291860833764 135 | Verifier accuracy: 90.47618865966797 136 | ------------iteration 2600---------- 137 | total loss 0.5898433923721313 138 | loss_r_feature 20.13226890563965 139 | main criterion 0.0014175687683746219 140 | Verifier accuracy: 91.66666412353516 141 | ------------iteration 2700---------- 142 | 
total loss 0.5021806955337524 143 | loss_r_feature 16.399860382080078 144 | main criterion 0.0010060241911560297 145 | Verifier accuracy: 96.42857360839844 146 | ------------iteration 2800---------- 147 | total loss 0.420066237449646 148 | loss_r_feature 13.413532257080078 149 | main criterion 0.000963824160862714 150 | Verifier accuracy: 92.85713958740234 151 | ------------iteration 2900---------- 152 | total loss 0.3616832494735718 153 | loss_r_feature 10.820902824401855 154 | main criterion 0.0006766319274902344 155 | Verifier accuracy: 95.23809814453125 156 | ------------iteration 3000---------- 157 | total loss 0.3493320047855377 158 | loss_r_feature 10.420402526855469 159 | main criterion 0.0005601133452728391 160 | Verifier accuracy: 96.42857360839844 -------------------------------------------------------------------------------- /example_logs/fp32_set0_mnv2_output_00030_gpu_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/example_logs/fp32_set0_mnv2_output_00030_gpu_0.jpg -------------------------------------------------------------------------------- /example_logs/fp32_set0_rn50.log: -------------------------------------------------------------------------------- 1 | $ python imagenet_inversion.py --bs=84 --do_flip --exp_name="test_rn50_fp32" --r_feature=0.01 --arch_name="resnet50" --verifier --setting_id=0 --lr=0. 2 | 2 --adi_scale=0.0 --l2=0.00001 3 | Namespace(adi_scale=0.0, arch_name='resnet50', bs=84, comment='', do_flip=True, epochs=20000, exp_name='test_rn50_fp32', fp16=False, jitter=30, l2=1e-05, local_rank=0, lr=0.2, main_loss_multiplier=1.0, no_cuda=False, r_feature=0.01, random_label=False, setting_id=0, t 4 | v_l1=0.0, tv_l2=0.0001, verifier=True, verifier_arch='mobilenet_v2', wd=0.01, worldsize=1) 5 | loading torchvision model for inversion with the name: resnet50 6 | ==> Resuming from checkpoint.. 
7 | ==> Getting BN params as feature statistics 8 | loading verifier: mobilenet_v2 9 | Deep inversion class generation 10 | get_images call 11 | ------------iteration 100---------- 12 | total loss 1.7325098514556885 13 | loss_r_feature 90.60043334960938 14 | main criterion 0.22487393021583557 15 | Verifier accuracy: 1.1904761791229248 16 | ------------iteration 200---------- 17 | total loss 1.3570595979690552 18 | loss_r_feature 73.92747497558594 19 | main criterion 0.08076667785644531 20 | Verifier accuracy: 5.952381134033203 21 | ------------iteration 300---------- 22 | total loss 1.192380666732788 23 | loss_r_feature 66.64324951171875 24 | main criterion 0.04419786483049393 25 | Verifier accuracy: 4.761904716491699 [68/4971] 26 | ------------iteration 400---------- 27 | total loss 1.1331039667129517 28 | loss_r_feature 63.37013244628906 29 | main criterion 0.027310485020279884 30 | Verifier accuracy: 15.476190567016602 31 | ------------iteration 500---------- 32 | total loss 1.0495476722717285 33 | loss_r_feature 58.564552307128906 34 | main criterion 0.004942258354276419 35 | Verifier accuracy: 13.095237731933594 36 | ------------iteration 600---------- 37 | total loss 1.05406653881073 38 | loss_r_feature 57.1053466796875 39 | main criterion 0.02717919647693634 40 | Verifier accuracy: 15.476190567016602 41 | ------------iteration 700---------- 42 | total loss 0.8731762170791626 43 | loss_r_feature 48.39915466308594 44 | main criterion 0.0029847281984984875 45 | Verifier accuracy: 11.904762268066406 46 | ------------iteration 800---------- 47 | total loss 0.8121204376220703 48 | loss_r_feature 44.756126403808594 49 | main criterion 0.0013222013367339969 50 | Verifier accuracy: 13.095237731933594 51 | ------------iteration 900---------- 52 | total loss 0.7302803993225098 53 | loss_r_feature 39.77760314941406 54 | main criterion 0.00771959638223052 55 | Verifier accuracy: 15.476190567016602 56 | ------------iteration 1000---------- 57 | total loss 0.6622204184532166 58 | loss_r_feature 35.62640380859375 59 | main criterion 0.0059319790452718735 60 | Verifier accuracy: 21.428571701049805 61 | ------------iteration 1100---------- 62 | total loss 0.5391901731491089 63 | loss_r_feature 28.116008758544922 64 | main criterion 0.0015326568391174078 65 | Verifier accuracy: 23.809524536132812 66 | ------------iteration 1200---------- 67 | total loss 0.5046427845954895 68 | loss_r_feature 26.73769187927246 69 | main criterion 0.001567840576171875 70 | Verifier accuracy: 21.428571701049805 71 | ------------iteration 1300---------- 72 | total loss 0.45175373554229736 73 | loss_r_feature 23.481212615966797 74 | main criterion 0.0010218393290415406 75 | Verifier accuracy: 16.66666603088379 76 | ------------iteration 1400---------- 77 | total loss 0.3774851858615875 78 | loss_r_feature 18.778751373291016 79 | main criterion 0.0009181839996017516 80 | Verifier accuracy: 28.571428298950195 81 | ------------iteration 1500---------- 82 | total loss 0.3348837196826935 83 | loss_r_feature 16.353933334350586 84 | main criterion 0.0006400971324183047 85 | Verifier accuracy: 26.190475463867188 86 | ------------iteration 1600---------- 87 | total loss 0.29075923562049866 88 | loss_r_feature 13.699808120727539 89 | main criterion 0.0005158697022125125 90 | Verifier accuracy: 32.14285659790039 91 | ------------iteration 1700---------- 92 | total loss 0.280352383852005 93 | loss_r_feature 13.329302787780762 94 | main criterion 0.0007046745158731937 95 | Verifier accuracy: 26.190475463867188 96 | 
------------iteration 1800---------- 97 | total loss 0.26061105728149414 98 | loss_r_feature 11.900742530822754 99 | main criterion 0.0009304682607762516 100 | Verifier accuracy: 25.0 101 | ------------iteration 1900---------- 102 | total loss 0.2539810240268707 103 | loss_r_feature 11.516486167907715 104 | main criterion 0.0013499259948730469 105 | Verifier accuracy: 27.380952835083008 106 | ------------iteration 2000---------- 107 | total loss 0.25407370924949646 108 | loss_r_feature 11.625514030456543 109 | main criterion 0.0009400731069035828 110 | Verifier accuracy: 27.380952835083008 111 | ------------iteration 2100---------- 112 | total loss 1.186625361442566 113 | loss_r_feature 43.52707290649414 114 | main criterion 0.008745352737605572 115 | Verifier accuracy: 51.19047546386719 116 | ------------iteration 2200---------- 117 | total loss 1.031857967376709 118 | loss_r_feature 35.6275749206543 119 | main criterion 0.0017157054971903563 120 | Verifier accuracy: 73.80952453613281 121 | ------------iteration 2300---------- 122 | total loss 0.8248406648635864 123 | loss_r_feature 30.27305030822754 124 | main criterion 0.001007091486826539 125 | Verifier accuracy: 85.71428680419922 126 | ------------iteration 2400---------- 127 | total loss 0.8060064315795898 128 | loss_r_feature 28.234004974365234 129 | main criterion 0.0011983144795522094 130 | Verifier accuracy: 86.9047622680664 131 | ------------iteration 2500---------- 132 | total loss 0.7223383188247681 133 | loss_r_feature 25.458057403564453 134 | main criterion 0.001484291860833764 135 | Verifier accuracy: 90.47618865966797 136 | ------------iteration 2600---------- 137 | total loss 0.5898433923721313 138 | loss_r_feature 20.13226890563965 139 | main criterion 0.0014175687683746219 140 | Verifier accuracy: 91.66666412353516 141 | ------------iteration 2700---------- 142 | total loss 0.5021806955337524 143 | loss_r_feature 16.399860382080078 144 | main criterion 0.0010060241911560297 145 | Verifier accuracy: 96.42857360839844 146 | ------------iteration 2800---------- 147 | total loss 0.420066237449646 148 | loss_r_feature 13.413532257080078 149 | main criterion 0.000963824160862714 150 | Verifier accuracy: 92.85713958740234 151 | ------------iteration 2900---------- 152 | total loss 0.3616832494735718 153 | loss_r_feature 10.820902824401855 154 | main criterion 0.0006766319274902344 155 | Verifier accuracy: 95.23809814453125 156 | ------------iteration 3000---------- 157 | total loss 0.3493320047855377 158 | loss_r_feature 10.420402526855469 159 | main criterion 0.0005601133452728391 160 | Verifier accuracy: 96.42857360839844 -------------------------------------------------------------------------------- /example_logs/fp32_set0_rn50_first_bn_scaled.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/example_logs/fp32_set0_rn50_first_bn_scaled.jpg -------------------------------------------------------------------------------- /example_logs/fp32_set0_rn50_first_bn_scaled.log: -------------------------------------------------------------------------------- 1 | python imagenet_inversion.py --bs=84 --do_flip --exp_name="rn50_inversion_first_bn_cpcp" --r_feature=0.01 --arch_name="resnet50" --verifier --adi_scale=0.0 --setting_id=0 --lr 0.25 2 | Namespace(adi_scale=0.0, arch_name='resnet50', bs=84, comment='', do_flip=True, epochs=20000, exp_name='rn50_inversion_first_bn_cpcp', first_bn_multiplier=10.0, 
fp16=False, jitter=30, l2=1e-05, local_rank=0, lr=0.25, main_loss_multiplier=1.0, no_cuda=False, r_feature=0.01, random_label=False, setting_id=0, store_best_images=False, tv_l1=0.0, tv_l2=0.0001, verifier=True, verifier_arch='mobilenet_v2', worldsize=1) 3 | loading torchvision model for inversion with the name: resnet50 4 | ==> Resuming from checkpoint.. 5 | ==> Getting BN params as feature statistics 6 | loading verifier: mobilenet_v2 7 | Deep inversion class generation 8 | get_images call 9 | ------------iteration 100---------- 10 | total loss 4.216484546661377 11 | loss_r_feature 319.7900085449219 12 | main criterion 0.3684564530849457 13 | Verifier accuracy: 0.0 14 | ------------iteration 200---------- 15 | total loss 3.074280023574829 16 | loss_r_feature 243.4846954345703 17 | main criterion 0.03289255499839783 18 | Verifier accuracy: 3.5714285373687744 19 | ------------iteration 300---------- 20 | total loss 2.556410312652588 21 | loss_r_feature 185.2130584716797 22 | main criterion 0.1536623239517212 23 | Verifier accuracy: 2.3809523582458496 24 | ------------iteration 400---------- 25 | total loss 2.1119472980499268 26 | loss_r_feature 157.12548828125 27 | main criterion 0.013974723406136036 28 | Verifier accuracy: 2.3809523582458496 29 | ------------iteration 500---------- 30 | total loss 1.6994301080703735 31 | loss_r_feature 120.24372863769531 32 | main criterion 0.023798726499080658 33 | Verifier accuracy: 7.142857074737549 34 | ------------iteration 600---------- 35 | total loss 1.5869784355163574 36 | loss_r_feature 102.4736099243164 37 | main criterion 0.12379765510559082 38 | Verifier accuracy: 8.333333015441895 39 | ------------iteration 700---------- 40 | total loss 1.2205699682235718 41 | loss_r_feature 81.42609405517578 42 | main criterion 0.012213945388793945 43 | Verifier accuracy: 14.285714149475098 44 | ------------iteration 800---------- 45 | total loss 1.0356595516204834 46 | loss_r_feature 67.60336303710938 47 | main criterion 0.0036269596312195063 48 | Verifier accuracy: 22.619047164916992 49 | ------------iteration 900---------- 50 | total loss 0.9518051743507385 51 | loss_r_feature 57.53125762939453 52 | main criterion 0.054748646914958954 53 | Verifier accuracy: 20.238094329833984 54 | ------------iteration 1000---------- 55 | total loss 0.762413501739502 56 | loss_r_feature 46.62602233886719 57 | main criterion 0.005131573881953955 58 | Verifier accuracy: 27.380952835083008 59 | ------------iteration 1100---------- 60 | total loss 0.6755117774009705 61 | loss_r_feature 38.807640075683594 62 | main criterion 0.020008916035294533 63 | Verifier accuracy: 25.0 64 | ------------iteration 1200---------- 65 | total loss 0.5802351832389832 66 | loss_r_feature 34.1475830078125 67 | main criterion 0.0015264465473592281 68 | Verifier accuracy: 28.571428298950195 69 | ------------iteration 1300---------- 70 | total loss 0.49006298184394836 71 | loss_r_feature 27.85464096069336 72 | main criterion 0.001066457713022828 73 | Verifier accuracy: 32.14285659790039 74 | ------------iteration 1400---------- 75 | total loss 0.4420880675315857 76 | loss_r_feature 23.253334045410156 77 | main criterion 0.0077417464926838875 78 | Verifier accuracy: 30.952381134033203 79 | ------------iteration 1500---------- 80 | total loss 0.4081510007381439 81 | loss_r_feature 21.22683334350586 82 | main criterion 0.0008004733244888484 83 | Verifier accuracy: 40.47618865966797 84 | ------------iteration 1600---------- 85 | total loss 0.36195051670074463 86 | loss_r_feature 17.16771697998047 87 | 
main criterion 0.0006432306254282594 88 | Verifier accuracy: 44.0476188659668 89 | ------------iteration 1700---------- 90 | total loss 0.3319593071937561 91 | loss_r_feature 14.380390167236328 92 | main criterion 0.0008895737701095641 93 | Verifier accuracy: 35.71428680419922 94 | ------------iteration 1800---------- 95 | total loss 0.31532973051071167 96 | loss_r_feature 12.930418968200684 97 | main criterion 0.0006053107208572328 98 | Verifier accuracy: 33.33333206176758 99 | ------------iteration 1900---------- 100 | total loss 0.298938125371933 101 | loss_r_feature 11.36793041229248 102 | main criterion 0.0007616224465891719 103 | Verifier accuracy: 36.904762268066406 104 | ------------iteration 2000---------- 105 | total loss 0.3005948066711426 106 | loss_r_feature 11.55113697052002 107 | main criterion 0.0006574335275217891 108 | Verifier accuracy: 38.095237731933594 109 | ------------iteration 2100---------- 110 | total loss 1.8834247589111328 111 | loss_r_feature 103.4402847290039 112 | main criterion 0.0071905795484781265 113 | Verifier accuracy: 52.380950927734375 114 | ------------iteration 2200---------- 115 | total loss 1.2003257274627686 116 | loss_r_feature 60.801185607910156 117 | main criterion 0.005007199011743069 118 | Verifier accuracy: 63.095237731933594 119 | ------------iteration 2300---------- 120 | total loss 1.0916008949279785 121 | loss_r_feature 51.438507080078125 122 | main criterion 0.0014722461346536875 123 | Verifier accuracy: 85.71428680419922 124 | ------------iteration 2400---------- 125 | total loss 0.8125420212745667 126 | loss_r_feature 33.99945068359375 127 | main criterion 0.0012546947691589594 128 | Verifier accuracy: 83.33333587646484 129 | ------------iteration 2500---------- 130 | total loss 0.8069868683815002 131 | loss_r_feature 35.57963562011719 132 | main criterion 0.0014406385598704219 133 | Verifier accuracy: 89.28571319580078 134 | ------------iteration 2600---------- 135 | total loss 0.6985664963722229 136 | loss_r_feature 28.85590934753418 137 | main criterion 0.0015078613068908453 138 | Verifier accuracy: 94.04761505126953 139 | ------------iteration 2700---------- 140 | total loss 0.640400767326355 141 | loss_r_feature 23.789012908935547 142 | main criterion 0.0010293552186340094 143 | Verifier accuracy: 96.42857360839844 144 | ------------iteration 2800---------- 145 | total loss 0.5649642944335938 146 | loss_r_feature 17.716434478759766 147 | main criterion 0.0007892563007771969 148 | Verifier accuracy: 91.66666412353516 149 | ------------iteration 2900---------- 150 | total loss 0.5026633739471436 151 | loss_r_feature 12.565704345703125 152 | main criterion 0.0006156648742035031 153 | Verifier accuracy: 94.04761505126953 154 | ------------iteration 3000---------- 155 | total loss 0.4865095913410187 156 | loss_r_feature 11.368947982788086 157 | main criterion 0.0005418686778284609 158 | Verifier accuracy: 91.66666412353516 159 | -------------------------------------------------------------------------------- /example_logs/fp32_set0_rn50_output_00030_gpu_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/example_logs/fp32_set0_rn50_output_00030_gpu_0.jpg -------------------------------------------------------------------------------- /example_logs/teaser.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/example_logs/teaser.png -------------------------------------------------------------------------------- /imagenet_inversion.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (C) 2020 NVIDIA Corporation. All rights reserved. 3 | # Nvidia Source Code License-NC 4 | # Official PyTorch implementation of CVPR2020 paper 5 | # Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion 6 | # Hongxu Yin, Pavlo Molchanov, Zhizhong Li, Jose M. Alvarez, Arun Mallya, Derek 7 | # Hoiem, Niraj K. Jha, and Jan Kautz 8 | # -------------------------------------------------------- 9 | 10 | from __future__ import division, print_function 11 | from __future__ import absolute_import 12 | from __future__ import division 13 | from __future__ import unicode_literals 14 | 15 | import argparse 16 | import torch 17 | from torch import distributed, nn 18 | import random 19 | import torch.nn as nn 20 | import torch.nn.parallel 21 | import torch.utils.data 22 | from torchvision import datasets, transforms 23 | 24 | import numpy as np 25 | import torch.cuda.amp as amp 26 | import os 27 | import torchvision.models as models 28 | from utils.utils import load_model_pytorch, distributed_is_initialized 29 | 30 | random.seed(0) 31 | 32 | 33 | def validate_one(input, target, model): 34 | """Perform validation on the validation set""" 35 | 36 | def accuracy(output, target, topk=(1,)): 37 | """Computes the precision@k for the specified values of k""" 38 | maxk = max(topk) 39 | batch_size = target.size(0) 40 | 41 | _, pred = output.topk(maxk, 1, True, True) 42 | pred = pred.t() 43 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 44 | 45 | res = [] 46 | for k in topk: 47 | correct_k = correct[:k].reshape(-1).float().sum(0) 48 | res.append(correct_k.mul_(100.0 / batch_size)) 49 | return res 50 | 51 | with torch.no_grad(): 52 | output = model(input) 53 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 54 | 55 | print("Verifier accuracy: ", prec1.item()) 56 | 57 | 58 | def run(args): 59 | torch.manual_seed(args.local_rank) 60 | device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu') 61 | 62 | if args.arch_name == "resnet50v15": 63 | from models.resnetv15 import build_resnet 64 | net = build_resnet("resnet50", "classic") 65 | else: 66 | print("loading torchvision model for inversion with the name: {}".format(args.arch_name)) 67 | net = models.__dict__[args.arch_name](pretrained=True) 68 | 69 | net = net.to(device) 70 | 71 | use_fp16 = args.fp16 72 | if use_fp16: 73 | net, _ = amp.initialize(net, [], opt_level="O2") 74 | 75 | print('==> Resuming from checkpoint..') 76 | 77 | ### load models 78 | if args.arch_name=="resnet50v15": 79 | path_to_model = "./models/resnet50v15/model_best.pth.tar" 80 | load_model_pytorch(net, path_to_model, gpu_n=torch.cuda.current_device()) 81 | 82 | net.to(device) 83 | net.eval() 84 | 85 | # reserved to compute test accuracy on generated images by different networks 86 | net_verifier = None 87 | if args.verifier and args.adi_scale == 0: 88 | # if multiple GPUs are used then we can change code to load different verifiers to different GPUs 89 | if args.local_rank == 0: 90 | print("loading verifier: ", args.verifier_arch) 91 | net_verifier = models.__dict__[args.verifier_arch](pretrained=True).to(device) 92 | net_verifier.eval() 93 | 94 | if 
use_fp16: 95 | net_verifier = net_verifier.half() 96 | 97 | if args.adi_scale != 0.0: 98 | student_arch = "resnet18" 99 | net_verifier = models.__dict__[student_arch](pretrained=True).to(device) 100 | net_verifier.eval() 101 | 102 | if use_fp16: 103 | net_verifier, _ = amp.initialize(net_verifier, [], opt_level="O2") 104 | 105 | net_verifier = net_verifier.to(device) 106 | net_verifier.train() 107 | 108 | if use_fp16: 109 | for module in net_verifier.modules(): 110 | if isinstance(module, nn.BatchNorm2d): 111 | module.eval().half() 112 | 113 | from deepinversion import DeepInversionClass 114 | 115 | exp_name = args.exp_name 116 | # final images will be stored here: 117 | adi_data_path = "./final_images/%s"%exp_name 118 | # temporal data and generations will be stored here 119 | exp_name = "generations/%s"%exp_name 120 | 121 | args.iterations = 2000 122 | args.start_noise = True 123 | # args.detach_student = False 124 | 125 | args.resolution = 224 126 | bs = args.bs 127 | jitter = 30 128 | 129 | parameters = dict() 130 | parameters["resolution"] = 224 131 | parameters["random_label"] = False 132 | parameters["start_noise"] = True 133 | parameters["detach_student"] = False 134 | parameters["do_flip"] = True 135 | 136 | parameters["do_flip"] = args.do_flip 137 | parameters["random_label"] = args.random_label 138 | parameters["store_best_images"] = args.store_best_images 139 | 140 | criterion = nn.CrossEntropyLoss() 141 | 142 | coefficients = dict() 143 | coefficients["r_feature"] = args.r_feature 144 | coefficients["first_bn_multiplier"] = args.first_bn_multiplier 145 | coefficients["tv_l1"] = args.tv_l1 146 | coefficients["tv_l2"] = args.tv_l2 147 | coefficients["l2"] = args.l2 148 | coefficients["lr"] = args.lr 149 | coefficients["main_loss_multiplier"] = args.main_loss_multiplier 150 | coefficients["adi_scale"] = args.adi_scale 151 | 152 | network_output_function = lambda x: x 153 | 154 | # check accuracy of verifier 155 | if args.verifier: 156 | hook_for_display = lambda x,y: validate_one(x, y, net_verifier) 157 | else: 158 | hook_for_display = None 159 | 160 | DeepInversionEngine = DeepInversionClass(net_teacher=net, 161 | final_data_path=adi_data_path, 162 | path=exp_name, 163 | parameters=parameters, 164 | setting_id=args.setting_id, 165 | bs = bs, 166 | use_fp16 = args.fp16, 167 | jitter = jitter, 168 | criterion=criterion, 169 | coefficients = coefficients, 170 | network_output_function = network_output_function, 171 | hook_for_display = hook_for_display) 172 | net_student=None 173 | if args.adi_scale != 0: 174 | net_student = net_verifier 175 | DeepInversionEngine.generate_batch(net_student=net_student) 176 | 177 | def main(): 178 | parser = argparse.ArgumentParser() 179 | parser.add_argument('-s', '--worldsize', type=int, default=1, help='Number of processes participating in the job.') 180 | parser.add_argument('--local_rank', '--rank', type=int, default=0, help='Rank of the current process.') 181 | parser.add_argument('--adi_scale', type=float, default=0.0, help='Coefficient for Adaptive Deep Inversion') 182 | parser.add_argument('--no-cuda', action='store_true') 183 | 184 | parser.add_argument('--epochs', default=20000, type=int, help='number of epochs for the optimization') 185 | parser.add_argument('--setting_id', default=0, type=int, help='settings for optimization: 0 - multi resolution, 1 - 2k iterations, 2 - 20k iterations') 186 | parser.add_argument('--bs', default=64, type=int, help='batch size') 187 | parser.add_argument('--jitter', default=30, type=int, help='random shift (jitter) in pixels applied to inputs') 188 | 
parser.add_argument('--comment', default='', type=str, help='optional comment appended to the experiment') 189 | parser.add_argument('--arch_name', default='resnet50', type=str, help='model name from torchvision or resnet50v15') 190 | 191 | parser.add_argument('--fp16', action='store_true', help='use FP16 for optimization') 192 | parser.add_argument('--exp_name', type=str, default='test', help='where to store experimental data') 193 | 194 | parser.add_argument('--verifier', action='store_true', help='evaluate batch with another model') 195 | parser.add_argument('--verifier_arch', type=str, default='mobilenet_v2', help = "arch name from torchvision models to act as a verifier") 196 | 197 | parser.add_argument('--do_flip', action='store_true', help='apply flip during model inversion') 198 | parser.add_argument('--random_label', action='store_true', help='generate random label for optimization') 199 | parser.add_argument('--r_feature', type=float, default=0.05, help='coefficient for feature distribution regularization') 200 | parser.add_argument('--first_bn_multiplier', type=float, default=10., help='additional multiplier on first bn layer of R_feature') 201 | parser.add_argument('--tv_l1', type=float, default=0.0, help='coefficient for total variation L1 loss') 202 | parser.add_argument('--tv_l2', type=float, default=0.0001, help='coefficient for total variation L2 loss') 203 | parser.add_argument('--lr', type=float, default=0.2, help='learning rate for optimization') 204 | parser.add_argument('--l2', type=float, default=0.00001, help='l2 loss on the image') 205 | parser.add_argument('--main_loss_multiplier', type=float, default=1.0, help='coefficient for the main loss in optimization') 206 | parser.add_argument('--store_best_images', action='store_true', help='save best images as separate files') 207 | 208 | args = parser.parse_args() 209 | print(args) 210 | 211 | torch.backends.cudnn.benchmark = True 212 | run(args) 213 | 214 | 215 | if __name__ == '__main__': 216 | main() 217 | -------------------------------------------------------------------------------- /models/resnetv15.py: -------------------------------------------------------------------------------- 1 | # Originated from https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/RN50v1.5 2 | # now code is at https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets 3 | import torch.nn as nn 4 | 5 | __all__ = ['ResNet', 'build_resnet', 'resnet_versions', 'resnet_configs'] 6 | 7 | # ResNetBuilder {{{ 8 | 9 | class ResNetBuilder(object): 10 | def __init__(self, version, config): 11 | self.config = config 12 | 13 | self.L = sum(version['layers']) 14 | self.M = version['block'].M 15 | 16 | self.layer_index = 0 17 | 18 | def conv(self, kernel_size, in_planes, out_planes, stride=1): 19 | if kernel_size == 3: 20 | conv = self.config['conv']( 21 | in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | elif kernel_size == 1: 24 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 25 | bias=False) 26 | elif kernel_size == 5: 27 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride, 28 | padding=2, bias=False) 29 | elif kernel_size == 7: 30 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=7, stride=stride, 31 | padding=3, bias=False) 32 | else: 33 | return None 34 | 35 | if self.config['nonlinearity'] == 'relu': 36 | nn.init.kaiming_normal_(conv.weight, 37 | mode=self.config['conv_init'], 38 | nonlinearity=self.config['nonlinearity']) 39 | 40 | 
return conv 41 | 42 | def conv3x3(self, in_planes, out_planes, stride=1): 43 | """3x3 convolution with padding""" 44 | c = self.conv(3, in_planes, out_planes, stride=stride) 45 | return c 46 | 47 | def conv1x1(self, in_planes, out_planes, stride=1): 48 | """1x1 convolution with padding""" 49 | c = self.conv(1, in_planes, out_planes, stride=stride) 50 | return c 51 | 52 | def conv7x7(self, in_planes, out_planes, stride=1): 53 | """7x7 convolution with padding""" 54 | c = self.conv(7, in_planes, out_planes, stride=stride) 55 | return c 56 | 57 | def conv5x5(self, in_planes, out_planes, stride=1): 58 | """5x5 convolution with padding""" 59 | c = self.conv(5, in_planes, out_planes, stride=stride) 60 | return c 61 | 62 | def batchnorm(self, planes, last_bn=False): 63 | bn = nn.BatchNorm2d(planes) 64 | gamma_init_val = 0 if last_bn and self.config['last_bn_0_init'] else 1 65 | nn.init.constant_(bn.weight, gamma_init_val) 66 | nn.init.constant_(bn.bias, 0) 67 | 68 | return bn 69 | 70 | def activation(self): 71 | return self.config['activation']() 72 | 73 | # ResNetBuilder }}} 74 | 75 | # BasicBlock {{{ 76 | class BasicBlock(nn.Module): 77 | M = 2 78 | expansion = 1 79 | 80 | def __init__(self, builder, inplanes, planes, stride=1, downsample=None): 81 | super(BasicBlock, self).__init__() 82 | self.conv1 = builder.conv3x3(inplanes, planes, stride) 83 | self.bn1 = builder.batchnorm(planes) 84 | self.relu = builder.activation() 85 | self.conv2 = builder.conv3x3(planes, planes) 86 | self.bn2 = builder.batchnorm(planes, last_bn=True) 87 | self.downsample = downsample 88 | self.stride = stride 89 | 90 | def forward(self, x): 91 | residual = x 92 | 93 | out = self.conv1(x) 94 | if self.bn1 is not None: 95 | out = self.bn1(out) 96 | 97 | out = self.relu(out) 98 | 99 | out = self.conv2(out) 100 | 101 | if self.bn2 is not None: 102 | out = self.bn2(out) 103 | 104 | if self.downsample is not None: 105 | residual = self.downsample(x) 106 | 107 | out += residual 108 | out = self.relu(out) 109 | 110 | return out 111 | # BasicBlock }}} 112 | 113 | # Bottleneck {{{ 114 | class Bottleneck(nn.Module): 115 | M = 3 116 | expansion = 4 117 | 118 | def __init__(self, builder, inplanes, planes, stride=1, downsample=None): 119 | super(Bottleneck, self).__init__() 120 | self.conv1 = builder.conv1x1(inplanes, planes) 121 | self.bn1 = builder.batchnorm(planes) 122 | self.conv2 = builder.conv3x3(planes, planes, stride=stride) 123 | self.bn2 = builder.batchnorm(planes) 124 | self.conv3 = builder.conv1x1(planes, planes * self.expansion) 125 | self.bn3 = builder.batchnorm(planes * self.expansion, last_bn=True) 126 | self.relu = builder.activation() 127 | self.downsample = downsample 128 | self.stride = stride 129 | 130 | 131 | def forward(self, x): 132 | residual = x 133 | 134 | out = self.conv1(x) 135 | out = self.bn1(out) 136 | out = self.relu(out) 137 | 138 | out = self.conv2(out) 139 | out = self.bn2(out) 140 | 141 | out = self.relu(out) 142 | 143 | out = self.conv3(out) 144 | out = self.bn3(out) 145 | 146 | if self.downsample is not None: 147 | residual = self.downsample(x) 148 | 149 | out += residual 150 | 151 | out = self.relu(out) 152 | 153 | return out 154 | # Bottleneck }}} 155 | 156 | # ResNet {{{ 157 | class ResNet(nn.Module): 158 | def __init__(self, builder, block, layers, num_classes=1000): 159 | self.inplanes = 64 160 | super(ResNet, self).__init__() 161 | self.conv1 = builder.conv7x7(3, 64, stride=2) 162 | self.bn1 = builder.batchnorm(64) 163 | self.relu = builder.activation() 164 | self.maxpool = 
nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 165 | self.layer1 = self._make_layer(builder, block, 64, layers[0]) 166 | self.layer2 = self._make_layer(builder, block, 128, layers[1], stride=2) 167 | self.layer3 = self._make_layer(builder, block, 256, layers[2], stride=2) 168 | self.layer4 = self._make_layer(builder, block, 512, layers[3], stride=2) 169 | self.avgpool = nn.AdaptiveAvgPool2d(1) 170 | self.fc = nn.Linear(512 * block.expansion, num_classes) 171 | 172 | def _make_layer(self, builder, block, planes, blocks, stride=1): 173 | downsample = None 174 | if stride != 1 or self.inplanes != planes * block.expansion: 175 | dconv = builder.conv1x1(self.inplanes, planes * block.expansion, 176 | stride=stride) 177 | dbn = builder.batchnorm(planes * block.expansion) 178 | if dbn is not None: 179 | downsample = nn.Sequential(dconv, dbn) 180 | else: 181 | downsample = dconv 182 | 183 | layers = [] 184 | layers.append(block(builder, self.inplanes, planes, stride, downsample)) 185 | self.inplanes = planes * block.expansion 186 | for i in range(1, blocks): 187 | layers.append(block(builder, self.inplanes, planes)) 188 | 189 | builder.layer_index += 1 190 | 191 | return nn.Sequential(*layers) 192 | 193 | def forward(self, x): 194 | x = self.conv1(x) 195 | if self.bn1 is not None: 196 | x = self.bn1(x) 197 | 198 | x = self.relu(x) 199 | x = self.maxpool(x) 200 | 201 | x = self.layer1(x) 202 | x = self.layer2(x) 203 | x = self.layer3(x) 204 | x = self.layer4(x) 205 | 206 | x = self.avgpool(x) 207 | x = x.view(x.size(0), -1) 208 | x = self.fc(x) 209 | 210 | return x 211 | # ResNet }}} 212 | 213 | 214 | resnet_configs = { 215 | 'classic' : { 216 | 'conv' : nn.Conv2d, 217 | 'conv_init' : 'fan_out', 218 | 'nonlinearity' : 'relu', 219 | 'last_bn_0_init' : False, 220 | 'activation' : lambda: nn.ReLU(inplace=True), 221 | }, 222 | 'fanin' : { 223 | 'conv' : nn.Conv2d, 224 | 'conv_init' : 'fan_in', 225 | 'nonlinearity' : 'relu', 226 | 'last_bn_0_init' : False, 227 | 'activation' : lambda: nn.ReLU(inplace=True), 228 | }, 229 | } 230 | 231 | resnet_versions = { 232 | 'resnet18' : { 233 | 'net' : ResNet, 234 | 'block' : BasicBlock, 235 | 'layers' : [2, 2, 2, 2], 236 | 'num_classes' : 1000, 237 | }, 238 | 'resnet34' : { 239 | 'net' : ResNet, 240 | 'block' : BasicBlock, 241 | 'layers' : [3, 4, 6, 3], 242 | 'num_classes' : 1000, 243 | }, 244 | 'resnet50' : { 245 | 'net' : ResNet, 246 | 'block' : Bottleneck, 247 | 'layers' : [3, 4, 6, 3], 248 | 'num_classes' : 1000, 249 | }, 250 | 'resnet101' : { 251 | 'net' : ResNet, 252 | 'block' : Bottleneck, 253 | 'layers' : [3, 4, 23, 3], 254 | 'num_classes' : 1000, 255 | }, 256 | 'resnet152' : { 257 | 'net' : ResNet, 258 | 'block' : Bottleneck, 259 | 'layers' : [3, 8, 36, 3], 260 | 'num_classes' : 1000, 261 | }, 262 | } 263 | 264 | 265 | def build_resnet(version, config, model_state=None): 266 | version = resnet_versions[version] 267 | config = resnet_configs[config] 268 | 269 | builder = ResNetBuilder(version, config) 270 | print("Version: {}".format(version)) 271 | print("Config: {}".format(config)) 272 | model = version['net'](builder, 273 | version['block'], 274 | version['layers'], 275 | version['num_classes']) 276 | 277 | return model 278 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (C) 2020 NVIDIA Corporation. All rights reserved. 
3 | # Nvidia Source Code License-NC 4 | # Code written by Pavlo Molchanov and Hongxu Yin 5 | # -------------------------------------------------------- 6 | 7 | import torch 8 | import os 9 | from torch import distributed, nn 10 | import random 11 | import numpy as np 12 | from collections import OrderedDict  # used below even when no key renaming branch runs 13 | def load_model_pytorch(model, load_model, gpu_n=0): 14 | print("=> loading checkpoint '{}'".format(load_model)) 15 | 16 | checkpoint = torch.load(load_model, map_location = lambda storage, loc: storage.cuda(gpu_n)) 17 | 18 | if 'state_dict' in checkpoint.keys(): 19 | load_from = checkpoint['state_dict'] 20 | else: 21 | load_from = checkpoint 22 | 23 | if 1: 24 | if 'module.' in list(model.state_dict().keys())[0]: 25 | if 'module.' not in list(load_from.keys())[0]: 26 | from collections import OrderedDict 27 | 28 | load_from = OrderedDict([("module.{}".format(k), v) for k, v in load_from.items()]) 29 | 30 | if 'module.' not in list(model.state_dict().keys())[0]: 31 | if 'module.' in list(load_from.keys())[0]: 32 | from collections import OrderedDict 33 | 34 | load_from = OrderedDict([(k.replace("module.", ""), v) for k, v in load_from.items()]) 35 | 36 | if 1: 37 | if list(load_from.items())[0][0][:2] == "1." and list(model.state_dict().items())[0][0][:2] != "1.": 38 | load_from = OrderedDict([(k[2:], v) for k, v in load_from.items()]) 39 | 40 | load_from = OrderedDict([(k, v) for k, v in load_from.items() if "gate" not in k]) 41 | 42 | model.load_state_dict(load_from, strict=True) 43 | 44 | epoch_from = -1 45 | if 'epoch' in checkpoint.keys(): 46 | epoch_from = checkpoint['epoch'] 47 | print("=> loaded checkpoint '{}' (epoch {})" 48 | .format(load_model, epoch_from)) 49 | 50 | 51 | def create_folder(directory): 52 | # from https://stackoverflow.com/a/273227 53 | if not os.path.exists(directory): 54 | os.makedirs(directory) 55 | 56 | 57 | random.seed(0) 58 | 59 | def distributed_is_initialized(): 60 | if distributed.is_available(): 61 | if distributed.is_initialized(): 62 | return True 63 | return False 64 | 65 | 66 | def lr_policy(lr_fn): 67 | def _alr(optimizer, iteration, epoch): 68 | lr = lr_fn(iteration, epoch) 69 | for param_group in optimizer.param_groups: 70 | param_group['lr'] = lr 71 | 72 | return _alr 73 | 74 | 75 | def lr_cosine_policy(base_lr, warmup_length, epochs): 76 | def _lr_fn(iteration, epoch): 77 | if epoch < warmup_length: 78 | lr = base_lr * (epoch + 1) / warmup_length 79 | else: 80 | e = epoch - warmup_length 81 | es = epochs - warmup_length 82 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 83 | return lr 84 | 85 | return lr_policy(_lr_fn) 86 | 87 | 88 | def beta_policy(mom_fn): 89 | def _alr(optimizer, iteration, epoch, param, indx): 90 | mom = mom_fn(iteration, epoch) 91 | for param_group in optimizer.param_groups: 92 | param_group[param][indx] = mom 93 | 94 | return _alr 95 | 96 | 97 | def mom_cosine_policy(base_beta, warmup_length, epochs): 98 | def _beta_fn(iteration, epoch): 99 | if epoch < warmup_length: 100 | beta = base_beta * (epoch + 1) / warmup_length 101 | else: 102 | beta = base_beta 103 | return beta 104 | 105 | return beta_policy(_beta_fn) 106 | 107 | 108 | def clip(image_tensor, use_fp16=False): 109 | ''' 110 | adjust the input based on mean and variance 111 | ''' 112 | if use_fp16: 113 | mean = np.array([0.485, 0.456, 0.406], dtype=np.float16) 114 | std = np.array([0.229, 0.224, 0.225], dtype=np.float16) 115 | else: 116 | mean = np.array([0.485, 0.456, 0.406]) 117 | std = np.array([0.229, 0.224, 0.225]) 118 | for c in range(3): 119 | m, s = mean[c], std[c] 120 | 
image_tensor[:, c] = torch.clamp(image_tensor[:, c], -m / s, (1 - m) / s) 121 | return image_tensor 122 | 123 | 124 | def denormalize(image_tensor, use_fp16=False): 125 | ''' 126 | convert floats back to input 127 | ''' 128 | if use_fp16: 129 | mean = np.array([0.485, 0.456, 0.406], dtype=np.float16) 130 | std = np.array([0.229, 0.224, 0.225], dtype=np.float16) 131 | else: 132 | mean = np.array([0.485, 0.456, 0.406]) 133 | std = np.array([0.229, 0.224, 0.225]) 134 | 135 | for c in range(3): 136 | m, s = mean[c], std[c] 137 | image_tensor[:, c] = torch.clamp(image_tensor[:, c] * s + m, 0, 1) 138 | 139 | return image_tensor --------------------------------------------------------------------------------
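The schedule and image-space helpers in `utils/utils.py` above are consumed by `DeepInversionClass` during optimization (see `deepinversion.py`, not reproduced here). The snippet below is a minimal, self-contained sketch of how `lr_cosine_policy`, `clip`, and `denormalize` can be wired into a toy image-optimization loop; the ResNet-18 model, the plain cross-entropy objective, the random targets, and the iteration count are illustrative assumptions, not the repository's actual DeepInversion loop.

```python
# Illustrative sketch only, not the repository's DeepInversion loop.
# Shows how lr_cosine_policy, clip, and denormalize from utils/utils.py
# fit together; the model, loss, and iteration count are placeholders.
import torch
import torch.nn as nn
import torchvision.models as models

from utils.utils import clip, denormalize, lr_cosine_policy

device = "cuda" if torch.cuda.is_available() else "cpu"
net = models.resnet18(pretrained=True).to(device).eval()

# A small batch of trainable inputs in ImageNet-normalized space.
inputs = torch.randn(4, 3, 224, 224, device=device, requires_grad=True)
targets = torch.randint(0, 1000, (4,), device=device)  # placeholder labels

optimizer = torch.optim.Adam([inputs], lr=0.2)
criterion = nn.CrossEntropyLoss()
# Cosine learning-rate schedule with a short warmup, stepped per iteration.
lr_schedule = lr_cosine_policy(base_lr=0.2, warmup_length=10, epochs=200)

for it in range(200):
    lr_schedule(optimizer, it, it)  # set the learning rate for this step
    optimizer.zero_grad()
    # Placeholder objective; the real code also adds BN-statistics and TV terms.
    loss = criterion(net(inputs), targets)
    loss.backward()
    optimizer.step()
    # Keep the optimized pixels inside the valid range of the normalized space.
    inputs.data = clip(inputs.data)

# Undo the ImageNet normalization before visualizing or saving the images.
images = denormalize(inputs.detach().clone())
```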