├── .gitignore
├── FLGradientInversion
│   ├── README.md
│   ├── config
│   │   └── config_inversion.json
│   ├── fl_gradient_inversion.py
│   ├── main.py
│   ├── orig.png
│   ├── prior
│   │   └── prior_1.jpg
│   ├── recon.png
│   ├── requirements.txt
│   └── torchvision_class.py
├── LICENSE
├── README.md
├── cifar10
│   ├── README.md
│   ├── deepinversion_cifar10.py
│   ├── images
│   │   └── better_last.png
│   └── resnet_cifar.py
├── deepinversion.py
├── example_logs
│   ├── fp16_set0_rn50.log
│   ├── fp16_set0_rn50_adi02.log
│   ├── fp16_set0_rn50_adi02_output_00030_gpu_0.jpg
│   ├── fp16_set0_rn50_output_00030_gpu_0.jpg
│   ├── fp16_set1_rn50.log
│   ├── fp16_set1_rn50_output_00020_gpu_0.jpg
│   ├── fp32_set0_mnv2.log
│   ├── fp32_set0_mnv2_output_00030_gpu_0.jpg
│   ├── fp32_set0_rn50.log
│   ├── fp32_set0_rn50_first_bn_scaled.jpg
│   ├── fp32_set0_rn50_first_bn_scaled.log
│   ├── fp32_set0_rn50_output_00030_gpu_0.jpg
│   └── teaser.png
├── imagenet_inversion.py
├── models
│   └── resnetv15.py
└── utils
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | generation/
2 | temp/
3 | __pycache__
4 | .idea/
5 | *.tar.gz
6 | *.zip
7 | *.pkl
8 | *.pyc
9 |
--------------------------------------------------------------------------------
/FLGradientInversion/README.md:
--------------------------------------------------------------------------------
1 | # FL Gradient Inversion
2 |
3 | This directory contains the tools necessary to recreate the chest X-ray
4 | experiments described in
5 |
6 |
7 | ### Do Gradient Inversion Attacks Make Federated Learning Unsafe? [arXiv:2202.06924](https://arxiv.org/abs/2202.06924)
8 |
9 | ###### Abstract:
10 |
11 | > Federated learning (FL) allows the collaborative training of AI models without needing to share raw data. This capability makes it especially interesting for healthcare applications where patient and data privacy is of utmost concern. However, recent works on the inversion of deep neural networks from model gradients raised concerns about the security of FL in preventing the leakage of training data. In this work, we show that these attacks presented in the literature are impractical in real FL use-cases and provide a new baseline attack that works for more realistic scenarios where the clients' training involves updating the Batch Normalization (BN) statistics. Furthermore, we present new ways to measure and visualize potential data leakage in FL. Our work is a step towards establishing reproducible methods of measuring data leakage in FL and could help determine the optimal tradeoffs between privacy-preserving techniques, such as differential privacy, and model accuracy based on quantifiable metrics.
12 |
13 |
14 | ## Updates
15 |
16 | ***01/16/2023***
17 |
18 | 1. Our FL Gradient Inversion [paper](https://arxiv.org/pdf/2202.06924.pdf) has been accepted to [IEEE Transactions on Medical Imaging (TMI)](https://www.embs.org/tmi/).
19 |
20 | 2. We release the code for FL Gradient Inversion.
21 |
22 | ## Quick-start
23 |
24 | First, install requirements. The code was tested with Python 3.10.
25 | ```setup
26 | pip install -r requirements.txt
27 | ```
28 |
29 | To run an example gradient inversion attack from pre-recorded FL gradients
30 | from a "high-risk" client sending updates in the 10th training round based on
31 | just one image (batch size 1), execute the following.
32 |
33 | #### 1. Download the pre-recorded weights
34 |
35 | Download the weights [here](https://drive.google.com/file/d/1o6aZy2oBSD7ayPgkHfZ41lzANhldTVyr/view?usp=share_link)
36 | and extract to `./weights`.
37 |
38 | The extracted folder should have the following content.
39 | ```
40 | weights
41 | ├── batchnorm_round10_client9.npz
42 | ├── FL_global_model_round10.pt
43 | └── updates_round10_client9.npz
44 | ```
45 |
46 | #### 2. Run the inversion code
47 | ```
48 | ./main.py
49 | ```
50 |
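The attack can also be driven programmatically instead of via `./main.py`. A minimal sketch (not part of the repository), assuming the default config and the weights downloaded in step 1 are in place; shortening `iterations` here is only meant as a quick smoke test:

```python
import json

from main import run  # wraps FLGradientInversion

with open("./config/config_inversion.json") as f:
    cfg = json.load(f)

cfg["iterations"] = 1000  # illustrative: shorten the run for a quick test
run(cfg)  # reconstructions are written to cfg["save_path"]
```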
51 |
52 | ## Federated Learning Experiments
53 |
54 | To reproduce the experiments in the paper, we use [NVIDIA FLARE](https://github.com/NVIDIA/NVFlare) to produce
55 | the model updates shared in federated learning. Please visit
56 | [here](https://nvidia.github.io/NVFlare/research/gradient-inversion) for
57 | details.
58 |
59 | The expected result is saved under [./outputs/recon.png](./outputs/recon.png). For larger
60 | training set sizes, several images will be reconstructed. See the
61 | "local_num_images" config option.
62 |
63 | #### Reconstruction
64 |
65 | | Original | Inversion |
66 | |-----------------|------------------|
67 | |  |  |
68 |
69 | > Note, the original image is from the [COVID-19 Radiography Database](https://www.kaggle.com/tawsifurrahman/covid19-radiography-database) (Normal-4085.png),
70 | > with a random patient name and date of birth overlaid.
71 |
72 | ## Citation
73 |
74 | > A. Hatamizadeh et al., "Do Gradient Inversion Attacks Make Federated Learning Unsafe?," in IEEE Transactions on Medical Imaging, doi: 10.1109/TMI.2023.3239391.
75 |
76 | BibTeX
77 | ```
78 | @ARTICLE{10025466,
79 | author={Hatamizadeh, Ali and Yin, Hongxu and Molchanov, Pavlo and Myronenko, Andriy and Li, Wenqi and Dogra, Prerna and Feng, Andrew and Flores, Mona G. and Kautz, Jan and Xu, Daguang and Roth, Holger R.},
80 | journal={IEEE Transactions on Medical Imaging},
81 | title={Do Gradient Inversion Attacks Make Federated Learning Unsafe?},
82 | year={2023},
83 | volume={},
84 | number={},
85 | pages={1-1},
86 | doi={10.1109/TMI.2023.3239391}}
87 | ```
88 |
89 | ## License
90 |
91 | Copyright (C) 2023 NVIDIA Corporation. All rights reserved.
92 |
93 | This work is made available under the Nvidia Source Code License (1-Way Commercial). To view a copy of this license, visit https://github.com/NVlabs/DeepInversion/blob/master/LICENSE
94 |
--------------------------------------------------------------------------------
/FLGradientInversion/config/config_inversion.json:
--------------------------------------------------------------------------------
1 | {
2 | "checkpoint_file": "./weights/FL_global_model_round10.pt",
3 | "weights_file": "./weights/updates_round10_client9.npz",
4 | "batchnorm_file": "./weights/batchnorm_round10_client9.npz",
5 | "img_prior": "./prior/prior_1.jpg",
6 | "save_path": "./outputs/",
7 | "model_name": "resnet18",
8 | "criterion": "BCEWithLogitsLoss",
9 | "num_classes": 2,
10 | "batch_size": 1,
11 | "iterations": 40000,
12 | "resolution": 224,
13 | "pretrained": false,
14 | "start_rand": false,
15 | "init_target_rand": true,
16 | "no_lr_decay": false,
17 | "grad_l2": 1e-3,
18 | "original_bn_l2": 1e-1,
19 | "energy_l2": 1e-1,
20 | "tv_l1": 0.0,
21 | "tv_l2": 1e-4,
22 | "lr": 1e-1,
23 | "l2": 1e-5,
24 | "lr_local": 1e-2,
25 | "local_bs": 1,
26 | "local_epoch": 1,
27 | "local_num_images": 1,
28 | "local_optim": "sgd",
29 | "save_every": 500
30 | }
31 |
--------------------------------------------------------------------------------
/FLGradientInversion/fl_gradient_inversion.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import collections
10 | import json
11 | import logging
12 | import os
13 | from copy import deepcopy
14 | from typing import Callable, Dict, Iterable, Optional, Tuple, Union
15 |
16 | import matplotlib.pyplot as plt
17 | import numpy as np
18 | import torch
19 | import torchvision
20 | from ignite.engine import Engine
21 | from monai.data import DataLoader
22 | from monai.engines import SupervisedTrainer
23 | from monai.engines.utils import IterationEvents, default_prepare_batch
24 | from monai.inferers import SimpleInferer
25 | from monai.utils.enums import CommonKeys as Keys
26 | from PIL import Image
27 |
28 |
29 | class FLGradientInversion(object):
30 | def __init__(
31 | self,
32 | network,
33 | grad_lst,
34 | bn_stats,
35 | model_bn,
36 | prior_transforms=None,
37 | save_transforms=None,
38 | ):
39 | """FLGradientInversion is used to reconstruct training images and
40 | targets (ground truth labels) by attempting to invert the gradients
41 | (model updates) shared in a federated learning framework.
42 |
43 | Args:
44 | network: network for which the gradients are being inverted,
45 | i.e. the current global model the model updates are being
46 | computed with respect to.
47 | grad_lst: model updates.
48 | bn_stats: updated batch norm statistics.
49 | model_bn: updated model containing current batch norm statistics.
50 | prior_transforms: Optional custom transforms to read the prior
51 | image. Defaults to None.
52 | save_transforms: Optional transforms to save the reconstructed
53 | images. Defaults to None.
54 | Returns:
55 | __call__() function returns the reconstructions.
56 | """
57 | self.network = network
58 | self.bn_stats = bn_stats
59 | self.model_bn = model_bn
60 | self.loss_r_feature_layers = []
61 | self.grad_lst = grad_lst
62 | self.logger = logging.getLogger(self.__class__.__name__)
63 | self.prior_transforms = prior_transforms
64 | self.save_transforms = save_transforms
65 |
66 | def __call__(self, cfg):
67 | """Run the gradient inversion attack.
68 |
69 | Args:
70 | cfg: Configuration dictionary containing the following keys used
71 | in this call.
72 | - img_prior: full path to prior image file used to initialize
73 | the attack.
74 | - save_path: Optional save directory where reconstructed
75 | images and targets are being saved.
76 | - criterion: Loss used for training the classification
77 | network, e.g. "BCEWithLogitsLoss".
78 | - iterations: number of iterations to run the attack.
79 | - resolution: x/y dimension of the images to be reconstructed.
80 | - start_rand: Whether to start from random initialization.
81 | If `False`, the `img_prior` is used.
82 | - init_target_rand: Whether to initialize the reconstructed
83 | targets using a uniform distribution. If `False`, targets
84 | are initialized as all zeros.
85 | - no_lr_decay: Disable the learning rate decay of the
86 | optimizer.
87 | - grad_l2: L2 scaling factor on the gradient loss.
88 | - original_bn_l2: Scaling factor for batchnorm matching loss.
89 | - energy_l2: Scaling factor for the Gaussian noise added to the inputs to help the optimization escape local minima.
90 | - tv_l1: Coefficient for total variation L1 loss.
91 | - tv_l2: Coefficient for total variation L2 loss.
92 | - lr: Learning rate for optimization.
93 | - l2: L2 loss on the image.
94 | - local_epoch: Local number of epochs used by the FL client.
95 | - local_optim: Local optimizer used by the FL client, either
96 | "sgd" or "adam".
97 | - save_every: How often to save the reconstructions to file.
98 | Returns:
99 | Reconstructed images.
100 | """
101 | self.save_path = cfg["save_path"]
102 | save_every = cfg["save_every"]
103 | if save_every > 0:
104 | self.create_folder(self.save_path)
105 |
106 | if cfg["criterion"] == "BCEWithLogitsLoss":
107 | criterion = torch.nn.BCEWithLogitsLoss()
108 | elif cfg["criterion"] == "CrossEntropyLoss":
109 | criterion = torch.nn.CrossEntropyLoss()
110 | else:
111 | raise ValueError(
112 | "criterion should be BCEWithLogitsLoss or CrossEntropyLoss."
113 | )
114 |
115 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
116 | network = self.network
117 | local_rank = torch.cuda.current_device()
118 | if cfg["start_rand"]:
119 | inputs_1 = torch.randn(
120 | (cfg["batch_size"], 1, cfg["resolution"], cfg["resolution"]),
121 | requires_grad=True,
122 | device=device,
123 | dtype=torch.float,
124 | )
125 | else:
126 | prior_file = cfg["img_prior"]
127 | if self.prior_transforms:
128 | _img = self.prior_transforms(prior_file)
129 | else: # use default prior loading transforms
130 | pil_img = Image.open(prior_file)
131 | self.prior_transforms = torchvision.transforms.Compose(
132 | [
133 | torchvision.transforms.Resize(
134 | (cfg["resolution"], cfg["resolution"])
135 | ),
136 | torchvision.transforms.ToTensor(),
137 | ]
138 | )
139 | _img = self.prior_transforms(pil_img)
140 |
141 | # make init batch
142 | images = torch.empty(
143 | size=(
144 | cfg["local_num_images"],
145 | 1,
146 | cfg["resolution"],
147 | cfg["resolution"],
148 | )
149 | )
150 | for i in range(cfg["local_num_images"]):
151 | images[i] = _img.unsqueeze_(0)
152 | inputs_1 = images.to(device)
153 | inputs_1.requires_grad_(True)
154 |
155 | if cfg["init_target_rand"]:
156 | targets_in = torch.rand(
157 | (cfg["local_num_images"], 2),
158 | requires_grad=True,
159 | device=device,
160 | dtype=torch.float,
161 | )
162 | else:
163 | targets_in = torch.zeros(
164 | (cfg["local_num_images"], 2),
165 | requires_grad=True,
166 | device=device,
167 | dtype=torch.float,
168 | )
169 |
170 | iteration = -1
171 | for lr_it, _ in enumerate([2, 1]):
172 | iterations_per_layer = cfg["iterations"]
173 | if lr_it == 0:
174 | continue
175 | optimizer = torch.optim.Adam(
176 | [inputs_1, targets_in],
177 | lr=cfg["lr"],
178 | betas=[0.9, 0.9],
179 | eps=1e-8,
180 | )
181 | lr_scheduler = self.lr_cosine_policy(cfg["lr"], 100, iterations_per_layer)
182 | local_trainer = self.create_trainer(
183 | cfg=cfg,
184 | network=network,
185 | inputs=(
186 | inputs_1 * torch.ones((1, 3, 1, 1)).cuda()
187 | ), # turn grayscale to RGB (3-channel inputs)
188 | targets=targets_in,
189 | criterion=criterion,
190 | device=torch.device("cuda"),
191 | )
192 | for iteration_loc in range(iterations_per_layer):
193 | iteration += 1
194 | if not cfg["no_lr_decay"]:
195 | lr_scheduler(optimizer, iteration_loc, iteration_loc)
196 | inputs = inputs_1 * torch.ones((1, 3, 1, 1)).cuda()
197 | optimizer.zero_grad()
198 | network.zero_grad()
199 | network.train()
200 | loss_var_l1, loss_var_l2 = self.img_prior(inputs)
201 | loss_l2 = torch.norm(
202 | inputs.view(cfg["local_num_images"], -1), dim=1
203 | ).mean()
204 | loss_aux = (
205 | cfg["tv_l2"] * loss_var_l2
206 | + cfg["tv_l1"] * loss_var_l1
207 | + cfg["l2"] * loss_l2
208 | )
209 | loss = loss_aux
210 | if cfg["grad_l2"] > 0:
211 | new_grad = self.sim_local_updates(
212 | cfg=cfg,
213 | trainer=local_trainer,
214 | network=network,
215 | inputs=inputs,
216 | targets=targets_in,
217 | use_sigmoid=True,
218 | use_softmax=False,
219 | )
220 | loss_grad = 0
221 | for a, b in zip(new_grad, self.grad_lst):
222 | loss_grad += cfg["grad_l2"] * (torch.norm(a - b[1]))
223 | loss = loss + loss_grad
224 |
225 | # add batch norm loss
226 | bn_hooks = []
227 | self.model_bn.train()
228 | for name, module in self.model_bn.named_modules():
229 | if isinstance(module, torch.nn.BatchNorm2d):
230 | bn_hooks.append(
231 | DeepInversionFeatureHook(
232 | module=module,
233 | bn_stats=self.bn_stats,
234 | name=name,
235 | )
236 | )
237 | # run forward path once to compute bn_hooks
238 | self.model_bn(inputs)
239 | loss_bn_tmp = 0
240 | for hook in bn_hooks:
241 | loss_bn_tmp += hook.r_feature
242 | hook.close()
243 | loss_bn = cfg["original_bn_l2"] * loss_bn_tmp
244 | loss += loss_bn
245 | loss.backward(retain_graph=True)
246 | optimizer.step()
247 | if local_rank == 0:
248 | if iteration % save_every == 0:
249 | self.logger.info(f"------------iteration {iteration}----------")
250 | self.logger.info(f"total loss {loss.item()}")
251 | self.logger.info(
252 | f"mean targets {torch.mean(targets_in, 0).detach().cpu().numpy()}"
253 | )
254 | self.logger.info(f"gradient loss {loss_grad.item()}")
255 | self.logger.info(f"bn matching loss {loss_bn.item()}")
256 | self.logger.info(
257 | f"tvl2 loss {cfg['tv_l2'] * loss_var_l2.item()}"
258 | )
259 | best_inputs = inputs.clone()
260 | if iteration % save_every == 0 and (save_every > 0):
261 | self.save_results(
262 | images=best_inputs, targets=targets_in, name="recon"
263 | )
264 | # save reconstruction collage
265 | torchvision.utils.save_image(
266 | best_inputs,
267 | os.path.join(self.save_path, "recon.png"),
268 | normalize=True,
269 | scale_each=True,
270 | nrow=int(int(cfg["local_num_images"]) ** 0.5),
271 | )
272 | if cfg["energy_l2"] > 0.0:
273 | inputs_noise_add = torch.randn(inputs.size(), device=device)
274 | for param_group in optimizer.param_groups:
275 | current_lr = param_group["lr"]
276 | break
277 | std = cfg["energy_l2"] * current_lr
278 | if iteration % save_every == 0:
279 | if local_rank == 0:
280 | self.logger.info(
281 | f"Energy method waken up, "
282 | f"adding Gaussian of std {std}"
283 | )
284 | inputs.data = inputs.data + inputs_noise_add * std
285 |
286 | if save_every > 0:
287 | self.save_results(images=best_inputs, targets=targets_in, name="recon")
288 |
289 | optimizer.state = collections.defaultdict(dict)
290 |
291 | return best_inputs, targets_in
292 |
293 | @staticmethod
294 | def sim_local_updates(
295 | cfg,
296 | trainer,
297 | network,
298 | inputs,
299 | targets,
300 | use_softmax=False,
301 | use_sigmoid=True,
302 | ):
303 | """
304 | Run the equivalent local optimization loop to get gradients
305 | which will be matched (using SupervisedTrainer)
306 | """
307 | trainer.logger.setLevel(logging.WARNING)
308 |
309 | params_before = deepcopy(network.state_dict())
310 | trainer.network.load_state_dict(params_before)
311 | if use_softmax and use_sigmoid:
312 | raise ValueError(
313 | "Only set one of `use_softmax` or `use_sigmoid` to be true."
314 | )
315 | if use_softmax:
316 | targets = torch.softmax(targets, dim=-1)
317 | if use_sigmoid:
318 | targets = torch.sigmoid(targets)
319 | data = []
320 | for i in range(cfg["local_num_images"]):
321 | data.append({Keys.IMAGE: inputs[i, ...], Keys.LABEL: targets[i, ...]})
322 | trainer.data_loader = DataLoader([data], batch_size=cfg["local_bs"])
323 | if cfg["local_optim"] == "sgd":
324 | optimizer = torch.optim.SGD(network.parameters(), cfg["lr_local"])
325 | elif cfg["local_optim"] == "adam":
326 | optimizer = torch.optim.Adam(network.parameters(), cfg["lr_local"])
327 | else:
328 | raise ValueError(
329 | f"Local optimizer {cfg['local_optim']} " f"is not currently supported !"
330 | )
331 | trainer.optimizer.load_state_dict(optimizer.state_dict())
332 | trainer.optimizer.zero_grad()
333 | trainer.network.zero_grad()
334 | trainer.run()
335 | params_after = trainer.network.state_dict()
336 | new_grad = []
337 | for name, _ in network.named_parameters():
338 | new_grad.append(params_after[name] - params_before[name])
339 | return new_grad
340 |
341 | @staticmethod
342 | def create_trainer(cfg, network, inputs, targets, criterion, device=None):
343 | if device is None:
344 | device = torch.device("cuda")
345 |
346 | data = []
347 | for i in range(cfg["local_num_images"]):
348 | data.append({Keys.IMAGE: inputs[i, ...], Keys.LABEL: targets[i, ...]})
349 | loader = DataLoader([data], batch_size=cfg["local_bs"])
350 | if cfg["local_optim"] == "sgd":
351 | optimizer = torch.optim.SGD(network.parameters(), cfg["lr_local"])
352 | elif cfg["local_optim"] == "adam":
353 | optimizer = torch.optim.Adam(network.parameters(), cfg["lr_local"])
354 | else:
355 | raise ValueError(
356 | "Local optimizer {} is not currently supported !".format(
357 | cfg["local_optim"]
358 | )
359 | )
360 | optimizer.zero_grad()
361 | trainer = InversionSupervisedTrainer(
362 | device=device,
363 | max_epochs=cfg["local_epoch"],
364 | train_data_loader=loader,
365 | network=network,
366 | optimizer=optimizer,
367 | loss_function=criterion,
368 | amp=False,
369 | )
370 | return trainer
371 |
372 | def img_prior(self, inputs_jit):
373 | # COMPUTE total variation regularization loss
374 | diff1 = inputs_jit[:, :, :, :-1] - inputs_jit[:, :, :, 1:]
375 | diff2 = inputs_jit[:, :, :-1, :] - inputs_jit[:, :, 1:, :]
376 | diff3 = inputs_jit[:, :, 1:, :-1] - inputs_jit[:, :, :-1, 1:]
377 | diff4 = inputs_jit[:, :, :-1, :-1] - inputs_jit[:, :, 1:, 1:]
378 | loss_var_l2 = (
379 | torch.norm(diff1)
380 | + torch.norm(diff2)
381 | + torch.norm(diff3)
382 | + torch.norm(diff4)
383 | )
384 | loss_var_l1 = (
385 | (diff1.abs() / 255.0).mean()
386 | + (diff2.abs() / 255.0).mean()
387 | + (diff3.abs() / 255.0).mean()
388 | + (diff4.abs() / 255.0).mean()
389 | )
390 | loss_var_l1 = loss_var_l1 * 255.0
391 | return loss_var_l1, loss_var_l2
392 |
393 | def denormalize(self, image_tensor, use_fp16=False):
394 |
395 | if use_fp16:
396 | mean = np.array([0.485, 0.456, 0.406], dtype=np.float16)
397 | std = np.array([0.229, 0.224, 0.225], dtype=np.float16)
398 | else:
399 | mean = np.array([0.485, 0.456, 0.406])
400 | std = np.array([0.229, 0.224, 0.225])
401 |
402 | for c in range(3):
403 | m, s = mean[c], std[c]
404 |
405 | if len(image_tensor.shape) == 4:
406 | image_tensor[:, c] = torch.clamp(image_tensor[:, c] * s + m, 0, 1)
407 |
408 | elif len(image_tensor.shape) == 3:
409 | image_tensor[c] = torch.clamp(image_tensor[c] * s + m, 0, 1)
410 | else:
411 | raise NotImplementedError()
412 |
413 | return image_tensor
414 |
415 | def create_folder(self, directory):
416 |
417 | if not os.path.exists(directory):
418 | os.makedirs(directory)
419 |
420 | def lr_policy(self, lr_fn):
421 | def _alr(optimizer, iteration, epoch):
422 | lr = lr_fn(iteration, epoch)
423 | for param_group in optimizer.param_groups:
424 | param_group["lr"] = lr
425 |
426 | return _alr
427 |
428 | def lr_cosine_policy(self, base_lr, warmup_length, epochs):
429 | def _lr_fn(iteration, epoch):
430 | if epoch < warmup_length:
431 | lr = base_lr * (epoch + 1) / warmup_length
432 | else:
433 | e = epoch - warmup_length
434 | es = epochs - warmup_length
435 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
436 | return lr
437 |
438 | return self.lr_policy(_lr_fn)
439 |
440 | def save_results(self, images, targets, name="recon"):
441 | # save reconstructed images
442 | for id in range(images.shape[0]):
443 | img = images[id, ...]
444 | if self.save_transforms:
445 | self.save_transforms(img)
446 | else:
447 | save_name = f"{name}_{id}.png"
448 | place_to_store = os.path.join(self.save_path, save_name)
449 |
450 | image_np = img.data.cpu().numpy()
451 | image_np = image_np.transpose((1, 2, 0))
452 | image_np = np.array(
453 | (image_np - np.min(image_np))
454 | / (np.max(image_np) - np.min(image_np))
455 | )
456 | plt.imsave(place_to_store, image_np)
457 |
458 | # save reconstructed targets
459 | place_to_store = os.path.join(self.save_path, f"{name}_targets.json")
460 |
461 | with open(place_to_store, "w") as f:
462 | json.dump(targets.detach().cpu().numpy().tolist(), f, indent=4)
463 |
464 |
465 | class InversionSupervisedTrainer(SupervisedTrainer):
466 | """
467 | Same as MONAI's SupervisedTrainer but using
468 | retain_graph=True in backward() calls.
469 | """
470 |
471 | def __init__(
472 | self,
473 | device: torch.device,
474 | max_epochs: int,
475 | train_data_loader: Union[Iterable, DataLoader],
476 | network: torch.nn.Module,
477 | optimizer: torch.optim.Optimizer,
478 | loss_function: Callable,
479 | epoch_length: Optional[int] = None,
480 | non_blocking: bool = False,
481 | prepare_batch: Callable = default_prepare_batch,
482 | amp: bool = False,
483 | ) -> None:
484 | super().__init__(
485 | device=device,
486 | max_epochs=max_epochs,
487 | train_data_loader=train_data_loader,
488 | network=network,
489 | optimizer=optimizer,
490 | loss_function=loss_function,
491 | epoch_length=epoch_length,
492 | non_blocking=non_blocking,
493 | prepare_batch=prepare_batch,
494 | iteration_update=None,
495 | inferer=SimpleInferer(),
496 | key_train_metric=None,
497 | additional_metrics=None,
498 | amp=amp,
499 | event_names=None,
500 | event_to_attr=None,
501 | )
502 |
503 | def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
504 | """
505 | Callback function for the Supervised Training processing logic of 1
506 | iteration in Ignite Engine.
507 | Return below items in a dictionary:
508 | - IMAGE: image Tensor data for model input, already moved
509 | to device.
510 | - LABEL: label Tensor data corresponding to the image, already
511 | moved to device.
512 | - PRED: prediction result of model.
513 | - LOSS: loss value computed by loss function.
514 |
515 | Args:
516 | engine: Ignite Engine, it can be a trainer, validator or evaluator.
517 | batchdata: input data for this iteration, usually can be dictionary
518 | or tuple of Tensor data.
519 |
520 | Raises:
521 | ValueError: When ``batchdata`` is None.
522 |
523 | """
524 | if batchdata is None:
525 | raise ValueError("Must provide batch data for current iteration.")
526 | batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
527 | if len(batch) == 2:
528 | inputs, targets = batch
529 | args: Tuple = ()
530 | kwargs: Dict = {}
531 | else:
532 | inputs, targets, args, kwargs = batch
533 | engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
534 |
535 | def _compute_pred_loss():
536 | engine.state.output[Keys.PRED] = self.inferer(
537 | inputs, self.network, *args, **kwargs
538 | )
539 | engine.fire_event(IterationEvents.FORWARD_COMPLETED)
540 | engine.state.output[Keys.LOSS] = self.loss_function(
541 | engine.state.output[Keys.PRED], targets
542 | ).mean()
543 | engine.fire_event(IterationEvents.LOSS_COMPLETED)
544 |
545 | self.network.train()
546 | self.network.zero_grad()
547 | self.optimizer.zero_grad()
548 | if self.amp and self.scaler is not None:
549 | with torch.cuda.amp.autocast():
550 | _compute_pred_loss()
551 | self.scaler.scale(engine.state.output[Keys.LOSS]).backward(
552 | retain_graph=True
553 | )
554 | engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
555 | self.scaler.step(self.optimizer)
556 | self.scaler.update()
557 | else:
558 | _compute_pred_loss()
559 | engine.state.output[Keys.LOSS].backward(retain_graph=True)
560 | engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
561 | self.optimizer.step()
562 | engine.fire_event(IterationEvents.MODEL_COMPLETED)
563 | return engine.state.output
564 |
565 |
566 | class DeepInversionFeatureHook:
567 | """
568 | Implementation of the forward hook to track feature statistics and
569 | compute a loss on them.
570 | Will compute mean and variance, and will use l2 as a loss
571 | """
572 |
573 | def __init__(self, module, bn_stats=None, name=None):
574 | self.hook = module.register_forward_hook(self.hook_fn)
575 | self.bn_stats = bn_stats
576 | self.name = name
577 | self.r_feature = None
578 | self.mean = None
579 | self.var = None
580 |
581 | def hook_fn(self, module, input, output):
582 | nch = input[0].shape[1]
583 | mean = input[0].mean([0, 2, 3])
584 | var = (
585 | input[0]
586 | .permute(1, 0, 2, 3)
587 | .contiguous()
588 | .view([nch, -1])
589 | .var(1, unbiased=False)
590 | )
591 | if self.bn_stats is None:
592 | var_feature = torch.norm(module.running_var.data - var, 2)
593 | mean_feature = torch.norm(module.running_mean.data - mean, 2)
594 | else:
595 | var_feature = torch.norm(
596 | torch.tensor(
597 | self.bn_stats[self.name + ".running_var"], device=input[0].device
598 | )
599 | - var,
600 | 2,
601 | )
602 | mean_feature = torch.norm(
603 | torch.tensor(
604 | self.bn_stats[self.name + ".running_mean"], device=input[0].device
605 | )
606 | - mean,
607 | 2,
608 | )
609 |
610 | rescale = 1.0
611 | self.r_feature = mean_feature + rescale * var_feature
612 | self.mean = mean
613 | self.var = var
614 |
615 | def close(self):
616 | self.hook.remove()
617 |
--------------------------------------------------------------------------------
/FLGradientInversion/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 |
12 | import json
13 | import re
14 | from copy import deepcopy
15 |
16 | import numpy as np
17 | import torch
18 | import torch.utils.data
19 |
20 | from fl_gradient_inversion import FLGradientInversion
21 | from torchvision_class import TorchVisionClassificationModel
22 |
23 |
24 | def run(cfg):
25 | """Run the gradient inversion attack.
26 |
27 | Args:
28 | cfg: Configuration dictionary containing the following keys used
29 | to set up the attack. Should also contain the keys expected by
30 | FLGradientInversion's __call__() function.
31 | - model_name: Used to select the model architecture,
32 | e.g. "resnet18".
33 | - num_classes: Number of output classes.
34 | - pretrained: Whether to load ImageNet-pretrained weights.
35 | - checkpoint_file: Path to the global model checkpoint.
36 | - weights_file: Path to the recorded client model updates (.npz).
37 | - batchnorm_file: Path to the recorded batch norm statistics (.npz).
38 | Returns:
39 | Reconstructed images.
40 | """
41 | torch.backends.cudnn.deterministic = True
42 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43 | net = TorchVisionClassificationModel(
44 | model_name=cfg["model_name"],
45 | num_classes=cfg["num_classes"],
46 | pretrained=cfg["pretrained"],
47 | )
48 |
49 | checkpoint_file = cfg["checkpoint_file"]
50 | add_weights = cfg["weights_file"]
51 | batchnorm_file = cfg["batchnorm_file"]
52 | input_parameters = []
53 | updates = np.load(add_weights, allow_pickle=True)["weights"].item()
54 | update_sum = 0.0
55 | n_excluded = 0
56 | weights = []
57 | if checkpoint_file:
58 | model_data = torch.load(checkpoint_file)
59 | if "model" in model_data.keys():
60 | net.load_state_dict(model_data["model"])
61 | else:
62 | net.load_state_dict(model_data)
63 | exclude_vars = None
64 | if exclude_vars:
65 | re_pattern = re.compile(exclude_vars)
66 | for name, _ in net.named_parameters():
67 | if exclude_vars:
68 | if re_pattern.search(name):
69 | n_excluded += 1
70 | weights.append(0.0)
71 | else:
72 | weights.append(1.0)
73 | val = updates[name]
74 | update_sum += np.sum(np.abs(val))
75 | val = torch.from_numpy(val).to(device)
76 | input_parameters.append(val)
77 | assert update_sum > 0.0, "All updates are zero!"
78 | model_bn = deepcopy(net).cuda()
79 | update_sum = 0.0
80 | new_state_dict = model_bn.state_dict()
81 | for n in updates.keys():
82 | val = updates[n]
83 | update_sum += np.sum(np.abs(val))
84 | new_state_dict[n] += torch.tensor(
85 | val, dtype=new_state_dict[n].dtype, device=new_state_dict[n].device
86 | )
87 | model_bn.load_state_dict(new_state_dict)
88 | assert update_sum > 0.0, "All updates are zero!"
89 | n_bn_updated = 0
90 | global_state_dict = net.state_dict()
91 | if batchnorm_file:
92 | bn_momentum = 0.1
93 | print(
94 | f"Using full BN stats from {batchnorm_file} "
95 | f"with momentum {bn_momentum} ! \n"
96 | )
97 | bn_stats = np.load(batchnorm_file, allow_pickle=True)["batchnorm"].item()
98 | for n in bn_stats.keys():
99 | if "running" in n:
100 | xt = (
101 | bn_stats[n] - (1 - bn_momentum) * global_state_dict[n].numpy()
102 | ) / bn_momentum
103 | n_bn_updated += 1
104 | bn_stats[n] = xt
105 |
106 | net = net.to(device)
107 | grad_lst = []
108 | grad_lst_orig = np.load(add_weights, allow_pickle=True)["weights"].item()
109 | for name, _ in net.named_parameters():
110 | val = torch.from_numpy(grad_lst_orig[name]).cuda()
111 | grad_lst.append([name, val])
112 | grad_inversion_engine = FLGradientInversion(
113 | network=net,
114 | grad_lst=grad_lst,
115 | bn_stats=bn_stats,
116 | model_bn=model_bn,
117 | )
118 | grad_inversion_engine(cfg)
119 |
120 |
121 | def main():
122 | with open("./config/config_inversion.json", "r") as f:
123 | cfg = json.load(f)
124 |
125 | run(cfg)
126 |
127 |
128 | if __name__ == "__main__":
129 | main()
130 |
--------------------------------------------------------------------------------
/FLGradientInversion/orig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/FLGradientInversion/orig.png
--------------------------------------------------------------------------------
/FLGradientInversion/prior/prior_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/FLGradientInversion/prior/prior_1.jpg
--------------------------------------------------------------------------------
/FLGradientInversion/recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/FLGradientInversion/recon.png
--------------------------------------------------------------------------------
/FLGradientInversion/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.13.0
2 | torchvision==0.14.0
3 | pytorch-ignite==0.4.10
4 | numpy
5 | Pillow
6 | monai==1.1.0
7 | matplotlib
--------------------------------------------------------------------------------
/FLGradientInversion/torchvision_class.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import torch
10 | from monai.utils import optional_import
11 |
12 | models, _ = optional_import("torchvision.models")
13 |
14 |
15 | class TorchVisionClassificationModel(torch.nn.Module):
16 | """
17 | Customize TorchVision models to replace final linear/fully-connected layer to fit number of classes.
18 |
19 | Args:
20 | model_name: name of a TorchVision model from https://pytorch.org/vision/stable/models.html whose final
21 | linear/fully-connected layer will be replaced, e.g. ``resnet18`` (default), ``alexnet``, ``vgg16``, etc.
22 | num_classes: number of classes for the last classification layer. Default to 1.
23 | pretrained: whether to use the imagenet pretrained weights. Default to False.
24 | """
25 |
26 | def __init__(
27 | self,
28 | model_name: str = "resnet18",
29 | num_classes: int = 1,
30 | pretrained: bool = False,
31 | bias=True,
32 | ):
33 | super().__init__()
34 | self.model = getattr(models, model_name)(pretrained=pretrained)
35 | if "fc" in dir(self.model):
36 | self.model.fc = torch.nn.Linear(
37 | in_features=self.model.fc.in_features,
38 | out_features=num_classes,
39 | bias=bias,
40 | )
41 | elif "classifier" in dir(self.model) and "vgg" not in model_name:
42 | self.model.classifier = torch.nn.Linear(
43 | in_features=self.model.classifier.in_features,
44 | out_features=num_classes,
45 | bias=bias,
46 | )
47 | elif "vgg" in model_name:
48 | self.model.classifier[-1] = torch.nn.Linear(
49 | in_features=self.model.classifier[-1].in_features,
50 | out_features=num_classes,
51 | bias=bias,
52 | )
53 | else:
54 | raise ValueError(
55 | f"Model ['{model_name}'] does not have a supported classifier attribute."
56 | )
57 |
58 | def forward(self, x):
59 | return self.model(x)
60 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Nvidia Source Code License-NC
2 |
3 | 1. Definitions
4 |
5 | “Licensor” means any person or entity that distributes its Work.
6 |
7 | “Software” means the original work of authorship made available under this License.
8 | “Work” means the Software and any additions to or derivative works of the Software that are made available under
9 | this License.
10 |
11 | “Nvidia Processors” means any central processing unit (CPU), graphics processing unit (GPU), field-programmable gate
12 | array (FPGA), application-specific integrated circuit (ASIC) or any combination thereof designed, made, sold, or
13 | provided by Nvidia or its affiliates.
14 |
15 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S.
16 | copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that
17 | remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
18 |
19 | Works, including the Software, are “made available” under this License by including in or with the Work either
20 | (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License.
21 |
22 | 2. License Grants
23 |
24 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual,
25 | worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display,
26 | publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
27 |
28 | 3. Limitations
29 |
30 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include
31 | a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent,
32 | trademark, or attribution notices that are present in the Work.
33 |
34 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and
35 | distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation
36 | in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to
37 | Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will
38 | continue to apply to the Work itself.
39 |
40 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially.
41 | The Work or derivative works thereof may be used or intended for use by Nvidia or its affiliates commercially or
42 | non-commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
43 |
44 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim,
45 | cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then
46 | your rights under this License from such Licensor (including the grants in Sections 2.1 and 2.2) will terminate
47 | immediately.
48 |
49 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or
50 | trademarks, except as necessary to reproduce the notices described in this License.
51 |
52 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the grants
53 | in Sections 2.1 and 2.2) will terminate immediately.
54 |
55 | 4. Disclaimer of Warranty.
56 |
57 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
58 | WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT.
59 | YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
60 |
61 | 5. Limitation of Liability.
62 |
63 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE),
64 | CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL,
65 | INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
66 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR
67 | MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY
68 | OF SUCH DAMAGES.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion
3 |
4 | This repository is the official PyTorch implementation of [Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion](https://arxiv.org/abs/1912.08795) presented at CVPR 2020.
5 |
6 | The code helps to invert images from torchvision models (pretrained on ImageNet) and to run the resulting images through another model to check generalization. We plan to update the repo with CIFAR10 examples and teacher-student training.
7 |
8 | Useful links:
9 | * [Camera Ready PDF](https://drive.google.com/file/d/1jg4o458y70aCqUPRklMEy6dOGlZ0qMde/view?usp=sharing)
10 | * [ArXiv Full](https://arxiv.org/pdf/1912.08795.pdf)
11 | * [Dataset - Synthesized ImageNet](https://drive.google.com/open?id=1AXCW6_E_Qtr5qyb9jygGaLub13gQo10c): from ResNet50v1.5, ~2GB, organized by classes, ~140k images. These were used in Section 4.4 (Data-free Knowledge Transfer); best viewed in gThumb.
12 |
13 | 
14 |
15 | ## License
16 |
17 | Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
18 |
19 | This work is made available under the Nvidia Source Code License (1-Way Commercial). To view a copy of this license, visit https://github.com/NVlabs/DeepInversion/blob/master/LICENSE
20 |
21 | ## Updates
22 |
23 | - 2020 July 7. Added CIFAR10 inversion result for ResNet34 in the folder cifar10. Code on knowledge distillation will follow soon.
24 | - 2020 June 16. Added a new scaling factor `first_bn_multiplier` for first BN layer. This improves fidelity.
25 |
26 | ## Requirements
27 |
28 | Code was tested in a virtual environment with Python 3.6. Install the requirements:
29 |
30 | ```setup
31 | pip install torch==1.4.0
32 | pip install torchvision==0.5.0
33 | pip install numpy
34 | pip install Pillow
35 | ```
36 |
37 | Additionally install APEX library for FP16 support (2x less memory, 2x faster): [Installing NVIDIA APEX](https://github.com/NVIDIA/apex#quick-start)
38 |
39 | The provided code was originally designed to invert a ResNet50v1.5 model trained for 90 epochs that achieves 77.26% top-1 on ImageNet. We are not able to share that model, but anyone can train it here: [ResNet50v1.5](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/resnet50v1.5).
40 | The code also works well for the default ResNet50 from the torchvision package.
41 |
42 | Code was tested on NVIDIA V100 GPU and Titan X Pascal.
43 |
44 | ## Running the code
45 |
46 | This snippet will generate 84 images by inverting the resnet50 model from the torchvision package.
47 |
48 | `python imagenet_inversion.py --bs=84 --do_flip --exp_name="rn50_inversion" --r_feature=0.01 --arch_name="resnet50" --verifier --adi_scale=0.0 --setting_id=0 --lr 0.25`
49 |
50 | Arguments:
51 |
52 | - `bs` - batch size; should be close to the batch size used during the original training, but this is not strictly necessary.
53 | - `lr` - learning rate for the optimizer of the input tensor for model inversion.
54 | - `do_flip` - applies random flipping between iterations.
55 | - `exp_name` - name of the experiment; creates a folder with this name in `./generations/` where intermediate generations are stored every 100 iterations.
56 | - `r_feature` - coefficient for feature distribution regularization; might need adjustment for other networks.
57 | - `arch_name` - name of the network architecture; should be one of the pretrained models from the torchvision package: `resnet50`, `resnet18`, `mobilenet_v2`, etc.
58 | - `fp16` - enables FP16 optimization if needed, via APEX AMP (O2 level).
59 | - `verifier` - enables checking the accuracy of generated images with another network (default `mobilenet_v2`) every 100 iterations.
60 | Useful to observe the generalizability of generated images.
61 | - `setting_id` - settings for optimization: 0 - multi-resolution scheme, 1 - 2k iterations at full resolution, 2 - 20k iterations (the closest to the ResNet50 experiments in the paper). Recommended to use setting_id={0, 1}.
62 | - `adi_scale` - competition coefficient. A positive value leads to images that are good for the original model but bad for the verifier. The value 0.2 was used in the paper.
63 | - `random_label` - randomly select classes for inversion. Without this argument the code generates hand-picked classes.
64 |
65 | After 3k iterations (~6 min on an NVIDIA V100), generation is done: `Verifier accuracy: 91.6...%` (an experiment with >98% verifier accuracy can be found in `/example_logs`). We generated these images by inverting a vanilla ResNet50 (not trained for image generation), and the classification accuracy by MobileNetV2 on them is >90%. A grid of images looks like this (from `/final_images/`; quality reduced due to JPEG compression):
66 | 
67 |
68 | Optimization is sensitive to hyper-parameters. Try local tuning for your setups/applications. Try changing the r_feature coefficient, the l2 regularization, and the betas of the Adam optimizer (beta=0 works well). Keep an eye on `loss_r_feature`, as it indicates how close the feature statistics are to those of the training distribution.
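
For instance, a run with a stronger feature-statistics regularizer might look like this (the values are only illustrative and need tuning per network):

`python imagenet_inversion.py --bs=84 --do_flip --exp_name="rn50_inversion_rfeature005" --r_feature=0.05 --arch_name="resnet50" --verifier --adi_scale=0.0 --setting_id=0 --lr 0.25`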
69 |
70 | Reduce the batch size if you run out of memory or run without FP16 optimization. In the paper we used a batch size of 152; larger batch sizes are preferred. This code generates images from 41 hand-picked classes. To randomize the target classes, simply use the argument `--random_label`.
71 |
72 | Examples of running code with different arguments and resulting images can be found at `/example_logs/`.
73 |
74 | Check if you can invert other architectures, or even apply the method to other applications (keypoints, detection, etc.).
75 | The method has room for improvement:
76 | (a) improving the loss for feature regularization (we used MSE in the paper, but that may not be ideal for distribution matching),
77 | (b) making it even faster,
78 | (c) generating images for which multiple models are confident,
79 | (d) increasing diversity.
80 |
81 | Share your most exciting images on Twitter with the hashtags [#Deepinversion](https://twitter.com/hashtag/deepinversion?src=hash) and [#DeepInvert](https://twitter.com/hashtag/DeepInvert?src=hashtag_click).
82 |
83 | ## Citation
84 |
85 | ```bibtex
86 | @inproceedings{yin2020dreaming,
87 | title = {Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion},
88 | author = {Yin, Hongxu and Molchanov, Pavlo and Alvarez, Jose M. and Li, Zhizhong and Mallya, Arun and Hoiem, Derek and Jha, Niraj K and Kautz, Jan},
89 | booktitle = {The IEEE/CVF Conf. Computer Vision and Pattern Recognition (CVPR)},
90 | month = June,
91 | year = {2020}
92 | }
93 | ```
94 |
--------------------------------------------------------------------------------
/cifar10/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # CIFAR10 experiments
3 |
4 | ## License
5 |
6 | Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
7 |
8 | This work is made available under the Nvidia Source Code License (1-Way Commercial). To view a copy of this license, visit https://github.com/NVlabs/DeepInversion/blob/master/LICENSE
9 |
10 |
11 | ## Requirements
12 |
13 | Code was tested in a virtual environment with Python 3.7. Install the requirements:
14 |
15 | ```setup
16 | pip install torch==1.4.0 torchvision==0.5.0 numpy Pillow
17 | ```
18 |
19 | Additionally install APEX library for FP16 support (2x less memory and 2x faster): [Installing NVIDIA APEX](https://github.com/NVIDIA/apex#quick-start)
20 |
21 | For CIFAR10 we first need to train a teacher model; for comparison purposes we choose the ResNet34 from the DAFL method.
22 | Instructions for training the teacher model can be found [here](https://github.com/huawei-noah/Data-Efficient-Model-Compression/tree/master/DAFL).
23 | Our model achieves 95.42% top-1 accuracy on the validation set.
24 |
25 | Running inversion with parameters from the paper:
26 | ```
27 | python deepinversion_cifar10.py --bs=256 --teacher_weights=./checkpoint/teacher_resnet34_only.weights\
28 | --r_feature_weight=10 --di_lr=0.05 --exp_descr="paper_parameters"
29 | ```
30 |
31 | Better reconstructed images can be obtained by tuning the parameters, for example by increasing the total variation coefficient: `--di_var_scale=0.001`.
32 | ```
33 | python deepinversion_cifar10.py --bs=256 --teacher_weights=./checkpoint/teacher_resnet34_only.weights\
34 | --r_feature_weight=10 --di_lr=0.1 --exp_descr="paper_parameters_better" --di_var_scale=0.001 --di_l2_scale=0.0
35 | ```
36 |
37 | 
--------------------------------------------------------------------------------
/cifar10/deepinversion_cifar10.py:
--------------------------------------------------------------------------------
1 | '''
2 | ResNet model inversion for CIFAR10.
3 |
4 | Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
5 |
6 | This work is made available under the Nvidia Source Code License (1-Way Commercial). To view a copy of this license, visit https://github.com/NVlabs/DeepInversion/blob/master/LICENSE
7 | '''
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 | from __future__ import unicode_literals
12 |
13 | import argparse
14 | import random
15 | import torch
16 | import torch.nn as nn
17 | # import torch.nn.parallel
18 | import torch.backends.cudnn as cudnn
19 | import torch.optim as optim
20 | # import torch.utils.data
21 | import torch.nn.functional as F
22 | import torchvision
23 | import torchvision.transforms as transforms
24 | import torchvision.utils as vutils
26 |
27 | import numpy as np
28 | import os
29 | import glob
30 | import collections
31 |
32 | from resnet_cifar import ResNet34, ResNet18
33 |
34 | try:
35 | from apex.parallel import DistributedDataParallel as DDP
36 | from apex import amp, optimizers
37 | USE_APEX = True
38 | except ImportError:
39 | print("Please install apex from https://www.github.com/nvidia/apex to run this example.")
40 | print("will attempt to run without it")
41 | USE_APEX = False
42 |
43 | # provide intermediate information
44 | debug_output = False
45 | debug_output = True
46 |
47 |
48 | class DeepInversionFeatureHook():
49 | '''
50 | Implementation of the forward hook to track feature statistics and compute a loss on them.
51 | Will compute mean and variance, and will use l2 as a loss
52 | '''
53 |
54 | def __init__(self, module):
55 | self.hook = module.register_forward_hook(self.hook_fn)
56 |
57 | def hook_fn(self, module, input, output):
58 | # hook to compute DeepInversion's feature distribution regularization
59 | nch = input[0].shape[1]
60 |
61 | mean = input[0].mean([0, 2, 3])
62 | var = input[0].permute(1, 0, 2, 3).contiguous().view([nch, -1]).var(1, unbiased=False)
63 |
64 | # forcing mean and variance to match between two distributions
65 | # other ways might work better, e.g. KL divergence
66 | r_feature = torch.norm(module.running_var.data.type(var.type()) - var, 2) + torch.norm(
67 | module.running_mean.data.type(var.type()) - mean, 2)
68 |
69 | self.r_feature = r_feature
70 | # must have no output
71 |
72 | def close(self):
73 | self.hook.remove()
74 |
75 | def get_images(net, bs=256, epochs=1000, idx=-1, var_scale=0.00005,
76 | net_student=None, prefix=None, competitive_scale=0.01, train_writer = None, global_iteration=None,
77 | use_amp=False,
78 | optimizer = None, inputs = None, bn_reg_scale = 0.0, random_labels = False, l2_coeff=0.0):
79 | '''
80 | Function returns inverted images from the pretrained model; parameters are tied to the CIFAR dataset
81 | args in:
82 | net: network to be inverted
83 | bs: batch size
84 | epochs: total number of iterations to generate inverted images, training longer helps a lot!
85 | idx: an external flag for printing purposes: only print in the first round, set as -1 to disable
86 | var_scale: the scaling factor for total variation regularization; may vary depending on bs
87 | (larger - more blurred but less noise)
88 | net_student: model to be used for Adaptive DeepInversion
89 | prefix: defines the path to store images
90 | competitive_scale: coefficient for Adaptive DeepInversion
91 | train_writer: tensorboardX object to store intermediate losses
92 | global_iteration: indexer to be used for tensorboard
93 | use_amp: boolean to indicate usage of APEX AMP for FP16 calculations - twice faster and less memory on TensorCores
94 | optimizer: optimizer to be used for model inversion
95 | inputs: data placeholder for optimization, will be reinitialized to noise
96 | bn_reg_scale: weight for r_feature_regularization
97 | random_labels: sample labels from random distribution or use columns of the same class
98 | l2_coeff: coefficient for L2 loss on input
99 | return:
100 | A tensor on GPU with shape (bs, 3, 32, 32) for CIFAR
101 | '''
102 |
103 | kl_loss = nn.KLDivLoss(reduction='batchmean').cuda()
104 |
105 | # preventing backpropagation through student for Adaptive DeepInversion
106 | net_student.eval()
107 |
108 | best_cost = 1e6
109 |
110 | # initialize gaussian inputs
111 | inputs.data = torch.randn((bs, 3, 32, 32), requires_grad=True, device='cuda')
112 | # if use_amp:
113 | # inputs.data = inputs.data.half()
114 |
115 | # set up criteria for optimization
116 | criterion = nn.CrossEntropyLoss()
117 |
118 | optimizer.state = collections.defaultdict(dict) # Reset state of optimizer
119 |
120 | # target outputs to generate
121 | if random_labels:
122 | targets = torch.LongTensor([random.randint(0,9) for _ in range(bs)]).to('cuda')
123 | else:
124 | targets = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9] * 25 + [0, 1, 2, 3, 4, 5]).to('cuda')
125 |
126 | ## Create hooks for feature statistics catching
127 | loss_r_feature_layers = []
128 | for module in net.modules():
129 | if isinstance(module, nn.BatchNorm2d):
130 | loss_r_feature_layers.append(DeepInversionFeatureHook(module))
131 |
132 | # setting up the range for jitter
133 | lim_0, lim_1 = 2, 2
134 |
135 | for epoch in range(epochs):
136 | # apply random jitter offsets
137 | off1 = random.randint(-lim_0, lim_0)
138 | off2 = random.randint(-lim_1, lim_1)
139 | inputs_jit = torch.roll(inputs, shifts=(off1,off2), dims=(2,3))
140 |
141 | # forward with jittered images
142 | optimizer.zero_grad()
143 | net.zero_grad()
144 | outputs = net(inputs_jit)
145 | loss = criterion(outputs, targets)
146 | loss_target = loss.item()
147 |
148 | # competition loss, Adaptive DeepInversion
149 | if competitive_scale != 0.0:
150 | net_student.zero_grad()
151 | outputs_student = net_student(inputs_jit)
152 | T = 3.0
153 |
154 | if 1:
155 | # Jensen-Shannon divergence:
156 | # another way to force KL between negative probabilities
157 | P = F.softmax(outputs_student / T, dim=1)
158 | Q = F.softmax(outputs / T, dim=1)
159 | M = 0.5 * (P + Q)
160 |
161 | P = torch.clamp(P, 0.01, 0.99)
162 | Q = torch.clamp(Q, 0.01, 0.99)
163 | M = torch.clamp(M, 0.01, 0.99)
164 | eps = 0.0
165 | # loss_verifier_cig = 0.5 * kl_loss(F.log_softmax(outputs_verifier / T, dim=1), M) + 0.5 * kl_loss(F.log_softmax(outputs/T, dim=1), M)
166 | loss_verifier_cig = 0.5 * kl_loss(torch.log(P + eps), M) + 0.5 * kl_loss(torch.log(Q + eps), M)
167 | # JS criteria - 0 means full correlation, 1 - means completely different
168 | loss_verifier_cig = 1.0 - torch.clamp(loss_verifier_cig, 0.0, 1.0)
169 |
170 | loss = loss + competitive_scale * loss_verifier_cig
171 |
172 | # apply total variation regularization
173 | diff1 = inputs_jit[:,:,:,:-1] - inputs_jit[:,:,:,1:]
174 | diff2 = inputs_jit[:,:,:-1,:] - inputs_jit[:,:,1:,:]
175 | diff3 = inputs_jit[:,:,1:,:-1] - inputs_jit[:,:,:-1,1:]
176 | diff4 = inputs_jit[:,:,:-1,:-1] - inputs_jit[:,:,1:,1:]
177 | loss_var = torch.norm(diff1) + torch.norm(diff2) + torch.norm(diff3) + torch.norm(diff4)
178 | loss = loss + var_scale*loss_var
179 |
180 | # R_feature loss
181 | loss_distr = sum([mod.r_feature for mod in loss_r_feature_layers])
182 | loss = loss + bn_reg_scale*loss_distr # best for noise before BN
183 |
184 | # l2 loss
185 | if 1:
186 | loss = loss + l2_coeff * torch.norm(inputs_jit, 2)
187 |
188 | if debug_output and epoch % 200==0:
189 | print(f"It {epoch}\t Losses: total: {loss.item():3.3f},\ttarget: {loss_target:3.3f} \tR_feature_loss unscaled:\t {loss_distr.item():3.3f}")
190 | vutils.save_image(inputs.data.clone(),
191 | './{}/output_{}.png'.format(prefix, epoch//200),
192 | normalize=True, scale_each=True, nrow=10)
193 |
194 | if best_cost > loss.item():
195 | best_cost = loss.item()
196 | best_inputs = inputs.data
197 |
198 | # backward pass
199 | if use_amp:
200 | with amp.scale_loss(loss, optimizer) as scaled_loss:
201 | scaled_loss.backward()
202 | else:
203 | loss.backward()
204 |
205 | optimizer.step()
206 |
207 | outputs=net(best_inputs)
208 | _, predicted_teach = outputs.max(1)
209 |
210 | outputs_student=net_student(best_inputs)
211 | _, predicted_std = outputs_student.max(1)
212 |
213 | if idx == 0:
214 | print('Teacher correct out of {}: {}, loss at {}'.format(bs, predicted_teach.eq(targets).sum().item(), criterion(outputs, targets).item()))
215 | print('Student correct out of {}: {}, loss at {}'.format(bs, predicted_std.eq(targets).sum().item(), criterion(outputs_student, targets).item()))
216 |
217 | name_use = "best_images"
218 | if prefix is not None:
219 | name_use = prefix + name_use
220 |     next_batch = len(glob.glob("./%s/*.png" % name_use))
221 |
222 | vutils.save_image(best_inputs[:20].clone(),
223 | './{}/output_{}.png'.format(name_use, next_batch),
224 | normalize=True, scale_each = True, nrow=10)
225 |
226 | if train_writer is not None:
227 | train_writer.add_scalar('gener_teacher_criteria', criterion(outputs, targets), global_iteration)
228 | train_writer.add_scalar('gener_student_criteria', criterion(outputs_student, targets), global_iteration)
229 |
230 | train_writer.add_scalar('gener_teacher_acc', predicted_teach.eq(targets).sum().item() / bs, global_iteration)
231 | train_writer.add_scalar('gener_student_acc', predicted_std.eq(targets).sum().item() / bs, global_iteration)
232 |
233 | train_writer.add_scalar('gener_loss_total', loss.item(), global_iteration)
234 | train_writer.add_scalar('gener_loss_var', (var_scale*loss_var).item(), global_iteration)
235 |
236 | net_student.train()
237 |
238 | return best_inputs
239 |
240 |
241 | def test():
242 | print('==> Teacher validation')
243 | net_teacher.eval()
244 | test_loss = 0
245 | correct = 0
246 | total = 0
247 |
248 | with torch.no_grad():
249 | for batch_idx, (inputs, targets) in enumerate(testloader):
250 | inputs, targets = inputs.to(device), targets.to(device)
251 | outputs = net_teacher(inputs)
252 | loss = criterion(outputs, targets)
253 |
254 | test_loss += loss.item()
255 | _, predicted = outputs.max(1)
256 | total += targets.size(0)
257 | correct += predicted.eq(targets).sum().item()
258 |
259 | print('Loss: %.3f | Acc: %.3f%% (%d/%d)'
260 | % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
261 |
262 |
263 | if __name__ == "__main__":
264 |
265 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 DeepInversion')
266 | parser.add_argument('--bs', default=256, type=int, help='batch size')
267 | parser.add_argument('--iters_mi', default=2000, type=int, help='number of iterations for model inversion')
268 | parser.add_argument('--cig_scale', default=0.0, type=float, help='competition score')
269 | parser.add_argument('--di_lr', default=0.1, type=float, help='lr for deep inversion')
270 | parser.add_argument('--di_var_scale', default=2.5e-5, type=float, help='TV L2 regularization coefficient')
271 | parser.add_argument('--di_l2_scale', default=0.0, type=float, help='L2 regularization coefficient')
272 | parser.add_argument('--r_feature_weight', default=1e2, type=float, help='weight for BN regularization statistic')
273 | parser.add_argument('--amp', action='store_true', help='use APEX AMP O1 acceleration')
274 | parser.add_argument('--exp_descr', default="try1", type=str, help='name to be added to experiment name')
275 |     parser.add_argument('--teacher_weights', default='./checkpoint/teacher_resnet34_only.weights', type=str, help='path to load weights of the model')
276 |
277 | args = parser.parse_args()
278 |
279 | print("loading resnet34")
280 |
281 | net_teacher = ResNet34()
282 | net_student = ResNet18()
283 |
284 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
285 |
286 | net_student = net_student.to(device)
287 | net_teacher = net_teacher.to(device)
288 |
289 | criterion = nn.CrossEntropyLoss()
290 |
291 |     # placeholder for the synthesized inputs
292 | data_type = torch.half if args.amp else torch.float
293 | inputs = torch.randn((args.bs, 3, 32, 32), requires_grad=True, device='cuda', dtype=data_type)
294 |
295 | optimizer_di = optim.Adam([inputs], lr=args.di_lr)
296 |
297 | if args.amp:
298 | opt_level = "O1"
299 | loss_scale = 'dynamic'
300 |
301 | [net_student, net_teacher], optimizer_di = amp.initialize(
302 | [net_student, net_teacher], optimizer_di,
303 | opt_level=opt_level,
304 | loss_scale=loss_scale)
305 |
306 | checkpoint = torch.load(args.teacher_weights)
307 | net_teacher.load_state_dict(checkpoint)
308 |     net_teacher.eval()  # important, otherwise the generated images will look unnatural
309 | if args.amp:
310 | # need to do this trick for FP16 support of batchnorms
311 | net_teacher.train()
312 | for module in net_teacher.modules():
313 | if isinstance(module, nn.BatchNorm2d):
314 | module.eval().half()
315 |
316 | cudnn.benchmark = True
317 |
318 |
319 | batch_idx = 0
320 | prefix = "runs/data_generation/"+args.exp_descr+"/"
321 |
322 | for create_folder in [prefix, prefix+"/best_images/"]:
323 | if not os.path.exists(create_folder):
324 | os.makedirs(create_folder)
325 |
326 | if 0:
327 | # loading
328 | transform_test = transforms.Compose([
329 | transforms.ToTensor(),
330 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
331 | ])
332 |
333 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
334 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.bs, shuffle=True, num_workers=6,
335 | drop_last=True)
336 | # Checking teacher accuracy
337 | print("Checking teacher accuracy")
338 | test()
339 |
340 |
341 |     train_writer = None  # tensorboard writer
342 | global_iteration = 0
343 |
344 | print("Starting model inversion")
345 |
346 | inputs = get_images(net=net_teacher, bs=args.bs, epochs=args.iters_mi, idx=batch_idx,
347 | net_student=net_student, prefix=prefix, competitive_scale=args.cig_scale,
348 | train_writer=train_writer, global_iteration=global_iteration, use_amp=args.amp,
349 | optimizer=optimizer_di, inputs=inputs, bn_reg_scale=args.r_feature_weight,
350 | var_scale=args.di_var_scale, random_labels=False, l2_coeff=args.di_l2_scale)
351 |
--------------------------------------------------------------------------------
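For reference, the objective assembled step by step in the CIFAR-10 `get_images` routine above can be written compactly as follows. The symbols are introduced only for this summary and do not appear in the code: $\hat{x}$ is the synthesized batch, $y$ the target labels, $f_T$ the teacher and $f_S$ the student.

$$
\mathcal{L}(\hat{x}) = \mathrm{CE}\big(f_T(\hat{x}), y\big)
+ \alpha_{tv}\,\mathcal{R}_{TV}(\hat{x})
+ \alpha_{f}\sum_{l}\Big(\big\lVert \mu_l(\hat{x}) - \mu_l^{BN}\big\rVert_2 + \big\lVert \sigma_l^{2}(\hat{x}) - \sigma_l^{2,BN}\big\rVert_2\Big)
+ \alpha_{\ell_2}\,\lVert \hat{x}\rVert_2
+ \alpha_{c}\,\big(1 - \mathrm{JS}\big(f_T(\hat{x}), f_S(\hat{x})\big)\big)
$$

Here $\alpha_{tv}$, $\alpha_f$, $\alpha_{\ell_2}$ and $\alpha_c$ correspond to `var_scale`, `bn_reg_scale`, `l2_coeff` and `competitive_scale` in the code, the sum runs over every `BatchNorm2d` layer hooked by `DeepInversionFeatureHook`, the JS term is the clamped estimate computed in the code, and the competition term is only added when `competitive_scale != 0`.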
/cifar10/images/better_last.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/cifar10/images/better_last.png
--------------------------------------------------------------------------------
/cifar10/resnet_cifar.py:
--------------------------------------------------------------------------------
1 | # 2019.07.24-Changed output of forward function
2 | # Huawei Technologies Co., Ltd.
3 | # taken from https://github.com/huawei-noah/Data-Efficient-Model-Compression/blob/master/DAFL/resnet.py
4 | # for comparison with DAFL
5 |
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | class BasicBlock(nn.Module):
12 | expansion = 1
13 |
14 | def __init__(self, in_planes, planes, stride=1):
15 | super(BasicBlock, self).__init__()
16 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
17 | self.bn1 = nn.BatchNorm2d(planes)
18 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
19 | self.bn2 = nn.BatchNorm2d(planes)
20 |
21 | self.shortcut = nn.Sequential()
22 | if stride != 1 or in_planes != self.expansion*planes:
23 | self.shortcut = nn.Sequential(
24 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
25 | nn.BatchNorm2d(self.expansion*planes)
26 | )
27 |
28 | def forward(self, x):
29 | out = F.relu(self.bn1(self.conv1(x)))
30 | out = self.bn2(self.conv2(out))
31 | out += self.shortcut(x)
32 | out = F.relu(out)
33 | return out
34 |
35 |
36 | class Bottleneck(nn.Module):
37 | expansion = 4
38 |
39 | def __init__(self, in_planes, planes, stride=1):
40 | super(Bottleneck, self).__init__()
41 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
42 | self.bn1 = nn.BatchNorm2d(planes)
43 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
44 | self.bn2 = nn.BatchNorm2d(planes)
45 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
46 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
47 |
48 | self.shortcut = nn.Sequential()
49 | if stride != 1 or in_planes != self.expansion*planes:
50 | self.shortcut = nn.Sequential(
51 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
52 | nn.BatchNorm2d(self.expansion*planes)
53 | )
54 |
55 | def forward(self, x):
56 | out = F.relu(self.bn1(self.conv1(x)))
57 | out = F.relu(self.bn2(self.conv2(out)))
58 | out = self.bn3(self.conv3(out))
59 | out += self.shortcut(x)
60 | out = F.relu(out)
61 | return out
62 |
63 |
64 | class ResNet(nn.Module):
65 | def __init__(self, block, num_blocks, num_classes=10):
66 | super(ResNet, self).__init__()
67 | self.in_planes = 64
68 |
69 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
70 | self.bn1 = nn.BatchNorm2d(64)
71 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
72 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
73 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
74 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
75 | self.linear = nn.Linear(512*block.expansion, num_classes)
76 |
77 | def _make_layer(self, block, planes, num_blocks, stride):
78 | strides = [stride] + [1]*(num_blocks-1)
79 | layers = []
80 | for stride in strides:
81 | layers.append(block(self.in_planes, planes, stride))
82 | self.in_planes = planes * block.expansion
83 | return nn.Sequential(*layers)
84 |
85 | def forward(self, x, out_feature=False):
86 | x = self.conv1(x)
87 |
88 | x = self.bn1(x)
89 | out = F.relu(x)
90 |
91 | out = self.layer1(out)
92 | out = self.layer2(out)
93 | out = self.layer3(out)
94 | out = self.layer4(out)
95 | out = F.avg_pool2d(out, 4)
96 | feature = out.view(out.size(0), -1)
97 | out = self.linear(feature)
98 | if out_feature == False:
99 | return out
100 | else:
101 | return out,feature
102 |
103 |
104 | def ResNet18(num_classes=10):
105 | return ResNet(BasicBlock, [2,2,2,2], num_classes)
106 |
107 | def ResNet34(num_classes=10):
108 | return ResNet(BasicBlock, [3,4,6,3], num_classes)
109 |
110 | def ResNet50(num_classes=10):
111 | return ResNet(Bottleneck, [3,4,6,3], num_classes)
112 |
113 | def ResNet101(num_classes=10):
114 | return ResNet(Bottleneck, [3,4,23,3], num_classes)
115 |
116 | def ResNet152(num_classes=10):
117 | return ResNet(Bottleneck, [3,8,36,3], num_classes)
118 |
119 |
--------------------------------------------------------------------------------
/deepinversion.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
3 | # Nvidia Source Code License-NC
4 | # Official PyTorch implementation of CVPR2020 paper
5 | # Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion
6 | # Hongxu Yin, Pavlo Molchanov, Zhizhong Li, Jose M. Alvarez, Arun Mallya, Derek
7 | # Hoiem, Niraj K. Jha, and Jan Kautz
8 | # --------------------------------------------------------
9 |
10 | from __future__ import division, print_function
11 | from __future__ import absolute_import
12 | from __future__ import division
13 | from __future__ import unicode_literals
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.optim as optim
18 | import collections
19 | from apex import amp  # NVIDIA APEX AMP; amp.initialize / amp.scale_loss below require it when use_fp16 is set
20 | import random
21 | import torch
22 | import torchvision.utils as vutils
23 | from PIL import Image
24 | import numpy as np
25 |
26 | from utils.utils import lr_cosine_policy, lr_policy, beta_policy, mom_cosine_policy, clip, denormalize, create_folder
27 |
28 |
29 | class DeepInversionFeatureHook():
30 | '''
31 | Implementation of the forward hook to track feature statistics and compute a loss on them.
32 |     Will compute the per-channel mean and variance of the input activations and use their l2 distance to the stored BN statistics as a loss
33 | '''
34 | def __init__(self, module):
35 | self.hook = module.register_forward_hook(self.hook_fn)
36 |
37 | def hook_fn(self, module, input, output):
38 |         # hook to compute DeepInversion's feature distribution regularization
39 | nch = input[0].shape[1]
40 | mean = input[0].mean([0, 2, 3])
41 | var = input[0].permute(1, 0, 2, 3).contiguous().view([nch, -1]).var(1, unbiased=False)
42 |
43 | #forcing mean and variance to match between two distributions
44 |         #other formulations might work better, e.g. a KL divergence
45 | r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm(
46 | module.running_mean.data - mean, 2)
47 |
48 | self.r_feature = r_feature
49 | # must have no output
50 |
51 | def close(self):
52 | self.hook.remove()
53 |
54 |
55 | def get_image_prior_losses(inputs_jit):
56 | # COMPUTE total variation regularization loss
57 | diff1 = inputs_jit[:, :, :, :-1] - inputs_jit[:, :, :, 1:]
58 | diff2 = inputs_jit[:, :, :-1, :] - inputs_jit[:, :, 1:, :]
59 | diff3 = inputs_jit[:, :, 1:, :-1] - inputs_jit[:, :, :-1, 1:]
60 | diff4 = inputs_jit[:, :, :-1, :-1] - inputs_jit[:, :, 1:, 1:]
61 |
62 | loss_var_l2 = torch.norm(diff1) + torch.norm(diff2) + torch.norm(diff3) + torch.norm(diff4)
63 | loss_var_l1 = (diff1.abs() / 255.0).mean() + (diff2.abs() / 255.0).mean() + (
64 | diff3.abs() / 255.0).mean() + (diff4.abs() / 255.0).mean()
65 | loss_var_l1 = loss_var_l1 * 255.0
66 | return loss_var_l1, loss_var_l2
67 |
68 |
69 | class DeepInversionClass(object):
70 | def __init__(self, bs=84,
71 | use_fp16=True, net_teacher=None, path="./gen_images/",
72 | final_data_path="/gen_images_final/",
73 | parameters=dict(),
74 | setting_id=0,
75 | jitter=30,
76 | criterion=None,
77 | coefficients=dict(),
78 | network_output_function=lambda x: x,
79 | hook_for_display = None):
80 | '''
81 | :param bs: batch size per GPU for image generation
82 | :param use_fp16: use FP16 (or APEX AMP) for model inversion, uses less memory and is faster for GPUs with Tensor Cores
83 |         :param net_teacher: PyTorch model to be inverted
84 | :param path: path where to write temporal images and data
85 | :param final_data_path: path to write final images into
86 | :param parameters: a dictionary of control parameters:
87 | "resolution": input image resolution, single value, assumed to be a square, 224
88 | "random_label" : for classification initialize target to be random values
89 | "start_noise" : start from noise, def True, other options are not supported at this time
90 | "detach_student": if computing Adaptive DI, should we detach student?
91 | :param setting_id: predefined settings for optimization:
92 |             0 - will run low resolution optimization for 2k iterations and then full resolution for 1k;
93 | 1 - will run optimization on high resolution for 2k
94 | 2 - will run optimization on high resolution for 20k
95 |
96 | :param jitter: amount of random shift applied to image at every iteration
97 | :param coefficients: dictionary with parameters and coefficients for optimization.
98 | keys:
99 | "r_feature" - coefficient for feature distribution regularization
100 | "tv_l1" - coefficient for total variation L1 loss
101 | "tv_l2" - coefficient for total variation L2 loss
102 | "l2" - l2 penalization weight
103 | "lr" - learning rate for optimization
104 | "main_loss_multiplier" - coefficient for the main loss optimization
105 | "adi_scale" - coefficient for Adaptive DeepInversion, competition, def =0 means no competition
106 |         :param network_output_function: function applied to the raw network output before computing the loss (identity by default)
107 |         :param hook_for_display: function executed at every print/save call, useful to check the accuracy of a verifier network
108 | '''
109 |
110 | print("Deep inversion class generation")
111 | # for reproducibility
112 | torch.manual_seed(torch.cuda.current_device())
113 |
114 | self.net_teacher = net_teacher
115 |
116 | if "resolution" in parameters.keys():
117 | self.image_resolution = parameters["resolution"]
118 | self.random_label = parameters["random_label"]
119 | self.start_noise = parameters["start_noise"]
120 | self.detach_student = parameters["detach_student"]
121 | self.do_flip = parameters["do_flip"]
122 | self.store_best_images = parameters["store_best_images"]
123 | else:
124 | self.image_resolution = 224
125 | self.random_label = False
126 | self.start_noise = True
127 | self.detach_student = False
128 | self.do_flip = True
129 | self.store_best_images = False
130 |
131 | self.setting_id = setting_id
132 | self.bs = bs # batch size
133 | self.use_fp16 = use_fp16
134 | self.save_every = 100
135 | self.jitter = jitter
136 | self.criterion = criterion
137 | self.network_output_function = network_output_function
138 | do_clip = True
139 |
140 | if "r_feature" in coefficients:
141 | self.bn_reg_scale = coefficients["r_feature"]
142 | self.first_bn_multiplier = coefficients["first_bn_multiplier"]
143 | self.var_scale_l1 = coefficients["tv_l1"]
144 | self.var_scale_l2 = coefficients["tv_l2"]
145 | self.l2_scale = coefficients["l2"]
146 | self.lr = coefficients["lr"]
147 | self.main_loss_multiplier = coefficients["main_loss_multiplier"]
148 | self.adi_scale = coefficients["adi_scale"]
149 | else:
150 | print("Provide a dictionary with ")
151 |
152 | self.num_generations = 0
153 | self.final_data_path = final_data_path
154 |
155 | ## Create folders for images and logs
156 | prefix = path
157 | self.prefix = prefix
158 |
159 | local_rank = torch.cuda.current_device()
160 | if local_rank==0:
161 | create_folder(prefix)
162 | create_folder(prefix + "/best_images/")
163 | create_folder(self.final_data_path)
164 | # save images to folders
165 | # for m in range(1000):
166 | # create_folder(self.final_data_path + "/s{:03d}".format(m))
167 |
168 | ## Create hooks for feature statistics
169 | self.loss_r_feature_layers = []
170 |
171 | for module in self.net_teacher.modules():
172 | if isinstance(module, nn.BatchNorm2d):
173 | self.loss_r_feature_layers.append(DeepInversionFeatureHook(module))
174 |
175 | self.hook_for_display = None
176 | if hook_for_display is not None:
177 | self.hook_for_display = hook_for_display
178 |
179 | def get_images(self, net_student=None, targets=None):
180 | print("get_images call")
181 |
182 | net_teacher = self.net_teacher
183 | use_fp16 = self.use_fp16
184 | save_every = self.save_every
185 |
186 | kl_loss = nn.KLDivLoss(reduction='batchmean').cuda()
187 | local_rank = torch.cuda.current_device()
188 | best_cost = 1e4
189 | criterion = self.criterion
190 |
191 | # setup target labels
192 | if targets is None:
193 | #only works for classification now, for other tasks need to provide target vector
194 | targets = torch.LongTensor([random.randint(0, 999) for _ in range(self.bs)]).to('cuda')
195 | if not self.random_label:
196 | # preselected classes, good for ResNet50v1.5
197 | targets = [1, 933, 946, 980, 25, 63, 92, 94, 107, 985, 151, 154, 207, 250, 270, 277, 283, 292, 294, 309,
198 | 311,
199 | 325, 340, 360, 386, 402, 403, 409, 530, 440, 468, 417, 590, 670, 817, 762, 920, 949, 963,
200 | 967, 574, 487]
201 |
202 | targets = torch.LongTensor(targets * (int(self.bs / len(targets)))).to('cuda')
203 |
204 | img_original = self.image_resolution
205 |
206 | data_type = torch.half if use_fp16 else torch.float
207 | inputs = torch.randn((self.bs, 3, img_original, img_original), requires_grad=True, device='cuda',
208 | dtype=data_type)
209 | pooling_function = nn.modules.pooling.AvgPool2d(kernel_size=2)
210 |
211 | if self.setting_id==0:
212 | skipfirst = False
213 | else:
214 | skipfirst = True
215 |
216 | iteration = 0
217 | for lr_it, lower_res in enumerate([2, 1]):
218 | if lr_it==0:
219 | iterations_per_layer = 2000
220 | else:
221 | iterations_per_layer = 1000 if not skipfirst else 2000
222 | if self.setting_id == 2:
223 | iterations_per_layer = 20000
224 |
225 | if lr_it==0 and skipfirst:
226 | continue
227 |
228 | lim_0, lim_1 = self.jitter // lower_res, self.jitter // lower_res
229 |
230 | if self.setting_id == 0:
231 |                 #multi-resolution: 2k iterations at low resolution, then 1k at full resolution; ResNet50v1.5 works best, ResNet50 is OK
232 | optimizer = optim.Adam([inputs], lr=self.lr, betas=[0.5, 0.9], eps = 1e-8)
233 | do_clip = True
234 | elif self.setting_id == 1:
235 |                 #2k iterations at full resolution, for ResNet50v1.5; ResNet50 works as well
236 | optimizer = optim.Adam([inputs], lr=self.lr, betas=[0.5, 0.9], eps = 1e-8)
237 | do_clip = True
238 | elif self.setting_id == 2:
239 |                 #20k iterations at full resolution, closest to the paper experiments for ResNet50
240 | optimizer = optim.Adam([inputs], lr=self.lr, betas=[0.9, 0.999], eps = 1e-8)
241 | do_clip = False
242 |
243 | if use_fp16:
244 |                 # a static loss scale (e.g. 256) also works; dynamic scaling is used here
245 | static_loss_scale = "dynamic"
246 | _, optimizer = amp.initialize([], optimizer, opt_level="O2", loss_scale=static_loss_scale)
247 |
248 | lr_scheduler = lr_cosine_policy(self.lr, 100, iterations_per_layer)
249 |
250 | for iteration_loc in range(iterations_per_layer):
251 | iteration += 1
252 | # learning rate scheduling
253 | lr_scheduler(optimizer, iteration_loc, iteration_loc)
254 |
255 | # perform downsampling if needed
256 | if lower_res!=1:
257 | inputs_jit = pooling_function(inputs)
258 | else:
259 | inputs_jit = inputs
260 |
261 | # apply random jitter offsets
262 | off1 = random.randint(-lim_0, lim_0)
263 | off2 = random.randint(-lim_1, lim_1)
264 | inputs_jit = torch.roll(inputs_jit, shifts=(off1, off2), dims=(2, 3))
265 |
266 | # Flipping
267 | flip = random.random() > 0.5
268 | if flip and self.do_flip:
269 | inputs_jit = torch.flip(inputs_jit, dims=(3,))
270 |
271 | # forward pass
272 | optimizer.zero_grad()
273 | net_teacher.zero_grad()
274 |
275 | outputs = net_teacher(inputs_jit)
276 | outputs = self.network_output_function(outputs)
277 |
278 | # R_cross classification loss
279 | loss = criterion(outputs, targets)
280 |
281 | # R_prior losses
282 | loss_var_l1, loss_var_l2 = get_image_prior_losses(inputs_jit)
283 |
284 | # R_feature loss
285 | rescale = [self.first_bn_multiplier] + [1. for _ in range(len(self.loss_r_feature_layers)-1)]
286 | loss_r_feature = sum([mod.r_feature * rescale[idx] for (idx, mod) in enumerate(self.loss_r_feature_layers)])
287 |
288 | # R_ADI
289 | loss_verifier_cig = torch.zeros(1)
290 | if self.adi_scale!=0.0:
291 | if self.detach_student:
292 | outputs_student = net_student(inputs_jit).detach()
293 | else:
294 | outputs_student = net_student(inputs_jit)
295 |
296 | T = 3.0
297 | if 1:
298 | T = 3.0
299 |                         # Jensen-Shannon divergence:
300 |                         # a symmetric alternative to a plain KL divergence between the two softened output distributions
301 | P = nn.functional.softmax(outputs_student / T, dim=1)
302 | Q = nn.functional.softmax(outputs / T, dim=1)
303 | M = 0.5 * (P + Q)
304 |
305 | P = torch.clamp(P, 0.01, 0.99)
306 | Q = torch.clamp(Q, 0.01, 0.99)
307 | M = torch.clamp(M, 0.01, 0.99)
308 | eps = 0.0
309 | loss_verifier_cig = 0.5 * kl_loss(torch.log(P + eps), M) + 0.5 * kl_loss(torch.log(Q + eps), M)
310 |                         # JS criterion: 0 means the two distributions match, 1 means they are completely different
311 | loss_verifier_cig = 1.0 - torch.clamp(loss_verifier_cig, 0.0, 1.0)
312 |
313 | if local_rank==0:
314 | if iteration % save_every==0:
315 | print('loss_verifier_cig', loss_verifier_cig.item())
316 |
317 | # l2 loss on images
318 | loss_l2 = torch.norm(inputs_jit.view(self.bs, -1), dim=1).mean()
319 |
320 | # combining losses
321 | loss_aux = self.var_scale_l2 * loss_var_l2 + \
322 | self.var_scale_l1 * loss_var_l1 + \
323 | self.bn_reg_scale * loss_r_feature + \
324 | self.l2_scale * loss_l2
325 |
326 | if self.adi_scale!=0.0:
327 | loss_aux += self.adi_scale * loss_verifier_cig
328 |
329 | loss = self.main_loss_multiplier * loss + loss_aux
330 |
331 | if local_rank==0:
332 | if iteration % save_every==0:
333 | print("------------iteration {}----------".format(iteration))
334 | print("total loss", loss.item())
335 | print("loss_r_feature", loss_r_feature.item())
336 | print("main criterion", criterion(outputs, targets).item())
337 |
338 | if self.hook_for_display is not None:
339 | self.hook_for_display(inputs, targets)
340 |
341 | # do image update
342 | if use_fp16:
343 | # optimizer.backward(loss)
344 | with amp.scale_loss(loss, optimizer) as scaled_loss:
345 | scaled_loss.backward()
346 | else:
347 | loss.backward()
348 |
349 | optimizer.step()
350 |
351 |                 # clip color outliers
352 | if do_clip:
353 | inputs.data = clip(inputs.data, use_fp16=use_fp16)
354 |
355 | if best_cost > loss.item() or iteration == 1:
356 | best_inputs = inputs.data.clone()
357 | best_cost = loss.item()
358 |
359 | if iteration % save_every==0 and (save_every > 0):
360 | if local_rank==0:
361 | vutils.save_image(inputs,
362 | '{}/best_images/output_{:05d}_gpu_{}.png'.format(self.prefix,
363 | iteration // save_every,
364 | local_rank),
365 | normalize=True, scale_each=True, nrow=int(10))
366 |
367 | if self.store_best_images:
368 | best_inputs = denormalize(best_inputs)
369 | self.save_images(best_inputs, targets)
370 |
371 |         # clear the optimizer state to reduce memory consumption
372 | optimizer.state = collections.defaultdict(dict)
373 |
374 | def save_images(self, images, targets):
375 | # method to store generated images locally
376 | local_rank = torch.cuda.current_device()
377 | for id in range(images.shape[0]):
378 | class_id = targets[id].item()
379 | if 0:
380 | #save into separate folders
381 | place_to_store = '{}/s{:03d}/img_{:05d}_id{:03d}_gpu_{}_2.jpg'.format(self.final_data_path, class_id,
382 | self.num_generations, id,
383 | local_rank)
384 | else:
385 | place_to_store = '{}/img_s{:03d}_{:05d}_id{:03d}_gpu_{}_2.jpg'.format(self.final_data_path, class_id,
386 | self.num_generations, id,
387 | local_rank)
388 |
389 | image_np = images[id].data.cpu().numpy().transpose((1, 2, 0))
390 | pil_image = Image.fromarray((image_np * 255).astype(np.uint8))
391 | pil_image.save(place_to_store)
392 |
393 | def generate_batch(self, net_student=None, targets=None):
394 |         # for ADI: put the student network into eval mode
395 | net_teacher = self.net_teacher
396 |
397 | use_fp16 = self.use_fp16
398 |
399 | # fix net_student
400 | if not (net_student is None):
401 | net_student = net_student.eval()
402 |
403 | if targets is not None:
404 | targets = torch.from_numpy(np.array(targets).squeeze()).cuda()
405 | if use_fp16:
406 | targets = targets.half()
407 |
408 | self.get_images(net_student=net_student, targets=targets)
409 |
410 | net_teacher.eval()
411 |
412 | self.num_generations += 1
413 |
--------------------------------------------------------------------------------
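For orientation, the sketch below shows one way the `DeepInversionClass` defined above can be driven. It is a minimal, hypothetical example rather than the repository's actual entry point (`imagenet_inversion.py`): the architecture, batch size, and coefficient values mirror the command line recorded in the `fp16_set0_rn50` example log where they appear there, while `first_bn_multiplier` and the output paths are assumptions made only for this illustration. A CUDA-capable GPU is required, since the class allocates the synthesized images on `cuda`.

```python
# Minimal, hypothetical driver for DeepInversionClass (see imagenet_inversion.py
# for the repository's real entry point). Values marked "assumed" are not taken
# from the example logs.
import torch.nn as nn
from torchvision import models

from deepinversion import DeepInversionClass

# teacher network to invert (ImageNet-pretrained ResNet-50, as in the example logs)
net_teacher = models.resnet50(pretrained=True).cuda().eval()

parameters = dict(
    resolution=224,          # square input resolution
    random_label=False,      # use the preselected ImageNet classes
    start_noise=True,        # start optimization from random noise
    detach_student=False,
    do_flip=True,            # random horizontal flips during optimization
    store_best_images=True,  # write denormalized best images to final_data_path
)

coefficients = dict(
    r_feature=0.01,           # BN statistics matching weight (as in the example logs)
    first_bn_multiplier=10.0, # assumed; extra weight on the first BN layer
    tv_l1=0.0,
    tv_l2=0.0001,
    l2=1e-5,
    lr=0.2,
    main_loss_multiplier=1.0,
    adi_scale=0.0,            # 0 disables the Adaptive DeepInversion competition term
)

engine = DeepInversionClass(
    bs=84,
    use_fp16=False,                        # True requires NVIDIA APEX
    net_teacher=net_teacher,
    path="./generations/",                 # assumed output folder for intermediate images
    final_data_path="./generations/final/",# assumed output folder for best images
    parameters=parameters,
    setting_id=0,                          # 2k low-res + 1k full-res iterations
    jitter=30,
    criterion=nn.CrossEntropyLoss(),
    coefficients=coefficients,
)

# synthesizes one batch; intermediate grids land in path/best_images/,
# final per-image JPEGs in final_data_path
engine.generate_batch()
```

With `adi_scale=0.0` no student network is needed; to exercise the Adaptive DeepInversion competition term, a student model would be passed via `generate_batch(net_student=...)` and `adi_scale` set to a non-zero value, as in the `fp16_set0_rn50_adi02` example log.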
/example_logs/fp16_set0_rn50.log:
--------------------------------------------------------------------------------
1 | $ python imagenet_inversion.py --bs=84 --do_flip --exp_name="test_rn50_fp16" --r_feature=0.01 --arch_name="resnet50" --verifier --setting_id=0 --lr=0.2 --adi_scale=0.0 --l2=0.00001 --fp16
2 | Namespace(adi_scale=0.0, arch_name='resnet50', bs=84, comment='', do_flip=True, epochs=20000, exp_name='test_rn50_fp16', fp16=True, jitter=30, l2=1e-05, local_rank=0, lr=0.2, main_loss_multiplier=1.0, no_cuda=False, r_feature=0.01, random_label=False, setting_id=0, tv_l1=0.0, tv_l2=0.0001, verifier=True, verifier_arch='mobilenet_v2', wd=0.01, worldsize=1)
3 | loading torchvision model for inversion with the name: resnet50
4 | Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
5 |
6 | Defaults for this optimization level are:
7 | enabled : True
8 | opt_level : O2
9 | cast_model_type : torch.float16
10 | patch_torch_functions : False
11 | keep_batchnorm_fp32 : True
12 | master_weights : True
13 | loss_scale : dynamic
14 | Processing user overrides (additional kwargs that are not None)...
15 | After processing overrides, optimization options are:
16 | enabled : True
17 | opt_level : O2
18 | cast_model_type : torch.float16
19 | patch_torch_functions : False
20 | keep_batchnorm_fp32 : True
21 | master_weights : True
22 | loss_scale : dynamic
23 | ==> Resuming from checkpoint..
24 | ==> Getting BN params as feature statistics
25 | loading verifier: mobilenet_v2
26 | Deep inversion class generation
27 | get_images call
28 | Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
29 |
30 | Defaults for this optimization level are:
31 | enabled : True
32 | opt_level : O2
33 | cast_model_type : torch.float16
34 | patch_torch_functions : False
35 | keep_batchnorm_fp32 : True
36 | master_weights : True
37 | loss_scale : dynamic
38 | Processing user overrides (additional kwargs that are not None)...
39 | After processing overrides, optimization options are:
40 | enabled : True
41 | opt_level : O2
42 | cast_model_type : torch.float16
43 | patch_torch_functions : False
44 | keep_batchnorm_fp32 : True
45 | master_weights : True
46 | loss_scale : dynamic
47 | Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 32768.0
48 | ------------iteration 100----------
49 | total loss 1.6876643896102905
50 | loss_r_feature 90.38799285888672
51 | main criterion 0.1843467503786087
52 | Verifier accuracy: 0.0
53 | ------------iteration 200----------
54 | total loss 1.3778828382492065
55 | loss_r_feature 78.72280883789062
56 | main criterion 0.006536892615258694
57 | Verifier accuracy: 3.5714285373687744
58 | ------------iteration 300----------
59 | total loss 1.1503252983093262
60 | loss_r_feature 64.16847229003906
61 | main criterion 0.03804704174399376
62 | Verifier accuracy: 5.952381134033203
63 | ------------iteration 400----------
64 | total loss 1.122263789176941
65 | loss_r_feature 63.72652816772461
66 | main criterion 0.014303048141300678
67 | Verifier accuracy: 9.523809432983398
68 | ------------iteration 500----------
69 | total loss 1.0680581331253052
70 | loss_r_feature 59.28545379638672
71 | main criterion 0.01540193147957325
72 | Verifier accuracy: 10.714285850524902
73 | ------------iteration 600----------
74 | total loss 0.9677876830101013
75 | loss_r_feature 54.25212478637695
76 | main criterion 0.006902865134179592
77 | Verifier accuracy: 10.714285850524902
78 | ------------iteration 700----------
79 | total loss 0.8655126094818115
80 | loss_r_feature 48.535343170166016
81 | main criterion 0.0010316940024495125
82 | Verifier accuracy: 10.714285850524902
83 | ------------iteration 800----------
84 | total loss 0.7858816981315613
85 | loss_r_feature 43.149375915527344
86 | main criterion 0.0030028026085346937
87 | Verifier accuracy: 8.333333015441895
88 | ------------iteration 900----------
89 | total loss 0.6898418068885803
90 | loss_r_feature 38.205223083496094
91 | main criterion 0.0007694562082178891
92 | Verifier accuracy: 14.285714149475098
93 | ------------iteration 1000----------
94 | total loss 0.5907580256462097
95 | loss_r_feature 31.332698822021484
96 | main criterion 0.006730636116117239
97 | Verifier accuracy: 10.714285850524902
98 | ------------iteration 1100----------
99 | total loss 0.5132060050964355
100 | loss_r_feature 26.838134765625
101 | main criterion 0.0011799221392720938
102 | Verifier accuracy: 9.523809432983398
103 | ------------iteration 1200----------
104 | total loss 0.4708683490753174
105 | loss_r_feature 24.262351989746094
106 | main criterion 0.0008049465250223875
107 | Verifier accuracy: 10.714285850524902
108 | ------------iteration 1300----------
109 | total loss 0.4237711727619171
110 | loss_r_feature 21.36450958251953
111 | main criterion 0.0014830997679382563
112 | Verifier accuracy: 5.952381134033203
113 | ------------iteration 1400----------
114 | total loss 0.39408981800079346
115 | loss_r_feature 20.10096549987793
116 | main criterion 0.0006744748097844422
117 | Verifier accuracy: 9.523809432983398
118 | ------------iteration 1500----------
119 | total loss 0.33674177527427673
120 | loss_r_feature 16.06620216369629
121 | main criterion 0.0008936041849665344
122 | Verifier accuracy: 10.714285850524902
123 | ------------iteration 1600----------
124 | total loss 0.3079169690608978
125 | loss_r_feature 14.631080627441406
126 | main criterion 0.0008300145273096859
127 | Verifier accuracy: 13.095237731933594
128 | ------------iteration 1700----------
129 | total loss 0.2896236181259155
130 | loss_r_feature 13.912619590759277
131 | main criterion 0.000469207763671875
132 | Verifier accuracy: 10.714285850524902
133 | ------------iteration 1800----------
134 | total loss 0.27033284306526184
135 | loss_r_feature 12.821993827819824
136 | main criterion 0.000999643700197339
137 | Verifier accuracy: 8.333333015441895
138 | ------------iteration 1900----------
139 | total loss 0.26189038157463074
140 | loss_r_feature 12.395000457763672
141 | main criterion 0.0004911195719614625
142 | Verifier accuracy: 8.333333015441895
143 | ------------iteration 2000----------
144 | total loss 0.2615959644317627
145 | loss_r_feature 12.435093879699707
146 | main criterion 0.0006502469186671078
147 | Verifier accuracy: 8.333333015441895
148 | Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
149 |
150 | Defaults for this optimization level are:
151 | enabled : True
152 | opt_level : O2
153 | cast_model_type : torch.float16
154 | patch_torch_functions : False
155 | keep_batchnorm_fp32 : True
156 | master_weights : True
157 | loss_scale : dynamic
158 | Processing user overrides (additional kwargs that are not None)...
159 | After processing overrides, optimization options are:
160 | enabled : True
161 | opt_level : O2
162 | cast_model_type : torch.float16
163 | patch_torch_functions : False
164 | keep_batchnorm_fp32 : True
165 | master_weights : True
166 | loss_scale : dynamic
167 | Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 32768.0
168 | ------------iteration 2100----------
169 | total loss 1.1602451801300049
170 | loss_r_feature 43.01371383666992
171 | main criterion 0.002820753026753664
172 | Verifier accuracy: 64.28571319580078
173 | ------------iteration 2200----------
174 | total loss 1.2936094999313354
175 | loss_r_feature 46.4441032409668
176 | main criterion 0.0017376854084432125
177 | Verifier accuracy: 65.47618865966797
178 | ------------iteration 2300----------
179 | total loss 0.9823821187019348
180 | loss_r_feature 38.5743408203125
181 | main criterion 0.0017328716348856688
182 | Verifier accuracy: 77.38095092773438
183 | ------------iteration 2400----------
184 | total loss 0.8478145003318787
185 | loss_r_feature 31.85713768005371
186 | main criterion 0.002227487973868847
187 | Verifier accuracy: 88.0952377319336
188 | ------------iteration 2500----------
189 | total loss 0.7831273674964905
190 | loss_r_feature 30.018978118896484
191 | main criterion 0.0015915121184661984
192 | Verifier accuracy: 88.0952377319336
193 | ------------iteration 2600----------
194 | total loss 0.6972528100013733
195 | loss_r_feature 25.377971649169922
196 | main criterion 0.0012009029742330313
197 | Verifier accuracy: 90.47618865966797
198 | ------------iteration 2700----------
199 | total loss 0.5916131734848022
200 | loss_r_feature 22.423397064208984
201 | main criterion 0.0010576248168945312
202 | Verifier accuracy: 96.42857360839844
203 | ------------iteration 2800----------
204 | total loss 0.4864479899406433
205 | loss_r_feature 18.20646095275879
206 | main criterion 0.0008288564858958125
207 | Verifier accuracy: 98.80952453613281
208 | ------------iteration 2900----------
209 | total loss 0.4284505844116211
210 | loss_r_feature 16.447580337524414
211 | main criterion 0.0007148924050852656
212 | Verifier accuracy: 100.0
213 | ------------iteration 3000----------
214 | total loss 0.4071371555328369
215 | loss_r_feature 15.495633125305176
216 | main criterion 0.0005214327829889953
217 | Verifier accuracy: 98.80952453613281
218 |
--------------------------------------------------------------------------------
/example_logs/fp16_set0_rn50_adi02.log:
--------------------------------------------------------------------------------
1 | $ python imagenet_inversion.py --bs=84 --do_flip --exp_name="test_rn50_adi02" --r_feature=0.01 --arch_name="resnet50" --verifier --setting_id=0 --lr=0.2 --adi_scale=0.2 --l2=0.00001 --fp16 > fp16_set0_rn50_adi02.log
2 | Namespace(adi_scale=0.2, arch_name='resnet50', bs=84, comment='', do_flip=True, epochs=20000, exp_name='test_rn50_adi02', fp16=True, jitter=30, l2=1e-05, local_rank=0, lr=0.2, main_loss_multiplier=1.0, no_cuda=False, r_feature=0.01, random_label=False, setting_id=0, tv_l1=0.0, tv_l2=0.0001, verifier=True, verifier_arch='mobilenet_v2', wd=0.01, worldsize=1)
3 | loading torchvision model for inversion with the name: resnet50
4 | Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
5 |
6 | Defaults for this optimization level are:
7 | enabled : True
8 | opt_level : O2
9 | cast_model_type : torch.float16
10 | patch_torch_functions : False
11 | keep_batchnorm_fp32 : True
12 | master_weights : True
13 | loss_scale : dynamic
14 | Processing user overrides (additional kwargs that are not None)...
15 | After processing overrides, optimization options are:
16 | enabled : True
17 | opt_level : O2
18 | cast_model_type : torch.float16
19 | patch_torch_functions : False
20 | keep_batchnorm_fp32 : True
21 | master_weights : True
22 | loss_scale : dynamic
23 | ==> Resuming from checkpoint..
24 | ==> Getting BN params as feature statistics
25 | Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
26 |
27 | Defaults for this optimization level are:
28 | enabled : True
29 | opt_level : O2
30 | cast_model_type : torch.float16
31 | patch_torch_functions : False
32 | keep_batchnorm_fp32 : True
33 | master_weights : True
34 | loss_scale : dynamic
35 | Processing user overrides (additional kwargs that are not None)...
36 | After processing overrides, optimization options are:
37 | enabled : True
38 | opt_level : O2
39 | cast_model_type : torch.float16
40 | patch_torch_functions : False
41 | keep_batchnorm_fp32 : True
42 | master_weights : True
43 | loss_scale : dynamic
44 | Deep inversion class generation
45 | get_images call
46 | Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
47 |
48 | Defaults for this optimization level are:
49 | enabled : True
50 | opt_level : O2
51 | cast_model_type : torch.float16
52 | patch_torch_functions : False
53 | keep_batchnorm_fp32 : True
54 | master_weights : True
55 | loss_scale : dynamic
56 | Processing user overrides (additional kwargs that are not None)...
57 | After processing overrides, optimization options are:
58 | enabled : True
59 | opt_level : O2
60 | cast_model_type : torch.float16
61 | patch_torch_functions : False
62 | keep_batchnorm_fp32 : True
63 | master_weights : True
64 | loss_scale : dynamic
65 | Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 32768.0
66 | loss_verifier_cig 0.12461590766906738
67 | ------------iteration 100----------
68 | total loss 1.5717917680740356
69 | loss_r_feature 83.22199249267578
70 | main criterion 0.18712279200553894
71 | Verifier accuracy: 1.1904761791229248
72 | loss_verifier_cig 0.0
73 | ------------iteration 200----------
74 | total loss 1.2926644086837769
75 | loss_r_feature 75.16395568847656
76 | main criterion 0.009654657915234566
77 | Verifier accuracy: 3.5714285373687744
78 | loss_verifier_cig 0.029001593589782715
79 | ------------iteration 300----------
80 | total loss 1.1569472551345825
81 | loss_r_feature 66.2757797241211
82 | main criterion 0.016784803941845894
83 | Verifier accuracy: 3.5714285373687744
84 | loss_verifier_cig 0.0
85 | ------------iteration 400----------
86 | total loss 1.1502078771591187
87 | loss_r_feature 66.83061218261719
88 | main criterion 0.002368540968745947
89 | Verifier accuracy: 8.333333015441895
90 | loss_verifier_cig 0.0
91 | ------------iteration 500----------
92 | total loss 1.0058181285858154
93 | loss_r_feature 58.08607482910156
94 | main criterion 0.0015441349241882563
95 | Verifier accuracy: 4.761904716491699
96 | loss_verifier_cig 0.0
97 | ------------iteration 600----------
98 | total loss 1.006028175354004
99 | loss_r_feature 57.82709503173828
100 | main criterion 0.0010576475178822875
101 | Verifier accuracy: 2.3809523582458496
102 | loss_verifier_cig 0.0
103 | ------------iteration 700----------
104 | total loss 0.8937127590179443
105 | loss_r_feature 50.34546661376953
106 | main criterion 0.0013211341574788094
107 | Verifier accuracy: 5.952381134033203
108 | loss_verifier_cig 0.0
109 | ------------iteration 800----------
110 | total loss 0.7877690196037292
111 | loss_r_feature 44.32988357543945
112 | main criterion 0.0008785157115198672
113 | Verifier accuracy: 8.333333015441895
114 | loss_verifier_cig 0.0
115 | ------------iteration 900----------
116 | total loss 0.7290716171264648
117 | loss_r_feature 40.597286224365234
118 | main criterion 0.002379996469244361
119 | Verifier accuracy: 7.142857074737549
120 | loss_verifier_cig 0.04963517189025879
121 | ------------iteration 1000----------
122 | total loss 0.7164633870124817
123 | loss_r_feature 34.76940155029297
124 | main criterion 0.07151930779218674
125 | Verifier accuracy: 7.142857074737549
126 | loss_verifier_cig 0.0
127 | ------------iteration 1100----------
128 | total loss 0.5727385878562927
129 | loss_r_feature 30.762893676757812
130 | main criterion 0.0021644660737365484
131 | Verifier accuracy: 5.952381134033203
132 | loss_verifier_cig 0.029660344123840332
133 | ------------iteration 1200----------
134 | total loss 0.5012800097465515
135 | loss_r_feature 26.345243453979492
136 | main criterion 0.0016442026244476438
137 | Verifier accuracy: 4.761904716491699
138 | loss_verifier_cig 0.0
139 | ------------iteration 1300----------
140 | total loss 0.45760539174079895
141 | loss_r_feature 24.4682674407959
142 | main criterion 0.0017067590961232781
143 | Verifier accuracy: 7.142857074737549
144 | loss_verifier_cig 0.0
145 | ------------iteration 1400----------
146 | total loss 0.4178113341331482
147 | loss_r_feature 21.507671356201172
148 | main criterion 0.0004278591659385711
149 | Verifier accuracy: 8.333333015441895
150 | loss_verifier_cig 0.0
151 | ------------iteration 1500----------
152 | total loss 0.3722783029079437
153 | loss_r_feature 18.461204528808594
154 | main criterion 0.00037799563142471015
155 | Verifier accuracy: 4.761904716491699
156 | loss_verifier_cig 0.0
157 | ------------iteration 1600----------
158 | total loss 0.3312247097492218
159 | loss_r_feature 16.000471115112305
160 | main criterion 0.0004169373423792422
161 | Verifier accuracy: 5.952381134033203
162 | loss_verifier_cig 0.0
163 | ------------iteration 1700----------
164 | total loss 0.3149060606956482
165 | loss_r_feature 15.741662979125977
166 | main criterion 0.0002418699732515961
167 | Verifier accuracy: 8.333333015441895
168 | loss_verifier_cig 0.0
169 | ------------iteration 1800----------
170 | total loss 0.28622013330459595
171 | loss_r_feature 13.761808395385742
172 | main criterion 0.00026948112645186484
173 | Verifier accuracy: 5.952381134033203
174 | loss_verifier_cig 0.0
175 | ------------iteration 1900----------
176 | total loss 0.2793577015399933
177 | loss_r_feature 13.480212211608887
178 | main criterion 0.00037529354449361563
179 | Verifier accuracy: 5.952381134033203
180 | loss_verifier_cig 0.0
181 | ------------iteration 2000----------
182 | total loss 0.2799757122993469
183 | loss_r_feature 13.588874816894531
184 | main criterion 0.0003949347010347992
185 | Verifier accuracy: 4.761904716491699
186 | Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
187 |
188 | Defaults for this optimization level are:
189 | enabled : True
190 | opt_level : O2
191 | cast_model_type : torch.float16
192 | patch_torch_functions : False
193 | keep_batchnorm_fp32 : True
194 | master_weights : True
195 | loss_scale : dynamic
196 | Processing user overrides (additional kwargs that are not None)...
197 | After processing overrides, optimization options are:
198 | enabled : True
199 | opt_level : O2
200 | cast_model_type : torch.float16
201 | patch_torch_functions : False
202 | keep_batchnorm_fp32 : True
203 | master_weights : True
204 | loss_scale : dynamic
205 | Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 32768.0
206 | loss_verifier_cig 0.1591559648513794
207 | ------------iteration 2100----------
208 | total loss 1.305101990699768
209 | loss_r_feature 49.516441345214844
210 | main criterion 0.00047656468814238906
211 | Verifier accuracy: 9.523809432983398
212 | loss_verifier_cig 0.0571821928024292
213 | ------------iteration 2200----------
214 | total loss 1.1693669557571411
215 | loss_r_feature 46.006778717041016
216 | main criterion 0.00037206921842880547
217 | Verifier accuracy: 10.714285850524902
218 | loss_verifier_cig 0.0
219 | ------------iteration 2300----------
220 | total loss 0.9874729514122009
221 | loss_r_feature 38.61001205444336
222 | main criterion 0.0001040413262671791
223 | Verifier accuracy: 3.5714285373687744
224 | loss_verifier_cig 0.00839078426361084
225 | ------------iteration 2400----------
226 | total loss 0.8956082463264465
227 | loss_r_feature 34.82595443725586
228 | main criterion 8.489972242387012e-05
229 | Verifier accuracy: 3.5714285373687744
230 | loss_verifier_cig 0.0
231 | ------------iteration 2500----------
232 | total loss 0.8721445798873901
233 | loss_r_feature 34.64632797241211
234 | main criterion 0.00011521294072736055
235 | Verifier accuracy: 7.142857074737549
236 | loss_verifier_cig 0.056126534938812256
237 | ------------iteration 2600----------
238 | total loss 0.7324535250663757
239 | loss_r_feature 28.194265365600586
240 | main criterion 0.0001758393773343414
241 | Verifier accuracy: 9.523809432983398
242 | loss_verifier_cig 0.001925349235534668
243 | ------------iteration 2700----------
244 | total loss 0.6335565447807312
245 | loss_r_feature 24.50282096862793
246 | main criterion 7.415952859446406e-05
247 | Verifier accuracy: 10.714285850524902
248 | loss_verifier_cig 0.006530642509460449
249 | ------------iteration 2800----------
250 | total loss 0.5335965752601624
251 | loss_r_feature 20.933963775634766
252 | main criterion 9.012222290039062e-05
253 | Verifier accuracy: 16.66666603088379
254 | loss_verifier_cig 0.0
255 | ------------iteration 2900----------
256 | total loss 0.4517979919910431
257 | loss_r_feature 17.627702713012695
258 | main criterion 3.878275674651377e-05
259 | Verifier accuracy: 9.523809432983398
260 | loss_verifier_cig 0.0
261 | ------------iteration 3000----------
262 | total loss 0.43130582571029663
263 | loss_r_feature 16.748857498168945
264 | main criterion 5.76518832531292e-05
265 | Verifier accuracy: 10.714285850524902
266 |
--------------------------------------------------------------------------------
/example_logs/fp16_set0_rn50_adi02_output_00030_gpu_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/example_logs/fp16_set0_rn50_adi02_output_00030_gpu_0.jpg
--------------------------------------------------------------------------------
/example_logs/fp16_set0_rn50_output_00030_gpu_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/example_logs/fp16_set0_rn50_output_00030_gpu_0.jpg
--------------------------------------------------------------------------------
/example_logs/fp16_set1_rn50.log:
--------------------------------------------------------------------------------
1 | $ python imagenet_inversion.py --bs=84 --do_flip --exp_name="test_rn50_fp16_set1" --r_feature=0.01 --arch_name="resnet50" --verifier --setting_id=1 --lr=0.2 --adi_scale=0.0 --l2=0.00001 --fp16
2 | Namespace(adi_scale=0.0, arch_name='resnet50', bs=84, comment='', do_flip=True, epochs=20000, exp_name='test_rn50_fp16_set1', fp16=True, jitter=30, l2=1e-05, local_rank=0, lr=0.2, main_loss_multiplier=1.0, no_cuda=False, r_feature=0.01, random_label=False, setting_id=1, tv_l1=0.0, tv_l2=0.0001, verifier=True, verifier_arch='mobilenet_v2', wd=0.01, worldsize=1)
3 | loading torchvision model for inversion with the name: resnet50
4 | Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
5 |
6 | Defaults for this optimization level are:
7 | enabled : True
8 | opt_level : O2
9 | cast_model_type : torch.float16
10 | patch_torch_functions : False
11 | keep_batchnorm_fp32 : True
12 | master_weights : True
13 | loss_scale : dynamic
14 | Processing user overrides (additional kwargs that are not None)...
15 | After processing overrides, optimization options are:
16 | enabled : True
17 | opt_level : O2
18 | cast_model_type : torch.float16
19 | patch_torch_functions : False
20 | keep_batchnorm_fp32 : True
21 | master_weights : True
22 | loss_scale : dynamic
23 | ==> Resuming from checkpoint..
24 | ==> Getting BN params as feature statistics
25 | loading verifier: mobilenet_v2
26 | Deep inversion class generation
27 | get_images call
28 | Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
29 |
30 | Defaults for this optimization level are:
31 | enabled : True
32 | opt_level : O2
33 | cast_model_type : torch.float16
34 | patch_torch_functions : False
35 | keep_batchnorm_fp32 : True
36 | master_weights : True
37 | loss_scale : dynamic
38 | Processing user overrides (additional kwargs that are not None)...
39 | After processing overrides, optimization options are:
40 | enabled : True
41 | opt_level : O2
42 | cast_model_type : torch.float16
43 | patch_torch_functions : False
44 | keep_batchnorm_fp32 : True
45 | master_weights : True
46 | loss_scale : dynamic
47 | Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 32768.0
48 | ------------iteration 100----------
49 | total loss 1.8805938959121704
50 | loss_r_feature 76.06755828857422
51 | main criterion 0.006961311679333448
52 | Verifier accuracy: 7.142857074737549
53 | ------------iteration 200----------
54 | total loss 1.3011034727096558
55 | loss_r_feature 58.76356887817383
56 | main criterion 0.0030642691999673843
57 | Verifier accuracy: 40.47618865966797
58 | ------------iteration 300----------
59 | total loss 1.096110224723816
60 | loss_r_feature 50.20344924926758
61 | main criterion 0.0015656152972951531
62 | Verifier accuracy: 76.19047546386719
63 | ------------iteration 400----------
64 | total loss 1.0675420761108398
65 | loss_r_feature 46.52493667602539
66 | main criterion 0.0015199638437479734
67 | Verifier accuracy: 67.85713958740234
68 | ------------iteration 500----------
69 | total loss 1.0750524997711182
70 | loss_r_feature 42.88349533081055
71 | main criterion 0.001011076383292675
72 | Verifier accuracy: 70.23809814453125
73 | ------------iteration 600----------
74 | total loss 0.9664787650108337
75 | loss_r_feature 38.243465423583984
76 | main criterion 0.001509393914602697
77 | Verifier accuracy: 63.095237731933594
78 | ------------iteration 700----------
79 | total loss 0.9061623811721802
80 | loss_r_feature 36.1928825378418
81 | main criterion 0.002603746484965086
82 | Verifier accuracy: 64.28571319580078
83 | ------------iteration 800----------
84 | total loss 0.8216720223426819
85 | loss_r_feature 32.11809539794922
86 | main criterion 0.0023030894808471203
87 | Verifier accuracy: 75.0
88 | ------------iteration 900----------
89 | total loss 0.9163199067115784
90 | loss_r_feature 36.28833770751953
91 | main criterion 0.05596952140331268
92 | Verifier accuracy: 72.61904907226562
93 | ------------iteration 1000----------
94 | total loss 0.7641246318817139
95 | loss_r_feature 30.2562313079834
96 | main criterion 0.0017005829140543938
97 | Verifier accuracy: 82.14285278320312
98 | ------------iteration 1100----------
99 | total loss 0.7134910225868225
100 | loss_r_feature 27.443801879882812
101 | main criterion 0.0016521726502105594
102 | Verifier accuracy: 84.52381134033203
103 | ------------iteration 1200----------
104 | total loss 0.6644324064254761
105 | loss_r_feature 25.020050048828125
106 | main criterion 0.002477305242791772
107 | Verifier accuracy: 85.71428680419922
108 | ------------iteration 1300----------
109 | total loss 0.6082080006599426
110 | loss_r_feature 23.141693115234375
111 | main criterion 0.0009327275329269469
112 | Verifier accuracy: 91.66666412353516
113 | ------------iteration 1400----------
114 | total loss 0.5574465394020081
115 | loss_r_feature 21.793500900268555
116 | main criterion 0.0010219528339803219
117 | Verifier accuracy: 95.23809814453125
118 | ------------iteration 1500----------
119 | total loss 0.5180432796478271
120 | loss_r_feature 20.062559127807617
121 | main criterion 0.0009084542398341
122 | Verifier accuracy: 94.04761505126953
123 | ------------iteration 1600----------
124 | total loss 0.46451741456985474
125 | loss_r_feature 17.89577865600586
126 | main criterion 0.0005521774291992188
127 | Verifier accuracy: 97.61904907226562
128 | ------------iteration 1700----------
129 | total loss 0.42708495259284973
130 | loss_r_feature 17.111621856689453
131 | main criterion 0.0005137125845067203
132 | Verifier accuracy: 96.42857360839844
133 | ------------iteration 1800----------
134 | total loss 0.3998515009880066
135 | loss_r_feature 16.164155960083008
136 | main criterion 0.00046278181253001094
137 | Verifier accuracy: 96.42857360839844
138 | ------------iteration 1900----------
139 | total loss 0.39311376214027405
140 | loss_r_feature 16.403812408447266
141 | main criterion 0.0006096022552810609
142 | Verifier accuracy: 96.42857360839844
143 | ------------iteration 2000----------
144 | total loss 0.390480101108551
145 | loss_r_feature 16.30877113342285
146 | main criterion 0.0006353174103423953
147 | Verifier accuracy: 96.42857360839844
148 |
--------------------------------------------------------------------------------
/example_logs/fp16_set1_rn50_output_00020_gpu_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/example_logs/fp16_set1_rn50_output_00020_gpu_0.jpg
--------------------------------------------------------------------------------
/example_logs/fp32_set0_mnv2.log:
--------------------------------------------------------------------------------
1 | $ python imagenet_inversion.py --bs=84 --do_flip --exp_name="test_mnv2_set0" --r_feature=0.01 --arch_name="mobilenet_v2" --verifier --verifier_arch="r
2 | esnet18" --setting_id=0 --lr=0.2 --adi_scale=0.0 --l2=0.00001
3 | Namespace(adi_scale=0.0, arch_name='mobilenet_v2', bs=84, comment='', do_flip=True, epochs=20000, exp_name='test_mnv2_set0', fp16=False, jitter=30, l2=1e-05, local_rank=0, lr=0.2, main_loss_multiplier=1.0, no_cuda=False, r_feature=0.01, random_label=False, setting_id=
4 | 0, tv_l1=0.0, tv_l2=0.0001, verifier=True, verifier_arch='resnet18', wd=0.01, worldsize=1)
5 | loading torchvision model for inversion with the name: mobilenet_v2
6 | ==> Resuming from checkpoint..
7 | ==> Getting BN params as feature statistics
8 | loading verifier: resnet18
9 | Deep inversion class generation
10 | get_images call
11 | ------------iteration 100----------
12 | total loss 3.482088804244995
13 | loss_r_feature 188.9123077392578
14 | main criterion 0.8707727789878845
15 | Verifier accuracy: 0.0
16 | ------------iteration 200----------
17 | total loss 3.0610427856445312
18 | loss_r_feature 207.21432495117188
19 | main criterion 0.16552281379699707
20 | Verifier accuracy: 0.0
21 | ------------iteration 300----------
22 | total loss 3.0640416145324707
23 | loss_r_feature 197.27249145507812
24 | main criterion 0.31159523129463196
25 | Verifier accuracy: 5.952381134033203
26 | ------------iteration 400----------
27 | total loss 2.802457332611084
28 | loss_r_feature 192.7845001220703
29 | main criterion 0.13211332261562347
30 | Verifier accuracy: 3.5714285373687744
31 | ------------iteration 500----------
32 | total loss 2.019845724105835
33 | loss_r_feature 130.70864868164062
34 | main criterion 0.011200053617358208
35 | Verifier accuracy: 9.523809432983398
36 | ------------iteration 600----------
37 | total loss 2.262993812561035
38 | loss_r_feature 154.48069763183594
39 | main criterion 0.05309981480240822
40 | Verifier accuracy: 5.952381134033203
41 | ------------iteration 700----------
42 | total loss 1.8827401399612427
43 | loss_r_feature 122.04618072509766
44 | main criterion 0.02727217972278595
45 | Verifier accuracy: 7.142857074737549
46 | ------------iteration 800----------
47 | total loss 1.8060630559921265
48 | loss_r_feature 121.67047882080078
49 | main criterion 0.007290567737072706
50 | Verifier accuracy: 15.476190567016602
51 | ------------iteration 900----------
52 | total loss 1.6456643342971802
53 | loss_r_feature 108.90213775634766
54 | main criterion 0.013873236253857613
55 | Verifier accuracy: 14.285714149475098
56 | ------------iteration 1000----------
57 | total loss 1.403316855430603
58 | loss_r_feature 90.18961334228516
59 | main criterion 0.016842082142829895
60 | Verifier accuracy: 16.66666603088379
61 | ------------iteration 1100----------
62 | total loss 1.5693998336791992
63 | loss_r_feature 112.5223617553711
64 | main criterion 0.0024354797787964344
65 | Verifier accuracy: 17.85714340209961
66 | ------------iteration 1200----------
67 | total loss 1.3196629285812378
68 | loss_r_feature 88.94979858398438
69 | main criterion 0.019705001264810562
70 | Verifier accuracy: 28.571428298950195
71 | ------------iteration 1300----------
72 | total loss 1.132418155670166
73 | loss_r_feature 75.13455963134766
74 | main criterion 0.007880108430981636
75 | Verifier accuracy: 29.761903762817383
76 | ------------iteration 1400----------
77 | total loss 1.1630631685256958
78 | loss_r_feature 81.51448822021484
79 | main criterion 0.003338359761983156
80 | Verifier accuracy: 25.0
81 | ------------iteration 1500----------
82 | total loss 0.849071204662323
83 | loss_r_feature 52.43641662597656
84 | main criterion 0.0011611892841756344
85 | Verifier accuracy: 27.380952835083008
86 | ------------iteration 1600----------
87 | total loss 0.7553449869155884
88 | loss_r_feature 44.74074935913086
89 | main criterion 0.0010790483793243766
90 | Verifier accuracy: 15.476190567016602
91 | ------------iteration 1700----------
92 | total loss 0.280352383852005
93 | loss_r_feature 13.329302787780762
94 | main criterion 0.0007046745158731937
95 | Verifier accuracy: 26.190475463867188
96 | ------------iteration 1800----------
97 | total loss 0.26061105728149414
98 | loss_r_feature 11.900742530822754
99 | main criterion 0.0009304682607762516
100 | Verifier accuracy: 25.0
101 | ------------iteration 1900----------
102 | total loss 0.2539810240268707
103 | loss_r_feature 11.516486167907715
104 | main criterion 0.0013499259948730469
105 | Verifier accuracy: 27.380952835083008
106 | ------------iteration 2000----------
107 | total loss 0.25407370924949646
108 | loss_r_feature 11.625514030456543
109 | main criterion 0.0009400731069035828
110 | Verifier accuracy: 27.380952835083008
111 | ------------iteration 2100----------
112 | total loss 1.186625361442566
113 | loss_r_feature 43.52707290649414
114 | main criterion 0.008745352737605572
115 | Verifier accuracy: 51.19047546386719
116 | ------------iteration 2200----------
117 | total loss 1.031857967376709
118 | loss_r_feature 35.6275749206543
119 | main criterion 0.0017157054971903563
120 | Verifier accuracy: 73.80952453613281
121 | ------------iteration 2300----------
122 | total loss 0.8248406648635864
123 | loss_r_feature 30.27305030822754
124 | main criterion 0.001007091486826539
125 | Verifier accuracy: 85.71428680419922
126 | ------------iteration 2400----------
127 | total loss 0.8060064315795898
128 | loss_r_feature 28.234004974365234
129 | main criterion 0.0011983144795522094
130 | Verifier accuracy: 86.9047622680664
131 | ------------iteration 2500----------
132 | total loss 0.7223383188247681
133 | loss_r_feature 25.458057403564453
134 | main criterion 0.001484291860833764
135 | Verifier accuracy: 90.47618865966797
136 | ------------iteration 2600----------
137 | total loss 0.5898433923721313
138 | loss_r_feature 20.13226890563965
139 | main criterion 0.0014175687683746219
140 | Verifier accuracy: 91.66666412353516
141 | ------------iteration 2700----------
142 | total loss 0.5021806955337524
143 | loss_r_feature 16.399860382080078
144 | main criterion 0.0010060241911560297
145 | Verifier accuracy: 96.42857360839844
146 | ------------iteration 2800----------
147 | total loss 0.420066237449646
148 | loss_r_feature 13.413532257080078
149 | main criterion 0.000963824160862714
150 | Verifier accuracy: 92.85713958740234
151 | ------------iteration 2900----------
152 | total loss 0.3616832494735718
153 | loss_r_feature 10.820902824401855
154 | main criterion 0.0006766319274902344
155 | Verifier accuracy: 95.23809814453125
156 | ------------iteration 3000----------
157 | total loss 0.3493320047855377
158 | loss_r_feature 10.420402526855469
159 | main criterion 0.0005601133452728391
160 | Verifier accuracy: 96.42857360839844
--------------------------------------------------------------------------------
/example_logs/fp32_set0_mnv2_output_00030_gpu_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/example_logs/fp32_set0_mnv2_output_00030_gpu_0.jpg
--------------------------------------------------------------------------------
/example_logs/fp32_set0_rn50.log:
--------------------------------------------------------------------------------
1 | $ python imagenet_inversion.py --bs=84 --do_flip --exp_name="test_rn50_fp32" --r_feature=0.01 --arch_name="resnet50" --verifier --setting_id=0 --lr=0.2 \
2 |     --adi_scale=0.0 --l2=0.00001
3 | Namespace(adi_scale=0.0, arch_name='resnet50', bs=84, comment='', do_flip=True, epochs=20000, exp_name='test_rn50_fp32', fp16=False, jitter=30, l2=1e-05, local_rank=0, lr=0.2, main_loss_multiplier=1.0, no_cuda=False, r_feature=0.01, random_label=False, setting_id=0,
4 |     tv_l1=0.0, tv_l2=0.0001, verifier=True, verifier_arch='mobilenet_v2', wd=0.01, worldsize=1)
5 | loading torchvision model for inversion with the name: resnet50
6 | ==> Resuming from checkpoint..
7 | ==> Getting BN params as feature statistics
8 | loading verifier: mobilenet_v2
9 | Deep inversion class generation
10 | get_images call
11 | ------------iteration 100----------
12 | total loss 1.7325098514556885
13 | loss_r_feature 90.60043334960938
14 | main criterion 0.22487393021583557
15 | Verifier accuracy: 1.1904761791229248
16 | ------------iteration 200----------
17 | total loss 1.3570595979690552
18 | loss_r_feature 73.92747497558594
19 | main criterion 0.08076667785644531
20 | Verifier accuracy: 5.952381134033203
21 | ------------iteration 300----------
22 | total loss 1.192380666732788
23 | loss_r_feature 66.64324951171875
24 | main criterion 0.04419786483049393
25 | Verifier accuracy: 4.761904716491699
26 | ------------iteration 400----------
27 | total loss 1.1331039667129517
28 | loss_r_feature 63.37013244628906
29 | main criterion 0.027310485020279884
30 | Verifier accuracy: 15.476190567016602
31 | ------------iteration 500----------
32 | total loss 1.0495476722717285
33 | loss_r_feature 58.564552307128906
34 | main criterion 0.004942258354276419
35 | Verifier accuracy: 13.095237731933594
36 | ------------iteration 600----------
37 | total loss 1.05406653881073
38 | loss_r_feature 57.1053466796875
39 | main criterion 0.02717919647693634
40 | Verifier accuracy: 15.476190567016602
41 | ------------iteration 700----------
42 | total loss 0.8731762170791626
43 | loss_r_feature 48.39915466308594
44 | main criterion 0.0029847281984984875
45 | Verifier accuracy: 11.904762268066406
46 | ------------iteration 800----------
47 | total loss 0.8121204376220703
48 | loss_r_feature 44.756126403808594
49 | main criterion 0.0013222013367339969
50 | Verifier accuracy: 13.095237731933594
51 | ------------iteration 900----------
52 | total loss 0.7302803993225098
53 | loss_r_feature 39.77760314941406
54 | main criterion 0.00771959638223052
55 | Verifier accuracy: 15.476190567016602
56 | ------------iteration 1000----------
57 | total loss 0.6622204184532166
58 | loss_r_feature 35.62640380859375
59 | main criterion 0.0059319790452718735
60 | Verifier accuracy: 21.428571701049805
61 | ------------iteration 1100----------
62 | total loss 0.5391901731491089
63 | loss_r_feature 28.116008758544922
64 | main criterion 0.0015326568391174078
65 | Verifier accuracy: 23.809524536132812
66 | ------------iteration 1200----------
67 | total loss 0.5046427845954895
68 | loss_r_feature 26.73769187927246
69 | main criterion 0.001567840576171875
70 | Verifier accuracy: 21.428571701049805
71 | ------------iteration 1300----------
72 | total loss 0.45175373554229736
73 | loss_r_feature 23.481212615966797
74 | main criterion 0.0010218393290415406
75 | Verifier accuracy: 16.66666603088379
76 | ------------iteration 1400----------
77 | total loss 0.3774851858615875
78 | loss_r_feature 18.778751373291016
79 | main criterion 0.0009181839996017516
80 | Verifier accuracy: 28.571428298950195
81 | ------------iteration 1500----------
82 | total loss 0.3348837196826935
83 | loss_r_feature 16.353933334350586
84 | main criterion 0.0006400971324183047
85 | Verifier accuracy: 26.190475463867188
86 | ------------iteration 1600----------
87 | total loss 0.29075923562049866
88 | loss_r_feature 13.699808120727539
89 | main criterion 0.0005158697022125125
90 | Verifier accuracy: 32.14285659790039
91 | ------------iteration 1700----------
92 | total loss 0.280352383852005
93 | loss_r_feature 13.329302787780762
94 | main criterion 0.0007046745158731937
95 | Verifier accuracy: 26.190475463867188
96 | ------------iteration 1800----------
97 | total loss 0.26061105728149414
98 | loss_r_feature 11.900742530822754
99 | main criterion 0.0009304682607762516
100 | Verifier accuracy: 25.0
101 | ------------iteration 1900----------
102 | total loss 0.2539810240268707
103 | loss_r_feature 11.516486167907715
104 | main criterion 0.0013499259948730469
105 | Verifier accuracy: 27.380952835083008
106 | ------------iteration 2000----------
107 | total loss 0.25407370924949646
108 | loss_r_feature 11.625514030456543
109 | main criterion 0.0009400731069035828
110 | Verifier accuracy: 27.380952835083008
111 | ------------iteration 2100----------
112 | total loss 1.186625361442566
113 | loss_r_feature 43.52707290649414
114 | main criterion 0.008745352737605572
115 | Verifier accuracy: 51.19047546386719
116 | ------------iteration 2200----------
117 | total loss 1.031857967376709
118 | loss_r_feature 35.6275749206543
119 | main criterion 0.0017157054971903563
120 | Verifier accuracy: 73.80952453613281
121 | ------------iteration 2300----------
122 | total loss 0.8248406648635864
123 | loss_r_feature 30.27305030822754
124 | main criterion 0.001007091486826539
125 | Verifier accuracy: 85.71428680419922
126 | ------------iteration 2400----------
127 | total loss 0.8060064315795898
128 | loss_r_feature 28.234004974365234
129 | main criterion 0.0011983144795522094
130 | Verifier accuracy: 86.9047622680664
131 | ------------iteration 2500----------
132 | total loss 0.7223383188247681
133 | loss_r_feature 25.458057403564453
134 | main criterion 0.001484291860833764
135 | Verifier accuracy: 90.47618865966797
136 | ------------iteration 2600----------
137 | total loss 0.5898433923721313
138 | loss_r_feature 20.13226890563965
139 | main criterion 0.0014175687683746219
140 | Verifier accuracy: 91.66666412353516
141 | ------------iteration 2700----------
142 | total loss 0.5021806955337524
143 | loss_r_feature 16.399860382080078
144 | main criterion 0.0010060241911560297
145 | Verifier accuracy: 96.42857360839844
146 | ------------iteration 2800----------
147 | total loss 0.420066237449646
148 | loss_r_feature 13.413532257080078
149 | main criterion 0.000963824160862714
150 | Verifier accuracy: 92.85713958740234
151 | ------------iteration 2900----------
152 | total loss 0.3616832494735718
153 | loss_r_feature 10.820902824401855
154 | main criterion 0.0006766319274902344
155 | Verifier accuracy: 95.23809814453125
156 | ------------iteration 3000----------
157 | total loss 0.3493320047855377
158 | loss_r_feature 10.420402526855469
159 | main criterion 0.0005601133452728391
160 | Verifier accuracy: 96.42857360839844
--------------------------------------------------------------------------------
/example_logs/fp32_set0_rn50_first_bn_scaled.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/example_logs/fp32_set0_rn50_first_bn_scaled.jpg
--------------------------------------------------------------------------------
/example_logs/fp32_set0_rn50_first_bn_scaled.log:
--------------------------------------------------------------------------------
1 | python imagenet_inversion.py --bs=84 --do_flip --exp_name="rn50_inversion_first_bn_cpcp" --r_feature=0.01 --arch_name="resnet50" --verifier --adi_scale=0.0 --setting_id=0 --lr 0.25
2 | Namespace(adi_scale=0.0, arch_name='resnet50', bs=84, comment='', do_flip=True, epochs=20000, exp_name='rn50_inversion_first_bn_cpcp', first_bn_multiplier=10.0, fp16=False, jitter=30, l2=1e-05, local_rank=0, lr=0.25, main_loss_multiplier=1.0, no_cuda=False, r_feature=0.01, random_label=False, setting_id=0, store_best_images=False, tv_l1=0.0, tv_l2=0.0001, verifier=True, verifier_arch='mobilenet_v2', worldsize=1)
3 | loading torchvision model for inversion with the name: resnet50
4 | ==> Resuming from checkpoint..
5 | ==> Getting BN params as feature statistics
6 | loading verifier: mobilenet_v2
7 | Deep inversion class generation
8 | get_images call
9 | ------------iteration 100----------
10 | total loss 4.216484546661377
11 | loss_r_feature 319.7900085449219
12 | main criterion 0.3684564530849457
13 | Verifier accuracy: 0.0
14 | ------------iteration 200----------
15 | total loss 3.074280023574829
16 | loss_r_feature 243.4846954345703
17 | main criterion 0.03289255499839783
18 | Verifier accuracy: 3.5714285373687744
19 | ------------iteration 300----------
20 | total loss 2.556410312652588
21 | loss_r_feature 185.2130584716797
22 | main criterion 0.1536623239517212
23 | Verifier accuracy: 2.3809523582458496
24 | ------------iteration 400----------
25 | total loss 2.1119472980499268
26 | loss_r_feature 157.12548828125
27 | main criterion 0.013974723406136036
28 | Verifier accuracy: 2.3809523582458496
29 | ------------iteration 500----------
30 | total loss 1.6994301080703735
31 | loss_r_feature 120.24372863769531
32 | main criterion 0.023798726499080658
33 | Verifier accuracy: 7.142857074737549
34 | ------------iteration 600----------
35 | total loss 1.5869784355163574
36 | loss_r_feature 102.4736099243164
37 | main criterion 0.12379765510559082
38 | Verifier accuracy: 8.333333015441895
39 | ------------iteration 700----------
40 | total loss 1.2205699682235718
41 | loss_r_feature 81.42609405517578
42 | main criterion 0.012213945388793945
43 | Verifier accuracy: 14.285714149475098
44 | ------------iteration 800----------
45 | total loss 1.0356595516204834
46 | loss_r_feature 67.60336303710938
47 | main criterion 0.0036269596312195063
48 | Verifier accuracy: 22.619047164916992
49 | ------------iteration 900----------
50 | total loss 0.9518051743507385
51 | loss_r_feature 57.53125762939453
52 | main criterion 0.054748646914958954
53 | Verifier accuracy: 20.238094329833984
54 | ------------iteration 1000----------
55 | total loss 0.762413501739502
56 | loss_r_feature 46.62602233886719
57 | main criterion 0.005131573881953955
58 | Verifier accuracy: 27.380952835083008
59 | ------------iteration 1100----------
60 | total loss 0.6755117774009705
61 | loss_r_feature 38.807640075683594
62 | main criterion 0.020008916035294533
63 | Verifier accuracy: 25.0
64 | ------------iteration 1200----------
65 | total loss 0.5802351832389832
66 | loss_r_feature 34.1475830078125
67 | main criterion 0.0015264465473592281
68 | Verifier accuracy: 28.571428298950195
69 | ------------iteration 1300----------
70 | total loss 0.49006298184394836
71 | loss_r_feature 27.85464096069336
72 | main criterion 0.001066457713022828
73 | Verifier accuracy: 32.14285659790039
74 | ------------iteration 1400----------
75 | total loss 0.4420880675315857
76 | loss_r_feature 23.253334045410156
77 | main criterion 0.0077417464926838875
78 | Verifier accuracy: 30.952381134033203
79 | ------------iteration 1500----------
80 | total loss 0.4081510007381439
81 | loss_r_feature 21.22683334350586
82 | main criterion 0.0008004733244888484
83 | Verifier accuracy: 40.47618865966797
84 | ------------iteration 1600----------
85 | total loss 0.36195051670074463
86 | loss_r_feature 17.16771697998047
87 | main criterion 0.0006432306254282594
88 | Verifier accuracy: 44.0476188659668
89 | ------------iteration 1700----------
90 | total loss 0.3319593071937561
91 | loss_r_feature 14.380390167236328
92 | main criterion 0.0008895737701095641
93 | Verifier accuracy: 35.71428680419922
94 | ------------iteration 1800----------
95 | total loss 0.31532973051071167
96 | loss_r_feature 12.930418968200684
97 | main criterion 0.0006053107208572328
98 | Verifier accuracy: 33.33333206176758
99 | ------------iteration 1900----------
100 | total loss 0.298938125371933
101 | loss_r_feature 11.36793041229248
102 | main criterion 0.0007616224465891719
103 | Verifier accuracy: 36.904762268066406
104 | ------------iteration 2000----------
105 | total loss 0.3005948066711426
106 | loss_r_feature 11.55113697052002
107 | main criterion 0.0006574335275217891
108 | Verifier accuracy: 38.095237731933594
109 | ------------iteration 2100----------
110 | total loss 1.8834247589111328
111 | loss_r_feature 103.4402847290039
112 | main criterion 0.0071905795484781265
113 | Verifier accuracy: 52.380950927734375
114 | ------------iteration 2200----------
115 | total loss 1.2003257274627686
116 | loss_r_feature 60.801185607910156
117 | main criterion 0.005007199011743069
118 | Verifier accuracy: 63.095237731933594
119 | ------------iteration 2300----------
120 | total loss 1.0916008949279785
121 | loss_r_feature 51.438507080078125
122 | main criterion 0.0014722461346536875
123 | Verifier accuracy: 85.71428680419922
124 | ------------iteration 2400----------
125 | total loss 0.8125420212745667
126 | loss_r_feature 33.99945068359375
127 | main criterion 0.0012546947691589594
128 | Verifier accuracy: 83.33333587646484
129 | ------------iteration 2500----------
130 | total loss 0.8069868683815002
131 | loss_r_feature 35.57963562011719
132 | main criterion 0.0014406385598704219
133 | Verifier accuracy: 89.28571319580078
134 | ------------iteration 2600----------
135 | total loss 0.6985664963722229
136 | loss_r_feature 28.85590934753418
137 | main criterion 0.0015078613068908453
138 | Verifier accuracy: 94.04761505126953
139 | ------------iteration 2700----------
140 | total loss 0.640400767326355
141 | loss_r_feature 23.789012908935547
142 | main criterion 0.0010293552186340094
143 | Verifier accuracy: 96.42857360839844
144 | ------------iteration 2800----------
145 | total loss 0.5649642944335938
146 | loss_r_feature 17.716434478759766
147 | main criterion 0.0007892563007771969
148 | Verifier accuracy: 91.66666412353516
149 | ------------iteration 2900----------
150 | total loss 0.5026633739471436
151 | loss_r_feature 12.565704345703125
152 | main criterion 0.0006156648742035031
153 | Verifier accuracy: 94.04761505126953
154 | ------------iteration 3000----------
155 | total loss 0.4865095913410187
156 | loss_r_feature 11.368947982788086
157 | main criterion 0.0005418686778284609
158 | Verifier accuracy: 91.66666412353516
159 |
--------------------------------------------------------------------------------
/example_logs/fp32_set0_rn50_output_00030_gpu_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/example_logs/fp32_set0_rn50_output_00030_gpu_0.jpg
--------------------------------------------------------------------------------
/example_logs/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/DeepInversion/2413af1f97f6c73c4dcb47b6a44266aad458ff28/example_logs/teaser.png
--------------------------------------------------------------------------------
/imagenet_inversion.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
3 | # Nvidia Source Code License-NC
4 | # Official PyTorch implementation of CVPR2020 paper
5 | # Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion
6 | # Hongxu Yin, Pavlo Molchanov, Zhizhong Li, Jose M. Alvarez, Arun Mallya, Derek
7 | # Hoiem, Niraj K. Jha, and Jan Kautz
8 | # --------------------------------------------------------
9 |
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 | from __future__ import unicode_literals
14 |
15 | import argparse
16 | import torch
17 | from torch import distributed, nn
18 | import random
19 | import torch.nn as nn
20 | import torch.nn.parallel
21 | import torch.utils.data
22 | from torchvision import datasets, transforms
23 |
24 | import numpy as np
25 | import torch.cuda.amp as amp
26 | import os
27 | import torchvision.models as models
28 | from utils.utils import load_model_pytorch, distributed_is_initialized
29 |
30 | random.seed(0)
31 |
32 |
33 | def validate_one(input, target, model):
34 | """Perform validation on the validation set"""
35 |
36 | def accuracy(output, target, topk=(1,)):
37 | """Computes the precision@k for the specified values of k"""
38 | maxk = max(topk)
39 | batch_size = target.size(0)
40 |
41 | _, pred = output.topk(maxk, 1, True, True)
42 | pred = pred.t()
43 | correct = pred.eq(target.view(1, -1).expand_as(pred))
44 |
45 | res = []
46 | for k in topk:
47 | correct_k = correct[:k].reshape(-1).float().sum(0)
48 | res.append(correct_k.mul_(100.0 / batch_size))
49 | return res
50 |
51 | with torch.no_grad():
52 | output = model(input)
53 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
54 |
55 | print("Verifier accuracy: ", prec1.item())
56 |
57 |
58 | def run(args):
59 | torch.manual_seed(args.local_rank)
60 | device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
61 |
62 | if args.arch_name == "resnet50v15":
63 | from models.resnetv15 import build_resnet
64 | net = build_resnet("resnet50", "classic")
65 | else:
66 | print("loading torchvision model for inversion with the name: {}".format(args.arch_name))
67 | net = models.__dict__[args.arch_name](pretrained=True)
68 |
69 | net = net.to(device)
70 |
71 | use_fp16 = args.fp16
72 | if use_fp16:
73 | net, _ = amp.initialize(net, [], opt_level="O2")
74 |
75 | print('==> Resuming from checkpoint..')
76 |
77 | ### load models
78 | if args.arch_name=="resnet50v15":
79 | path_to_model = "./models/resnet50v15/model_best.pth.tar"
80 | load_model_pytorch(net, path_to_model, gpu_n=torch.cuda.current_device())
81 |
82 | net.to(device)
83 | net.eval()
84 |
85 | # reserved to compute test accuracy on generated images by different networks
86 | net_verifier = None
87 | if args.verifier and args.adi_scale == 0:
88 | # if multiple GPUs are used then we can change code to load different verifiers to different GPUs
89 | if args.local_rank == 0:
90 | print("loading verifier: ", args.verifier_arch)
91 | net_verifier = models.__dict__[args.verifier_arch](pretrained=True).to(device)
92 | net_verifier.eval()
93 |
94 | if use_fp16:
95 | net_verifier = net_verifier.half()
96 |
97 | if args.adi_scale != 0.0:
98 | student_arch = "resnet18"
99 | net_verifier = models.__dict__[student_arch](pretrained=True).to(device)
100 | net_verifier.eval()
101 |
102 | if use_fp16:
103 | net_verifier, _ = amp.initialize(net_verifier, [], opt_level="O2")
104 |
105 | net_verifier = net_verifier.to(device)
106 | net_verifier.train()
107 |
108 | if use_fp16:
109 | for module in net_verifier.modules():
110 | if isinstance(module, nn.BatchNorm2d):
111 | module.eval().half()
112 |
113 | from deepinversion import DeepInversionClass
114 |
115 | exp_name = args.exp_name
116 | # final images will be stored here:
117 | adi_data_path = "./final_images/%s"%exp_name
118 | # temporal data and generations will be stored here
119 | exp_name = "generations/%s"%exp_name
120 |
121 | args.iterations = 2000
122 | args.start_noise = True
123 | # args.detach_student = False
124 |
125 | args.resolution = 224
126 | bs = args.bs
127 | jitter = 30
128 |
129 | parameters = dict()
130 | parameters["resolution"] = 224
131 | parameters["random_label"] = False
132 | parameters["start_noise"] = True
133 | parameters["detach_student"] = False
134 | parameters["do_flip"] = True
135 |
136 | parameters["do_flip"] = args.do_flip
137 | parameters["random_label"] = args.random_label
138 | parameters["store_best_images"] = args.store_best_images
139 |
140 | criterion = nn.CrossEntropyLoss()
141 |
142 | coefficients = dict()
143 | coefficients["r_feature"] = args.r_feature
144 | coefficients["first_bn_multiplier"] = args.first_bn_multiplier
145 | coefficients["tv_l1"] = args.tv_l1
146 | coefficients["tv_l2"] = args.tv_l2
147 | coefficients["l2"] = args.l2
148 | coefficients["lr"] = args.lr
149 | coefficients["main_loss_multiplier"] = args.main_loss_multiplier
150 | coefficients["adi_scale"] = args.adi_scale
151 |
152 | network_output_function = lambda x: x
153 |
154 | # check accuracy of verifier
155 | if args.verifier:
156 | hook_for_display = lambda x,y: validate_one(x, y, net_verifier)
157 | else:
158 | hook_for_display = None
159 |
160 | DeepInversionEngine = DeepInversionClass(net_teacher=net,
161 | final_data_path=adi_data_path,
162 | path=exp_name,
163 | parameters=parameters,
164 | setting_id=args.setting_id,
165 | bs = bs,
166 | use_fp16 = args.fp16,
167 | jitter = jitter,
168 | criterion=criterion,
169 | coefficients = coefficients,
170 | network_output_function = network_output_function,
171 | hook_for_display = hook_for_display)
172 | net_student=None
173 | if args.adi_scale != 0:
174 | net_student = net_verifier
175 | DeepInversionEngine.generate_batch(net_student=net_student)
176 |
177 | def main():
178 | parser = argparse.ArgumentParser()
179 | parser.add_argument('-s', '--worldsize', type=int, default=1, help='Number of processes participating in the job.')
180 | parser.add_argument('--local_rank', '--rank', type=int, default=0, help='Rank of the current process.')
181 | parser.add_argument('--adi_scale', type=float, default=0.0, help='Coefficient for Adaptive Deep Inversion')
182 | parser.add_argument('--no-cuda', action='store_true')
183 |
184 |     parser.add_argument('--epochs', default=20000, type=int, help='number of epochs for optimization')
185 | parser.add_argument('--setting_id', default=0, type=int, help='settings for optimization: 0 - multi resolution, 1 - 2k iterations, 2 - 20k iterations')
186 | parser.add_argument('--bs', default=64, type=int, help='batch size')
187 |     parser.add_argument('--jitter', default=30, type=int, help='random shift (jitter) in pixels applied to inputs')
188 |     parser.add_argument('--comment', default='', type=str, help='optional comment for the experiment')
189 | parser.add_argument('--arch_name', default='resnet50', type=str, help='model name from torchvision or resnet50v15')
190 |
191 | parser.add_argument('--fp16', action='store_true', help='use FP16 for optimization')
192 | parser.add_argument('--exp_name', type=str, default='test', help='where to store experimental data')
193 |
194 | parser.add_argument('--verifier', action='store_true', help='evaluate batch with another model')
195 | parser.add_argument('--verifier_arch', type=str, default='mobilenet_v2', help = "arch name from torchvision models to act as a verifier")
196 |
197 | parser.add_argument('--do_flip', action='store_true', help='apply flip during model inversion')
198 | parser.add_argument('--random_label', action='store_true', help='generate random label for optimization')
199 | parser.add_argument('--r_feature', type=float, default=0.05, help='coefficient for feature distribution regularization')
200 | parser.add_argument('--first_bn_multiplier', type=float, default=10., help='additional multiplier on first bn layer of R_feature')
201 | parser.add_argument('--tv_l1', type=float, default=0.0, help='coefficient for total variation L1 loss')
202 | parser.add_argument('--tv_l2', type=float, default=0.0001, help='coefficient for total variation L2 loss')
203 | parser.add_argument('--lr', type=float, default=0.2, help='learning rate for optimization')
204 | parser.add_argument('--l2', type=float, default=0.00001, help='l2 loss on the image')
205 | parser.add_argument('--main_loss_multiplier', type=float, default=1.0, help='coefficient for the main loss in optimization')
206 | parser.add_argument('--store_best_images', action='store_true', help='save best images as separate files')
207 |
208 | args = parser.parse_args()
209 | print(args)
210 |
211 | torch.backends.cudnn.benchmark = True
212 | run(args)
213 |
214 |
215 | if __name__ == '__main__':
216 | main()
217 |
--------------------------------------------------------------------------------
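For orientation, the following is a minimal sketch (not part of the repository) of driving the inversion engine from Python rather than from the command line. It reuses only the keyword arguments that `run()` above passes to `DeepInversionClass`; the engine itself is defined in `deepinversion.py`, and the concrete coefficient values below are illustrative defaults, not prescribed settings.

```python
# Hypothetical driver; assumes deepinversion.DeepInversionClass accepts the same
# keyword arguments that run() in imagenet_inversion.py passes above.
import torch
import torch.nn as nn
import torchvision.models as models

from deepinversion import DeepInversionClass

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = models.resnet50(pretrained=True).to(device).eval()

parameters = {"resolution": 224, "random_label": False, "start_noise": True,
              "detach_student": False, "do_flip": True, "store_best_images": False}
coefficients = {"r_feature": 0.01, "first_bn_multiplier": 10.0, "tv_l1": 0.0,
                "tv_l2": 1e-4, "l2": 1e-5, "lr": 0.2,
                "main_loss_multiplier": 1.0, "adi_scale": 0.0}

engine = DeepInversionClass(net_teacher=net,
                            final_data_path="./final_images/example",
                            path="generations/example",
                            parameters=parameters,
                            setting_id=0,
                            bs=32,
                            use_fp16=False,
                            jitter=30,
                            criterion=nn.CrossEntropyLoss(),
                            coefficients=coefficients,
                            network_output_function=lambda x: x,
                            hook_for_display=None)
engine.generate_batch(net_student=None)
```

--------------------------------------------------------------------------------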
/models/resnetv15.py:
--------------------------------------------------------------------------------
1 | # Originated from https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/RN50v1.5
2 | # now code is at https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets
3 | import torch.nn as nn
4 |
5 | __all__ = ['ResNet', 'build_resnet', 'resnet_versions', 'resnet_configs']
6 |
7 | # ResNetBuilder {{{
8 |
9 | class ResNetBuilder(object):
10 | def __init__(self, version, config):
11 | self.config = config
12 |
13 | self.L = sum(version['layers'])
14 | self.M = version['block'].M
15 |
16 | self.layer_index = 0
17 |
18 | def conv(self, kernel_size, in_planes, out_planes, stride=1):
19 | if kernel_size == 3:
20 | conv = self.config['conv'](
21 | in_planes, out_planes, kernel_size=3, stride=stride,
22 | padding=1, bias=False)
23 | elif kernel_size == 1:
24 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
25 | bias=False)
26 | elif kernel_size == 5:
27 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride,
28 | padding=2, bias=False)
29 | elif kernel_size == 7:
30 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=7, stride=stride,
31 | padding=3, bias=False)
32 | else:
33 | return None
34 |
35 | if self.config['nonlinearity'] == 'relu':
36 | nn.init.kaiming_normal_(conv.weight,
37 | mode=self.config['conv_init'],
38 | nonlinearity=self.config['nonlinearity'])
39 |
40 | return conv
41 |
42 | def conv3x3(self, in_planes, out_planes, stride=1):
43 | """3x3 convolution with padding"""
44 | c = self.conv(3, in_planes, out_planes, stride=stride)
45 | return c
46 |
47 | def conv1x1(self, in_planes, out_planes, stride=1):
48 | """1x1 convolution with padding"""
49 | c = self.conv(1, in_planes, out_planes, stride=stride)
50 | return c
51 |
52 | def conv7x7(self, in_planes, out_planes, stride=1):
53 | """7x7 convolution with padding"""
54 | c = self.conv(7, in_planes, out_planes, stride=stride)
55 | return c
56 |
57 | def conv5x5(self, in_planes, out_planes, stride=1):
58 | """5x5 convolution with padding"""
59 | c = self.conv(5, in_planes, out_planes, stride=stride)
60 | return c
61 |
62 | def batchnorm(self, planes, last_bn=False):
63 | bn = nn.BatchNorm2d(planes)
64 | gamma_init_val = 0 if last_bn and self.config['last_bn_0_init'] else 1
65 | nn.init.constant_(bn.weight, gamma_init_val)
66 | nn.init.constant_(bn.bias, 0)
67 |
68 | return bn
69 |
70 | def activation(self):
71 | return self.config['activation']()
72 |
73 | # ResNetBuilder }}}
74 |
75 | # BasicBlock {{{
76 | class BasicBlock(nn.Module):
77 | M = 2
78 | expansion = 1
79 |
80 | def __init__(self, builder, inplanes, planes, stride=1, downsample=None):
81 | super(BasicBlock, self).__init__()
82 | self.conv1 = builder.conv3x3(inplanes, planes, stride)
83 | self.bn1 = builder.batchnorm(planes)
84 | self.relu = builder.activation()
85 | self.conv2 = builder.conv3x3(planes, planes)
86 | self.bn2 = builder.batchnorm(planes, last_bn=True)
87 | self.downsample = downsample
88 | self.stride = stride
89 |
90 | def forward(self, x):
91 | residual = x
92 |
93 | out = self.conv1(x)
94 | if self.bn1 is not None:
95 | out = self.bn1(out)
96 |
97 | out = self.relu(out)
98 |
99 | out = self.conv2(out)
100 |
101 | if self.bn2 is not None:
102 | out = self.bn2(out)
103 |
104 | if self.downsample is not None:
105 | residual = self.downsample(x)
106 |
107 | out += residual
108 | out = self.relu(out)
109 |
110 | return out
111 | # BasicBlock }}}
112 |
113 | # Bottleneck {{{
114 | class Bottleneck(nn.Module):
115 | M = 3
116 | expansion = 4
117 |
118 | def __init__(self, builder, inplanes, planes, stride=1, downsample=None):
119 | super(Bottleneck, self).__init__()
120 | self.conv1 = builder.conv1x1(inplanes, planes)
121 | self.bn1 = builder.batchnorm(planes)
122 | self.conv2 = builder.conv3x3(planes, planes, stride=stride)
123 | self.bn2 = builder.batchnorm(planes)
124 | self.conv3 = builder.conv1x1(planes, planes * self.expansion)
125 | self.bn3 = builder.batchnorm(planes * self.expansion, last_bn=True)
126 | self.relu = builder.activation()
127 | self.downsample = downsample
128 | self.stride = stride
129 |
130 |
131 | def forward(self, x):
132 | residual = x
133 |
134 | out = self.conv1(x)
135 | out = self.bn1(out)
136 | out = self.relu(out)
137 |
138 | out = self.conv2(out)
139 | out = self.bn2(out)
140 |
141 | out = self.relu(out)
142 |
143 | out = self.conv3(out)
144 | out = self.bn3(out)
145 |
146 | if self.downsample is not None:
147 | residual = self.downsample(x)
148 |
149 | out += residual
150 |
151 | out = self.relu(out)
152 |
153 | return out
154 | # Bottleneck }}}
155 |
156 | # ResNet {{{
157 | class ResNet(nn.Module):
158 | def __init__(self, builder, block, layers, num_classes=1000):
159 | self.inplanes = 64
160 | super(ResNet, self).__init__()
161 | self.conv1 = builder.conv7x7(3, 64, stride=2)
162 | self.bn1 = builder.batchnorm(64)
163 | self.relu = builder.activation()
164 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
165 | self.layer1 = self._make_layer(builder, block, 64, layers[0])
166 | self.layer2 = self._make_layer(builder, block, 128, layers[1], stride=2)
167 | self.layer3 = self._make_layer(builder, block, 256, layers[2], stride=2)
168 | self.layer4 = self._make_layer(builder, block, 512, layers[3], stride=2)
169 | self.avgpool = nn.AdaptiveAvgPool2d(1)
170 | self.fc = nn.Linear(512 * block.expansion, num_classes)
171 |
172 | def _make_layer(self, builder, block, planes, blocks, stride=1):
173 | downsample = None
174 | if stride != 1 or self.inplanes != planes * block.expansion:
175 | dconv = builder.conv1x1(self.inplanes, planes * block.expansion,
176 | stride=stride)
177 | dbn = builder.batchnorm(planes * block.expansion)
178 | if dbn is not None:
179 | downsample = nn.Sequential(dconv, dbn)
180 | else:
181 | downsample = dconv
182 |
183 | layers = []
184 | layers.append(block(builder, self.inplanes, planes, stride, downsample))
185 | self.inplanes = planes * block.expansion
186 | for i in range(1, blocks):
187 | layers.append(block(builder, self.inplanes, planes))
188 |
189 | builder.layer_index += 1
190 |
191 | return nn.Sequential(*layers)
192 |
193 | def forward(self, x):
194 | x = self.conv1(x)
195 | if self.bn1 is not None:
196 | x = self.bn1(x)
197 |
198 | x = self.relu(x)
199 | x = self.maxpool(x)
200 |
201 | x = self.layer1(x)
202 | x = self.layer2(x)
203 | x = self.layer3(x)
204 | x = self.layer4(x)
205 |
206 | x = self.avgpool(x)
207 | x = x.view(x.size(0), -1)
208 | x = self.fc(x)
209 |
210 | return x
211 | # ResNet }}}
212 |
213 |
214 | resnet_configs = {
215 | 'classic' : {
216 | 'conv' : nn.Conv2d,
217 | 'conv_init' : 'fan_out',
218 | 'nonlinearity' : 'relu',
219 | 'last_bn_0_init' : False,
220 | 'activation' : lambda: nn.ReLU(inplace=True),
221 | },
222 | 'fanin' : {
223 | 'conv' : nn.Conv2d,
224 | 'conv_init' : 'fan_in',
225 | 'nonlinearity' : 'relu',
226 | 'last_bn_0_init' : False,
227 | 'activation' : lambda: nn.ReLU(inplace=True),
228 | },
229 | }
230 |
231 | resnet_versions = {
232 | 'resnet18' : {
233 | 'net' : ResNet,
234 | 'block' : BasicBlock,
235 | 'layers' : [2, 2, 2, 2],
236 | 'num_classes' : 1000,
237 | },
238 | 'resnet34' : {
239 | 'net' : ResNet,
240 | 'block' : BasicBlock,
241 | 'layers' : [3, 4, 6, 3],
242 | 'num_classes' : 1000,
243 | },
244 | 'resnet50' : {
245 | 'net' : ResNet,
246 | 'block' : Bottleneck,
247 | 'layers' : [3, 4, 6, 3],
248 | 'num_classes' : 1000,
249 | },
250 | 'resnet101' : {
251 | 'net' : ResNet,
252 | 'block' : Bottleneck,
253 | 'layers' : [3, 4, 23, 3],
254 | 'num_classes' : 1000,
255 | },
256 | 'resnet152' : {
257 | 'net' : ResNet,
258 | 'block' : Bottleneck,
259 | 'layers' : [3, 8, 36, 3],
260 | 'num_classes' : 1000,
261 | },
262 | }
263 |
264 |
265 | def build_resnet(version, config, model_state=None):
266 | version = resnet_versions[version]
267 | config = resnet_configs[config]
268 |
269 | builder = ResNetBuilder(version, config)
270 | print("Version: {}".format(version))
271 | print("Config: {}".format(config))
272 | model = version['net'](builder,
273 | version['block'],
274 | version['layers'],
275 | version['num_classes'])
276 |
277 | return model
278 |
--------------------------------------------------------------------------------
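As a quick sanity check on the definitions above, a short sketch of instantiating the v1.5 ResNet-50 (the same `build_resnet("resnet50", "classic")` call that `imagenet_inversion.py` uses for `--arch_name=resnet50v15`) and running a dummy forward pass:

```python
# Sketch: build the ResNet-50 v1.5 defined above and check the output shape.
import torch
from models.resnetv15 import build_resnet

net = build_resnet("resnet50", "classic")  # keys come from resnet_versions / resnet_configs
net.eval()
with torch.no_grad():
    logits = net(torch.randn(2, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([2, 1000])
```

--------------------------------------------------------------------------------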
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
3 | # Nvidia Source Code License-NC
4 | # Code written by Pavlo Molchanov and Hongxu Yin
5 | # --------------------------------------------------------
6 |
7 | import torch
8 | import os
9 | from torch import distributed, nn
10 | import random
11 | import numpy as np
12 |
13 | def load_model_pytorch(model, load_model, gpu_n=0):
14 | print("=> loading checkpoint '{}'".format(load_model))
15 |
16 | checkpoint = torch.load(load_model, map_location = lambda storage, loc: storage.cuda(gpu_n))
17 |
18 | if 'state_dict' in checkpoint.keys():
19 | load_from = checkpoint['state_dict']
20 | else:
21 | load_from = checkpoint
22 |
23 | if 1:
24 | if 'module.' in list(model.state_dict().keys())[0]:
25 | if 'module.' not in list(load_from.keys())[0]:
26 | from collections import OrderedDict
27 |
28 | load_from = OrderedDict([("module.{}".format(k), v) for k, v in load_from.items()])
29 |
30 | if 'module.' not in list(model.state_dict().keys())[0]:
31 | if 'module.' in list(load_from.keys())[0]:
32 | from collections import OrderedDict
33 |
34 | load_from = OrderedDict([(k.replace("module.", ""), v) for k, v in load_from.items()])
35 |
36 |     from collections import OrderedDict
37 |     if list(load_from.items())[0][0][:2] == "1." and list(model.state_dict().items())[0][0][:2] != "1.":
38 |         load_from = OrderedDict([(k[2:], v) for k, v in load_from.items()])
39 |
40 | load_from = OrderedDict([(k, v) for k, v in load_from.items() if "gate" not in k])
41 |
42 | model.load_state_dict(load_from, strict=True)
43 |
44 | epoch_from = -1
45 | if 'epoch' in checkpoint.keys():
46 | epoch_from = checkpoint['epoch']
47 | print("=> loaded checkpoint '{}' (epoch {})"
48 | .format(load_model, epoch_from))
49 |
50 |
51 | def create_folder(directory):
52 | # from https://stackoverflow.com/a/273227
53 | if not os.path.exists(directory):
54 | os.makedirs(directory)
55 |
56 |
57 | random.seed(0)
58 |
59 | def distributed_is_initialized():
60 | if distributed.is_available():
61 | if distributed.is_initialized():
62 | return True
63 | return False
64 |
65 |
66 | def lr_policy(lr_fn):
67 | def _alr(optimizer, iteration, epoch):
68 | lr = lr_fn(iteration, epoch)
69 | for param_group in optimizer.param_groups:
70 | param_group['lr'] = lr
71 |
72 | return _alr
73 |
74 |
75 | def lr_cosine_policy(base_lr, warmup_length, epochs):
76 | def _lr_fn(iteration, epoch):
77 | if epoch < warmup_length:
78 | lr = base_lr * (epoch + 1) / warmup_length
79 | else:
80 | e = epoch - warmup_length
81 | es = epochs - warmup_length
82 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
83 | return lr
84 |
85 | return lr_policy(_lr_fn)
86 |
87 |
88 | def beta_policy(mom_fn):
89 | def _alr(optimizer, iteration, epoch, param, indx):
90 | mom = mom_fn(iteration, epoch)
91 | for param_group in optimizer.param_groups:
92 | param_group[param][indx] = mom
93 |
94 | return _alr
95 |
96 |
97 | def mom_cosine_policy(base_beta, warmup_length, epochs):
98 | def _beta_fn(iteration, epoch):
99 | if epoch < warmup_length:
100 | beta = base_beta * (epoch + 1) / warmup_length
101 | else:
102 | beta = base_beta
103 | return beta
104 |
105 | return beta_policy(_beta_fn)
106 |
107 |
108 | def clip(image_tensor, use_fp16=False):
109 | '''
110 | adjust the input based on mean and variance
111 | '''
112 | if use_fp16:
113 | mean = np.array([0.485, 0.456, 0.406], dtype=np.float16)
114 | std = np.array([0.229, 0.224, 0.225], dtype=np.float16)
115 | else:
116 | mean = np.array([0.485, 0.456, 0.406])
117 | std = np.array([0.229, 0.224, 0.225])
118 | for c in range(3):
119 | m, s = mean[c], std[c]
120 | image_tensor[:, c] = torch.clamp(image_tensor[:, c], -m / s, (1 - m) / s)
121 | return image_tensor
122 |
123 |
124 | def denormalize(image_tensor, use_fp16=False):
125 | '''
126 | convert floats back to input
127 | '''
128 | if use_fp16:
129 | mean = np.array([0.485, 0.456, 0.406], dtype=np.float16)
130 | std = np.array([0.229, 0.224, 0.225], dtype=np.float16)
131 | else:
132 | mean = np.array([0.485, 0.456, 0.406])
133 | std = np.array([0.229, 0.224, 0.225])
134 |
135 | for c in range(3):
136 | m, s = mean[c], std[c]
137 | image_tensor[:, c] = torch.clamp(image_tensor[:, c] * s + m, 0, 1)
138 |
139 | return image_tensor
--------------------------------------------------------------------------------
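The helpers above are the pieces a DeepInversion-style optimization loop typically combines: a cosine learning-rate schedule for the image optimizer, `clip()` to keep pixels inside the ImageNet-normalized range, and `denormalize()` to map the result back to [0, 1] for saving. A rough sketch of wiring them together follows; the loss term is a placeholder, and treating the iteration index as the "epoch" argument is an assumption about how a caller might drive the policy, not something `utils.py` enforces.

```python
# Rough sketch combining lr_cosine_policy, clip and denormalize from utils/utils.py.
import torch
from utils.utils import lr_cosine_policy, clip, denormalize

inputs = torch.randn(4, 3, 224, 224, requires_grad=True)
optimizer = torch.optim.Adam([inputs], lr=0.2)
schedule = lr_cosine_policy(base_lr=0.2, warmup_length=100, epochs=2000)

for it in range(2000):
    schedule(optimizer, it, it)          # linear warmup, then cosine decay
    optimizer.zero_grad()
    loss = inputs.pow(2).mean()          # placeholder for the real inversion loss
    loss.backward()
    optimizer.step()
    inputs.data = clip(inputs.data)      # clamp to the valid normalized pixel range

images = denormalize(inputs.data.clone())  # back to [0, 1] for saving / visualization
```

--------------------------------------------------------------------------------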