├── LICENSE ├── README.md ├── assets └── teaser_v7.jpeg ├── bpda_eot ├── LICENSE_BPDA └── bpda_eot_attack.py ├── classifiers ├── attribute_classifier.py ├── attribute_net.py └── cifar10_resnet.py ├── configs ├── celeba.yml ├── cifar10.yml └── imagenet.yml ├── data ├── __init__.py └── datasets.py ├── ddpm ├── LICENSE_UNET_DDPM └── unet_ddpm.py ├── diffpure.Dockerfile ├── eval_sde_adv.py ├── eval_sde_adv_bpda.py ├── guided_diffusion ├── LICENSE_GUIDED_DIFFUSION ├── __init__.py ├── dist_util.py ├── fp16_util.py ├── gaussian_diffusion.py ├── image_datasets.py ├── logger.py ├── losses.py ├── nn.py ├── resample.py ├── respace.py ├── script_util.py ├── train_util.py └── unet.py ├── run_scripts ├── celebahq │ ├── run_celebahq_bpda_glasses.sh │ └── run_celebahq_bpda_smiling.sh ├── cifar10 │ ├── run_cifar_bpda_eot.sh │ ├── run_cifar_rand_L2.sh │ ├── run_cifar_rand_L2_70-16-dp.sh │ ├── run_cifar_rand_L2_rn50_eps1.sh │ ├── run_cifar_rand_inf.sh │ ├── run_cifar_rand_inf_70-16-dp.sh │ ├── run_cifar_rand_inf_ode.sh │ ├── run_cifar_rand_inf_rn50.sh │ ├── run_cifar_stadv_rn50.sh │ ├── run_cifar_stand_L2.sh │ ├── run_cifar_stand_L2_70-16-dp.sh │ ├── run_cifar_stand_L2_rn50_eps1.sh │ ├── run_cifar_stand_inf.sh │ ├── run_cifar_stand_inf_70-16-dp.sh │ ├── run_cifar_stand_inf_ode.sh │ └── run_cifar_stand_inf_rn50.sh └── imagenet │ ├── run_in_rand_inf.sh │ ├── run_in_rand_inf_50-2.sh │ ├── run_in_rand_inf_deits.sh │ ├── run_in_stand_inf.sh │ ├── run_in_stand_inf_50-2.sh │ └── run_in_stand_inf_deits.sh ├── runners ├── diffpure_ddpm.py ├── diffpure_guided.py ├── diffpure_ldsde.py ├── diffpure_ode.py └── diffpure_sde.py ├── score_sde ├── LICENSE_SCORE_SDE ├── losses.py ├── models │ ├── __init__.py │ ├── ddpm.py │ ├── ema.py │ ├── layers.py │ ├── layerspp.py │ ├── ncsnpp.py │ ├── ncsnv2.py │ ├── normalization.py │ ├── up_or_down_sampling.py │ └── utils.py ├── op │ ├── __init__.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── sampling.py └── sde_lib.py ├── stadv_eot ├── attacks.py └── recoloradv │ ├── LICENSE_RECOLORADV │ ├── __init__.py │ ├── color_spaces.py │ ├── color_transformers.py │ ├── mister_ed │ ├── README.md │ ├── __init__.py │ ├── adversarial_attacks.py │ ├── adversarial_perturbations.py │ ├── adversarial_training.py │ ├── config.py │ ├── loss_functions.py │ ├── spatial_transformers.py │ └── utils │ │ ├── __init__.py │ │ ├── checkpoints.py │ │ ├── discretization.py │ │ ├── image_utils.py │ │ ├── pytorch_ssim.py │ │ └── pytorch_utils.py │ ├── norms.py │ ├── perturbations.py │ └── utils.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | NVIDIA Source Code License for DiffPure 2 | 3 | 1. Definitions 4 | 5 | “Licensor” means any person or entity that distributes its Work. 6 | 7 | “Software” means the original work of authorship made available under this License. 8 | 9 | “Work” means the Software and any additions to or derivative works of the Software that are made available under 10 | this License. 11 | 12 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under 13 | U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include 14 | works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 
15 | 
16 | Works, including the Software, are “made available” under this License by including in or with the Work either 
17 | (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. 
18 | 
19 | 2. License Grant 
20 | 
21 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, 
22 | worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly 
23 | display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 
24 | 
25 | 3. Limitations 
26 | 
27 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you 
28 | include a complete copy of this License with your distribution, and (c) you retain without modification any 
29 | copyright, patent, trademark, or attribution notices that are present in the Work. 
30 | 
31 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and 
32 | distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use 
33 | limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works 
34 | that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution 
35 | requirements in Section 3.1) will continue to apply to the Work itself. 
36 | 
37 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use 
38 | non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative 
39 | works commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 
40 | 
41 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, 
42 | cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then 
43 | your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately. 
44 | 
45 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, 
46 | or trademarks, except as necessary to reproduce the notices described in this License. 
47 | 
48 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the 
49 | grant in Section 2.1) will terminate immediately. 
50 | 
51 | 4. Disclaimer of Warranty. 
52 | 
53 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING 
54 | WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU 
55 | BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 
56 | 
57 | 5. Limitation of Liability.
58 | 
59 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING 
60 | NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 
61 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR 
62 | INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR 
63 | DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN 
64 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 
--------------------------------------------------------------------------------
/assets/teaser_v7.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/DiffPure/985703facc13da38d84bcaf52b3bc5b2ef9906f1/assets/teaser_v7.jpeg
--------------------------------------------------------------------------------
/bpda_eot/LICENSE_BPDA:
--------------------------------------------------------------------------------
1 | MIT License 
2 | 
3 | Copyright (c) 2020 Mitch Hill and Jonathan Mitchell 
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy 
6 | of this software and associated documentation files (the "Software"), to deal 
7 | in the Software without restriction, including without limitation the rights 
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 
9 | copies of the Software, and to permit persons to whom the Software is 
10 | furnished to do so, subject to the following conditions: 
11 | 
12 | The above copyright notice and this permission notice shall be included in all 
13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 
21 | SOFTWARE. 
22 | 
23 | 
24 | Code for train_ebm.py is derived from: https://github.com/point0bar1/ebm-anatomy 
25 | Copyright (c) Mitch Hill and Erik Nijkamp under MIT License. 
26 | 
27 | MIT License 
28 | 
29 | Copyright (c) 2019 Mitch Hill and Erik Nijkamp 
30 | 
31 | Permission is hereby granted, free of charge, to any person obtaining a copy 
32 | of this software and associated documentation files (the "Software"), to deal 
33 | in the Software without restriction, including without limitation the rights 
34 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 
35 | copies of the Software, and to permit persons to whom the Software is 
36 | furnished to do so, subject to the following conditions: 
37 | 
38 | The above copyright notice and this permission notice shall be included in all 
39 | copies or substantial portions of the Software. 
40 | 
41 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 
42 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 
43 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 44 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 45 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 46 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 47 | SOFTWARE. 48 | 49 | 50 | Code for WideResNet class in nets.py is derived from: https://github.com/meliketoy/wide-resnet.pytorch 51 | Copyright (c) Bumsoo Kim under MIT License. 52 | 53 | MIT License 54 | 55 | Copyright (c) 2018 Bumsoo Kim 56 | 57 | Permission is hereby granted, free of charge, to any person obtaining a copy 58 | of this software and associated documentation files (the "Software"), to deal 59 | in the Software without restriction, including without limitation the rights 60 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 61 | copies of the Software, and to permit persons to whom the Software is 62 | furnished to do so, subject to the following conditions: 63 | 64 | The above copyright notice and this permission notice shall be included in all 65 | copies or substantial portions of the Software. 66 | 67 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 68 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 69 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 70 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 71 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 72 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 73 | SOFTWARE. -------------------------------------------------------------------------------- /bpda_eot/bpda_eot_attack.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from ebm-defense. 5 | # 6 | # Source: 7 | # https://github.com/point0bar1/ebm-defense/blob/master/bpda_eot_attack.py 8 | # 9 | # The license for the original version of this file can be 10 | # found in this directory (LICENSE_BPDA). 11 | # The modifications to this file are subject to the same license. 
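#
# In brief: the purification step is treated as a black box whose gradient is
# approximated by the identity (BPDA). The attack differentiates the classifier
# on purified inputs and applies that gradient to the original inputs, and the
# randomness of purification is handled by averaging logits and gradients over
# repeated purification runs (EOT; see purify_and_predict below).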
12 | # ---------------------------------------------------------------
13 | 
14 | import torch
15 | import torch.nn.functional as F
16 | 
17 | criterion = torch.nn.CrossEntropyLoss()
18 | 
19 | 
20 | class BPDA_EOT_Attack():
21 |     def __init__(self, model, adv_eps=8.0/255, eot_defense_reps=150, eot_attack_reps=15):
22 |         self.model = model
23 | 
24 |         self.config = {
25 |             'eot_defense_ave': 'logits',
26 |             'eot_attack_ave': 'logits',
27 |             'eot_defense_reps': eot_defense_reps,
28 |             'eot_attack_reps': eot_attack_reps,
29 |             'adv_steps': 50,
30 |             'adv_norm': 'l_inf',
31 |             'adv_eps': adv_eps,
32 |             'adv_eta': 2.0 / 255,
33 |             'log_freq': 10
34 |         }
35 | 
36 |         print(f'BPDA_EOT config: {self.config}')
37 | 
38 |     def purify(self, x):
39 |         return self.model(x, mode='purify')
40 | 
41 |     def eot_defense_prediction(self, logits, reps=1, eot_defense_ave=None):
42 |         if eot_defense_ave == 'logits':
43 |             logits_pred = logits.view([reps, int(logits.shape[0]/reps), logits.shape[1]]).mean(0)
44 |         elif eot_defense_ave == 'softmax':
45 |             logits_pred = F.softmax(logits, dim=1).view([reps, int(logits.shape[0]/reps), logits.shape[1]]).mean(0)
46 |         elif eot_defense_ave == 'logsoftmax':
47 |             logits_pred = F.log_softmax(logits, dim=1).view([reps, int(logits.shape[0] / reps), logits.shape[1]]).mean(0)
48 |         elif reps == 1:
49 |             logits_pred = logits
50 |         else:
51 |             raise RuntimeError('Invalid ave_method_pred (use "logits" or "softmax" or "logsoftmax")')
52 |         _, y_pred = torch.max(logits_pred, 1)
53 |         return y_pred
54 | 
55 |     def eot_attack_loss(self, logits, y, reps=1, eot_attack_ave='loss'):
56 |         if eot_attack_ave == 'logits':
57 |             logits_loss = logits.view([reps, int(logits.shape[0] / reps), logits.shape[1]]).mean(0)
58 |             y_loss = y
59 |         elif eot_attack_ave == 'softmax':
60 |             logits_loss = torch.log(F.softmax(logits, dim=1).view([reps, int(logits.shape[0] / reps), logits.shape[1]]).mean(0))
61 |             y_loss = y
62 |         elif eot_attack_ave == 'logsoftmax':
63 |             logits_loss = F.log_softmax(logits, dim=1).view([reps, int(logits.shape[0] / reps), logits.shape[1]]).mean(0)
64 |             y_loss = y
65 |         elif eot_attack_ave == 'loss':
66 |             logits_loss = logits
67 |             y_loss = y.repeat(reps)
68 |         else:
69 |             raise RuntimeError('Invalid ave_method_eot ("logits", "softmax", "logsoftmax", "loss")')
70 |         loss = criterion(logits_loss, y_loss)
71 |         return loss
72 | 
73 |     def predict(self, X, y, requires_grad=True, reps=1, eot_defense_ave=None, eot_attack_ave='loss'):
74 |         if requires_grad:
75 |             logits = self.model(X, mode='classify')
76 |         else:
77 |             with torch.no_grad():
78 |                 logits = self.model(X.data, mode='classify')
79 | 
80 |         y_pred = self.eot_defense_prediction(logits.detach(), reps, eot_defense_ave)
81 |         correct = torch.eq(y_pred, y)
82 |         loss = self.eot_attack_loss(logits, y, reps, eot_attack_ave)
83 | 
84 |         return correct.detach(), loss
85 | 
86 |     def pgd_update(self, X_adv, grad, X, adv_norm, adv_eps, adv_eta, eps=1e-10):
87 |         if adv_norm == 'l_inf':
88 |             X_adv.data += adv_eta * torch.sign(grad)
89 |             X_adv = torch.clamp(torch.min(X + adv_eps, torch.max(X - adv_eps, X_adv)), min=0, max=1)
90 |         elif adv_norm == 'l_2':
91 |             X_adv.data += adv_eta * grad / grad.view(X.shape[0], -1).norm(p=2, dim=1).view(X.shape[0], 1, 1, 1)
92 |             dists = (X_adv - X).view(X.shape[0], -1).norm(dim=1, p=2).view(X.shape[0], 1, 1, 1)
93 |             X_adv = torch.clamp(X + torch.min(dists, adv_eps*torch.ones_like(dists))*(X_adv-X)/(dists+eps), min=0, max=1)
94 |         else:
95 |             raise RuntimeError('Invalid adv_norm ("l_inf" or "l_2")')
96 |         return X_adv
97 | 
98 |     def purify_and_predict(self, X, y,
purify_reps=1, requires_grad=True): 99 | X_repeat = X.repeat([purify_reps, 1, 1, 1]) 100 | X_repeat_purified = self.purify(X_repeat).detach().clone() 101 | X_repeat_purified.requires_grad_() 102 | correct, loss = self.predict(X_repeat_purified, y, requires_grad, purify_reps, 103 | self.config['eot_defense_ave'], self.config['eot_attack_ave']) 104 | if requires_grad: 105 | X_grads = torch.autograd.grad(loss, [X_repeat_purified])[0] 106 | # average gradients over parallel samples for EOT attack 107 | attack_grad = X_grads.view([purify_reps]+list(X.shape)).mean(dim=0) 108 | return correct, attack_grad 109 | else: 110 | return correct, None 111 | 112 | def eot_defense_verification(self, X_adv, y, correct, defended): 113 | for verify_ind in range(correct.nelement()): 114 | if correct[verify_ind] == 0 and defended[verify_ind] == 1: 115 | defended[verify_ind] = self.purify_and_predict(X_adv[verify_ind].unsqueeze(0), y[verify_ind].view([1]), 116 | self.config['eot_defense_reps'], requires_grad=False)[0] 117 | return defended 118 | 119 | def eval_and_bpda_eot_grad(self, X_adv, y, defended, requires_grad=True): 120 | correct, attack_grad = self.purify_and_predict(X_adv, y, self.config['eot_attack_reps'], requires_grad) 121 | if self.config['eot_defense_reps'] > 0: 122 | defended = self.eot_defense_verification(X_adv, y, correct, defended) 123 | else: 124 | defended *= correct 125 | return defended, attack_grad 126 | 127 | def attack_batch(self, X, y): 128 | # get baseline accuracy for natural images 129 | defended = self.eval_and_bpda_eot_grad(X, y, torch.ones_like(y).bool(), False)[0] 130 | print('Baseline: {} of {}'.format(defended.sum(), len(defended))) 131 | 132 | class_batch = torch.zeros([self.config['adv_steps'] + 2, X.shape[0]]).bool() 133 | class_batch[0] = defended.cpu() 134 | ims_adv_batch = torch.zeros(X.shape) 135 | for ind in range(defended.nelement()): 136 | if defended[ind] == 0: 137 | ims_adv_batch[ind] = X[ind].cpu() 138 | 139 | X_adv = X.clone() 140 | 141 | # adversarial attacks on a single batch of images 142 | for step in range(self.config['adv_steps'] + 1): 143 | defended, attack_grad = self.eval_and_bpda_eot_grad(X_adv, y, defended) 144 | 145 | class_batch[step+1] = defended.cpu() 146 | for ind in range(defended.nelement()): 147 | if class_batch[step, ind] == 1 and defended[ind] == 0: 148 | ims_adv_batch[ind] = X_adv[ind].cpu() 149 | 150 | # update adversarial images (except on final iteration so final adv images match final eval) 151 | if step < self.config['adv_steps']: 152 | X_adv = self.pgd_update(X_adv, attack_grad, X, self.config['adv_norm'], self.config['adv_eps'], self.config['adv_eta']) 153 | X_adv = X_adv.detach().clone() 154 | 155 | if step == 1 or step % self.config['log_freq'] == 0 or step == self.config['adv_steps']: 156 | print('Attack {} of {} Batch defended: {} of {}'. 
157 |                       format(step, self.config['adv_steps'], int(torch.sum(defended).cpu().numpy()), X_adv.shape[0]))
158 | 
159 |                 if int(torch.sum(defended).cpu().numpy()) == 0:
160 |                     print('Attack succeeded on the entire batch!')
161 |                     break
162 | 
163 |         for ind in range(defended.nelement()):
164 |             if defended[ind] == 1:
165 |                 ims_adv_batch[ind] = X_adv[ind].cpu()
166 | 
167 |         return class_batch, ims_adv_batch
168 | 
169 |     def attack_all(self, X, y, batch_size):
170 |         class_path = torch.zeros([self.config['adv_steps'] + 2, 0]).bool()
171 |         ims_adv = torch.zeros(0)
172 | 
173 |         # ceil division, so a final partial batch is not silently dropped
174 |         n_batches = (X.shape[0] + batch_size - 1) // batch_size
175 | 
176 |         for counter in range(n_batches):
177 |             X_batch = X[counter * batch_size:min((counter + 1) * batch_size, X.shape[0])].clone().to(X.device)
178 |             y_batch = y[counter * batch_size:min((counter + 1) * batch_size, X.shape[0])].clone().to(X.device)
179 | 
180 |             class_batch, ims_adv_batch = self.attack_batch(X_batch.contiguous(), y_batch.contiguous())
181 |             class_path = torch.cat((class_path, class_batch), dim=1)
182 |             ims_adv = torch.cat((ims_adv, ims_adv_batch), dim=0)
183 |             print(f'Finished batch {counter + 1} of {n_batches} in attack_all')
184 | 
185 |         return class_path, ims_adv
186 | 
--------------------------------------------------------------------------------
/classifiers/attribute_classifier.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for DiffPure. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 | 
8 | import torch
9 | import os
10 | from . import attribute_net
11 | 
12 | softmax = torch.nn.Softmax(dim=1)
13 | 
14 | 
15 | def downsample(images, size=256):
16 |     # Downsample to 256x256. The attribute classifiers were built for 256x256.
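    # (integer-factor box filter: reshape into factor x factor blocks, then average);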
17 |     # follows https://github.com/NVlabs/stylegan/blob/master/metrics/linear_separability.py#L127
18 |     if images.shape[2] > size:
19 |         factor = images.shape[2] // size
20 |         assert (factor * size == images.shape[2])
21 |         images = images.view(
22 |             [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor])
23 |         images = images.mean(dim=[3, 5])
24 |         return images
25 |     else:
26 |         assert (images.shape[-1] == 256)
27 |         return images
28 | 
29 | 
30 | def get_logit(net, im):
31 |     im_256 = downsample(im)
32 |     logit = net(im_256)
33 |     return logit
34 | 
35 | 
36 | def get_softmaxed(net, im):
37 |     logit = get_logit(net, im)
38 |     logits = torch.cat([logit, -logit], dim=1)
39 |     softmaxed = softmax(logits)[:, 1]
40 |     return logits, softmaxed
41 | 
42 | 
43 | def load_attribute_classifier(attribute, ckpt_path=None):
44 |     if ckpt_path is None:
45 |         base_path = 'pretrained/celebahq'
46 |         attribute_pkl = os.path.join(base_path, attribute, 'net_best.pth')
47 |         ckpt = torch.load(attribute_pkl)
48 |     else:
49 |         ckpt = torch.load(ckpt_path)
50 |     print("Using classifier at epoch: %d" % ckpt['epoch'])
51 |     if 'valacc' in ckpt:
52 |         print("Validation acc on raw images: %0.5f" % ckpt['valacc'])
53 |     detector = attribute_net.from_state_dict(
54 |         ckpt['state_dict'], fixed_size=True, use_mbstd=False).cuda().eval()
55 |     return detector
56 | 
57 | 
58 | class ClassifierWrapper(torch.nn.Module):
59 |     def __init__(self, classifier_name, ckpt_path=None, device='cuda'):
60 |         super(ClassifierWrapper, self).__init__()
61 |         self.net = load_attribute_classifier(classifier_name, ckpt_path).eval().to(device)
62 | 
63 |     def forward(self, ims):
64 |         out = (ims - 0.5) / 0.5  # map [0, 1] inputs to the [-1, 1] range the net expects
65 |         return get_softmaxed(self.net, out)[0]
--------------------------------------------------------------------------------
/classifiers/attribute_net.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for DiffPure. To view a copy of this license, see the LICENSE file.
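#
# ProGAN/StyleGAN-style discriminator network, used here as a binary attribute
# classifier for 256x256 CelebA-HQ images (a single logit per image; wrapped by
# classifiers/attribute_classifier.py).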
6 | # --------------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | import numpy as np 11 | 12 | 13 | def lerp_clip(a, b, t): 14 | return a + (b - a) * torch.clamp(t, 0.0, 1.0) 15 | 16 | 17 | class WScaleLayer(nn.Module): 18 | def __init__(self, size, fan_in, gain=np.sqrt(2), bias=True): 19 | super(WScaleLayer, self).__init__() 20 | self.scale = gain / np.sqrt(fan_in) # No longer a parameter 21 | if bias: 22 | self.b = nn.Parameter(torch.randn(size)) 23 | else: 24 | self.b = 0 25 | self.size = size 26 | 27 | def forward(self, x): 28 | x_size = x.size() 29 | x = x * self.scale 30 | # modified to remove warning 31 | if type(self.b) == nn.Parameter and len(x_size) == 4: 32 | x = x + self.b.view(1, -1, 1, 1).expand( 33 | x_size[0], self.size, x_size[2], x_size[3]) 34 | if type(self.b) == nn.Parameter and len(x_size) == 2: 35 | x = x + self.b.view(1, -1).expand( 36 | x_size[0], self.size) 37 | return x 38 | 39 | 40 | class WScaleConv2d(nn.Module): 41 | def __init__(self, in_channels, out_channels, kernel_size, padding=0, 42 | bias=True, gain=np.sqrt(2)): 43 | super().__init__() 44 | self.conv = nn.Conv2d(in_channels, out_channels, 45 | kernel_size=kernel_size, 46 | padding=padding, 47 | bias=False) 48 | fan_in = in_channels * kernel_size * kernel_size 49 | self.wscale = WScaleLayer(out_channels, fan_in, gain=gain, bias=bias) 50 | 51 | def forward(self, x): 52 | return self.wscale(self.conv(x)) 53 | 54 | 55 | class WScaleLinear(nn.Module): 56 | def __init__(self, in_channels, out_channels, bias=True, gain=np.sqrt(2)): 57 | super().__init__() 58 | self.linear = nn.Linear(in_channels, out_channels, bias=False) 59 | self.wscale = WScaleLayer(out_channels, in_channels, gain=gain, 60 | bias=bias) 61 | 62 | def forward(self, x): 63 | return self.wscale(self.linear(x)) 64 | 65 | 66 | class FromRGB(nn.Module): 67 | def __init__(self, in_channels, out_channels, kernel_size, 68 | act=nn.LeakyReLU(0.2), bias=True): 69 | super().__init__() 70 | self.conv = WScaleConv2d(in_channels, out_channels, kernel_size, 71 | padding=0, bias=bias) 72 | self.act = act 73 | 74 | def forward(self, x): 75 | return self.act(self.conv(x)) 76 | 77 | 78 | class Downscale2d(nn.Module): 79 | def __init__(self, factor=2): 80 | super().__init__() 81 | self.downsample = nn.AvgPool2d(kernel_size=factor, stride=factor) 82 | 83 | def forward(self, x): 84 | return self.downsample(x) 85 | 86 | 87 | class DownscaleConvBlock(nn.Module): 88 | def __init__(self, in_channels, conv0_channels, conv1_channels, 89 | kernel_size, padding, bias=True, act=nn.LeakyReLU(0.2)): 90 | super().__init__() 91 | self.downscale = Downscale2d() 92 | self.conv0 = WScaleConv2d(in_channels, conv0_channels, 93 | kernel_size=kernel_size, 94 | padding=padding, 95 | bias=bias) 96 | self.conv1 = WScaleConv2d(conv0_channels, conv1_channels, 97 | kernel_size=kernel_size, 98 | padding=padding, 99 | bias=bias) 100 | self.act = act 101 | 102 | def forward(self, x): 103 | x = self.act(self.conv0(x)) 104 | # conv2d_downscale2d applies downscaling before activation 105 | # the order matters here! 
has to be conv -> bias -> downscale -> act 106 | x = self.conv1(x) 107 | x = self.downscale(x) 108 | x = self.act(x) 109 | return x 110 | 111 | 112 | class MinibatchStdLayer(nn.Module): 113 | def __init__(self, group_size=4): 114 | super().__init__() 115 | self.group_size = group_size 116 | 117 | def forward(self, x): 118 | group_size = min(self.group_size, x.shape[0]) 119 | s = x.shape 120 | y = x.view([group_size, -1, s[1], s[2], s[3]]) 121 | y = y.float() 122 | y = y - torch.mean(y, dim=0, keepdim=True) 123 | y = torch.mean(y * y, dim=0) 124 | y = torch.sqrt(y + 1e-8) 125 | y = torch.mean(torch.mean(torch.mean(y, dim=3, keepdim=True), 126 | dim=2, keepdim=True), dim=1, keepdim=True) 127 | y = y.type(x.type()) 128 | y = y.repeat(group_size, 1, s[2], s[3]) 129 | return torch.cat([x, y], dim=1) 130 | 131 | 132 | class PredictionBlock(nn.Module): 133 | def __init__(self, in_channels, dense0_feat, dense1_feat, out_feat, 134 | pool_size=2, act=nn.LeakyReLU(0.2), use_mbstd=True): 135 | super().__init__() 136 | self.use_mbstd = use_mbstd # attribute classifiers don't have this 137 | if self.use_mbstd: 138 | self.mbstd_layer = MinibatchStdLayer() 139 | # MinibatchStdLayer adds an additional feature dimension 140 | self.conv = WScaleConv2d(in_channels + int(self.use_mbstd), 141 | dense0_feat, kernel_size=3, padding=1) 142 | self.dense0 = WScaleLinear(dense0_feat * pool_size * pool_size, dense1_feat) 143 | self.dense1 = WScaleLinear(dense1_feat, out_feat, gain=1) 144 | self.act = act 145 | 146 | def forward(self, x): 147 | if self.use_mbstd: 148 | x = self.mbstd_layer(x) 149 | x = self.act(self.conv(x)) 150 | x = x.view([x.shape[0], -1]) 151 | x = self.act(self.dense0(x)) 152 | x = self.dense1(x) 153 | return x 154 | 155 | 156 | class D(nn.Module): 157 | 158 | def __init__( 159 | self, 160 | num_channels=3, # Number of input color channels. Overridden based on dataset. 161 | resolution=128, # Input resolution. Overridden based on dataset. 162 | fmap_base=8192, # Overall multiplier for the number of feature maps. 163 | fmap_decay=1.0, # log2 feature map reduction when doubling the resolution. 164 | fmap_max=512, # Maximum number of feature maps in any layer. 165 | fixed_size=False, # True = load fromrgb_lod0 weights only 166 | use_mbstd=True, # False = no mbstd layer in PredictionBlock 167 | **kwargs): # Ignore unrecognized keyword args. 
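        """Progressive-GAN style discriminator repurposed as a classifier.

        fromrgb_lod0 lifts the RGB input to feature maps at full resolution,
        a stack of DownscaleConvBlocks halves the resolution down to 4x4, and
        a PredictionBlock maps the 4x4 features to a single output logit.
        When fixed_size is False, extra fromrgb_lod%d branches are kept and
        blended in via lerp_clip according to lod_in (progressive growing).
        """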
168 | super().__init__() 169 | 170 | self.resolution_log2 = resolution_log2 = int(np.log2(resolution)) 171 | assert resolution == 2 ** resolution_log2 and resolution >= 4 172 | 173 | def nf(stage): 174 | return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) 175 | 176 | self.register_buffer('lod_in', torch.from_numpy(np.array(0.0))) 177 | 178 | res = resolution_log2 179 | 180 | setattr(self, 'fromrgb_lod0', FromRGB(num_channels, nf(res - 1), 1)) 181 | 182 | for i, res in enumerate(range(resolution_log2, 2, -1), 1): 183 | lod = resolution_log2 - res 184 | block = DownscaleConvBlock(nf(res - 1), nf(res - 1), nf(res - 2), 185 | kernel_size=3, padding=1) 186 | setattr(self, '%dx%d' % (2 ** res, 2 ** res), block) 187 | fromrgb = FromRGB(3, nf(res - 2), 1) 188 | if not fixed_size: 189 | setattr(self, 'fromrgb_lod%d' % i, fromrgb) 190 | 191 | res = 2 192 | pool_size = 2 ** res 193 | block = PredictionBlock(nf(res + 1 - 2), nf(res - 1), nf(res - 2), 1, 194 | pool_size, use_mbstd=use_mbstd) 195 | setattr(self, '%dx%d' % (pool_size, pool_size), block) 196 | self.downscale = Downscale2d() 197 | self.fixed_size = fixed_size 198 | 199 | def forward(self, img): 200 | x = self.fromrgb_lod0(img) 201 | for i, res in enumerate(range(self.resolution_log2, 2, -1), 1): 202 | lod = self.resolution_log2 - res 203 | x = getattr(self, '%dx%d' % (2 ** res, 2 ** res))(x) 204 | if not self.fixed_size: 205 | img = self.downscale(img) 206 | y = getattr(self, 'fromrgb_lod%d' % i)(img) 207 | x = lerp_clip(x, y, self.lod_in - lod) 208 | res = 2 209 | pool_size = 2 ** res 210 | out = getattr(self, '%dx%d' % (pool_size, pool_size))(x) 211 | return out 212 | 213 | 214 | def max_res_from_state_dict(state_dict): 215 | for i in range(3, 12): 216 | if '%dx%d.conv0.conv.weight' % (2 ** i, 2 ** i) not in state_dict: 217 | break 218 | return 2 ** (i - 1) 219 | 220 | 221 | def from_state_dict(state_dict, fixed_size=False, use_mbstd=True): 222 | res = max_res_from_state_dict(state_dict) 223 | print(f'res: {res}') 224 | d = D(num_channels=3, resolution=res, fixed_size=fixed_size, 225 | use_mbstd=use_mbstd) 226 | d.load_state_dict(state_dict) 227 | return d 228 | -------------------------------------------------------------------------------- /classifiers/cifar10_resnet.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for DiffPure. To view a copy of this license, see the LICENSE file. 
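#
# CIFAR-10 classifiers (ResNet-50 built from Bottleneck blocks, and WideResNet
# variants including 70-16). The CIFAR-10 mean/std normalization is applied
# inside forward(), so inputs are expected in [0, 1].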
6 | # --------------------------------------------------------------- 7 | 8 | import math 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | import torch.nn as nn 13 | 14 | 15 | # ---------------------------- ResNet ---------------------------- 16 | 17 | class Bottleneck(nn.Module): 18 | expansion = 4 19 | 20 | def __init__(self, in_planes, planes, stride=1): 21 | super(Bottleneck, self).__init__() 22 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 27 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 28 | 29 | self.shortcut = nn.Sequential() 30 | if stride != 1 or in_planes != self.expansion * planes: 31 | self.shortcut = nn.Sequential( 32 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(self.expansion * planes) 34 | ) 35 | 36 | def forward(self, x): 37 | out = F.relu(self.bn1(self.conv1(x))) 38 | out = F.relu(self.bn2(self.conv2(out))) 39 | out = self.bn3(self.conv3(out)) 40 | out += self.shortcut(x) 41 | out = F.relu(out) 42 | return out 43 | 44 | 45 | class ResNet(nn.Module): 46 | def __init__(self, block, num_blocks, num_classes=10): 47 | super(ResNet, self).__init__() 48 | self.in_planes = 64 49 | 50 | num_input_channels = 3 51 | mean = (0.4914, 0.4822, 0.4465) 52 | std = (0.2471, 0.2435, 0.2616) 53 | self.mean = torch.tensor(mean).view(num_input_channels, 1, 1) 54 | self.std = torch.tensor(std).view(num_input_channels, 1, 1) 55 | 56 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 57 | self.bn1 = nn.BatchNorm2d(64) 58 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 59 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 60 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 61 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 62 | self.linear = nn.Linear(512 * block.expansion, num_classes) 63 | 64 | def _make_layer(self, block, planes, num_blocks, stride): 65 | strides = [stride] + [1] * (num_blocks - 1) 66 | layers = [] 67 | for stride in strides: 68 | layers.append(block(self.in_planes, planes, stride)) 69 | self.in_planes = planes * block.expansion 70 | return nn.Sequential(*layers) 71 | 72 | def forward(self, x): 73 | out = (x - self.mean.to(x.device)) / self.std.to(x.device) 74 | out = F.relu(self.bn1(self.conv1(out))) 75 | out = self.layer1(out) 76 | out = self.layer2(out) 77 | out = self.layer3(out) 78 | out = self.layer4(out) 79 | out = F.avg_pool2d(out, 4) 80 | out = out.view(out.size(0), -1) 81 | out = self.linear(out) 82 | return out 83 | 84 | 85 | def ResNet50(): 86 | return ResNet(Bottleneck, [3, 4, 6, 3]) 87 | 88 | 89 | # ---------------------------- ResNet ---------------------------- 90 | 91 | 92 | # ---------------------------- WideResNet ---------------------------- 93 | 94 | class BasicBlock(nn.Module): 95 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 96 | super(BasicBlock, self).__init__() 97 | self.bn1 = nn.BatchNorm2d(in_planes) 98 | self.relu1 = nn.ReLU(inplace=True) 99 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 100 | padding=1, bias=False) 101 | self.bn2 = nn.BatchNorm2d(out_planes) 102 | self.relu2 = nn.ReLU(inplace=True) 103 | self.conv2 = 
nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 104 | padding=1, bias=False) 105 | self.droprate = dropRate 106 | self.equalInOut = (in_planes == out_planes) 107 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 108 | padding=0, bias=False) or None 109 | 110 | def forward(self, x): 111 | if not self.equalInOut: 112 | x = self.relu1(self.bn1(x)) 113 | else: 114 | out = self.relu1(self.bn1(x)) 115 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 116 | if self.droprate > 0: 117 | out = F.dropout(out, p=self.droprate, training=self.training) 118 | out = self.conv2(out) 119 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 120 | 121 | 122 | class NetworkBlock(nn.Module): 123 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 124 | super(NetworkBlock, self).__init__() 125 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 126 | 127 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 128 | layers = [] 129 | for i in range(int(nb_layers)): 130 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 131 | return nn.Sequential(*layers) 132 | 133 | def forward(self, x): 134 | return self.layer(x) 135 | 136 | 137 | class WideResNet(nn.Module): 138 | """ Based on code from https://github.com/yaodongyu/TRADES """ 139 | 140 | def __init__(self, depth=28, num_classes=10, widen_factor=10, sub_block1=False, dropRate=0.0, bias_last=True): 141 | super(WideResNet, self).__init__() 142 | 143 | num_input_channels = 3 144 | mean = (0.4914, 0.4822, 0.4465) 145 | std = (0.2471, 0.2435, 0.2616) 146 | self.mean = torch.tensor(mean).view(num_input_channels, 1, 1) 147 | self.std = torch.tensor(std).view(num_input_channels, 1, 1) 148 | 149 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 150 | assert ((depth - 4) % 6 == 0) 151 | n = (depth - 4) / 6 152 | block = BasicBlock 153 | # 1st conv before any network block 154 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 155 | padding=1, bias=False) 156 | # 1st block 157 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 158 | if sub_block1: 159 | # 1st sub-block 160 | self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 161 | # 2nd block 162 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 163 | # 3rd block 164 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 165 | # global average pooling and classifier 166 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 167 | self.relu = nn.ReLU(inplace=True) 168 | self.fc = nn.Linear(nChannels[3], num_classes, bias=bias_last) 169 | self.nChannels = nChannels[3] 170 | 171 | for m in self.modules(): 172 | if isinstance(m, nn.Conv2d): 173 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 174 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 175 | elif isinstance(m, nn.BatchNorm2d): 176 | m.weight.data.fill_(1) 177 | m.bias.data.zero_() 178 | elif isinstance(m, nn.Linear) and not m.bias is None: 179 | m.bias.data.zero_() 180 | 181 | def forward(self, x): 182 | out = (x - self.mean.to(x.device)) / self.std.to(x.device) 183 | out = self.conv1(out) 184 | out = self.block1(out) 185 | out = self.block2(out) 186 | out = self.block3(out) 187 | out = self.relu(self.bn1(out)) 188 | out = F.avg_pool2d(out, 8) 189 | out = out.view(-1, self.nChannels) 190 | return self.fc(out) 191 | 192 | 193 | def WideResNet_70_16(): 194 | return WideResNet(depth=70, widen_factor=16, dropRate=0.0) 195 | 196 | 197 | def WideResNet_70_16_dropout(): 198 | return WideResNet(depth=70, widen_factor=16, dropRate=0.3) 199 | # ---------------------------- WideResNet ---------------------------- 200 | -------------------------------------------------------------------------------- /configs/celeba.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "CelebA_HQ" 3 | category: "celeba" 4 | image_size: 256 5 | channels: 3 6 | logit_transform: false 7 | uniform_dequantization: false 8 | gaussian_dequantization: false 9 | random_flip: true 10 | rescaled: true 11 | num_workers: 32 12 | 13 | model: 14 | type: "simple" 15 | in_channels: 3 16 | out_ch: 3 17 | ch: 128 18 | ch_mult: [1, 1, 2, 2, 4, 4] 19 | num_res_blocks: 2 20 | attn_resolutions: [16, ] 21 | dropout: 0.0 22 | var_type: fixedsmall 23 | ema_rate: 0.999 24 | ema: True 25 | resamp_with_conv: True 26 | 27 | diffusion: 28 | beta_schedule: linear 29 | beta_start: 0.0001 30 | beta_end: 0.02 31 | num_diffusion_timesteps: 1000 32 | 33 | sampling: 34 | batch_size: 8 35 | last_only: True -------------------------------------------------------------------------------- /configs/cifar10.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "CIFAR10" 3 | category: "cifar10" 4 | image_size: 32 5 | num_channels: 3 6 | random_flip: True 7 | centered: True 8 | uniform_dequantization: False 9 | 10 | model: 11 | sigma_min: 0.01 12 | sigma_max: 50 13 | num_scales: 1000 14 | beta_min: 0.1 15 | beta_max: 20. 16 | dropout: 0.1 17 | 18 | name: 'ncsnpp' 19 | scale_by_sigma: False 20 | ema_rate: 0.9999 21 | normalization: 'GroupNorm' 22 | nonlinearity: 'swish' 23 | nf: 128 24 | ch_mult: [1, 2, 2, 2] # (1, 2, 2, 2) 25 | num_res_blocks: 8 26 | attn_resolutions: [16] # (16,) 27 | resamp_with_conv: True 28 | conditional: True 29 | fir: False 30 | fir_kernel: [1, 3, 3, 1] 31 | skip_rescale: True 32 | resblock_type: 'biggan' 33 | progressive: 'none' 34 | progressive_input: 'none' 35 | progressive_combine: 'sum' 36 | attention_type: 'ddpm' 37 | init_scale: 0. 38 | embedding_type: 'positional' 39 | fourier_scale: 16 40 | conv_size: 3 41 | 42 | training: 43 | sde: 'vpsde' 44 | continuous: True 45 | reduce_mean: True 46 | n_iters: 950001 47 | 48 | optim: 49 | weight_decay: 0 50 | optimizer: 'Adam' 51 | lr: 0.0002 # 2e-4 52 | beta1: 0.9 53 | eps: 0.00000001 # 1e-8 54 | warmup: 5000 55 | grad_clip: 1. 
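# The block below selects the predictor-corrector ('pc') sampler from
# score_sde/sampling.py: an Euler-Maruyama predictor with no corrector.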
56 | 57 | sampling: 58 | n_steps_each: 1 59 | noise_removal: True 60 | probability_flow: False 61 | snr: 0.16 62 | 63 | method: 'pc' 64 | predictor: 'euler_maruyama' 65 | corrector: 'none' -------------------------------------------------------------------------------- /configs/imagenet.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "ImageNet" 3 | category: "imagenet" 4 | 5 | model: 6 | attention_resolutions: '32,16,8' 7 | class_cond: False 8 | diffusion_steps: 1000 9 | rescale_timesteps: True 10 | timestep_respacing: '1000' # Modify this value to decrease the number of timesteps. 11 | image_size: 256 12 | learn_sigma: True 13 | noise_schedule: 'linear' 14 | num_channels: 256 15 | num_head_channels: 64 16 | num_res_blocks: 2 17 | resblock_updown: True 18 | use_fp16: True 19 | use_scale_shift_norm: True 20 | 21 | sampling: 22 | batch_size: 8 23 | last_only: True -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for DiffPure. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | from .datasets import imagenet_lmdb_dataset, imagenet_lmdb_dataset_sub, cifar10_dataset_sub 9 | 10 | def get_transform(dataset_name, transform_type, base_size=256): 11 | from . import datasets 12 | if dataset_name == 'celebahq': 13 | return datasets.get_transform(dataset_name, transform_type, base_size) 14 | elif 'imagenet' in dataset_name: 15 | return datasets.get_transform(dataset_name, transform_type, base_size) 16 | else: 17 | raise NotImplementedError 18 | 19 | 20 | def get_dataset(dataset_name, partition, *args, **kwargs): 21 | from . import datasets 22 | if dataset_name == 'celebahq': 23 | return datasets.CelebAHQDataset(partition, *args, **kwargs) 24 | else: 25 | raise NotImplementedError -------------------------------------------------------------------------------- /ddpm/LICENSE_UNET_DDPM: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Ermon Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /diffpure.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.0.3-devel-ubuntu20.04 2 | 3 | ENV DEBIAN_FRONTEND=noninteractive 4 | 5 | # Install package dependencies 6 | RUN apt-get update && apt-get install -y --no-install-recommends \ 7 | build-essential \ 8 | autoconf \ 9 | automake \ 10 | libtool \ 11 | pkg-config \ 12 | ca-certificates \ 13 | wget \ 14 | git \ 15 | curl \ 16 | ca-certificates \ 17 | libjpeg-dev \ 18 | libpng-dev \ 19 | python \ 20 | python3-dev \ 21 | python3-pip \ 22 | python3-setuptools \ 23 | zlib1g-dev \ 24 | swig \ 25 | cmake \ 26 | vim \ 27 | locales \ 28 | locales-all \ 29 | screen \ 30 | zip \ 31 | unzip 32 | RUN apt-get clean 33 | 34 | ENV LC_ALL en_US.UTF-8 35 | ENV LANG en_US.UTF-8 36 | ENV LANGUAGE en_US.UTF-8 37 | 38 | RUN cd /usr/local/bin && \ 39 | ln -s /usr/bin/python3 python && \ 40 | ln -s /usr/bin/pip3 pip && \ 41 | pip install --upgrade pip setuptools 42 | 43 | RUN pip install numpy==1.19.4 \ 44 | pyyaml==5.3.1 \ 45 | wheel==0.34.2 \ 46 | scipy==1.5.2 \ 47 | torch==1.7.1 \ 48 | torchvision==0.8.2 \ 49 | pillow==7.2.0 \ 50 | matplotlib==3.3.0 \ 51 | tqdm==4.56.1 \ 52 | tensorboardX==2.0 \ 53 | seaborn==0.10.1 \ 54 | pandas==1.2.0 \ 55 | requests==2.25.0 \ 56 | xvfbwrapper==0.2.9 \ 57 | torchdiffeq==0.2.1 \ 58 | timm==0.5.4 \ 59 | lmdb \ 60 | Ninja \ 61 | foolbox \ 62 | torchsde \ 63 | git+https://github.com/RobustBench/robustbench.git -------------------------------------------------------------------------------- /guided_diffusion/LICENSE_GUIDED_DIFFUSION: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /guided_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Taken from the following link as is from: 3 | # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/__init__.py 4 | # 5 | # The license for the original version of this file can be 6 | # found in this directory (LICENSE_GUIDED_DIFFUSION). 
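#
# In DiffPure this package provides the unconditional 256x256 ImageNet
# diffusion model configured by configs/imagenet.yml and driven by
# runners/diffpure_guided.py.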
7 | # --------------------------------------------------------------- 8 | 9 | """ 10 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 11 | """ 12 | -------------------------------------------------------------------------------- /guided_diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Taken from the following link as is from: 3 | # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/dist_util.py 4 | # 5 | # The license for the original version of this file can be 6 | # found in this directory (LICENSE_GUIDED_DIFFUSION). 7 | # --------------------------------------------------------------- 8 | 9 | """ 10 | Helpers for distributed training. 11 | """ 12 | 13 | import io 14 | import os 15 | import socket 16 | 17 | import blobfile as bf 18 | from mpi4py import MPI 19 | import torch as th 20 | import torch.distributed as dist 21 | 22 | # Change this to reflect your cluster layout. 23 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 24 | GPUS_PER_NODE = 8 25 | 26 | SETUP_RETRY_COUNT = 3 27 | 28 | 29 | def setup_dist(): 30 | """ 31 | Setup a distributed process group. 32 | """ 33 | if dist.is_initialized(): 34 | return 35 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" 36 | 37 | comm = MPI.COMM_WORLD 38 | backend = "gloo" if not th.cuda.is_available() else "nccl" 39 | 40 | if backend == "gloo": 41 | hostname = "localhost" 42 | else: 43 | hostname = socket.gethostbyname(socket.getfqdn()) 44 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 45 | os.environ["RANK"] = str(comm.rank) 46 | os.environ["WORLD_SIZE"] = str(comm.size) 47 | 48 | port = comm.bcast(_find_free_port(), root=0) 49 | os.environ["MASTER_PORT"] = str(port) 50 | dist.init_process_group(backend=backend, init_method="env://") 51 | 52 | 53 | def dev(): 54 | """ 55 | Get the device to use for torch.distributed. 56 | """ 57 | if th.cuda.is_available(): 58 | return th.device(f"cuda") 59 | return th.device("cpu") 60 | 61 | 62 | def load_state_dict(path, **kwargs): 63 | """ 64 | Load a PyTorch file without redundant fetches across MPI ranks. 65 | """ 66 | chunk_size = 2 ** 30 # MPI has a relatively small size limit 67 | if MPI.COMM_WORLD.Get_rank() == 0: 68 | with bf.BlobFile(path, "rb") as f: 69 | data = f.read() 70 | num_chunks = len(data) // chunk_size 71 | if len(data) % chunk_size: 72 | num_chunks += 1 73 | MPI.COMM_WORLD.bcast(num_chunks) 74 | for i in range(0, len(data), chunk_size): 75 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) 76 | else: 77 | num_chunks = MPI.COMM_WORLD.bcast(None) 78 | data = bytes() 79 | for _ in range(num_chunks): 80 | data += MPI.COMM_WORLD.bcast(None) 81 | 82 | return th.load(io.BytesIO(data), **kwargs) 83 | 84 | 85 | def sync_params(params): 86 | """ 87 | Synchronize a sequence of Tensors across ranks from rank 0. 
88 | """ 89 | for p in params: 90 | with th.no_grad(): 91 | dist.broadcast(p, 0) 92 | 93 | 94 | def _find_free_port(): 95 | try: 96 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 97 | s.bind(("", 0)) 98 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 99 | return s.getsockname()[1] 100 | finally: 101 | s.close() 102 | -------------------------------------------------------------------------------- /guided_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Taken from the following link as is from: 3 | # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/fp16_util.py 4 | # 5 | # The license for the original version of this file can be 6 | # found in this directory (LICENSE_GUIDED_DIFFUSION). 7 | # --------------------------------------------------------------- 8 | 9 | """ 10 | Helpers to train with 16-bit precision. 11 | """ 12 | 13 | import numpy as np 14 | import torch as th 15 | import torch.nn as nn 16 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 17 | 18 | from . import logger 19 | 20 | INITIAL_LOG_LOSS_SCALE = 20.0 21 | 22 | 23 | def convert_module_to_f16(l): 24 | """ 25 | Convert primitive modules to float16. 26 | """ 27 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 28 | l.weight.data = l.weight.data.half() 29 | if l.bias is not None: 30 | l.bias.data = l.bias.data.half() 31 | 32 | 33 | def convert_module_to_f32(l): 34 | """ 35 | Convert primitive modules to float32, undoing convert_module_to_f16(). 36 | """ 37 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 38 | l.weight.data = l.weight.data.float() 39 | if l.bias is not None: 40 | l.bias.data = l.bias.data.float() 41 | 42 | 43 | def make_master_params(param_groups_and_shapes): 44 | """ 45 | Copy model parameters into a (differently-shaped) list of full-precision 46 | parameters. 47 | """ 48 | master_params = [] 49 | for param_group, shape in param_groups_and_shapes: 50 | master_param = nn.Parameter( 51 | _flatten_dense_tensors( 52 | [param.detach().float() for (_, param) in param_group] 53 | ).view(shape) 54 | ) 55 | master_param.requires_grad = True 56 | master_params.append(master_param) 57 | return master_params 58 | 59 | 60 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 61 | """ 62 | Copy the gradients from the model parameters into the master parameters 63 | from make_master_params(). 64 | """ 65 | for master_param, (param_group, shape) in zip( 66 | master_params, param_groups_and_shapes 67 | ): 68 | master_param.grad = _flatten_dense_tensors( 69 | [param_grad_or_zeros(param) for (_, param) in param_group] 70 | ).view(shape) 71 | 72 | 73 | def master_params_to_model_params(param_groups_and_shapes, master_params): 74 | """ 75 | Copy the master parameter data back into the model parameters. 76 | """ 77 | # Without copying to a list, if a generator is passed, this will 78 | # silently not copy any parameters. 
79 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 80 | for (_, param), unflat_master_param in zip( 81 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 82 | ): 83 | param.detach().copy_(unflat_master_param) 84 | 85 | 86 | def unflatten_master_params(param_group, master_param): 87 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 88 | 89 | 90 | def get_param_groups_and_shapes(named_model_params): 91 | named_model_params = list(named_model_params) 92 | scalar_vector_named_params = ( 93 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 94 | (-1), 95 | ) 96 | matrix_named_params = ( 97 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 98 | (1, -1), 99 | ) 100 | return [scalar_vector_named_params, matrix_named_params] 101 | 102 | 103 | def master_params_to_state_dict( 104 | model, param_groups_and_shapes, master_params, use_fp16 105 | ): 106 | if use_fp16: 107 | state_dict = model.state_dict() 108 | for master_param, (param_group, _) in zip( 109 | master_params, param_groups_and_shapes 110 | ): 111 | for (name, _), unflat_master_param in zip( 112 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 113 | ): 114 | assert name in state_dict 115 | state_dict[name] = unflat_master_param 116 | else: 117 | state_dict = model.state_dict() 118 | for i, (name, _value) in enumerate(model.named_parameters()): 119 | assert name in state_dict 120 | state_dict[name] = master_params[i] 121 | return state_dict 122 | 123 | 124 | def state_dict_to_master_params(model, state_dict, use_fp16): 125 | if use_fp16: 126 | named_model_params = [ 127 | (name, state_dict[name]) for name, _ in model.named_parameters() 128 | ] 129 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 130 | master_params = make_master_params(param_groups_and_shapes) 131 | else: 132 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 133 | return master_params 134 | 135 | 136 | def zero_master_grads(master_params): 137 | for param in master_params: 138 | param.grad = None 139 | 140 | 141 | def zero_grad(model_params): 142 | for param in model_params: 143 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 144 | if param.grad is not None: 145 | param.grad.detach_() 146 | param.grad.zero_() 147 | 148 | 149 | def param_grad_or_zeros(param): 150 | if param.grad is not None: 151 | return param.grad.data.detach() 152 | else: 153 | return th.zeros_like(param) 154 | 155 | 156 | class MixedPrecisionTrainer: 157 | def __init__( 158 | self, 159 | *, 160 | model, 161 | use_fp16=False, 162 | fp16_scale_growth=1e-3, 163 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 164 | ): 165 | self.model = model 166 | self.use_fp16 = use_fp16 167 | self.fp16_scale_growth = fp16_scale_growth 168 | 169 | self.model_params = list(self.model.parameters()) 170 | self.master_params = self.model_params 171 | self.param_groups_and_shapes = None 172 | self.lg_loss_scale = initial_lg_loss_scale 173 | 174 | if self.use_fp16: 175 | self.param_groups_and_shapes = get_param_groups_and_shapes( 176 | self.model.named_parameters() 177 | ) 178 | self.master_params = make_master_params(self.param_groups_and_shapes) 179 | self.model.convert_to_fp16() 180 | 181 | def zero_grad(self): 182 | zero_grad(self.model_params) 183 | 184 | def backward(self, loss: th.Tensor): 185 | if self.use_fp16: 186 | loss_scale = 2 ** self.lg_loss_scale 187 | (loss * 
loss_scale).backward()
188 |         else:
189 |             loss.backward()
190 | 
191 |     def optimize(self, opt: th.optim.Optimizer):
192 |         if self.use_fp16:
193 |             return self._optimize_fp16(opt)
194 |         else:
195 |             return self._optimize_normal(opt)
196 | 
197 |     def _optimize_fp16(self, opt: th.optim.Optimizer):
198 |         logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
199 |         model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
200 |         grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
201 |         if check_overflow(grad_norm):
202 |             self.lg_loss_scale -= 1
203 |             logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
204 |             zero_master_grads(self.master_params)
205 |             return False
206 | 
207 |         logger.logkv_mean("grad_norm", grad_norm)
208 |         logger.logkv_mean("param_norm", param_norm)
209 | 
210 |         self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
211 |         opt.step()
212 |         zero_master_grads(self.master_params)
213 |         master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
214 |         self.lg_loss_scale += self.fp16_scale_growth
215 |         return True
216 | 
217 |     def _optimize_normal(self, opt: th.optim.Optimizer):
218 |         grad_norm, param_norm = self._compute_norms()
219 |         logger.logkv_mean("grad_norm", grad_norm)
220 |         logger.logkv_mean("param_norm", param_norm)
221 |         opt.step()
222 |         return True
223 | 
224 |     def _compute_norms(self, grad_scale=1.0):
225 |         grad_norm = 0.0
226 |         param_norm = 0.0
227 |         for p in self.master_params:
228 |             with th.no_grad():
229 |                 param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
230 |                 if p.grad is not None:
231 |                     grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
232 |         return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
233 | 
234 |     def master_params_to_state_dict(self, master_params):
235 |         return master_params_to_state_dict(
236 |             self.model, self.param_groups_and_shapes, master_params, self.use_fp16
237 |         )
238 | 
239 |     def state_dict_to_master_params(self, state_dict):
240 |         return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
241 | 
242 | 
243 | def check_overflow(value):
244 |     return (value == float("inf")) or (value == -float("inf")) or (value != value)
245 | 
--------------------------------------------------------------------------------
/guided_diffusion/image_datasets.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Taken from the following link as is from:
3 | # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/image_datasets.py
4 | #
5 | # The license for the original version of this file can be
6 | # found in this directory (LICENSE_GUIDED_DIFFUSION).
7 | # ---------------------------------------------------------------
8 | 
9 | import math
10 | import random
11 | 
12 | from PIL import Image
13 | import blobfile as bf
14 | from mpi4py import MPI
15 | import numpy as np
16 | from torch.utils.data import DataLoader, Dataset
17 | 
18 | 
19 | def load_data(
20 |     *,
21 |     data_dir,
22 |     batch_size,
23 |     image_size,
24 |     class_cond=False,
25 |     deterministic=False,
26 |     random_crop=False,
27 |     random_flip=True,
28 | ):
29 |     """
30 |     For a dataset, create a generator over (images, kwargs) pairs.
31 | 
32 |     Each image is an NCHW float tensor, and the kwargs dict contains zero or
33 |     more keys, each of which maps to a batched Tensor of its own.
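    For example (illustrative; "path/to/images" stands in for a real image
    directory with at least batch_size files, since the loaders drop the
    last partial batch, and mpi4py must be importable because the dataset
    shards by MPI rank):

        data = load_data(data_dir="path/to/images", batch_size=4, image_size=64)
        images, cond = next(data)  # images: (4, 3, 64, 64) floats in [-1, 1]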
34 | The kwargs dict can be used for class labels, in which case the key is "y" 35 | and the values are integer tensors of class labels. 36 | 37 | :param data_dir: a dataset directory. 38 | :param batch_size: the batch size of each returned pair. 39 | :param image_size: the size to which images are resized. 40 | :param class_cond: if True, include a "y" key in returned dicts for class 41 | label. If classes are not available and this is true, an 42 | exception will be raised. 43 | :param deterministic: if True, yield results in a deterministic order. 44 | :param random_crop: if True, randomly crop the images for augmentation. 45 | :param random_flip: if True, randomly flip the images for augmentation. 46 | """ 47 | if not data_dir: 48 | raise ValueError("unspecified data directory") 49 | all_files = _list_image_files_recursively(data_dir) 50 | classes = None 51 | if class_cond: 52 | # Assume classes are the first part of the filename, 53 | # before an underscore. 54 | class_names = [bf.basename(path).split("_")[0] for path in all_files] 55 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 56 | classes = [sorted_classes[x] for x in class_names] 57 | dataset = ImageDataset( 58 | image_size, 59 | all_files, 60 | classes=classes, 61 | shard=MPI.COMM_WORLD.Get_rank(), 62 | num_shards=MPI.COMM_WORLD.Get_size(), 63 | random_crop=random_crop, 64 | random_flip=random_flip, 65 | ) 66 | if deterministic: 67 | loader = DataLoader( 68 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True 69 | ) 70 | else: 71 | loader = DataLoader( 72 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True 73 | ) 74 | while True: 75 | yield from loader 76 | 77 | 78 | def _list_image_files_recursively(data_dir): 79 | results = [] 80 | for entry in sorted(bf.listdir(data_dir)): 81 | full_path = bf.join(data_dir, entry) 82 | ext = entry.split(".")[-1] 83 | if "." 
in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 84 | results.append(full_path) 85 | elif bf.isdir(full_path): 86 | results.extend(_list_image_files_recursively(full_path)) 87 | return results 88 | 89 | 90 | class ImageDataset(Dataset): 91 | def __init__( 92 | self, 93 | resolution, 94 | image_paths, 95 | classes=None, 96 | shard=0, 97 | num_shards=1, 98 | random_crop=False, 99 | random_flip=True, 100 | ): 101 | super().__init__() 102 | self.resolution = resolution 103 | self.local_images = image_paths[shard:][::num_shards] 104 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 105 | self.random_crop = random_crop 106 | self.random_flip = random_flip 107 | 108 | def __len__(self): 109 | return len(self.local_images) 110 | 111 | def __getitem__(self, idx): 112 | path = self.local_images[idx] 113 | with bf.BlobFile(path, "rb") as f: 114 | pil_image = Image.open(f) 115 | pil_image.load() 116 | pil_image = pil_image.convert("RGB") 117 | 118 | if self.random_crop: 119 | arr = random_crop_arr(pil_image, self.resolution) 120 | else: 121 | arr = center_crop_arr(pil_image, self.resolution) 122 | 123 | if self.random_flip and random.random() < 0.5: 124 | arr = arr[:, ::-1] 125 | 126 | arr = arr.astype(np.float32) / 127.5 - 1 127 | 128 | out_dict = {} 129 | if self.local_classes is not None: 130 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) 131 | return np.transpose(arr, [2, 0, 1]), out_dict 132 | 133 | 134 | def center_crop_arr(pil_image, image_size): 135 | # We are not on a new enough PIL to support the `reducing_gap` 136 | # argument, which uses BOX downsampling at powers of two first. 137 | # Thus, we do it by hand to improve downsample quality. 138 | while min(*pil_image.size) >= 2 * image_size: 139 | pil_image = pil_image.resize( 140 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 141 | ) 142 | 143 | scale = image_size / min(*pil_image.size) 144 | pil_image = pil_image.resize( 145 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 146 | ) 147 | 148 | arr = np.array(pil_image) 149 | crop_y = (arr.shape[0] - image_size) // 2 150 | crop_x = (arr.shape[1] - image_size) // 2 151 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 152 | 153 | 154 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0): 155 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac) 156 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac) 157 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) 158 | 159 | # We are not on a new enough PIL to support the `reducing_gap` 160 | # argument, which uses BOX downsampling at powers of two first. 161 | # Thus, we do it by hand to improve downsample quality. 
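    # Worked example of the size bounds above: with image_size=224,
    # min_crop_frac=0.8 and max_crop_frac=1.0, smaller_dim_size is drawn
    # uniformly from [ceil(224 / 1.0), ceil(224 / 0.8)] = [224, 280], so the
    # final 224x224 crop covers between 80% and 100% of the shorter side.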
162 | while min(*pil_image.size) >= 2 * smaller_dim_size: 163 | pil_image = pil_image.resize( 164 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 165 | ) 166 | 167 | scale = smaller_dim_size / min(*pil_image.size) 168 | pil_image = pil_image.resize( 169 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 170 | ) 171 | 172 | arr = np.array(pil_image) 173 | crop_y = random.randrange(arr.shape[0] - image_size + 1) 174 | crop_x = random.randrange(arr.shape[1] - image_size + 1) 175 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 176 | -------------------------------------------------------------------------------- /guided_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Taken from the following link as is from: 3 | # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/losses.py 4 | # 5 | # The license for the original version of this file can be 6 | # found in this directory (LICENSE_GUIDED_DIFFUSION). 7 | # --------------------------------------------------------------- 8 | 9 | """ 10 | Helpers for various likelihood-based losses. These are ported from the original 11 | Ho et al. diffusion models codebase: 12 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 13 | """ 14 | 15 | import numpy as np 16 | 17 | import torch as th 18 | 19 | 20 | def normal_kl(mean1, logvar1, mean2, logvar2): 21 | """ 22 | Compute the KL divergence between two gaussians. 23 | 24 | Shapes are automatically broadcasted, so batches can be compared to 25 | scalars, among other use cases. 26 | """ 27 | tensor = None 28 | for obj in (mean1, logvar1, mean2, logvar2): 29 | if isinstance(obj, th.Tensor): 30 | tensor = obj 31 | break 32 | assert tensor is not None, "at least one argument must be a Tensor" 33 | 34 | # Force variances to be Tensors. Broadcasting helps convert scalars to 35 | # Tensors, but it does not work for th.exp(). 36 | logvar1, logvar2 = [ 37 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 38 | for x in (logvar1, logvar2) 39 | ] 40 | 41 | return 0.5 * ( 42 | -1.0 43 | + logvar2 44 | - logvar1 45 | + th.exp(logvar1 - logvar2) 46 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 47 | ) 48 | 49 | 50 | def approx_standard_normal_cdf(x): 51 | """ 52 | A fast approximation of the cumulative distribution function of the 53 | standard normal. 54 | """ 55 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 56 | 57 | 58 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 59 | """ 60 | Compute the log-likelihood of a Gaussian distribution discretizing to a 61 | given image. 62 | 63 | :param x: the target images. It is assumed that this was uint8 values, 64 | rescaled to the range [-1, 1]. 65 | :param means: the Gaussian mean Tensor. 66 | :param log_scales: the Gaussian log stddev Tensor. 67 | :return: a tensor like x of log probabilities (in nats). 
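    Concretely, each value x is treated as the bin (x - 1/255, x + 1/255)
    on the [-1, 1] scale, and its log-probability is the Gaussian CDF mass
    over that bin; the lowest and highest bins are widened to cover -inf
    and +inf, which is what the x < -0.999 and x > 0.999 branches below
    implement.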
68 | """ 69 | assert x.shape == means.shape == log_scales.shape 70 | centered_x = x - means 71 | inv_stdv = th.exp(-log_scales) 72 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 73 | cdf_plus = approx_standard_normal_cdf(plus_in) 74 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 75 | cdf_min = approx_standard_normal_cdf(min_in) 76 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 77 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 78 | cdf_delta = cdf_plus - cdf_min 79 | log_probs = th.where( 80 | x < -0.999, 81 | log_cdf_plus, 82 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 83 | ) 84 | assert log_probs.shape == x.shape 85 | return log_probs 86 | -------------------------------------------------------------------------------- /guided_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Taken from the following link as is from: 3 | # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/nn.py 4 | # 5 | # The license for the original version of this file can be 6 | # found in this directory (LICENSE_GUIDED_DIFFUSION). 7 | # --------------------------------------------------------------- 8 | 9 | """ 10 | Various utilities for neural networks. 11 | """ 12 | 13 | import math 14 | 15 | import torch as th 16 | import torch.nn as nn 17 | 18 | 19 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 20 | class SiLU(nn.Module): 21 | def forward(self, x): 22 | return x * th.sigmoid(x) 23 | 24 | 25 | class GroupNorm32(nn.GroupNorm): 26 | def forward(self, x): 27 | return super().forward(x.float()).type(x.dtype) 28 | 29 | 30 | def conv_nd(dims, *args, **kwargs): 31 | """ 32 | Create a 1D, 2D, or 3D convolution module. 33 | """ 34 | if dims == 1: 35 | return nn.Conv1d(*args, **kwargs) 36 | elif dims == 2: 37 | return nn.Conv2d(*args, **kwargs) 38 | elif dims == 3: 39 | return nn.Conv3d(*args, **kwargs) 40 | raise ValueError(f"unsupported dimensions: {dims}") 41 | 42 | 43 | def linear(*args, **kwargs): 44 | """ 45 | Create a linear module. 46 | """ 47 | return nn.Linear(*args, **kwargs) 48 | 49 | 50 | def avg_pool_nd(dims, *args, **kwargs): 51 | """ 52 | Create a 1D, 2D, or 3D average pooling module. 53 | """ 54 | if dims == 1: 55 | return nn.AvgPool1d(*args, **kwargs) 56 | elif dims == 2: 57 | return nn.AvgPool2d(*args, **kwargs) 58 | elif dims == 3: 59 | return nn.AvgPool3d(*args, **kwargs) 60 | raise ValueError(f"unsupported dimensions: {dims}") 61 | 62 | 63 | def update_ema(target_params, source_params, rate=0.99): 64 | """ 65 | Update target parameters to be closer to those of source parameters using 66 | an exponential moving average. 67 | 68 | :param target_params: the target parameter sequence. 69 | :param source_params: the source parameter sequence. 70 | :param rate: the EMA rate (closer to 1 means slower). 71 | """ 72 | for targ, src in zip(target_params, source_params): 73 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 74 | 75 | 76 | def zero_module(module): 77 | """ 78 | Zero out the parameters of a module and return it. 79 | """ 80 | for p in module.parameters(): 81 | p.detach().zero_() 82 | return module 83 | 84 | 85 | def scale_module(module, scale): 86 | """ 87 | Scale the parameters of a module and return it. 
88 | """ 89 | for p in module.parameters(): 90 | p.detach().mul_(scale) 91 | return module 92 | 93 | 94 | def mean_flat(tensor): 95 | """ 96 | Take the mean over all non-batch dimensions. 97 | """ 98 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 99 | 100 | 101 | def normalization(channels): 102 | """ 103 | Make a standard normalization layer. 104 | 105 | :param channels: number of input channels. 106 | :return: an nn.Module for normalization. 107 | """ 108 | return GroupNorm32(32, channels) 109 | 110 | 111 | def timestep_embedding(timesteps, dim, max_period=10000): 112 | """ 113 | Create sinusoidal timestep embeddings. 114 | 115 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 116 | These may be fractional. 117 | :param dim: the dimension of the output. 118 | :param max_period: controls the minimum frequency of the embeddings. 119 | :return: an [N x dim] Tensor of positional embeddings. 120 | """ 121 | half = dim // 2 122 | freqs = th.exp( 123 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 124 | ).to(device=timesteps.device) 125 | args = timesteps[:, None].float() * freqs[None] 126 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 127 | if dim % 2: 128 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 129 | return embedding 130 | 131 | 132 | def checkpoint(func, inputs, params, flag): 133 | """ 134 | Evaluate a function without caching intermediate activations, allowing for 135 | reduced memory at the expense of extra compute in the backward pass. 136 | 137 | :param func: the function to evaluate. 138 | :param inputs: the argument sequence to pass to `func`. 139 | :param params: a sequence of parameters `func` depends on but does not 140 | explicitly take as arguments. 141 | :param flag: if False, disable gradient checkpointing. 142 | """ 143 | if flag: 144 | args = tuple(inputs) + tuple(params) 145 | return CheckpointFunction.apply(func, len(inputs), *args) 146 | else: 147 | return func(*inputs) 148 | 149 | 150 | class CheckpointFunction(th.autograd.Function): 151 | @staticmethod 152 | def forward(ctx, run_function, length, *args): 153 | ctx.run_function = run_function 154 | ctx.input_tensors = list(args[:length]) 155 | ctx.input_params = list(args[length:]) 156 | with th.no_grad(): 157 | output_tensors = ctx.run_function(*ctx.input_tensors) 158 | return output_tensors 159 | 160 | @staticmethod 161 | def backward(ctx, *output_grads): 162 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 163 | with th.enable_grad(): 164 | # Fixes a bug where the first op in run_function modifies the 165 | # Tensor storage in place, which is not allowed for detach()'d 166 | # Tensors. 
167 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 168 | output_tensors = ctx.run_function(*shallow_copies) 169 | input_grads = th.autograd.grad( 170 | output_tensors, 171 | ctx.input_tensors + ctx.input_params, 172 | output_grads, 173 | allow_unused=True, 174 | ) 175 | del ctx.input_tensors 176 | del ctx.input_params 177 | del output_tensors 178 | return (None, None) + input_grads 179 | -------------------------------------------------------------------------------- /guided_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Taken from the following link as is from: 3 | # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/resample.py 4 | # 5 | # The license for the original version of this file can be 6 | # found in this directory (LICENSE_GUIDED_DIFFUSION). 7 | # --------------------------------------------------------------- 8 | 9 | from abc import ABC, abstractmethod 10 | 11 | import numpy as np 12 | import torch as th 13 | import torch.distributed as dist 14 | 15 | 16 | def create_named_schedule_sampler(name, diffusion): 17 | """ 18 | Create a ScheduleSampler from a library of pre-defined samplers. 19 | 20 | :param name: the name of the sampler. 21 | :param diffusion: the diffusion object to sample for. 22 | """ 23 | if name == "uniform": 24 | return UniformSampler(diffusion) 25 | elif name == "loss-second-moment": 26 | return LossSecondMomentResampler(diffusion) 27 | else: 28 | raise NotImplementedError(f"unknown schedule sampler: {name}") 29 | 30 | 31 | class ScheduleSampler(ABC): 32 | """ 33 | A distribution over timesteps in the diffusion process, intended to reduce 34 | variance of the objective. 35 | 36 | By default, samplers perform unbiased importance sampling, in which the 37 | objective's mean is unchanged. 38 | However, subclasses may override sample() to change how the resampled 39 | terms are reweighted, allowing for actual changes in the objective. 40 | """ 41 | 42 | @abstractmethod 43 | def weights(self): 44 | """ 45 | Get a numpy array of weights, one per diffusion step. 46 | 47 | The weights needn't be normalized, but must be positive. 48 | """ 49 | 50 | def sample(self, batch_size, device): 51 | """ 52 | Importance-sample timesteps for a batch. 53 | 54 | :param batch_size: the number of timesteps. 55 | :param device: the torch device to save to. 56 | :return: a tuple (timesteps, weights): 57 | - timesteps: a tensor of timestep indices. 58 | - weights: a tensor of weights to scale the resulting losses. 59 | """ 60 | w = self.weights() 61 | p = w / np.sum(w) 62 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 63 | indices = th.from_numpy(indices_np).long().to(device) 64 | weights_np = 1 / (len(p) * p[indices_np]) 65 | weights = th.from_numpy(weights_np).float().to(device) 66 | return indices, weights 67 | 68 | 69 | class UniformSampler(ScheduleSampler): 70 | def __init__(self, diffusion): 71 | self.diffusion = diffusion 72 | self._weights = np.ones([diffusion.num_timesteps]) 73 | 74 | def weights(self): 75 | return self._weights 76 | 77 | 78 | class LossAwareSampler(ScheduleSampler): 79 | def update_with_local_losses(self, local_ts, local_losses): 80 | """ 81 | Update the reweighting using losses from a model. 82 | 83 | Call this method from each rank with a batch of timesteps and the 84 | corresponding losses for each of those timesteps. 
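        Batches may differ in size across ranks, so they are padded to a
        common length before being gathered.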
85 |         This method will perform synchronization to make sure all of the ranks
86 |         maintain the exact same reweighting.
87 | 
88 |         :param local_ts: an integer Tensor of timesteps.
89 |         :param local_losses: a 1D Tensor of losses.
90 |         """
91 |         batch_sizes = [
92 |             th.tensor([0], dtype=th.int32, device=local_ts.device)
93 |             for _ in range(dist.get_world_size())
94 |         ]
95 |         dist.all_gather(
96 |             batch_sizes,
97 |             th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
98 |         )
99 | 
100 |         # Pad all_gather batches to be the maximum batch size.
101 |         batch_sizes = [x.item() for x in batch_sizes]
102 |         max_bs = max(batch_sizes)
103 | 
104 |         timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
105 |         loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
106 |         dist.all_gather(timestep_batches, local_ts)
107 |         dist.all_gather(loss_batches, local_losses)
108 |         timesteps = [
109 |             x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
110 |         ]
111 |         losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
112 |         self.update_with_all_losses(timesteps, losses)
113 | 
114 |     @abstractmethod
115 |     def update_with_all_losses(self, ts, losses):
116 |         """
117 |         Update the reweighting using losses from a model.
118 | 
119 |         Sub-classes should override this method to update the reweighting
120 |         using losses from the model.
121 | 
122 |         This method directly updates the reweighting without synchronizing
123 |         between workers. It is called by update_with_local_losses from all
124 |         ranks with identical arguments. Thus, it should have deterministic
125 |         behavior to maintain state across workers.
126 | 
127 |         :param ts: a list of int timesteps.
128 |         :param losses: a list of float losses, one per timestep.
129 |         """
130 | 
131 | 
132 | class LossSecondMomentResampler(LossAwareSampler):
133 |     def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
134 |         self.diffusion = diffusion
135 |         self.history_per_term = history_per_term
136 |         self.uniform_prob = uniform_prob
137 |         self._loss_history = np.zeros(
138 |             [diffusion.num_timesteps, history_per_term], dtype=np.float64
139 |         )
140 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
141 | 
142 |     def weights(self):
143 |         if not self._warmed_up():
144 |             return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
145 |         weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
146 |         weights /= np.sum(weights)
147 |         weights *= 1 - self.uniform_prob
148 |         weights += self.uniform_prob / len(weights)
149 |         return weights
150 | 
151 |     def update_with_all_losses(self, ts, losses):
152 |         for t, loss in zip(ts, losses):
153 |             if self._loss_counts[t] == self.history_per_term:
154 |                 # Shift out the oldest loss term.
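                # Each row of _loss_history is full at this point; once every
                # timestep has history_per_term entries, _warmed_up() returns
                # True and weights() switches from uniform to second-moment
                # weighting.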
155 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 156 | self._loss_history[t, -1] = loss 157 | else: 158 | self._loss_history[t, self._loss_counts[t]] = loss 159 | self._loss_counts[t] += 1 160 | 161 | def _warmed_up(self): 162 | return (self._loss_counts == self.history_per_term).all() 163 | -------------------------------------------------------------------------------- /guided_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Taken from the following link as is from: 3 | # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py 4 | # 5 | # The license for the original version of this file can be 6 | # found in this directory (LICENSE_GUIDED_DIFFUSION). 7 | # --------------------------------------------------------------- 8 | 9 | import numpy as np 10 | import torch as th 11 | 12 | from .gaussian_diffusion import GaussianDiffusion 13 | 14 | 15 | def space_timesteps(num_timesteps, section_counts): 16 | """ 17 | Create a list of timesteps to use from an original diffusion process, 18 | given the number of timesteps we want to take from equally-sized portions 19 | of the original process. 20 | 21 | For example, if there's 300 timesteps and the section counts are [10,15,20] 22 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 23 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 24 | 25 | If the stride is a string starting with "ddim", then the fixed striding 26 | from the DDIM paper is used, and only one section is allowed. 27 | 28 | :param num_timesteps: the number of diffusion steps in the original 29 | process to divide up. 30 | :param section_counts: either a list of numbers, or a string containing 31 | comma-separated numbers, indicating the step count 32 | per section. As a special case, use "ddimN" where N 33 | is a number of steps to use the striding from the 34 | DDIM paper. 35 | :return: a set of diffusion steps from the original process to use. 36 | """ 37 | if isinstance(section_counts, str): 38 | if section_counts.startswith("ddim"): 39 | desired_count = int(section_counts[len("ddim") :]) 40 | for i in range(1, num_timesteps): 41 | if len(range(0, num_timesteps, i)) == desired_count: 42 | return set(range(0, num_timesteps, i)) 43 | raise ValueError( 44 | f"cannot create exactly {num_timesteps} steps with an integer stride" 45 | ) 46 | section_counts = [int(x) for x in section_counts.split(",")] 47 | size_per = num_timesteps // len(section_counts) 48 | extra = num_timesteps % len(section_counts) 49 | start_idx = 0 50 | all_steps = [] 51 | for i, section_count in enumerate(section_counts): 52 | size = size_per + (1 if i < extra else 0) 53 | if size < section_count: 54 | raise ValueError( 55 | f"cannot divide section of {size} steps into {section_count}" 56 | ) 57 | if section_count <= 1: 58 | frac_stride = 1 59 | else: 60 | frac_stride = (size - 1) / (section_count - 1) 61 | cur_idx = 0.0 62 | taken_steps = [] 63 | for _ in range(section_count): 64 | taken_steps.append(start_idx + round(cur_idx)) 65 | cur_idx += frac_stride 66 | all_steps += taken_steps 67 | start_idx += size 68 | return set(all_steps) 69 | 70 | 71 | class SpacedDiffusion(GaussianDiffusion): 72 | """ 73 | A diffusion process which can skip steps in a base diffusion process. 74 | 75 | :param use_timesteps: a collection (sequence or set) of timesteps from the 76 | original diffusion process to retain. 
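                          Typically this is the set returned by
                          space_timesteps, e.g. space_timesteps(1000, "ddim25").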
77 | :param kwargs: the kwargs to create the base diffusion process. 78 | """ 79 | 80 | def __init__(self, use_timesteps, **kwargs): 81 | self.use_timesteps = set(use_timesteps) 82 | self.timestep_map = [] 83 | self.original_num_steps = len(kwargs["betas"]) 84 | 85 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 86 | last_alpha_cumprod = 1.0 87 | new_betas = [] 88 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 89 | if i in self.use_timesteps: 90 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 91 | last_alpha_cumprod = alpha_cumprod 92 | self.timestep_map.append(i) 93 | kwargs["betas"] = np.array(new_betas) 94 | super().__init__(**kwargs) 95 | 96 | def p_mean_variance( 97 | self, model, *args, **kwargs 98 | ): # pylint: disable=signature-differs 99 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 100 | 101 | def training_losses( 102 | self, model, *args, **kwargs 103 | ): # pylint: disable=signature-differs 104 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 105 | 106 | def condition_mean(self, cond_fn, *args, **kwargs): 107 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 108 | 109 | def condition_score(self, cond_fn, *args, **kwargs): 110 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 111 | 112 | def _wrap_model(self, model): 113 | if isinstance(model, _WrappedModel): 114 | return model 115 | return _WrappedModel( 116 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 117 | ) 118 | 119 | def _scale_timesteps(self, t): 120 | # Scaling is done by the wrapped model. 121 | return t 122 | 123 | 124 | class _WrappedModel: 125 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 126 | self.model = model 127 | self.timestep_map = timestep_map 128 | self.rescale_timesteps = rescale_timesteps 129 | self.original_num_steps = original_num_steps 130 | 131 | def __call__(self, x, ts, **kwargs): 132 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 133 | new_ts = map_tensor[ts] 134 | if self.rescale_timesteps: 135 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 136 | return self.model(x, new_ts, **kwargs) 137 | -------------------------------------------------------------------------------- /run_scripts/celebahq/run_celebahq_bpda_glasses.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for classifier_name in celebahq__Eyeglasses; do 8 | for t in 500; do 9 | for adv_eps in 0.062745098; do 10 | for seed in $SEED1; do 11 | for data_seed in $SEED2; do 12 | 13 | CUDA_VISIBLE_DEVICES=0,1,2,3 python eval_sde_adv_bpda.py --exp ./exp_results --config celeba.yml \ 14 | -i celebahq-adv-$t-eps$adv_eps-2x4-disc-bpda-rev \ 15 | --t $t \ 16 | --adv_eps $adv_eps \ 17 | --adv_batch_size 2 \ 18 | --domain celebahq \ 19 | --classifier_name $classifier_name \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type celebahq-ddpm \ 23 | --eot_defense_reps 20 \ 24 | --eot_attack_reps 15 \ 25 | 26 | done 27 | done 28 | done 29 | done 30 | done 31 | -------------------------------------------------------------------------------- /run_scripts/celebahq/run_celebahq_bpda_smiling.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 
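# Usage: bash run_celebahq_bpda_smiling.sh <seed> <data_seed>
# The two positional arguments below seed the attack run and the data subset.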
3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for classifier_name in celebahq__Smiling; do 8 | for t in 500; do 9 | for adv_eps in 0.062745098; do 10 | for seed in $SEED1; do 11 | for data_seed in $SEED2; do 12 | 13 | CUDA_VISIBLE_DEVICES=0,1,2,3 python eval_sde_adv_bpda.py --exp ./exp_results --config celeba.yml \ 14 | -i celebahq-adv-$t-eps$adv_eps-2x4-disc-bpda-rev \ 15 | --t $t \ 16 | --adv_eps $adv_eps \ 17 | --adv_batch_size 2 \ 18 | --domain celebahq \ 19 | --classifier_name $classifier_name \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type celebahq-ddpm \ 23 | --eot_defense_reps 20 \ 24 | --eot_attack_reps 15 \ 25 | 26 | done 27 | done 28 | done 29 | done 30 | done 31 | -------------------------------------------------------------------------------- /run_scripts/cifar10/run_cifar_bpda_eot.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 100; do 8 | for adv_eps in 0.031373; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0 python eval_sde_adv_bpda.py --exp ./exp_results --config cifar10.yml \ 13 | -i cifar10-robust_adv-$t-eps$adv_eps-200x1-bm0-t0-end1e-5-cont-bpda \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 10 \ 17 | --num_sub 200 \ 18 | --domain cifar10 \ 19 | --classifier_name cifar10-wideresnet-28-10 \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --score_type score_sde \ 24 | 25 | done 26 | done 27 | done 28 | done 29 | -------------------------------------------------------------------------------- /run_scripts/cifar10/run_cifar_rand_L2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 75; do 8 | for adv_eps in 0.5; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0 python eval_sde_adv.py --exp ./exp_results --config cifar10.yml \ 13 | -i cifar10-robust_adv-$t-eps$adv_eps-64x1-bm0-t0-end1e-5-cont-L2-eot20 \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 64 \ 17 | --num_sub 64 \ 18 | --domain cifar10 \ 19 | --classifier_name cifar10-wideresnet-28-10 \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --score_type score_sde \ 24 | --attack_version rand \ 25 | --eot_iter 20 \ 26 | --lp_norm L2 \ 27 | 28 | done 29 | done 30 | done 31 | done 32 | -------------------------------------------------------------------------------- /run_scripts/cifar10/run_cifar_rand_L2_70-16-dp.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 
3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 75; do 8 | for adv_eps in 0.5; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0 python eval_sde_adv.py --exp ./exp_results --config cifar10.yml \ 13 | -i cifar10-robust_adv-$t-eps$adv_eps-64x1-bm0-t0-end1e-5-cont-wres70-16-L2-eot20 \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 64 \ 17 | --num_sub 64 \ 18 | --domain cifar10 \ 19 | --classifier_name cifar10-wrn-70-16-dropout \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --score_type score_sde \ 24 | --attack_version rand \ 25 | --eot_iter 20 \ 26 | --lp_norm L2 \ 27 | 28 | done 29 | done 30 | done 31 | done 32 | -------------------------------------------------------------------------------- /run_scripts/cifar10/run_cifar_rand_L2_rn50_eps1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 125; do 8 | for adv_eps in 1; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0 python eval_sde_adv.py --exp ./exp_results --config cifar10.yml \ 13 | -i cifar10-robust_adv-$t-eps$adv_eps-64x1-bm0-t0-end1e-5-cont-L2-eot20 \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 64 \ 17 | --num_sub 64 \ 18 | --domain cifar10 \ 19 | --classifier_name cifar10-resnet-50 \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --score_type score_sde \ 24 | --attack_version rand \ 25 | --eot_iter 20 \ 26 | --lp_norm L2 \ 27 | 28 | done 29 | done 30 | done 31 | done 32 | -------------------------------------------------------------------------------- /run_scripts/cifar10/run_cifar_rand_inf.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 100; do 8 | for adv_eps in 0.031373; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0 python eval_sde_adv.py --exp ./exp_results --config cifar10.yml \ 13 | -i cifar10-robust_adv-$t-eps$adv_eps-64x1-bm0-t0-end1e-5-cont-eot20 \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 64 \ 17 | --num_sub 64 \ 18 | --domain cifar10 \ 19 | --classifier_name cifar10-wideresnet-28-10 \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --score_type score_sde \ 24 | --attack_version rand \ 25 | --eot_iter 20 26 | 27 | done 28 | done 29 | done 30 | done 31 | -------------------------------------------------------------------------------- /run_scripts/cifar10/run_cifar_rand_inf_70-16-dp.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 
3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 100; do 8 | for adv_eps in 0.031373; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0 python eval_sde_adv.py --exp ./exp_results --config cifar10.yml \ 13 | -i cifar10-robust_adv-$t-eps$adv_eps-64x1-bm0-t0-end1e-5-cont-wres70-16-eot20 \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 64 \ 17 | --num_sub 64 \ 18 | --domain cifar10 \ 19 | --classifier_name cifar10-wrn-70-16-dropout \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --score_type score_sde \ 24 | --attack_version rand \ 25 | --eot_iter 20 26 | 27 | done 28 | done 29 | done 30 | done 31 | -------------------------------------------------------------------------------- /run_scripts/cifar10/run_cifar_rand_inf_ode.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 100; do 8 | for adv_eps in 0.031373; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0 python eval_sde_adv.py --exp ./exp_results --config cifar10.yml \ 13 | -i cifar10-robust_adv-$t-eps$adv_eps-64x1-bm0-t0-end1e-5-cont-ode-eot20 \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 64 \ 17 | --num_sub 64 \ 18 | --domain cifar10 \ 19 | --classifier_name cifar10-wideresnet-28-10 \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type ode \ 23 | --score_type score_sde \ 24 | --attack_version rand \ 25 | --eot_iter 20 \ 26 | --step_size 1e-3 27 | 28 | done 29 | done 30 | done 31 | done 32 | -------------------------------------------------------------------------------- /run_scripts/cifar10/run_cifar_rand_inf_rn50.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 125; do 8 | for adv_eps in 0.031373; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0 python eval_sde_adv.py --exp ./exp_results --config cifar10.yml \ 13 | -i cifar10-robust_adv-$t-eps$adv_eps-64x1-bm0-t0-end1e-5-cont-eot20 \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 64 \ 17 | --num_sub 64 \ 18 | --domain cifar10 \ 19 | --classifier_name cifar10-resnet-50 \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --score_type score_sde \ 24 | --attack_version rand \ 25 | --eot_iter 20 26 | 27 | done 28 | done 29 | done 30 | done 31 | -------------------------------------------------------------------------------- /run_scripts/cifar10/run_cifar_stadv_rn50.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 
3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 125; do 8 | for adv_eps in 0.05; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0 python eval_sde_adv.py --exp ./exp_results --config cifar10.yml \ 13 | -i cifar10-robust_adv-$t-eps$adv_eps-64x1-bm0-t0-end1e-5-cont-stadv-eot20 \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 64 \ 17 | --num_sub 64 \ 18 | --domain cifar10 \ 19 | --classifier_name cifar10-resnet-50 \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --score_type score_sde \ 24 | --attack_version stadv \ 25 | --eot_iter 20 26 | 27 | done 28 | done 29 | done 30 | done 31 | -------------------------------------------------------------------------------- /run_scripts/cifar10/run_cifar_stand_L2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 75; do 8 | for adv_eps in 0.5; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0 python eval_sde_adv.py --exp ./exp_results --config cifar10.yml \ 13 | -i cifar10-robust_adv-$t-eps$adv_eps-64x1-bm0-t0-end1e-5-cont-L2 \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 64 \ 17 | --num_sub 64 \ 18 | --domain cifar10 \ 19 | --classifier_name cifar10-wideresnet-28-10 \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --score_type score_sde \ 24 | --attack_version standard \ 25 | --lp_norm L2 \ 26 | 27 | done 28 | done 29 | done 30 | done 31 | -------------------------------------------------------------------------------- /run_scripts/cifar10/run_cifar_stand_L2_70-16-dp.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 75; do 8 | for adv_eps in 0.5; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0 python eval_sde_adv.py --exp ./exp_results --config cifar10.yml \ 13 | -i cifar10-robust_adv-$t-eps$adv_eps-64x1-bm0-t0-end1e-5-cont-L2-wres70-16 \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 64 \ 17 | --num_sub 64 \ 18 | --domain cifar10 \ 19 | --classifier_name cifar10-wrn-70-16-dropout \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --score_type score_sde \ 24 | --attack_version standard \ 25 | --lp_norm L2 \ 26 | 27 | done 28 | done 29 | done 30 | done 31 | -------------------------------------------------------------------------------- /run_scripts/cifar10/run_cifar_stand_L2_rn50_eps1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 
3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 125; do 8 | for adv_eps in 1; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0 python eval_sde_adv.py --exp ./exp_results --config cifar10.yml \ 13 | -i cifar10-robust_adv-$t-eps$adv_eps-64x1-bm0-t0-end1e-5-cont-L2 \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 64 \ 17 | --num_sub 64 \ 18 | --domain cifar10 \ 19 | --classifier_name cifar10-resnet-50 \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --score_type score_sde \ 24 | --attack_version standard \ 25 | --lp_norm L2 \ 26 | 27 | done 28 | done 29 | done 30 | done 31 | -------------------------------------------------------------------------------- /run_scripts/cifar10/run_cifar_stand_inf.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 100; do 8 | for adv_eps in 0.031373; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0 python eval_sde_adv.py --exp ./exp_results --config cifar10.yml \ 13 | -i cifar10-robust_adv-$t-eps$adv_eps-64x1-bm0-t0-end1e-5-cont \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 64 \ 17 | --num_sub 64 \ 18 | --domain cifar10 \ 19 | --classifier_name cifar10-wideresnet-28-10 \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --score_type score_sde \ 24 | --attack_version standard \ 25 | 26 | done 27 | done 28 | done 29 | done 30 | -------------------------------------------------------------------------------- /run_scripts/cifar10/run_cifar_stand_inf_70-16-dp.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 100; do 8 | for adv_eps in 0.031373; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0 python eval_sde_adv.py --exp ./exp_results --config cifar10.yml \ 13 | -i cifar10-robust_adv-$t-eps$adv_eps-64x1-bm0-t0-end1e-5-cont-wres70-16 \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 64 \ 17 | --num_sub 64 \ 18 | --domain cifar10 \ 19 | --classifier_name cifar10-wrn-70-16-dropout \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --score_type score_sde \ 24 | --attack_version standard \ 25 | 26 | done 27 | done 28 | done 29 | done 30 | -------------------------------------------------------------------------------- /run_scripts/cifar10/run_cifar_stand_inf_ode.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 
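# Same evaluation as run_cifar_stand_inf.sh, but purification integrates the
# probability-flow ODE (--diffusion_type ode) with a fixed 1e-3 solver step.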
3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 100; do 8 | for adv_eps in 0.031373; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0 python eval_sde_adv.py --exp ./exp_results --config cifar10.yml \ 13 | -i cifar10-robust_adv-$t-eps$adv_eps-64x1-bm0-t0-end1e-5-cont-ode \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 64 \ 17 | --num_sub 64 \ 18 | --domain cifar10 \ 19 | --classifier_name cifar10-wideresnet-28-10 \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type ode \ 23 | --score_type score_sde \ 24 | --attack_version standard \ 25 | --step_size 1e-3 26 | 27 | done 28 | done 29 | done 30 | done 31 | -------------------------------------------------------------------------------- /run_scripts/cifar10/run_cifar_stand_inf_rn50.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 125; do 8 | for adv_eps in 0.031373; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0 python eval_sde_adv.py --exp ./exp_results --config cifar10.yml \ 13 | -i cifar10-robust_adv-$t-eps$adv_eps-64x1-bm0-t0-end1e-5-cont \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 64 \ 17 | --num_sub 64 \ 18 | --domain cifar10 \ 19 | --classifier_name cifar10-resnet-50 \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --score_type score_sde \ 24 | --attack_version standard \ 25 | 26 | done 27 | done 28 | done 29 | done 30 | -------------------------------------------------------------------------------- /run_scripts/imagenet/run_in_rand_inf.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 150; do 8 | for adv_eps in 0.0157; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0,1,2,3 python eval_sde_adv.py --exp ./exp_results --config imagenet.yml \ 13 | -i imagenet-robust_adv-$t-eps$adv_eps-4x4-bm0-t0-end1e-5-cont-eot20 \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 4 \ 17 | --num_sub 16 \ 18 | --domain imagenet \ 19 | --classifier_name imagenet-resnet50 \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --attack_version rand \ 24 | --eot_iter 20 25 | 26 | done 27 | done 28 | done 29 | done 30 | -------------------------------------------------------------------------------- /run_scripts/imagenet/run_in_rand_inf_50-2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 
3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 150; do 8 | for adv_eps in 0.0157; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0,1,2,3 python eval_sde_adv.py --exp ./exp_results --config imagenet.yml \ 13 | -i imagenet-robust_adv-$t-eps$adv_eps-4x4-bm0-t0-end1e-5-cont-eot20 \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 4 \ 17 | --num_sub 16 \ 18 | --domain imagenet \ 19 | --classifier_name imagenet-wideresnet-50-2 \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --attack_version rand \ 24 | --eot_iter 20 25 | 26 | done 27 | done 28 | done 29 | done 30 | -------------------------------------------------------------------------------- /run_scripts/imagenet/run_in_rand_inf_deits.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 150; do 8 | for adv_eps in 0.0157; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0,1,2,3 python eval_sde_adv.py --exp ./exp_results --config imagenet.yml \ 13 | -i imagenet-robust_adv-$t-eps$adv_eps-4x4-bm0-t0-end1e-5-cont-eot20 \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 4 \ 17 | --num_sub 16 \ 18 | --domain imagenet \ 19 | --classifier_name imagenet-deit-s \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --attack_version rand \ 24 | --eot_iter 20 25 | 26 | done 27 | done 28 | done 29 | done 30 | -------------------------------------------------------------------------------- /run_scripts/imagenet/run_in_stand_inf.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 150; do 8 | for adv_eps in 0.0157; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0,1,2,3 python eval_sde_adv.py --exp ./exp_results --config imagenet.yml \ 13 | -i imagenet-robust_adv-$t-eps$adv_eps-4x4-bm0-t0-end1e-5-cont \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 4 \ 17 | --num_sub 16 \ 18 | --domain imagenet \ 19 | --classifier_name imagenet-resnet50 \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --attack_version standard \ 24 | 25 | done 26 | done 27 | done 28 | done 29 | -------------------------------------------------------------------------------- /run_scripts/imagenet/run_in_stand_inf_50-2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 
3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 150; do 8 | for adv_eps in 0.0157; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0,1,2,3 python eval_sde_adv.py --exp ./exp_results --config imagenet.yml \ 13 | -i imagenet-robust_adv-$t-eps$adv_eps-4x4-bm0-t0-end1e-5-cont \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 4 \ 17 | --num_sub 16 \ 18 | --domain imagenet \ 19 | --classifier_name imagenet-wideresnet-50-2 \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --attack_version standard \ 24 | 25 | done 26 | done 27 | done 28 | done 29 | -------------------------------------------------------------------------------- /run_scripts/imagenet/run_in_stand_inf_deits.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../.. 3 | 4 | SEED1=$1 5 | SEED2=$2 6 | 7 | for t in 150; do 8 | for adv_eps in 0.0157; do 9 | for seed in $SEED1; do 10 | for data_seed in $SEED2; do 11 | 12 | CUDA_VISIBLE_DEVICES=0,1,2,3 python eval_sde_adv.py --exp ./exp_results --config imagenet.yml \ 13 | -i imagenet-robust_adv-$t-eps$adv_eps-4x4-bm0-t0-end1e-5-cont \ 14 | --t $t \ 15 | --adv_eps $adv_eps \ 16 | --adv_batch_size 4 \ 17 | --num_sub 16 \ 18 | --domain imagenet \ 19 | --classifier_name imagenet-deit-s \ 20 | --seed $seed \ 21 | --data_seed $data_seed \ 22 | --diffusion_type sde \ 23 | --attack_version standard \ 24 | 25 | done 26 | done 27 | done 28 | done 29 | -------------------------------------------------------------------------------- /runners/diffpure_ddpm.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for DiffPure. To view a copy of this license, see the LICENSE file. 
6 | # --------------------------------------------------------------- 7 | 8 | import os 9 | import random 10 | 11 | import numpy as np 12 | 13 | import torch 14 | import torchvision.utils as tvu 15 | 16 | from ddpm.unet_ddpm import Model 17 | 18 | 19 | def get_beta_schedule(*, beta_start, beta_end, num_diffusion_timesteps): 20 | betas = np.linspace(beta_start, beta_end, 21 | num_diffusion_timesteps, dtype=np.float64) 22 | assert betas.shape == (num_diffusion_timesteps,) 23 | return betas 24 | 25 | 26 | def extract(a, t, x_shape): 27 | """Extract coefficients from a based on t and reshape to make it 28 | broadcastable with x_shape.""" 29 | bs, = t.shape 30 | assert x_shape[0] == bs 31 | out = torch.gather(torch.tensor(a, dtype=torch.float, device=t.device), 0, t.long()) 32 | assert out.shape == (bs,) 33 | out = out.reshape((bs,) + (1,) * (len(x_shape) - 1)) 34 | return out 35 | 36 | 37 | def image_editing_denoising_step_flexible_mask(x, t, *, model, logvar, betas): 38 | """ 39 | Sample from p(x_{t-1} | x_t) 40 | """ 41 | alphas = 1.0 - betas 42 | alphas_cumprod = alphas.cumprod(dim=0) 43 | 44 | model_output = model(x, t) 45 | weighted_score = betas / torch.sqrt(1 - alphas_cumprod) 46 | mean = extract(1 / torch.sqrt(alphas), t, x.shape) * (x - extract(weighted_score, t, x.shape) * model_output) 47 | 48 | logvar = extract(logvar, t, x.shape) 49 | noise = torch.randn_like(x) 50 | mask = 1 - (t == 0).float() 51 | mask = mask.reshape((x.shape[0],) + (1,) * (len(x.shape) - 1)) 52 | sample = mean + mask * torch.exp(0.5 * logvar) * noise 53 | sample = sample.float() 54 | return sample 55 | 56 | 57 | class Diffusion(torch.nn.Module): 58 | def __init__(self, args, config, device=None): 59 | super().__init__() 60 | self.args = args 61 | self.config = config 62 | if device is None: 63 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 64 | self.device = device 65 | 66 | print("Loading model") 67 | if self.config.data.dataset == "CelebA_HQ": 68 | url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/celeba_hq.ckpt" 69 | else: 70 | raise ValueError 71 | 72 | model = Model(self.config) 73 | ckpt = torch.hub.load_state_dict_from_url(url, map_location='cpu') 74 | model.load_state_dict(ckpt) 75 | model.eval() 76 | 77 | self.model = model 78 | 79 | self.model_var_type = config.model.var_type 80 | betas = get_beta_schedule( 81 | beta_start=config.diffusion.beta_start, 82 | beta_end=config.diffusion.beta_end, 83 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps 84 | ) 85 | self.betas = torch.from_numpy(betas).float() 86 | self.num_timesteps = betas.shape[0] 87 | 88 | alphas = 1.0 - betas 89 | alphas_cumprod = np.cumprod(alphas, axis=0) 90 | alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) 91 | posterior_variance = betas * \ 92 | (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) 93 | if self.model_var_type == "fixedlarge": 94 | self.logvar = np.log(np.append(posterior_variance[1], betas[1:])) 95 | 96 | elif self.model_var_type == 'fixedsmall': 97 | self.logvar = np.log(np.maximum(posterior_variance, 1e-20)) 98 | 99 | def image_editing_sample(self, img=None, bs_id=0, tag=None): 100 | assert isinstance(img, torch.Tensor) 101 | batch_size = img.shape[0] 102 | 103 | with torch.no_grad(): 104 | if tag is None: 105 | tag = 'rnd' + str(random.randint(0, 10000)) 106 | out_dir = os.path.join(self.args.log_dir, 'bs' + str(bs_id) + '_' + tag) 107 | 108 | assert img.ndim == 4, img.ndim 109 | x0 = img 110 | 111 | if bs_id < 2: 112 | 
os.makedirs(out_dir, exist_ok=True) 113 | tvu.save_image((x0 + 1) * 0.5, os.path.join(out_dir, f'original_input.png')) 114 | 115 | xs = [] 116 | for it in range(self.args.sample_step): 117 | e = torch.randn_like(x0) 118 | total_noise_levels = self.args.t 119 | a = (1 - self.betas).cumprod(dim=0).to(x0.device) 120 | x = x0 * a[total_noise_levels - 1].sqrt() + e * (1.0 - a[total_noise_levels - 1]).sqrt() 121 | 122 | if bs_id < 2: 123 | tvu.save_image((x + 1) * 0.5, os.path.join(out_dir, f'init_{it}.png')) 124 | 125 | for i in reversed(range(total_noise_levels)): 126 | t = torch.tensor([i] * batch_size, device=img.device) 127 | x = image_editing_denoising_step_flexible_mask(x, t=t, model=self.model, 128 | logvar=self.logvar, 129 | betas=self.betas.to(img.device)) 130 | # added intermediate step vis 131 | if (i - 49) % 50 == 0 and bs_id < 2: 132 | tvu.save_image((x + 1) * 0.5, os.path.join(out_dir, f'noise_t_{i}_{it}.png')) 133 | 134 | x0 = x 135 | 136 | if bs_id < 2: 137 | torch.save(x0, os.path.join(out_dir, f'samples_{it}.pth')) 138 | tvu.save_image((x0 + 1) * 0.5, os.path.join(out_dir, f'samples_{it}.png')) 139 | 140 | xs.append(x0) 141 | 142 | return torch.cat(xs, dim=0) 143 | -------------------------------------------------------------------------------- /runners/diffpure_guided.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for DiffPure. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import os 9 | import random 10 | 11 | import torch 12 | import torchvision.utils as tvu 13 | 14 | from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults 15 | 16 | 17 | class GuidedDiffusion(torch.nn.Module): 18 | def __init__(self, args, config, device=None, model_dir='pretrained/guided_diffusion'): 19 | super().__init__() 20 | self.args = args 21 | self.config = config 22 | if device is None: 23 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 24 | self.device = device 25 | 26 | # load model 27 | model_config = model_and_diffusion_defaults() 28 | model_config.update(vars(self.config.model)) 29 | print(f'model_config: {model_config}') 30 | model, diffusion = create_model_and_diffusion(**model_config) 31 | model.load_state_dict(torch.load(f'{model_dir}/256x256_diffusion_uncond.pt', map_location='cpu')) 32 | model.requires_grad_(False).eval().to(self.device) 33 | 34 | if model_config['use_fp16']: 35 | model.convert_to_fp16() 36 | 37 | self.model = model 38 | self.diffusion = diffusion 39 | self.betas = torch.from_numpy(diffusion.betas).float().to(self.device) 40 | 41 | def image_editing_sample(self, img, bs_id=0, tag=None): 42 | with torch.no_grad(): 43 | assert isinstance(img, torch.Tensor) 44 | batch_size = img.shape[0] 45 | 46 | if tag is None: 47 | tag = 'rnd' + str(random.randint(0, 10000)) 48 | out_dir = os.path.join(self.args.log_dir, 'bs' + str(bs_id) + '_' + tag) 49 | 50 | assert img.ndim == 4, img.ndim 51 | img = img.to(self.device) 52 | x0 = img 53 | 54 | if bs_id < 2: 55 | os.makedirs(out_dir, exist_ok=True) 56 | tvu.save_image((x0 + 1) * 0.5, os.path.join(out_dir, f'original_input.png')) 57 | 58 | xs = [] 59 | for it in range(self.args.sample_step): 60 | e = torch.randn_like(x0) 61 | 
total_noise_levels = self.args.t
62 |                 a = (1 - self.betas).cumprod(dim=0)
63 |                 x = x0 * a[total_noise_levels - 1].sqrt() + e * (1.0 - a[total_noise_levels - 1]).sqrt()
64 |
65 |                 if bs_id < 2:
66 |                     tvu.save_image((x + 1) * 0.5, os.path.join(out_dir, f'init_{it}.png'))
67 |
68 |                 for i in reversed(range(total_noise_levels)):
69 |                     t = torch.tensor([i] * batch_size, device=self.device)
70 |
71 |                     x = self.diffusion.p_sample(self.model, x, t,
72 |                                                 clip_denoised=True,
73 |                                                 denoised_fn=None,
74 |                                                 cond_fn=None,
75 |                                                 model_kwargs=None)["sample"]
76 |
77 |                     # added intermediate step vis
78 |                     if (i - 99) % 100 == 0 and bs_id < 2:
79 |                         tvu.save_image((x + 1) * 0.5, os.path.join(out_dir, f'noise_t_{i}_{it}.png'))
80 |
81 |                 x0 = x
82 |
83 |                 if bs_id < 2:
84 |                     torch.save(x0, os.path.join(out_dir, f'samples_{it}.pth'))
85 |                     tvu.save_image((x0 + 1) * 0.5, os.path.join(out_dir, f'samples_{it}.png'))
86 |
87 |                 xs.append(x0)
88 |
89 |             return torch.cat(xs, dim=0)
90 |
-------------------------------------------------------------------------------- /score_sde/losses.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """All functions related to loss computation and optimization.
17 | """
18 |
19 | import torch
20 | import torch.optim as optim
21 | import numpy as np
22 | from .models import utils as mutils
23 | from .sde_lib import VESDE, VPSDE
24 |
25 |
26 | def get_optimizer(config, params):
27 |   """Returns a PyTorch optimizer object based on `config`."""
28 |   if config.optim.optimizer == 'Adam':
29 |     optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps,
30 |                            weight_decay=config.optim.weight_decay)
31 |   else:
32 |     raise NotImplementedError(
33 |       f'Optimizer {config.optim.optimizer} not supported yet!')
34 |
35 |   return optimizer
36 |
37 |
38 | def optimization_manager(config):
39 |   """Returns an optimize_fn based on `config`."""
40 |
41 |   def optimize_fn(optimizer, params, step, lr=config.optim.lr,
42 |                   warmup=config.optim.warmup,
43 |                   grad_clip=config.optim.grad_clip):
44 |     """Optimizes with warmup and gradient clipping (disabled if negative)."""
45 |     if warmup > 0:
46 |       for g in optimizer.param_groups:
47 |         g['lr'] = lr * np.minimum(step / warmup, 1.0)
48 |     if grad_clip >= 0:
49 |       torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
50 |     optimizer.step()
51 |
52 |   return optimize_fn
53 |
54 |
55 | def get_sde_loss_fn(sde, train, reduce_mean=True, continuous=True, likelihood_weighting=True, eps=1e-5):
56 |   """Create a loss function for training with arbitrary SDEs.
57 |
58 |   Args:
59 |     sde: An `sde_lib.SDE` object that represents the forward SDE.
60 |     train: `True` for training loss and `False` for evaluation loss.
61 |     reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions.
62 | continuous: `True` indicates that the model is defined to take continuous time steps. Otherwise it requires 63 | ad-hoc interpolation to take continuous time steps. 64 | likelihood_weighting: If `True`, weight the mixture of score matching losses 65 | according to https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended in our paper. 66 | eps: A `float` number. The smallest time step to sample from. 67 | 68 | Returns: 69 | A loss function. 70 | """ 71 | reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs) 72 | 73 | def loss_fn(model, batch): 74 | """Compute the loss function. 75 | 76 | Args: 77 | model: A score model. 78 | batch: A mini-batch of training data. 79 | 80 | Returns: 81 | loss: A scalar that represents the average loss value across the mini-batch. 82 | """ 83 | score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous) 84 | t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps 85 | z = torch.randn_like(batch) 86 | mean, std = sde.marginal_prob(batch, t) 87 | perturbed_data = mean + std[:, None, None, None] * z 88 | score = score_fn(perturbed_data, t) 89 | 90 | if not likelihood_weighting: 91 | losses = torch.square(score * std[:, None, None, None] + z) 92 | losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) 93 | else: 94 | g2 = sde.sde(torch.zeros_like(batch), t)[1] ** 2 95 | losses = torch.square(score + z / std[:, None, None, None]) 96 | losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) * g2 97 | 98 | loss = torch.mean(losses) 99 | return loss 100 | 101 | return loss_fn 102 | 103 | 104 | def get_smld_loss_fn(vesde, train, reduce_mean=False): 105 | """Legacy code to reproduce previous results on SMLD(NCSN). Not recommended for new work.""" 106 | assert isinstance(vesde, VESDE), "SMLD training only works for VESDEs." 107 | 108 | # Previous SMLD models assume descending sigmas 109 | smld_sigma_array = torch.flip(vesde.discrete_sigmas, dims=(0,)) 110 | reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs) 111 | 112 | def loss_fn(model, batch): 113 | model_fn = mutils.get_model_fn(model, train=train) 114 | labels = torch.randint(0, vesde.N, (batch.shape[0],), device=batch.device) 115 | sigmas = smld_sigma_array.to(batch.device)[labels] 116 | noise = torch.randn_like(batch) * sigmas[:, None, None, None] 117 | perturbed_data = noise + batch 118 | score = model_fn(perturbed_data, labels) 119 | target = -noise / (sigmas ** 2)[:, None, None, None] 120 | losses = torch.square(score - target) 121 | losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) * sigmas ** 2 122 | loss = torch.mean(losses) 123 | return loss 124 | 125 | return loss_fn 126 | 127 | 128 | def get_ddpm_loss_fn(vpsde, train, reduce_mean=True): 129 | """Legacy code to reproduce previous results on DDPM. Not recommended for new work.""" 130 | assert isinstance(vpsde, VPSDE), "DDPM training only works for VPSDEs." 
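  # The objective below is epsilon-prediction: x_0 is perturbed to
  #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
  # and the model output is regressed against `noise` with a squared error.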
131 |
132 |   reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)
133 |
134 |   def loss_fn(model, batch):
135 |     model_fn = mutils.get_model_fn(model, train=train)
136 |     labels = torch.randint(0, vpsde.N, (batch.shape[0],), device=batch.device)
137 |     sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod.to(batch.device)
138 |     sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod.to(batch.device)
139 |     noise = torch.randn_like(batch)
140 |     perturbed_data = sqrt_alphas_cumprod[labels, None, None, None] * batch + \
141 |                      sqrt_1m_alphas_cumprod[labels, None, None, None] * noise
142 |     score = model_fn(perturbed_data, labels)
143 |     losses = torch.square(score - noise)
144 |     losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1)
145 |     loss = torch.mean(losses)
146 |     return loss
147 |
148 |   return loss_fn
149 |
150 |
151 | def get_step_fn(sde, train, optimize_fn=None, reduce_mean=False, continuous=True, likelihood_weighting=False):
152 |   """Create a one-step training/evaluation function.
153 |
154 |   Args:
155 |     sde: An `sde_lib.SDE` object that represents the forward SDE.
156 |     optimize_fn: An optimization function.
157 |     reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions.
158 |     continuous: `True` indicates that the model is defined to take continuous time steps.
159 |     likelihood_weighting: If `True`, weight the mixture of score matching losses according to
160 |       https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended by our paper.
161 |
162 |   Returns:
163 |     A one-step function for training or evaluation.
164 |   """
165 |   if continuous:
166 |     loss_fn = get_sde_loss_fn(sde, train, reduce_mean=reduce_mean,
167 |                               continuous=True, likelihood_weighting=likelihood_weighting)
168 |   else:
169 |     assert not likelihood_weighting, "Likelihood weighting is not supported for original SMLD/DDPM training."
170 |     if isinstance(sde, VESDE):
171 |       loss_fn = get_smld_loss_fn(sde, train, reduce_mean=reduce_mean)
172 |     elif isinstance(sde, VPSDE):
173 |       loss_fn = get_ddpm_loss_fn(sde, train, reduce_mean=reduce_mean)
174 |     else:
175 |       raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.")
176 |
177 |   def step_fn(state, batch):
178 |     """Run one step of training or evaluation.
179 |
180 |     In the original JAX implementation, this function was wrapped in `jax.lax.scan`
181 |     so that multiple steps could be pmapped and jit-compiled together for faster execution.
182 |
183 |     Args:
184 |       state: A dictionary of training information, containing the score model, optimizer,
185 |         EMA status, and number of optimization steps.
186 |       batch: A mini-batch of training/evaluation data.
187 |
188 |     Returns:
189 |       loss: The average loss value over this mini-batch.
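
    Example (illustrative sketch; the exact `state` construction lives in the
    caller's training script, not in this file):
      state = dict(model=score_model, optimizer=optimizer, ema=ema, step=0)
      loss = step_fn(state, batch)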
190 | """ 191 | model = state['model'] 192 | if train: 193 | optimizer = state['optimizer'] 194 | optimizer.zero_grad() 195 | loss = loss_fn(model, batch) 196 | loss.backward() 197 | optimize_fn(optimizer, model.parameters(), step=state['step']) 198 | state['step'] += 1 199 | state['ema'].update(model.parameters()) 200 | else: 201 | with torch.no_grad(): 202 | ema = state['ema'] 203 | ema.store(model.parameters()) 204 | ema.copy_to(model.parameters()) 205 | loss = loss_fn(model, batch) 206 | ema.restore(model.parameters()) 207 | 208 | return loss 209 | 210 | return step_fn 211 | -------------------------------------------------------------------------------- /score_sde/models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from . import ncsnpp 16 | -------------------------------------------------------------------------------- /score_sde/models/ddpm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | """DDPM model. 18 | 19 | This code is the pytorch equivalent of: 20 | https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/models/unet.py 21 | """ 22 | import torch 23 | import torch.nn as nn 24 | import functools 25 | 26 | from . 
import utils, layers, normalization
27 |
28 | RefineBlock = layers.RefineBlock
29 | ResidualBlock = layers.ResidualBlock
30 | ResnetBlockDDPM = layers.ResnetBlockDDPM
31 | Upsample = layers.Upsample
32 | Downsample = layers.Downsample
33 | conv3x3 = layers.ddpm_conv3x3
34 | get_act = layers.get_act
35 | get_normalization = normalization.get_normalization
36 | default_initializer = layers.default_init
37 |
38 |
39 | @utils.register_model(name='ddpm')
40 | class DDPM(nn.Module):
41 |   def __init__(self, config):
42 |     super().__init__()
43 |     self.act = act = get_act(config)
44 |     self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))
45 |
46 |     self.nf = nf = config.model.nf
47 |     ch_mult = config.model.ch_mult
48 |     self.num_res_blocks = num_res_blocks = config.model.num_res_blocks
49 |     self.attn_resolutions = attn_resolutions = config.model.attn_resolutions
50 |     dropout = config.model.dropout
51 |     resamp_with_conv = config.model.resamp_with_conv
52 |     self.num_resolutions = num_resolutions = len(ch_mult)
53 |     self.all_resolutions = all_resolutions = [config.data.image_size // (2 ** i) for i in range(num_resolutions)]
54 |
55 |     AttnBlock = functools.partial(layers.AttnBlock)
56 |     self.conditional = conditional = config.model.conditional
57 |     ResnetBlock = functools.partial(ResnetBlockDDPM, act=act, temb_dim=4 * nf, dropout=dropout)
58 |     modules = []  # start empty so the list exists even when `conditional` is False
59 |     if conditional:  # Condition on noise levels.
60 |       modules.append(nn.Linear(nf, nf * 4))
61 |       modules[0].weight.data = default_initializer()(modules[0].weight.data.shape)
62 |       nn.init.zeros_(modules[0].bias)
63 |       modules.append(nn.Linear(nf * 4, nf * 4))
64 |       modules[1].weight.data = default_initializer()(modules[1].weight.data.shape)
65 |       nn.init.zeros_(modules[1].bias)
66 |
67 |     self.centered = config.data.centered
68 |     channels = config.data.num_channels
69 |
70 |     # Downsampling block
71 |     modules.append(conv3x3(channels, nf))
72 |     hs_c = [nf]
73 |     in_ch = nf
74 |     for i_level in range(num_resolutions):
75 |       # Residual blocks for this resolution
76 |       for i_block in range(num_res_blocks):
77 |         out_ch = nf * ch_mult[i_level]
78 |         modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
79 |         in_ch = out_ch
80 |         if all_resolutions[i_level] in attn_resolutions:
81 |           modules.append(AttnBlock(channels=in_ch))
82 |         hs_c.append(in_ch)
83 |       if i_level != num_resolutions - 1:
84 |         modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv))
85 |         hs_c.append(in_ch)
86 |
87 |     in_ch = hs_c[-1]
88 |     modules.append(ResnetBlock(in_ch=in_ch))
89 |     modules.append(AttnBlock(channels=in_ch))
90 |     modules.append(ResnetBlock(in_ch=in_ch))
91 |
92 |     # Upsampling block
93 |     for i_level in reversed(range(num_resolutions)):
94 |       for i_block in range(num_res_blocks + 1):
95 |         out_ch = nf * ch_mult[i_level]
96 |         modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
97 |         in_ch = out_ch
98 |       if all_resolutions[i_level] in attn_resolutions:
99 |         modules.append(AttnBlock(channels=in_ch))
100 |       if i_level != 0:
101 |         modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv))
102 |
103 |     assert not hs_c
104 |     modules.append(nn.GroupNorm(num_channels=in_ch, num_groups=32, eps=1e-6))
105 |     modules.append(conv3x3(in_ch, channels, init_scale=0.))
106 |     self.all_modules = nn.ModuleList(modules)
107 |
108 |     self.scale_by_sigma = config.model.scale_by_sigma
109 |
110 |   def forward(self, x, labels):
111 |     modules = self.all_modules
112 |     m_idx = 0
113 |     if self.conditional:
114 |       # timestep/scale embedding
115 |       timesteps = labels
116 |       temb =
layers.get_timestep_embedding(timesteps, self.nf) 117 | temb = modules[m_idx](temb) 118 | m_idx += 1 119 | temb = modules[m_idx](self.act(temb)) 120 | m_idx += 1 121 | else: 122 | temb = None 123 | 124 | if self.centered: 125 | # Input is in [-1, 1] 126 | h = x 127 | else: 128 | # Input is in [0, 1] 129 | h = 2 * x - 1. 130 | 131 | # Downsampling block 132 | hs = [modules[m_idx](h)] 133 | m_idx += 1 134 | for i_level in range(self.num_resolutions): 135 | # Residual blocks for this resolution 136 | for i_block in range(self.num_res_blocks): 137 | h = modules[m_idx](hs[-1], temb) 138 | m_idx += 1 139 | if h.shape[-1] in self.attn_resolutions: 140 | h = modules[m_idx](h) 141 | m_idx += 1 142 | hs.append(h) 143 | if i_level != self.num_resolutions - 1: 144 | hs.append(modules[m_idx](hs[-1])) 145 | m_idx += 1 146 | 147 | h = hs[-1] 148 | h = modules[m_idx](h, temb) 149 | m_idx += 1 150 | h = modules[m_idx](h) 151 | m_idx += 1 152 | h = modules[m_idx](h, temb) 153 | m_idx += 1 154 | 155 | # Upsampling block 156 | for i_level in reversed(range(self.num_resolutions)): 157 | for i_block in range(self.num_res_blocks + 1): 158 | h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb) 159 | m_idx += 1 160 | if h.shape[-1] in self.attn_resolutions: 161 | h = modules[m_idx](h) 162 | m_idx += 1 163 | if i_level != 0: 164 | h = modules[m_idx](h) 165 | m_idx += 1 166 | 167 | assert not hs 168 | h = self.act(modules[m_idx](h)) 169 | m_idx += 1 170 | h = modules[m_idx](h) 171 | m_idx += 1 172 | assert m_idx == len(modules) 173 | 174 | if self.scale_by_sigma: 175 | # Divide the output by sigmas. Useful for training with the NCSN loss. 176 | # The DDPM loss scales the network output by sigma in the loss function, 177 | # so no need of doing it here. 178 | used_sigmas = self.sigmas[labels, None, None, None] 179 | h = h / used_sigmas 180 | 181 | return h 182 | -------------------------------------------------------------------------------- /score_sde/models/ema.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Taken from the following link as is from: 3 | # https://github.com/yang-song/score_sde_pytorch/blob/main/models/ema.py 4 | # 5 | # The license for the original version of this file can be 6 | # found in the `score_sde` directory (LICENSE_SCORE_SDE). 7 | # --------------------------------------------------------------- 8 | 9 | # Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py 10 | 11 | from __future__ import division 12 | from __future__ import unicode_literals 13 | 14 | import torch 15 | 16 | 17 | # Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py 18 | class ExponentialMovingAverage: 19 | """ 20 | Maintains (exponential) moving average of a set of parameters. 21 | """ 22 | 23 | def __init__(self, parameters, decay, use_num_updates=True): 24 | """ 25 | Args: 26 | parameters: Iterable of `torch.nn.Parameter`; usually the result of 27 | `model.parameters()`. 28 | decay: The exponential decay. 29 | use_num_updates: Whether to use number of updates when computing 30 | averages. 
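
    Example (illustrative sketch):
      ema = ExponentialMovingAverage(model.parameters(), decay=0.999)
      # ... after every optimizer.step():
      ema.update(model.parameters())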
31 | """ 32 | if decay < 0.0 or decay > 1.0: 33 | raise ValueError('Decay must be between 0 and 1') 34 | self.decay = decay 35 | self.num_updates = 0 if use_num_updates else None 36 | self.shadow_params = [p.clone().detach() 37 | for p in parameters if p.requires_grad] 38 | self.collected_params = [] 39 | 40 | def update(self, parameters): 41 | """ 42 | Update currently maintained parameters. 43 | 44 | Call this every time the parameters are updated, such as the result of 45 | the `optimizer.step()` call. 46 | 47 | Args: 48 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of 49 | parameters used to initialize this object. 50 | """ 51 | decay = self.decay 52 | if self.num_updates is not None: 53 | self.num_updates += 1 54 | decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates)) 55 | one_minus_decay = 1.0 - decay 56 | with torch.no_grad(): 57 | parameters = [p for p in parameters if p.requires_grad] 58 | for s_param, param in zip(self.shadow_params, parameters): 59 | s_param.sub_(one_minus_decay * (s_param - param)) 60 | 61 | def copy_to(self, parameters): 62 | """ 63 | Copy current parameters into given collection of parameters. 64 | 65 | Args: 66 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 67 | updated with the stored moving averages. 68 | """ 69 | parameters = [p for p in parameters if p.requires_grad] 70 | for s_param, param in zip(self.shadow_params, parameters): 71 | if param.requires_grad: 72 | param.data.copy_(s_param.data) 73 | 74 | def store(self, parameters): 75 | """ 76 | Save the current parameters for restoring later. 77 | 78 | Args: 79 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 80 | temporarily stored. 81 | """ 82 | self.collected_params = [param.clone() for param in parameters] 83 | 84 | def restore(self, parameters): 85 | """ 86 | Restore the parameters stored with the `store` method. 87 | Useful to validate the model with EMA parameters without affecting the 88 | original optimization process. Store the parameters before the 89 | `copy_to` method. After validation (or model saving), use this to 90 | restore the former parameters. 91 | 92 | Args: 93 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 94 | updated with the stored parameters. 95 | """ 96 | for c_param, param in zip(self.collected_params, parameters): 97 | param.data.copy_(c_param.data) 98 | 99 | def state_dict(self): 100 | return dict(decay=self.decay, num_updates=self.num_updates, 101 | shadow_params=self.shadow_params) 102 | 103 | def load_state_dict(self, state_dict): 104 | self.decay = state_dict['decay'] 105 | self.num_updates = state_dict['num_updates'] 106 | self.shadow_params = state_dict['shadow_params'] -------------------------------------------------------------------------------- /score_sde/models/normalization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Normalization layers."""
17 | import torch.nn as nn
18 | import torch
19 | import functools
20 |
21 |
22 | def get_normalization(config, conditional=False):
23 |   """Obtain normalization modules from the config file."""
24 |   norm = config.model.normalization
25 |   if conditional:
26 |     if norm == 'InstanceNorm++':
27 |       return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes)
28 |     else:
29 |       raise NotImplementedError(f'{norm} not implemented yet.')
30 |   else:
31 |     if norm == 'InstanceNorm':
32 |       return nn.InstanceNorm2d
33 |     elif norm == 'InstanceNorm++':
34 |       return InstanceNorm2dPlus
35 |     elif norm == 'VarianceNorm':
36 |       return VarianceNorm2d
37 |     elif norm == 'GroupNorm':
38 |       return nn.GroupNorm
39 |     else:
40 |       raise ValueError('Unknown normalization: %s' % norm)
41 |
42 |
43 | class ConditionalBatchNorm2d(nn.Module):
44 |   def __init__(self, num_features, num_classes, bias=True):
45 |     super().__init__()
46 |     self.num_features = num_features
47 |     self.bias = bias
48 |     self.bn = nn.BatchNorm2d(num_features, affine=False)
49 |     if self.bias:
50 |       self.embed = nn.Embedding(num_classes, num_features * 2)
51 |       self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
52 |       self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
53 |     else:
54 |       self.embed = nn.Embedding(num_classes, num_features)
55 |       self.embed.weight.data.uniform_()
56 |
57 |   def forward(self, x, y):
58 |     out = self.bn(x)
59 |     if self.bias:
60 |       gamma, beta = self.embed(y).chunk(2, dim=1)
61 |       out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
62 |     else:
63 |       gamma = self.embed(y)
64 |       out = gamma.view(-1, self.num_features, 1, 1) * out
65 |     return out
66 |
67 |
68 | class ConditionalInstanceNorm2d(nn.Module):
69 |   def __init__(self, num_features, num_classes, bias=True):
70 |     super().__init__()
71 |     self.num_features = num_features
72 |     self.bias = bias
73 |     self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
74 |     if bias:
75 |       self.embed = nn.Embedding(num_classes, num_features * 2)
76 |       self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
77 |       self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
78 |     else:
79 |       self.embed = nn.Embedding(num_classes, num_features)
80 |       self.embed.weight.data.uniform_()
81 |
82 |   def forward(self, x, y):
83 |     h = self.instance_norm(x)
84 |     if self.bias:
85 |       gamma, beta = self.embed(y).chunk(2, dim=-1)
86 |       out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
87 |     else:
88 |       gamma = self.embed(y)
89 |       out = gamma.view(-1, self.num_features, 1, 1) * h
90 |     return out
91 |
92 |
93 | class ConditionalVarianceNorm2d(nn.Module):
94 |   def __init__(self, num_features, num_classes, bias=False):
95 |     super().__init__()
96 |     self.num_features = num_features
97 |     self.bias = bias
98 |     self.embed = nn.Embedding(num_classes, num_features)
99 |     self.embed.weight.data.normal_(1, 0.02)
100 |
101 |   def forward(self, x, y):
102 |     vars = torch.var(x, dim=(2, 3), keepdim=True)
103 |     h = x / torch.sqrt(vars + 1e-5)
104 |
105 |     gamma = self.embed(y)
106 |     out = gamma.view(-1, self.num_features, 1, 1) * h
107 |     return out
108 |
109 |
110 | class VarianceNorm2d(nn.Module):
111 |   def __init__(self, num_features, bias=False):
112 |
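    # Scale-only normalization: forward() divides by the per-channel standard
    # deviation and rescales with the learned per-channel gain `alpha`
    # (no mean subtraction and no additive bias).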
super().__init__()
113 |     self.num_features = num_features
114 |     self.bias = bias
115 |     self.alpha = nn.Parameter(torch.zeros(num_features))
116 |     self.alpha.data.normal_(1, 0.02)
117 |
118 |   def forward(self, x):
119 |     vars = torch.var(x, dim=(2, 3), keepdim=True)
120 |     h = x / torch.sqrt(vars + 1e-5)
121 |
122 |     out = self.alpha.view(-1, self.num_features, 1, 1) * h
123 |     return out
124 |
125 |
126 | class ConditionalNoneNorm2d(nn.Module):
127 |   def __init__(self, num_features, num_classes, bias=True):
128 |     super().__init__()
129 |     self.num_features = num_features
130 |     self.bias = bias
131 |     if bias:
132 |       self.embed = nn.Embedding(num_classes, num_features * 2)
133 |       self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
134 |       self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
135 |     else:
136 |       self.embed = nn.Embedding(num_classes, num_features)
137 |       self.embed.weight.data.uniform_()
138 |
139 |   def forward(self, x, y):
140 |     if self.bias:
141 |       gamma, beta = self.embed(y).chunk(2, dim=-1)
142 |       out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1)
143 |     else:
144 |       gamma = self.embed(y)
145 |       out = gamma.view(-1, self.num_features, 1, 1) * x
146 |     return out
147 |
148 |
149 | class NoneNorm2d(nn.Module):
150 |   def __init__(self, num_features, bias=True):
151 |     super().__init__()
152 |
153 |   def forward(self, x):
154 |     return x
155 |
156 |
157 | class InstanceNorm2dPlus(nn.Module):
158 |   def __init__(self, num_features, bias=True):
159 |     super().__init__()
160 |     self.num_features = num_features
161 |     self.bias = bias
162 |     self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
163 |     self.alpha = nn.Parameter(torch.zeros(num_features))
164 |     self.gamma = nn.Parameter(torch.zeros(num_features))
165 |     self.alpha.data.normal_(1, 0.02)
166 |     self.gamma.data.normal_(1, 0.02)
167 |     if bias:
168 |       self.beta = nn.Parameter(torch.zeros(num_features))
169 |
170 |   def forward(self, x):
171 |     means = torch.mean(x, dim=(2, 3))
172 |     m = torch.mean(means, dim=-1, keepdim=True)
173 |     v = torch.var(means, dim=-1, keepdim=True)
174 |     means = (means - m) / (torch.sqrt(v + 1e-5))
175 |     h = self.instance_norm(x)
176 |
177 |     if self.bias:
178 |       h = h + means[..., None, None] * self.alpha[..., None, None]
179 |       out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1)
180 |     else:
181 |       h = h + means[..., None, None] * self.alpha[..., None, None]
182 |       out = self.gamma.view(-1, self.num_features, 1, 1) * h
183 |     return out
184 |
185 |
186 | class ConditionalInstanceNorm2dPlus(nn.Module):
187 |   def __init__(self, num_features, num_classes, bias=True):
188 |     super().__init__()
189 |     self.num_features = num_features
190 |     self.bias = bias
191 |     self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
192 |     if bias:
193 |       self.embed = nn.Embedding(num_classes, num_features * 3)
194 |       self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02)  # Initialise scale at N(1, 0.02)
195 |       self.embed.weight.data[:, 2 * num_features:].zero_()  # Initialise bias at 0
196 |     else:
197 |       self.embed = nn.Embedding(num_classes, 2 * num_features)
198 |       self.embed.weight.data.normal_(1, 0.02)
199 |
200 |   def forward(self, x, y):
201 |     means = torch.mean(x, dim=(2, 3))
202 |     m = torch.mean(means, dim=-1, keepdim=True)
203 |     v = torch.var(means, dim=-1, keepdim=True)
204 |     means = (means - m) / (torch.sqrt(v + 1e-5))
205 |     h =
self.instance_norm(x)
206 |
207 |     if self.bias:
208 |       gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
209 |       h = h + means[..., None, None] * alpha[..., None, None]
210 |       out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
211 |     else:
212 |       gamma, alpha = self.embed(y).chunk(2, dim=-1)
213 |       h = h + means[..., None, None] * alpha[..., None, None]
214 |       out = gamma.view(-1, self.num_features, 1, 1) * h
215 |     return out
216 |
-------------------------------------------------------------------------------- /score_sde/models/utils.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """All functions and modules related to model definition.
17 | """
18 |
19 | import torch
20 | from score_sde import sde_lib
21 | import numpy as np
22 |
23 | _MODELS = {}
24 |
25 |
26 | def register_model(cls=None, *, name=None):
27 |   """A decorator for registering model classes."""
28 |
29 |   def _register(cls):
30 |     if name is None:
31 |       local_name = cls.__name__
32 |     else:
33 |       local_name = name
34 |     if local_name in _MODELS:
35 |       raise ValueError(f'Already registered model with name: {local_name}')
36 |     _MODELS[local_name] = cls
37 |     return cls
38 |
39 |   if cls is None:
40 |     return _register
41 |   else:
42 |     return _register(cls)
43 |
44 |
45 | def get_model(name):
46 |   return _MODELS[name]
47 |
48 |
49 | def get_sigmas(config):
50 |   """Get sigmas --- the set of noise levels for SMLD from config files.
51 |   Args:
52 |     config: A ConfigDict object parsed from the config file
53 |   Returns:
54 |     sigmas: a numpy array of noise levels
55 |   """
56 |   sigmas = np.exp(
57 |     np.linspace(np.log(config.model.sigma_max), np.log(config.model.sigma_min), config.model.num_scales))
58 |
59 |   return sigmas
60 |
61 |
62 | def get_ddpm_params(config):
63 |   """Get betas and alphas --- parameters used in the original DDPM paper."""
64 |   num_diffusion_timesteps = 1000
65 |   # parameters need to be adapted if number of time steps differs from 1000
66 |   beta_start = config.model.beta_min / config.model.num_scales
67 |   beta_end = config.model.beta_max / config.model.num_scales
68 |   betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
69 |
70 |   alphas = 1. - betas
71 |   alphas_cumprod = np.cumprod(alphas, axis=0)
72 |   sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
73 |   sqrt_1m_alphas_cumprod = np.sqrt(1.
- alphas_cumprod)
74 |
75 |   return {
76 |     'betas': betas,
77 |     'alphas': alphas,
78 |     'alphas_cumprod': alphas_cumprod,
79 |     'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
80 |     'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
81 |     'beta_min': beta_start * (num_diffusion_timesteps - 1),
82 |     'beta_max': beta_end * (num_diffusion_timesteps - 1),
83 |     'num_diffusion_timesteps': num_diffusion_timesteps
84 |   }
85 |
86 |
87 | def create_model(config):
88 |   """Create the score model."""
89 |   model_name = config.model.name
90 |   score_model = get_model(model_name)(config)
91 |   # score_model = score_model.to(config.device)
92 |   # score_model = torch.nn.DataParallel(score_model)
93 |   return score_model
94 |
95 |
96 | def get_model_fn(model, train=False):
97 |   """Create a function to give the output of the score-based model.
98 |
99 |   Args:
100 |     model: The score model.
101 |     train: `True` for training and `False` for evaluation.
102 |
103 |   Returns:
104 |     A model function.
105 |   """
106 |
107 |   def model_fn(x, labels):
108 |     """Compute the output of the score-based model.
109 |
110 |     Args:
111 |       x: A mini-batch of input data.
112 |       labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
113 |         for different models.
114 |
115 |     Returns:
116 |       The model output.
117 |     """
118 |     if not train:
119 |       model.eval()
120 |       return model(x, labels)
121 |     else:
122 |       model.train()
123 |       return model(x, labels)
124 |
125 |   return model_fn
126 |
127 |
128 | def get_score_fn(sde, model, train=False, continuous=False):
129 |   """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
130 |
131 |   Args:
132 |     sde: An `sde_lib.SDE` object that represents the forward SDE.
133 |     model: A score model.
134 |     train: `True` for training and `False` for evaluation.
135 |     continuous: If `True`, the score-based model is expected to directly take continuous time steps.
136 |
137 |   Returns:
138 |     A score function.
139 |   """
140 |   model_fn = get_model_fn(model, train=train)
141 |
142 |   if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
143 |     def score_fn(x, t):
144 |       # Scale neural network output by standard deviation and flip sign
145 |       if continuous or isinstance(sde, sde_lib.subVPSDE):
146 |         # For VP-trained models, t=0 corresponds to the lowest noise level
147 |         # The maximum value of time embedding is assumed to be 999 for
148 |         # continuously-trained models.
149 | labels = t * 999 150 | score = model_fn(x, labels) 151 | std = sde.marginal_prob(torch.zeros_like(x), t)[1] 152 | else: 153 | # For VP-trained models, t=0 corresponds to the lowest noise level 154 | labels = t * (sde.N - 1) 155 | score = model_fn(x, labels) 156 | std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()] 157 | 158 | score = -score / std[:, None, None, None] 159 | return score 160 | 161 | elif isinstance(sde, sde_lib.VESDE): 162 | def score_fn(x, t): 163 | if continuous: 164 | labels = sde.marginal_prob(torch.zeros_like(x), t)[1] 165 | else: 166 | # For VE-trained models, t=0 corresponds to the highest noise level 167 | labels = sde.T - t 168 | labels *= sde.N - 1 169 | labels = torch.round(labels).long() 170 | 171 | score = model_fn(x, labels) 172 | return score 173 | 174 | else: 175 | raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") 176 | 177 | return score_fn 178 | 179 | 180 | def to_flattened_numpy(x): 181 | """Flatten a torch tensor `x` and convert it to numpy.""" 182 | return x.detach().cpu().numpy().reshape((-1,)) 183 | 184 | 185 | def from_flattened_numpy(x, shape): 186 | """Form a torch tensor with the given `shape` from a flattened numpy array `x`.""" 187 | return torch.from_numpy(x.reshape(shape)) 188 | -------------------------------------------------------------------------------- /score_sde/op/__init__.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Taken from the following link as is from: 3 | # https://github.com/yang-song/score_sde_pytorch/blob/main/op/__init__.py 4 | # 5 | # The license for the original version of this file can be 6 | # found in the `score_sde` directory (LICENSE_SCORE_SDE). 7 | # --------------------------------------------------------------- 8 | 9 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 10 | from .upfirdn2d import upfirdn2d 11 | -------------------------------------------------------------------------------- /score_sde/op/fused_act.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Taken from the following link as is from: 3 | # https://github.com/yang-song/score_sde_pytorch/blob/main/op/fused_act.py 4 | # 5 | # The license for the original version of this file can be 6 | # found in the `score_sde` directory (LICENSE_SCORE_SDE). 
7 | # --------------------------------------------------------------- 8 | 9 | import os 10 | 11 | import torch 12 | from torch import nn 13 | from torch.nn import functional as F 14 | from torch.autograd import Function 15 | from torch.utils.cpp_extension import load 16 | 17 | 18 | module_path = os.path.dirname(__file__) 19 | fused = load( 20 | "fused", 21 | sources=[ 22 | os.path.join(module_path, "fused_bias_act.cpp"), 23 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 24 | ], 25 | ) 26 | 27 | 28 | class FusedLeakyReLUFunctionBackward(Function): 29 | @staticmethod 30 | def forward(ctx, grad_output, out, negative_slope, scale): 31 | ctx.save_for_backward(out) 32 | ctx.negative_slope = negative_slope 33 | ctx.scale = scale 34 | 35 | empty = grad_output.new_empty(0) 36 | 37 | grad_input = fused.fused_bias_act( 38 | grad_output, empty, out, 3, 1, negative_slope, scale 39 | ) 40 | 41 | dim = [0] 42 | 43 | if grad_input.ndim > 2: 44 | dim += list(range(2, grad_input.ndim)) 45 | 46 | grad_bias = grad_input.sum(dim).detach() 47 | 48 | return grad_input, grad_bias 49 | 50 | @staticmethod 51 | def backward(ctx, gradgrad_input, gradgrad_bias): 52 | out, = ctx.saved_tensors 53 | gradgrad_out = fused.fused_bias_act( 54 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 55 | ) 56 | 57 | return gradgrad_out, None, None, None 58 | 59 | 60 | class FusedLeakyReLUFunction(Function): 61 | @staticmethod 62 | def forward(ctx, input, bias, negative_slope, scale): 63 | empty = input.new_empty(0) 64 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 65 | ctx.save_for_backward(out) 66 | ctx.negative_slope = negative_slope 67 | ctx.scale = scale 68 | 69 | return out 70 | 71 | @staticmethod 72 | def backward(ctx, grad_output): 73 | out, = ctx.saved_tensors 74 | 75 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 76 | grad_output, out, ctx.negative_slope, ctx.scale 77 | ) 78 | 79 | return grad_input, grad_bias, None, None 80 | 81 | 82 | class FusedLeakyReLU(nn.Module): 83 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 84 | super().__init__() 85 | 86 | self.bias = nn.Parameter(torch.zeros(channel)) 87 | self.negative_slope = negative_slope 88 | self.scale = scale 89 | 90 | def forward(self, input): 91 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 92 | 93 | 94 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 95 | if input.device.type == "cpu": 96 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 97 | return ( 98 | F.leaky_relu( 99 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 100 | ) 101 | * scale 102 | ) 103 | 104 | else: 105 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 106 | -------------------------------------------------------------------------------- /score_sde/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | // --------------------------------------------------------------- 2 | // Taken from the following link as is from: 3 | // https://github.com/yang-song/score_sde_pytorch/blob/main/op/fused_bias_act.cpp 4 | // 5 | // The license for the original version of this file can be 6 | // found in the `score_sde` directory (LICENSE_SCORE_SDE). 
7 | // ---------------------------------------------------------------
8 |
9 | #include <torch/extension.h>
10 |
11 |
12 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
13 |                                 int act, int grad, float alpha, float scale);
14 |
15 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
16 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
17 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
18 |
19 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
20 |                              int act, int grad, float alpha, float scale) {
21 |     CHECK_CUDA(input);
22 |     CHECK_CUDA(bias);
23 |
24 |     return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
25 | }
26 |
27 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
28 |     m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
29 | }
-------------------------------------------------------------------------------- /score_sde/op/fused_bias_act_kernel.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | template <typename scalar_t>
19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20 |     int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21 |     int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22 |
23 |     scalar_t zero = 0.0;
24 |
25 |     for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26 |         scalar_t x = p_x[xi];
27 |
28 |         if (use_bias) {
29 |             x += p_b[(xi / step_b) % size_b];
30 |         }
31 |
32 |         scalar_t ref = use_ref ? p_ref[xi] : zero;
33 |
34 |         scalar_t y;
35 |
36 |         switch (act * 10 + grad) {
37 |             default:
38 |             case 10: y = x; break;
39 |             case 11: y = x; break;
40 |             case 12: y = 0.0; break;
41 |
42 |             case 30: y = (x > 0.0) ? x : x * alpha; break;
43 |             case 31: y = (ref > 0.0) ? x : x * alpha; break;
44 |             case 32: y = 0.0; break;
45 |         }
46 |
47 |         out[xi] = y * scale;
48 |     }
49 | }
50 |
51 |
52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53 |                                 int act, int grad, float alpha, float scale) {
54 |     int curDevice = -1;
55 |     cudaGetDevice(&curDevice);
56 |     cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57 |
58 |     auto x = input.contiguous();
59 |     auto b = bias.contiguous();
60 |     auto ref = refer.contiguous();
61 |
62 |     int use_bias = b.numel() ? 1 : 0;
63 |     int use_ref = ref.numel() ?
1 : 0;
64 |
65 |     int size_x = x.numel();
66 |     int size_b = b.numel();
67 |     int step_b = 1;
68 |
69 |     for (int i = 1 + 1; i < x.dim(); i++) {
70 |         step_b *= x.size(i);
71 |     }
72 |
73 |     int loop_x = 4;
74 |     int block_size = 4 * 32;
75 |     int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76 |
77 |     auto y = torch::empty_like(x);
78 |
79 |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80 |         fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81 |             y.data_ptr<scalar_t>(),
82 |             x.data_ptr<scalar_t>(),
83 |             b.data_ptr<scalar_t>(),
84 |             ref.data_ptr<scalar_t>(),
85 |             act,
86 |             grad,
87 |             alpha,
88 |             scale,
89 |             loop_x,
90 |             size_x,
91 |             step_b,
92 |             size_b,
93 |             use_bias,
94 |             use_ref
95 |         );
96 |     });
97 |
98 |     return y;
99 | }
-------------------------------------------------------------------------------- /score_sde/op/upfirdn2d.cpp: --------------------------------------------------------------------------------
1 | // ---------------------------------------------------------------
2 | // Taken from the following link as is from:
3 | // https://github.com/yang-song/score_sde_pytorch/blob/main/op/upfirdn2d.cpp
4 | //
5 | // The license for the original version of this file can be
6 | // found in the `score_sde` directory (LICENSE_SCORE_SDE).
7 | // ---------------------------------------------------------------
8 |
9 | #include <torch/extension.h>
10 |
11 |
12 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
13 |                            int up_x, int up_y, int down_x, int down_y,
14 |                            int pad_x0, int pad_x1, int pad_y0, int pad_y1);
15 |
16 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
17 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
18 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
19 |
20 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
21 |                         int up_x, int up_y, int down_x, int down_y,
22 |                         int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
23 |     CHECK_CUDA(input);
24 |     CHECK_CUDA(kernel);
25 |
26 |     return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
27 | }
28 |
29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
30 |     m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
31 | }
-------------------------------------------------------------------------------- /score_sde/op/upfirdn2d.py: --------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Taken from the following link as is from:
3 | # https://github.com/yang-song/score_sde_pytorch/blob/main/op/upfirdn2d.py
4 | #
5 | # The license for the original version of this file can be
6 | # found in the `score_sde` directory (LICENSE_SCORE_SDE).
7 | # --------------------------------------------------------------- 8 | 9 | import os 10 | 11 | import torch 12 | from torch.nn import functional as F 13 | from torch.autograd import Function 14 | from torch.utils.cpp_extension import load 15 | 16 | 17 | module_path = os.path.dirname(__file__) 18 | upfirdn2d_op = load( 19 | "upfirdn2d", 20 | sources=[ 21 | os.path.join(module_path, "upfirdn2d.cpp"), 22 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 23 | ], 24 | ) 25 | 26 | 27 | class UpFirDn2dBackward(Function): 28 | @staticmethod 29 | def forward( 30 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 31 | ): 32 | 33 | up_x, up_y = up 34 | down_x, down_y = down 35 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 36 | 37 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 38 | 39 | grad_input = upfirdn2d_op.upfirdn2d( 40 | grad_output, 41 | grad_kernel, 42 | down_x, 43 | down_y, 44 | up_x, 45 | up_y, 46 | g_pad_x0, 47 | g_pad_x1, 48 | g_pad_y0, 49 | g_pad_y1, 50 | ) 51 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 52 | 53 | ctx.save_for_backward(kernel) 54 | 55 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 56 | 57 | ctx.up_x = up_x 58 | ctx.up_y = up_y 59 | ctx.down_x = down_x 60 | ctx.down_y = down_y 61 | ctx.pad_x0 = pad_x0 62 | ctx.pad_x1 = pad_x1 63 | ctx.pad_y0 = pad_y0 64 | ctx.pad_y1 = pad_y1 65 | ctx.in_size = in_size 66 | ctx.out_size = out_size 67 | 68 | return grad_input 69 | 70 | @staticmethod 71 | def backward(ctx, gradgrad_input): 72 | kernel, = ctx.saved_tensors 73 | 74 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 75 | 76 | gradgrad_out = upfirdn2d_op.upfirdn2d( 77 | gradgrad_input, 78 | kernel, 79 | ctx.up_x, 80 | ctx.up_y, 81 | ctx.down_x, 82 | ctx.down_y, 83 | ctx.pad_x0, 84 | ctx.pad_x1, 85 | ctx.pad_y0, 86 | ctx.pad_y1, 87 | ) 88 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 89 | gradgrad_out = gradgrad_out.view( 90 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 91 | ) 92 | 93 | return gradgrad_out, None, None, None, None, None, None, None, None 94 | 95 | 96 | class UpFirDn2d(Function): 97 | @staticmethod 98 | def forward(ctx, input, kernel, up, down, pad): 99 | up_x, up_y = up 100 | down_x, down_y = down 101 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 102 | 103 | kernel_h, kernel_w = kernel.shape 104 | batch, channel, in_h, in_w = input.shape 105 | ctx.in_size = input.shape 106 | 107 | input = input.reshape(-1, in_h, in_w, 1) 108 | 109 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 110 | 111 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 112 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 113 | ctx.out_size = (out_h, out_w) 114 | 115 | ctx.up = (up_x, up_y) 116 | ctx.down = (down_x, down_y) 117 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 118 | 119 | g_pad_x0 = kernel_w - pad_x0 - 1 120 | g_pad_y0 = kernel_h - pad_y0 - 1 121 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 122 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 123 | 124 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 125 | 126 | out = upfirdn2d_op.upfirdn2d( 127 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 128 | ) 129 | # out = out.view(major, out_h, out_w, minor) 130 | out = out.view(-1, channel, out_h, out_w) 131 | 132 | return out 133 | 134 | @staticmethod 135 | def backward(ctx, grad_output): 136 | 
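        # The gradient of upfirdn2d is itself an upfirdn2d call with the
        # up/down factors swapped, the spatially flipped kernel, and the
        # 'gradient' padding computed in forward() (see UpFirDn2dBackward).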
kernel, grad_kernel = ctx.saved_tensors 137 | 138 | grad_input = UpFirDn2dBackward.apply( 139 | grad_output, 140 | kernel, 141 | grad_kernel, 142 | ctx.up, 143 | ctx.down, 144 | ctx.pad, 145 | ctx.g_pad, 146 | ctx.in_size, 147 | ctx.out_size, 148 | ) 149 | 150 | return grad_input, None, None, None, None 151 | 152 | 153 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 154 | if input.device.type == "cpu": 155 | out = upfirdn2d_native( 156 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 157 | ) 158 | 159 | else: 160 | out = UpFirDn2d.apply( 161 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 162 | ) 163 | 164 | return out 165 | 166 | 167 | def upfirdn2d_native( 168 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 169 | ): 170 | _, channel, in_h, in_w = input.shape 171 | input = input.reshape(-1, in_h, in_w, 1) 172 | 173 | _, in_h, in_w, minor = input.shape 174 | kernel_h, kernel_w = kernel.shape 175 | 176 | out = input.view(-1, in_h, 1, in_w, 1, minor) 177 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 178 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 179 | 180 | out = F.pad( 181 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 182 | ) 183 | out = out[ 184 | :, 185 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 186 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 187 | :, 188 | ] 189 | 190 | out = out.permute(0, 3, 1, 2) 191 | out = out.reshape( 192 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 193 | ) 194 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 195 | out = F.conv2d(out, w) 196 | out = out.reshape( 197 | -1, 198 | minor, 199 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 200 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 201 | ) 202 | out = out.permute(0, 2, 3, 1) 203 | out = out[:, ::down_y, ::down_x, :] 204 | 205 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 206 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 207 | 208 | return out.view(-1, channel, out_h, out_w) 209 | -------------------------------------------------------------------------------- /score_sde/sde_lib.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Taken from the following link as is from: 3 | # https://github.com/yang-song/score_sde_pytorch/blob/main/sde_lib.py 4 | # 5 | # The license for the original version of this file can be 6 | # found in this directory (LICENSE_SCORE_SDE). 7 | # --------------------------------------------------------------- 8 | 9 | """Abstract SDE classes, Reverse SDE, and VE/VP SDEs.""" 10 | import abc 11 | import torch 12 | import numpy as np 13 | 14 | 15 | class SDE(abc.ABC): 16 | """SDE abstract class. Functions are designed for a mini-batch of inputs.""" 17 | 18 | def __init__(self, N): 19 | """Construct an SDE. 20 | 21 | Args: 22 | N: number of discretization time steps. 
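    Concrete subclasses (`VPSDE`, `subVPSDE`, and `VESDE` below) implement `sde`,
    `marginal_prob`, `prior_sampling`, `prior_logp`, and the `T` property;
    `discretize` defaults to Euler-Maruyama.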
23 |     """
24 |     super().__init__()
25 |     self.N = N
26 |
27 |   @property
28 |   @abc.abstractmethod
29 |   def T(self):
30 |     """End time of the SDE."""
31 |     pass
32 |
33 |   @abc.abstractmethod
34 |   def sde(self, x, t):
35 |     pass
36 |
37 |   @abc.abstractmethod
38 |   def marginal_prob(self, x, t):
39 |     """Parameters to determine the marginal distribution of the SDE, $p_t(x)$."""
40 |     pass
41 |
42 |   @abc.abstractmethod
43 |   def prior_sampling(self, shape):
44 |     """Generate one sample from the prior distribution, $p_T(x)$."""
45 |     pass
46 |
47 |   @abc.abstractmethod
48 |   def prior_logp(self, z):
49 |     """Compute log-density of the prior distribution.
50 |
51 |     Useful for computing the log-likelihood via probability flow ODE.
52 |
53 |     Args:
54 |       z: latent code
55 |     Returns:
56 |       log probability density
57 |     """
58 |     pass
59 |
60 |   def discretize(self, x, t):
61 |     """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
62 |
63 |     Useful for reverse diffusion sampling and probability flow sampling.
64 |     Defaults to Euler-Maruyama discretization.
65 |
66 |     Args:
67 |       x: a torch tensor
68 |       t: a torch float representing the time step (from 0 to `self.T`)
69 |
70 |     Returns:
71 |       f, G
72 |     """
73 |     dt = 1 / self.N
74 |     drift, diffusion = self.sde(x, t)
75 |     f = drift * dt
76 |     G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
77 |     return f, G
78 |
79 |   def reverse(self, score_fn, probability_flow=False):
80 |     """Create the reverse-time SDE/ODE.
81 |
82 |     Args:
83 |       score_fn: A time-dependent score-based model that takes x and t and returns the score.
84 |       probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
85 |     """
86 |     N = self.N
87 |     T = self.T
88 |     sde_fn = self.sde
89 |     discretize_fn = self.discretize
90 |
91 |     # Build the class for reverse-time SDE.
92 |     class RSDE(self.__class__):
93 |       def __init__(self):
94 |         self.N = N
95 |         self.probability_flow = probability_flow
96 |
97 |       @property
98 |       def T(self):
99 |         return T
100 |
101 |       def sde(self, x, t):
102 |         """Create the drift and diffusion functions for the reverse SDE/ODE."""
103 |         drift, diffusion = sde_fn(x, t)
104 |         score = score_fn(x, t)
105 |         drift = drift - diffusion[:, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)
106 |         # Set the diffusion function to zero for ODEs.
107 |         diffusion = 0. if self.probability_flow else diffusion
108 |         return drift, diffusion
109 |
110 |       def discretize(self, x, t):
111 |         """Create discretized iteration rules for the reverse diffusion sampler."""
112 |         f, G = discretize_fn(x, t)
113 |         rev_f = f - G[:, None, None, None] ** 2 * score_fn(x, t) * (0.5 if self.probability_flow else 1.)
114 |         rev_G = torch.zeros_like(G) if self.probability_flow else G
115 |         return rev_f, rev_G
116 |
117 |     return RSDE()
118 |
119 |
120 | class VPSDE(SDE):
121 |   def __init__(self, beta_min=0.1, beta_max=20, N=1000):
122 |     """Construct a Variance Preserving SDE.
123 |
124 |     Args:
125 |       beta_min: value of beta(0)
126 |       beta_max: value of beta(1)
127 |       N: number of discretization steps
128 |     """
129 |     super().__init__(N)
130 |     self.beta_0 = beta_min
131 |     self.beta_1 = beta_max
132 |     self.N = N
133 |     self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N)
134 |     self.alphas = 1. - self.discrete_betas
135 |     self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
136 |     self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
137 |     self.sqrt_1m_alphas_cumprod = torch.sqrt(1.
- self.alphas_cumprod) 138 | 139 | @property 140 | def T(self): 141 | return 1 142 | 143 | def sde(self, x, t): 144 | beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0) 145 | drift = -0.5 * beta_t[:, None, None, None] * x 146 | diffusion = torch.sqrt(beta_t) 147 | return drift, diffusion 148 | 149 | def marginal_prob(self, x, t): 150 | log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 151 | mean = torch.exp(log_mean_coeff[:, None, None, None]) * x 152 | std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) 153 | return mean, std 154 | 155 | def prior_sampling(self, shape): 156 | return torch.randn(*shape) 157 | 158 | def prior_logp(self, z): 159 | shape = z.shape 160 | N = np.prod(shape[1:]) 161 | logps = -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2. 162 | return logps 163 | 164 | def discretize(self, x, t): 165 | """DDPM discretization.""" 166 | timestep = (t * (self.N - 1) / self.T).long() 167 | beta = self.discrete_betas.to(x.device)[timestep] 168 | alpha = self.alphas.to(x.device)[timestep] 169 | sqrt_beta = torch.sqrt(beta) 170 | f = torch.sqrt(alpha)[:, None, None, None] * x - x 171 | G = sqrt_beta 172 | return f, G 173 | 174 | 175 | class subVPSDE(SDE): 176 | def __init__(self, beta_min=0.1, beta_max=20, N=1000): 177 | """Construct the sub-VP SDE that excels at likelihoods. 178 | 179 | Args: 180 | beta_min: value of beta(0) 181 | beta_max: value of beta(1) 182 | N: number of discretization steps 183 | """ 184 | super().__init__(N) 185 | self.beta_0 = beta_min 186 | self.beta_1 = beta_max 187 | self.N = N 188 | 189 | @property 190 | def T(self): 191 | return 1 192 | 193 | def sde(self, x, t): 194 | beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0) 195 | drift = -0.5 * beta_t[:, None, None, None] * x 196 | discount = 1. - torch.exp(-2 * self.beta_0 * t - (self.beta_1 - self.beta_0) * t ** 2) 197 | diffusion = torch.sqrt(beta_t * discount) 198 | return drift, diffusion 199 | 200 | def marginal_prob(self, x, t): 201 | log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 202 | mean = torch.exp(log_mean_coeff)[:, None, None, None] * x 203 | std = 1 - torch.exp(2. * log_mean_coeff) 204 | return mean, std 205 | 206 | def prior_sampling(self, shape): 207 | return torch.randn(*shape) 208 | 209 | def prior_logp(self, z): 210 | shape = z.shape 211 | N = np.prod(shape[1:]) 212 | return -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2. 213 | 214 | 215 | class VESDE(SDE): 216 | def __init__(self, sigma_min=0.01, sigma_max=50, N=1000): 217 | """Construct a Variance Exploding SDE. 218 | 219 | Args: 220 | sigma_min: smallest sigma. 221 | sigma_max: largest sigma. 
222 | N: number of discretization steps 223 | """ 224 | super().__init__(N) 225 | self.sigma_min = sigma_min 226 | self.sigma_max = sigma_max 227 | self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N)) 228 | self.N = N 229 | 230 | @property 231 | def T(self): 232 | return 1 233 | 234 | def sde(self, x, t): 235 | sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t 236 | drift = torch.zeros_like(x) 237 | diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)), 238 | device=t.device)) 239 | return drift, diffusion 240 | 241 | def marginal_prob(self, x, t): 242 | std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t 243 | mean = x 244 | return mean, std 245 | 246 | def prior_sampling(self, shape): 247 | return torch.randn(*shape) * self.sigma_max 248 | 249 | def prior_logp(self, z): 250 | shape = z.shape 251 | N = np.prod(shape[1:]) 252 | return -N / 2. * np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * self.sigma_max ** 2) 253 | 254 | def discretize(self, x, t): 255 | """SMLD(NCSN) discretization.""" 256 | timestep = (t * (self.N - 1) / self.T).long() 257 | sigma = self.discrete_sigmas.to(t.device)[timestep] 258 | adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), 259 | self.discrete_sigmas[timestep - 1].to(t.device)) 260 | f = torch.zeros_like(x) 261 | G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2) 262 | return f, G -------------------------------------------------------------------------------- /stadv_eot/attacks.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for DiffPure. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import functools 9 | import torch 10 | from torch import nn 11 | from torch import optim 12 | 13 | # mister_ed 14 | from .recoloradv.mister_ed import loss_functions as lf 15 | from .recoloradv.mister_ed import adversarial_training as advtrain 16 | from .recoloradv.mister_ed import adversarial_perturbations as ap 17 | from .recoloradv.mister_ed import adversarial_attacks as aa 18 | from .recoloradv.mister_ed import spatial_transformers as st 19 | 20 | 21 | PGD_ITERS = 20 22 | 23 | 24 | def run_attack_with_random_targets(attack, model, inputs, labels, num_classes): 25 | """ 26 | Runs an attack with targets randomly selected from all classes besides the 27 | correct one. The attack should be a function from (inputs, labels) to 28 | adversarial examples. 29 | """ 30 | 31 | rand_targets = torch.randint( 32 | 0, num_classes - 1, labels.size(), 33 | dtype=labels.dtype, device=labels.device, 34 | ) 35 | targets = torch.remainder(labels + rand_targets + 1, num_classes) 36 | 37 | adv_inputs = attack(inputs, targets) 38 | adv_labels = model(adv_inputs).argmax(1) 39 | unsuccessful = adv_labels != targets 40 | adv_inputs[unsuccessful] = inputs[unsuccessful] 41 | 42 | return adv_inputs 43 | 44 | 45 | class MisterEdAttack(nn.Module): 46 | """ 47 | Base class for attacks using the mister_ed library. 
48 | """ 49 | 50 | def __init__(self, model, threat_model, randomize=False, 51 | perturbation_norm_loss=False, lr=0.001, random_targets=False, 52 | num_classes=None, **kwargs): 53 | super().__init__() 54 | 55 | self.model = model 56 | self.normalizer = nn.Identity() 57 | 58 | self.threat_model = threat_model 59 | self.randomize = randomize 60 | self.perturbation_norm_loss = perturbation_norm_loss 61 | self.attack_kwargs = kwargs 62 | self.lr = lr 63 | self.random_targets = random_targets 64 | self.num_classes = num_classes 65 | 66 | self.attack = None 67 | 68 | def _setup_attack(self): 69 | cw_loss = lf.CWLossF6(self.model, self.normalizer, kappa=float('inf')) 70 | if self.random_targets: 71 | cw_loss.forward = functools.partial(cw_loss.forward, targeted=True) 72 | perturbation_loss = lf.PerturbationNormLoss(lp=2) 73 | pert_factor = 0.0 74 | if self.perturbation_norm_loss is True: 75 | pert_factor = 0.05 76 | elif type(self.perturbation_norm_loss) is float: 77 | pert_factor = self.perturbation_norm_loss 78 | adv_loss = lf.RegularizedLoss({ 79 | 'cw': cw_loss, 80 | 'pert': perturbation_loss, 81 | }, { 82 | 'cw': 1.0, 83 | 'pert': pert_factor, 84 | }, negate=True) 85 | 86 | self.pgd_attack = aa.PGD(self.model, self.normalizer, 87 | self.threat_model(), adv_loss) 88 | 89 | attack_params = { 90 | 'optimizer': optim.Adam, 91 | 'optimizer_kwargs': {'lr': self.lr}, 92 | 'signed': False, 93 | 'verbose': False, 94 | 'num_iterations': 0 if self.randomize else PGD_ITERS, 95 | 'random_init': self.randomize, 96 | } 97 | attack_params.update(self.attack_kwargs) 98 | 99 | self.attack = advtrain.AdversarialAttackParameters( 100 | self.pgd_attack, 101 | 1.0, 102 | attack_specific_params={'attack_kwargs': attack_params}, 103 | ) 104 | self.attack.set_gpu(False) 105 | 106 | def forward(self, inputs, labels): 107 | if self.attack is None: 108 | self._setup_attack() 109 | assert self.attack is not None 110 | 111 | if self.random_targets: 112 | return run_attack_with_random_targets( 113 | lambda inputs, labels: self.attack.attack(inputs, labels)[0], 114 | self.model, 115 | inputs, 116 | labels, 117 | num_classes=self.num_classes, 118 | ) 119 | else: 120 | return self.attack.attack(inputs, labels)[0] 121 | 122 | 123 | class StAdvAttack(MisterEdAttack): 124 | def __init__(self, model, bound=0.05, **kwargs): 125 | kwargs.setdefault('lr', 0.01) 126 | super().__init__( 127 | model, 128 | threat_model=lambda: ap.ThreatModel(ap.ParameterizedXformAdv, { 129 | 'lp_style': 'inf', 130 | 'lp_bound': bound, 131 | 'xform_class': st.FullSpatial, 132 | 'use_stadv': True, 133 | }), 134 | perturbation_norm_loss=0.0025 / bound, 135 | **kwargs, 136 | ) 137 | -------------------------------------------------------------------------------- /stadv_eot/recoloradv/LICENSE_RECOLORADV: -------------------------------------------------------------------------------- 1 | MIT License 2 | Copyright (c) 2018 YOUR NAME 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | The above copyright notice and this permission notice shall be included in all 10 | copies or substantial portions of the Software. 
11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 13 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 14 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 15 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 16 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 17 | SOFTWARE. -------------------------------------------------------------------------------- /stadv_eot/recoloradv/__init__.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from recoloradv. 5 | # 6 | # Source: 7 | # https://github.com/cassidylaidlaw/ReColorAdv/blob/master/recoloradv/__init__.py 8 | # 9 | # The license for the original version of this file can be 10 | # found in this directory (LICENSE_RECOLORADV). 11 | # The modifications to this file are subject to the same license. 12 | # --------------------------------------------------------------- 13 | -------------------------------------------------------------------------------- /stadv_eot/recoloradv/color_spaces.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from recoloradv. 5 | # 6 | # Source: 7 | # https://github.com/cassidylaidlaw/ReColorAdv/blob/master/recoloradv/color_spaces.py 8 | # 9 | # The license for the original version of this file can be 10 | # found in this directory (LICENSE_RECOLORADV). 11 | # The modifications to this file are subject to the same license. 12 | # --------------------------------------------------------------- 13 | 14 | """ 15 | Contains classes that convert from RGB to various other color spaces and back. 16 | """ 17 | 18 | import torch 19 | import numpy as np 20 | import math 21 | 22 | 23 | class ColorSpace(object): 24 | """ 25 | Base class for color spaces. 26 | """ 27 | 28 | def from_rgb(self, imgs): 29 | """ 30 | Converts an Nx3xWxH tensor in RGB color space to a Nx3xWxH tensor in 31 | this color space. All outputs should be in the 0-1 range. 32 | """ 33 | raise NotImplementedError() 34 | 35 | def to_rgb(self, imgs): 36 | """ 37 | Converts an Nx3xWxH tensor in this color space to a Nx3xWxH tensor in 38 | RGB color space. 39 | """ 40 | raise NotImplementedError() 41 | 42 | 43 | class RGBColorSpace(ColorSpace): 44 | """ 45 | RGB color space. Just applies identity transformation. 46 | """ 47 | 48 | def from_rgb(self, imgs): 49 | return imgs 50 | 51 | def to_rgb(self, imgs): 52 | return imgs 53 | 54 | 55 | class YPbPrColorSpace(ColorSpace): 56 | """ 57 | YPbPr color space. Uses ITU-R BT.601 standard by default. 
58 | """ 59 | 60 | def __init__(self, kr=0.299, kg=0.587, kb=0.114, luma_factor=1, 61 | chroma_factor=1): 62 | self.kr, self.kg, self.kb = kr, kg, kb 63 | self.luma_factor = luma_factor 64 | self.chroma_factor = chroma_factor 65 | 66 | def from_rgb(self, imgs): 67 | r, g, b = imgs.permute(1, 0, 2, 3) 68 | 69 | y = r * self.kr + g * self.kg + b * self.kb 70 | pb = (b - y) / (2 * (1 - self.kb)) 71 | pr = (r - y) / (2 * (1 - self.kr)) 72 | 73 | return torch.stack([y * self.luma_factor, 74 | pb * self.chroma_factor + 0.5, 75 | pr * self.chroma_factor + 0.5], 1) 76 | 77 | def to_rgb(self, imgs): 78 | y_prime, pb_prime, pr_prime = imgs.permute(1, 0, 2, 3) 79 | y = y_prime / self.luma_factor 80 | pb = (pb_prime - 0.5) / self.chroma_factor 81 | pr = (pr_prime - 0.5) / self.chroma_factor 82 | 83 | b = pb * 2 * (1 - self.kb) + y 84 | r = pr * 2 * (1 - self.kr) + y 85 | g = (y - r * self.kr - b * self.kb) / self.kg 86 | 87 | return torch.stack([r, g, b], 1).clamp(0, 1) 88 | 89 | 90 | class ApproxHSVColorSpace(ColorSpace): 91 | """ 92 | Converts from RGB to approximately the HSV cone using a much smoother 93 | transformation. 94 | """ 95 | 96 | def from_rgb(self, imgs): 97 | r, g, b = imgs.permute(1, 0, 2, 3) 98 | 99 | x = r * np.sqrt(2) / 3 - g / (np.sqrt(2) * 3) - b / (np.sqrt(2) * 3) 100 | y = g / np.sqrt(6) - b / np.sqrt(6) 101 | z, _ = imgs.max(1) 102 | 103 | return torch.stack([z, x + 0.5, y + 0.5], 1) 104 | 105 | def to_rgb(self, imgs): 106 | z, xp, yp = imgs.permute(1, 0, 2, 3) 107 | x, y = xp - 0.5, yp - 0.5 108 | 109 | rp = float(np.sqrt(2)) * x 110 | gp = -x / np.sqrt(2) + y * np.sqrt(3 / 2) 111 | bp = -x / np.sqrt(2) - y * np.sqrt(3 / 2) 112 | 113 | delta = z - torch.max(torch.stack([rp, gp, bp], 1), 1)[0] 114 | r, g, b = rp + delta, gp + delta, bp + delta 115 | 116 | return torch.stack([r, g, b], 1).clamp(0, 1) 117 | 118 | 119 | class HSVConeColorSpace(ColorSpace): 120 | """ 121 | Converts from RGB to the HSV "cone", where (x, y, z) = 122 | (s * v cos h, s * v sin h, v). Note that this cone is then squashed to fit 123 | in [0, 1]^3 by letting (x', y', z') = ((x + 1) / 2, (y + 1) / 2, z). 
124 | 125 | WARNING: has a very complex derivative, not very useful in practice 126 | """ 127 | 128 | def from_rgb(self, imgs): 129 | r, g, b = imgs.permute(1, 0, 2, 3) 130 | 131 | mx, argmx = imgs.max(1) 132 | mn, _ = imgs.min(1) 133 | chroma = mx - mn 134 | eps = 1e-10 135 | h_max_r = math.pi / 3 * (g - b) / (chroma + eps) 136 | h_max_g = math.pi / 3 * (b - r) / (chroma + eps) + math.pi * 2 / 3 137 | h_max_b = math.pi / 3 * (r - g) / (chroma + eps) + math.pi * 4 / 3 138 | 139 | h = (((argmx == 0) & (chroma != 0)).float() * h_max_r 140 | + ((argmx == 1) & (chroma != 0)).float() * h_max_g 141 | + ((argmx == 2) & (chroma != 0)).float() * h_max_b) 142 | 143 | x = torch.cos(h) * chroma 144 | y = torch.sin(h) * chroma 145 | z = mx 146 | 147 | return torch.stack([(x + 1) / 2, (y + 1) / 2, z], 1) 148 | 149 | def _to_rgb_part(self, h, chroma, v, n): 150 | """ 151 | Implements the function f(n) defined here: 152 | https://en.wikipedia.org/wiki/HSL_and_HSV#Alternative_HSV_to_RGB 153 | """ 154 | 155 | k = (n + h * 3 / math.pi) % 6 156 | return v - chroma * torch.min(k, 4 - k).clamp(0, 1) 157 | 158 | def to_rgb(self, imgs): 159 | xp, yp, z = imgs.permute(1, 0, 2, 3) 160 | x, y = xp * 2 - 1, yp * 2 - 1 161 | 162 | # prevent NaN gradients when calculating atan2 163 | x_nonzero = (1 - 2 * (torch.sign(x) == -1).float()) * (torch.abs(x) + 1e-10) 164 | h = torch.atan2(y, x_nonzero) 165 | v = z.clamp(0, 1) 166 | chroma = torch.min(torch.sqrt(x ** 2 + y ** 2 + 1e-10), v) 167 | 168 | r = self._to_rgb_part(h, chroma, v, 5) 169 | g = self._to_rgb_part(h, chroma, v, 3) 170 | b = self._to_rgb_part(h, chroma, v, 1) 171 | 172 | return torch.stack([r, g, b], 1).clamp(0, 1) 173 | 174 | 175 | class CIEXYZColorSpace(ColorSpace): 176 | """ 177 | The 1931 CIE XYZ color space (assuming input is in sRGB). 178 | 179 | Warning: may have values outside [0, 1] range. Should only be used in 180 | the process of converting to/from other color spaces. 181 | """ 182 | 183 | def from_rgb(self, imgs): 184 | # apply gamma correction 185 | small_values_mask = (imgs < 0.04045).float() 186 | imgs_corrected = ( 187 | (imgs / 12.92) * small_values_mask + 188 | ((imgs + 0.055) / 1.055) ** 2.4 * (1 - small_values_mask) 189 | ) 190 | 191 | # linear transformation to XYZ 192 | r, g, b = imgs_corrected.permute(1, 0, 2, 3) 193 | x = 0.4124 * r + 0.3576 * g + 0.1805 * b 194 | y = 0.2126 * r + 0.7152 * g + 0.0722 * b 195 | z = 0.0193 * r + 0.1192 * g + 0.9504 * b 196 | 197 | return torch.stack([x, y, z], 1) 198 | 199 | def to_rgb(self, imgs): 200 | # linear transformation 201 | x, y, z = imgs.permute(1, 0, 2, 3) 202 | r = 3.2406 * x - 1.5372 * y - 0.4986 * z 203 | g = -0.9689 * x + 1.8758 * y + 0.0415 * z 204 | b = 0.0557 * x - 0.2040 * y + 1.0570 * z 205 | 206 | imgs = torch.stack([r, g, b], 1) 207 | 208 | # apply gamma correction 209 | small_values_mask = (imgs < 0.0031308).float() 210 | imgs_clamped = imgs.clamp(min=1e-10) # prevent NaN gradients 211 | imgs_corrected = ( 212 | (12.92 * imgs) * small_values_mask + 213 | (1.055 * imgs_clamped ** (1 / 2.4) - 0.055) * 214 | (1 - small_values_mask) 215 | ) 216 | 217 | return imgs_corrected 218 | 219 | 220 | class CIELUVColorSpace(ColorSpace): 221 | """ 222 | Converts to the 1976 CIE L*u*v* color space. 
223 | """ 224 | 225 | def __init__(self, up_white=0.1978, vp_white=0.4683, y_white=1, 226 | eps=1e-10): 227 | self.xyz_cspace = CIEXYZColorSpace() 228 | self.up_white = up_white 229 | self.vp_white = vp_white 230 | self.y_white = y_white 231 | self.eps = eps 232 | 233 | def from_rgb(self, imgs): 234 | x, y, z = self.xyz_cspace.from_rgb(imgs).permute(1, 0, 2, 3) 235 | 236 | # calculate u' and v' 237 | denom = x + 15 * y + 3 * z + self.eps 238 | up = 4 * x / denom 239 | vp = 9 * y / denom 240 | 241 | # calculate L*, u*, and v* 242 | small_values_mask = (y / self.y_white < (6 / 29) ** 3).float() 243 | y_clamped = y.clamp(min=self.eps) # prevent NaN gradients 244 | L = ( 245 | ((29 / 3) ** 3 * y / self.y_white) * small_values_mask + 246 | (116 * (y_clamped / self.y_white) ** (1 / 3) - 16) * 247 | (1 - small_values_mask) 248 | ) 249 | u = 13 * L * (up - self.up_white) 250 | v = 13 * L * (vp - self.vp_white) 251 | 252 | return torch.stack([L / 100, (u + 100) / 200, (v + 100) / 200], 1) 253 | 254 | def to_rgb(self, imgs): 255 | L = imgs[:, 0, :, :] * 100 256 | u = imgs[:, 1, :, :] * 200 - 100 257 | v = imgs[:, 2, :, :] * 200 - 100 258 | 259 | up = u / (13 * L + self.eps) + self.up_white 260 | vp = v / (13 * L + self.eps) + self.vp_white 261 | 262 | small_values_mask = (L <= 8).float() 263 | y = ( 264 | (self.y_white * L * (3 / 29) ** 3) * small_values_mask + 265 | (self.y_white * ((L + 16) / 116) ** 3) * (1 - small_values_mask) 266 | ) 267 | denom = 4 * vp + self.eps 268 | x = y * 9 * up / denom 269 | z = y * (12 - 3 * up - 20 * vp) / denom 270 | 271 | return self.xyz_cspace.to_rgb( 272 | torch.stack([x, y, z], 1).clamp(0, 1.1)).clamp(0, 1) 273 | -------------------------------------------------------------------------------- /stadv_eot/recoloradv/mister_ed/README.md: -------------------------------------------------------------------------------- 1 | Code in this directory is adapted from the [`mister_ed`](https://github.com/revbucket/mister_ed) library. -------------------------------------------------------------------------------- /stadv_eot/recoloradv/mister_ed/__init__.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from recoloradv. 5 | # 6 | # Source: 7 | # https://github.com/cassidylaidlaw/ReColorAdv/blob/master/recoloradv/mister_ed/__init__.py 8 | # 9 | # The license for the original version of this file can be 10 | # found in the `recoloradv` directory (LICENSE_RECOLORADV). 11 | # The modifications to this file are subject to the same license. 12 | # --------------------------------------------------------------- 13 | -------------------------------------------------------------------------------- /stadv_eot/recoloradv/mister_ed/config.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from recoloradv. 5 | # 6 | # Source: 7 | # https://github.com/cassidylaidlaw/ReColorAdv/blob/master/recoloradv/mister_ed/config.py 8 | # 9 | # The license for the original version of this file can be 10 | # found in the `recoloradv` directory (LICENSE_RECOLORADV). 11 | # The modifications to this file are subject to the same license. 
12 | # --------------------------------------------------------------- 13 | 14 | import os 15 | 16 | config_dir = os.path.abspath(os.path.dirname(__file__)) 17 | 18 | 19 | def path_resolver(path): 20 | if path.startswith('~/'): 21 | return os.path.expanduser(path) 22 | 23 | if path.startswith('./'): 24 | return os.path.join(*[config_dir] + path.split('/')[1:]) 25 | 26 | 27 | DEFAULT_DATASETS_DIR = path_resolver('~/datasets') 28 | MODEL_PATH = path_resolver('./pretrained_models/') 29 | OUTPUT_IMAGE_PATH = path_resolver('./output_images/') 30 | 31 | DEFAULT_BATCH_SIZE = 128 32 | DEFAULT_WORKERS = 4 33 | CIFAR10_MEANS = [0.485, 0.456, 0.406] 34 | CIFAR10_STDS = [0.229, 0.224, 0.225] 35 | 36 | WIDE_CIFAR10_MEANS = [0.4914, 0.4822, 0.4465] 37 | WIDE_CIFAR10_STDS = [0.2023, 0.1994, 0.2010] 38 | 39 | IMAGENET_MEANS = [0.485, 0.456, 0.406] 40 | IMAGENET_STDS = [0.229, 0.224, 0.225] 41 | -------------------------------------------------------------------------------- /stadv_eot/recoloradv/mister_ed/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from recoloradv. 5 | # 6 | # Source: 7 | # https://github.com/cassidylaidlaw/ReColorAdv/blob/master/recoloradv/mister_ed/utils/__init__.py 8 | # 9 | # The license for the original version of this file can be 10 | # found in the `recoloradv` directory (LICENSE_RECOLORADV). 11 | # The modifications to this file are subject to the same license. 12 | # --------------------------------------------------------------- 13 | -------------------------------------------------------------------------------- /stadv_eot/recoloradv/mister_ed/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from recoloradv. 5 | # 6 | # Source: 7 | # https://github.com/cassidylaidlaw/ReColorAdv/blob/master/recoloradv/mister_ed/utils/image_utils.py 8 | # 9 | # The license for the original version of this file can be 10 | # found in the `recoloradv` directory (LICENSE_RECOLORADV). 11 | # The modifications to this file are subject to the same license. 12 | # --------------------------------------------------------------- 13 | 14 | """ Specific utilities for image classification 15 | (i.e. RGB images i.e. tensors of the form NxCxHxW ) 16 | """ 17 | 18 | from __future__ import print_function 19 | import torch 20 | import numpy as np 21 | import matplotlib.pyplot as plt 22 | import random 23 | 24 | 25 | def nhwc255_xform(img_np_array): 26 | """ Takes in a numpy array and transposes it so that the channel is the last 27 | axis. 
Also multiplies all values by 255.0 28 | ARGS: 29 | img_np_array : np.ndarray - array of shape (NxHxWxC) or (NxCxHxW) 30 | [assumes that we're in NCHW by default, 31 | but if not ambiguous will handle NHWC too ] 32 | RETURNS: 33 | array of form NHWC 34 | """ 35 | assert isinstance(img_np_array, np.ndarray) 36 | shape = img_np_array.shape 37 | assert len(shape) == 4 38 | 39 | # determine which configuration we're in 40 | ambiguous = (shape[1] == shape[3] == 3) 41 | nhwc = (shape[3] == 3) 42 | 43 | # transpose unless we're unambiguously in nhwc case 44 | if nhwc and not ambiguous: 45 | return img_np_array * 255.0 46 | else: 47 | return np.transpose(img_np_array, (0, 2, 3, 1)) * 255.0 48 | 49 | 50 | def show_images(images, normalize=None, ipython=True, 51 | margin_height=2, margin_color='red', 52 | figsize=(18, 16)): 53 | """ Shows pytorch tensors/variables as images """ 54 | 55 | # first format the first arg to be hz-stacked numpy arrays 56 | if not isinstance(images, list): 57 | images = [images] 58 | images = [np.dstack(image.cpu().numpy()) for image in images] 59 | image_shape = images[0].shape 60 | assert all(image.shape == image_shape for image in images) 61 | assert all(image.ndim == 3 for image in images) # CxHxW 62 | 63 | # now build the list of final rows 64 | rows = [] 65 | if margin_height > 0: 66 | assert margin_color in ['red', 'black'] 67 | margin_shape = list(image_shape) 68 | margin_shape[1] = margin_height 69 | margin = np.zeros(margin_shape) 70 | if margin_color == 'red': 71 | margin[0] = 1 72 | else: 73 | margin = None 74 | 75 | for image_row in images: 76 | rows.append(margin) 77 | rows.append(image_row) 78 | 79 | rows = [_ for _ in rows[1:] if _ is not None] 80 | plt.figure(figsize=figsize, dpi=80, facecolor='w', edgecolor='k') 81 | 82 | cat_rows = np.concatenate(rows, 1).transpose(1, 2, 0) 83 | imshow_kwargs = {} 84 | if cat_rows.shape[-1] == 1: # 1 channel: greyscale 85 | cat_rows = cat_rows.squeeze() 86 | imshow_kwargs['cmap'] = 'gray' 87 | 88 | plt.imshow(cat_rows, **imshow_kwargs) 89 | 90 | plt.show() 91 | 92 | 93 | def display_adversarial_2row(classifier_net, normalizer, original_images, 94 | adversarial_images, num_to_show=4, which='incorrect', 95 | ipython=False, margin_width=2): 96 | """ Displays adversarial images side-by-side with their unperturbed 97 | counterparts. Opens a window displaying two rows: top row is original 98 | images, bottom row is perturbed 99 | ARGS: 100 | classifier_net : nn - with a .forward method that takes normalized 101 | variables and outputs logits 102 | normalizer : object w/ .forward method - should probably be an instance 103 | of utils.DifferentiableNormalize or utils.IdentityNormalize 104 | original_images: Variable or Tensor (NxCxHxW) - original images to 105 | display. Images in [0., 1.] range 106 | adversarial_images: Variable or Tensor (NxCxHxW) - perturbed images to 107 | display. Should be same shape as original_images 108 | num_to_show : int - number of images to show 109 | which : string in ['incorrect', 'random', 'correct'] - which images to 110 | show. 111 | -- 'incorrect' means successfully attacked images, 112 | -- 'random' means some random selection of images 113 | -- 'correct' means unsuccessfully attacked images 114 | ipython: bool - if True, we are in an ipython notebook, so use a 115 | slightly different way to show images 116 | margin_width - int : height in pixels of the red margin separating top 117 | and bottom rows. 
Set to 0 for no margin 118 | RETURNS: 119 | None, but displays images 120 | """ 121 | assert which in ['incorrect', 'random', 'correct'] 122 | 123 | # If not 'random' selection, prune to only the valid things 124 | to_sample_idxs = [] 125 | if which != 'random': 126 | classifier_net.eval() # can never be too safe =) 127 | 128 | # classify the originals with top1 129 | original_norm_var = normalizer.forward(original_images) 130 | original_out_logits = classifier_net.forward(original_norm_var) 131 | _, original_out_classes = original_out_logits.max(1) 132 | 133 | # classify the adversarials with top1 134 | adv_norm_var = normalizer.forward(adversarial_images) 135 | adv_out_logits = classifier_net.forward(adv_norm_var) 136 | _, adv_out_classes = adv_out_logits.max(1) 137 | 138 | # collect indices of matching 139 | selector = lambda var: (which == 'correct') == bool(float(var)) 140 | for idx, var_el in enumerate(original_out_classes == adv_out_classes): 141 | if selector(var_el): 142 | to_sample_idxs.append(idx) 143 | else: 144 | to_sample_idxs = list(range(original_images.shape[0])) 145 | 146 | # Now select some indices to show 147 | if to_sample_idxs == []: 148 | print("Couldn't show anything. Try changing the 'which' argument here") 149 | return 150 | 151 | to_show_idxs = random.sample(to_sample_idxs, min([num_to_show, 152 | len(to_sample_idxs)])) 153 | 154 | # Now start building up the images : first horizontally, then vertically 155 | top_row = torch.cat([original_images[idx] for idx in to_show_idxs], dim=2) 156 | bottom_row = torch.cat([adversarial_images[idx] for idx in to_show_idxs], 157 | dim=2) 158 | 159 | if margin_width > 0: 160 | margin = torch.zeros(3, margin_width, top_row.shape[-1]) 161 | margin[0] = 1.0 # make it red 162 | margin = margin.type(type(top_row)) 163 | stack = [top_row, margin, bottom_row] 164 | else: 165 | stack = [top_row, bottom_row] 166 | 167 | plt.imshow(torch.cat(stack, dim=1).cpu().numpy().transpose(1, 2, 0)) 168 | plt.show() 169 | 170 | 171 | def display_adversarial_notebook(): 172 | pass 173 | 174 | 175 | def nchw_l2(x, y, squared=True): 176 | """ Computes l2 norm between two NxCxHxW images 177 | ARGS: 178 | x, y: Tensor/Variable (NxCxHxW) - x, y must be same type & shape. 179 | squared : bool - if True we return squared loss, otherwise we return 180 | square root of l2 181 | RETURNS: 182 | ||x - y ||_2 ^2 (no exponent if squared == False), 183 | shape is (Nx1x1x1) 184 | """ 185 | temp = torch.pow(x - y, 2) # square diff 186 | 187 | for i in range(1, temp.dim()): # reduce on all but first dimension 188 | temp = torch.sum(temp, i, keepdim=True) 189 | 190 | if not squared: 191 | temp = torch.pow(temp, 0.5) 192 | 193 | return temp.squeeze() 194 | -------------------------------------------------------------------------------- /stadv_eot/recoloradv/mister_ed/utils/pytorch_ssim.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from recoloradv. 5 | # 6 | # Source: 7 | # https://github.com/cassidylaidlaw/ReColorAdv/blob/master/recoloradv/mister_ed/utils/pytorch_ssim.py 8 | # 9 | # The license for the original version of this file can be 10 | # found in the `recoloradv` directory (LICENSE_RECOLORADV). 11 | # The modifications to this file are subject to the same license. 
12 | # --------------------------------------------------------------- 13 | 14 | """ Implementation directly lifted from Po-Hsun-Su for pytorch ssim 15 | See github repo here: https://github.com/Po-Hsun-Su/pytorch-ssim 16 | """ 17 | 18 | import torch 19 | import torch.nn.functional as F 20 | from torch.autograd import Variable 21 | from math import exp 22 | 23 | def gaussian(window_size, sigma): 24 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 25 | return gauss/gauss.sum() 26 | 27 | def create_window(window_size, channel): 28 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 29 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 30 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 31 | return window 32 | 33 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 34 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 35 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 36 | 37 | mu1_sq = mu1.pow(2) 38 | mu2_sq = mu2.pow(2) 39 | mu1_mu2 = mu1*mu2 40 | 41 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 42 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 43 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 44 | 45 | C1 = 0.01**2 46 | C2 = 0.03**2 47 | 48 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 49 | 50 | if size_average: 51 | return ssim_map.mean() 52 | else: 53 | return ssim_map.mean(1).mean(1).mean(1) 54 | 55 | class SSIM(torch.nn.Module): 56 | def __init__(self, window_size = 11, size_average = True): 57 | super(SSIM, self).__init__() 58 | self.window_size = window_size 59 | self.size_average = size_average 60 | self.channel = 1 61 | self.window = create_window(window_size, self.channel) 62 | 63 | def forward(self, img1, img2): 64 | (_, channel, _, _) = img1.size() 65 | 66 | if channel == self.channel and self.window.data.type() == img1.data.type(): 67 | window = self.window 68 | else: 69 | window = create_window(self.window_size, channel) 70 | 71 | if img1.is_cuda: 72 | window = window.cuda(img1.get_device()) 73 | window = window.type_as(img1) 74 | 75 | self.window = window 76 | self.channel = channel 77 | 78 | 79 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 80 | 81 | def ssim(img1, img2, window_size = 11, size_average = True): 82 | (_, channel, _, _) = img1.size() 83 | window = create_window(window_size, channel) 84 | 85 | if img1.is_cuda: 86 | window = window.cuda(img1.get_device()) 87 | window = window.type_as(img1) 88 | 89 | return _ssim(img1, img2, window, window_size, channel, size_average) -------------------------------------------------------------------------------- /stadv_eot/recoloradv/norms.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from recoloradv. 5 | # 6 | # Source: 7 | # https://github.com/cassidylaidlaw/ReColorAdv/blob/master/recoloradv/norms.py 8 | # 9 | # The license for the original version of this file can be 10 | # found in this directory (LICENSE_RECOLORADV). 11 | # The modifications to this file are subject to the same license. 
12 | # --------------------------------------------------------------- 13 | 14 | import torch 15 | from torch.autograd import Variable 16 | 17 | 18 | def smoothness(grid): 19 | """ 20 | Given a variable of dimensions (N, X, Y, [Z], C), computes the sum of 21 | the differences between adjacent points in the grid formed by the 22 | dimensions X, Y, and (optionally) Z. Returns a tensor of dimension N. 23 | """ 24 | 25 | num_dims = len(grid.size()) - 2 26 | batch_size = grid.size()[0] 27 | norm = Variable(torch.zeros(batch_size, dtype=grid.data.dtype, 28 | device=grid.data.device)) 29 | 30 | for dim in range(num_dims): 31 | slice_before = (slice(None),) * (dim + 1) 32 | slice_after = (slice(None),) * (num_dims - dim) 33 | shifted_grids = [ 34 | # left 35 | torch.cat([ 36 | grid[slice_before + (slice(1, None),) + slice_after], 37 | grid[slice_before + (slice(-1, None),) + slice_after], 38 | ], dim + 1), 39 | # right 40 | torch.cat([ 41 | grid[slice_before + (slice(None, 1),) + slice_after], 42 | grid[slice_before + (slice(None, -1),) + slice_after], 43 | ], dim + 1) 44 | ] 45 | for shifted_grid in shifted_grids: 46 | delta = shifted_grid - grid 47 | norm_components = (delta.pow(2).sum(-1) + 1e-10).pow(0.5) 48 | norm.add_(norm_components.sum( 49 | tuple(range(1, len(norm_components.size()))))) 50 | 51 | return norm 52 | -------------------------------------------------------------------------------- /stadv_eot/recoloradv/perturbations.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from recoloradv. 5 | # 6 | # Source: 7 | # https://github.com/cassidylaidlaw/ReColorAdv/blob/master/recoloradv/perturbations.py 8 | # 9 | # The license for the original version of this file can be 10 | # found in this directory (LICENSE_RECOLORADV). 11 | # The modifications to this file are subject to the same license. 12 | # --------------------------------------------------------------- 13 | 14 | from .mister_ed import adversarial_perturbations as ap 15 | from .mister_ed.adversarial_perturbations import initialized 16 | from .mister_ed.utils import pytorch_utils as utils 17 | 18 | from . import color_transformers as ct 19 | from . import color_spaces as cs 20 | 21 | 22 | class ReColorAdv(ap.AdversarialPerturbation): 23 | """ 24 | Puts the color at each pixel in the image through the same transformation. 25 | 26 | Parameters: 27 | - lp_style: number or 'inf' 28 | - lp_bound: maximum norm of color transformation. Can be a tensor of size 29 | (num_channels,), in which case each channel will be bounded by the 30 | corresponding bound in the tensor. For instance, passing 31 | [0.1, 0.15, 0.05] would allow a norm of 0.1 for R, 0.15 for G, and 0.05 32 | for B. Not supported by all transformations. 33 | - use_smooth_loss: whether to optimize using the loss function 34 | for FullSpatial that rewards smooth vector fields 35 | - xform_class: a subclass of 36 | color_transformers.ParameterizedTransformation 37 | - xform_params: dict of parameters to pass to the xform_class. 
38 | - cspace_class: a subclass of color_spaces.ColorSpace that indicates 39 | in which color space the transformation should be performed 40 | (RGB by default) 41 | """ 42 | 43 | def __init__(self, threat_model, perturbation_params, *other_args): 44 | super().__init__(threat_model, perturbation_params) 45 | assert issubclass(perturbation_params.xform_class, 46 | ct.ParameterizedTransformation) 47 | 48 | self.lp_style = perturbation_params.lp_style 49 | self.lp_bound = perturbation_params.lp_bound 50 | self.use_smooth_loss = perturbation_params.use_smooth_loss 51 | self.scalar_step = perturbation_params.scalar_step or 1.0 52 | self.cspace = perturbation_params.cspace or cs.RGBColorSpace() 53 | 54 | def _merge_setup(self, num_examples, new_xform): 55 | """ DANGEROUS TO BE CALLED OUTSIDE OF THIS FILE!!!""" 56 | self.num_examples = num_examples 57 | self.xform = new_xform 58 | self.initialized = True 59 | 60 | def setup(self, originals): 61 | super().setup(originals) 62 | self.xform = self.perturbation_params.xform_class( 63 | shape=originals.shape, manual_gpu=self.use_gpu, 64 | cspace=self.cspace, 65 | **(self.perturbation_params.xform_params or {}), 66 | ) 67 | self.initialized = True 68 | 69 | @initialized 70 | def perturbation_norm(self, x=None, lp_style=None): 71 | lp_style = lp_style or self.lp_style 72 | if self.use_smooth_loss: 73 | assert isinstance(self.xform, ct.FullSpatial) 74 | return self.xform.smoothness_norm() 75 | else: 76 | return self.xform.norm(lp=lp_style) 77 | 78 | @initialized 79 | def constrain_params(self, x=None): 80 | # Do lp projections 81 | if isinstance(self.lp_style, int) or self.lp_style == 'inf': 82 | self.xform.project_params(self.lp_style, self.lp_bound) 83 | 84 | @initialized 85 | def update_params(self, step_fxn): 86 | param_list = list(self.xform.parameters()) 87 | assert len(param_list) == 1 88 | params = param_list[0] 89 | assert params.grad.data is not None 90 | self.add_to_params(step_fxn(params.grad.data) * self.scalar_step) 91 | 92 | @initialized 93 | def add_to_params(self, grad_data): 94 | """ Assumes only one parameters object in the Spatial Transform """ 95 | param_list = list(self.xform.parameters()) 96 | assert len(param_list) == 1 97 | params = param_list[0] 98 | params.data.add_(grad_data) 99 | 100 | @initialized 101 | def random_init(self): 102 | param_list = list(self.xform.parameters()) 103 | assert len(param_list) == 1 104 | param = param_list[0] 105 | random_perturb = utils.random_from_lp_ball(param.data, 106 | self.lp_style, 107 | self.lp_bound) 108 | 109 | param.data.add_(self.xform.identity_params + 110 | random_perturb - self.xform.xform_params.data) 111 | 112 | @initialized 113 | def merge_perturbation(self, other, self_mask): 114 | super().merge_perturbation(other, self_mask) 115 | new_perturbation = ReColorAdv(self.threat_model, 116 | self.perturbation_params) 117 | 118 | new_xform = self.xform.merge_xform(other.xform, self_mask) 119 | new_perturbation._merge_setup(self.num_examples, new_xform) 120 | 121 | return new_perturbation 122 | 123 | def forward(self, x): 124 | if not self.initialized: 125 | self.setup(x) 126 | self.constrain_params() 127 | 128 | return self.cspace.to_rgb( 129 | self.xform.forward(self.cspace.from_rgb(x))) 130 | -------------------------------------------------------------------------------- /stadv_eot/recoloradv/utils.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2022, NVIDIA 
CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from recoloradv. 5 | # 6 | # Source: 7 | # https://github.com/cassidylaidlaw/ReColorAdv/blob/master/recoloradv/utils.py 8 | # 9 | # The license for the original version of this file can be 10 | # found in this directory (LICENSE_RECOLORADV). 11 | # The modifications to this file are subject to the same license. 12 | # --------------------------------------------------------------- 13 | 14 | from torch import nn 15 | from torch import optim 16 | 17 | from .mister_ed.utils.pytorch_utils import DifferentiableNormalize 18 | from .mister_ed import adversarial_perturbations as ap 19 | from .mister_ed import adversarial_attacks as aa 20 | from .mister_ed import spatial_transformers as st 21 | from .mister_ed import loss_functions as lf 22 | from .mister_ed import adversarial_training as advtrain 23 | 24 | from . import perturbations as pt 25 | from . import color_transformers as ct 26 | from . import color_spaces as cs 27 | 28 | 29 | def get_attack_from_name( 30 | name: str, 31 | classifier: nn.Module, 32 | normalizer: DifferentiableNormalize, 33 | verbose: bool = False, 34 | ) -> advtrain.AdversarialAttackParameters: 35 | """ 36 | Builds an attack from a name like "recoloradv" or "stadv+delta" or 37 | "recoloradv+stadv+delta". 38 | """ 39 | 40 | threats = [] 41 | norm_weights = [] 42 | 43 | for attack_part in name.split('+'): 44 | if attack_part == 'delta': 45 | threats.append(ap.ThreatModel( 46 | ap.DeltaAddition, 47 | ap.PerturbationParameters( 48 | lp_style='inf', 49 | lp_bound=8.0 / 255, 50 | ), 51 | )) 52 | norm_weights.append(0.0) 53 | elif attack_part == 'stadv': 54 | threats.append(ap.ThreatModel( 55 | ap.ParameterizedXformAdv, 56 | ap.PerturbationParameters( 57 | lp_style='inf', 58 | lp_bound=0.05, 59 | xform_class=st.FullSpatial, 60 | use_stadv=True, 61 | ), 62 | )) 63 | norm_weights.append(1.0) 64 | elif attack_part == 'recoloradv': 65 | threats.append(ap.ThreatModel( 66 | pt.ReColorAdv, 67 | ap.PerturbationParameters( 68 | lp_style='inf', 69 | lp_bound=[0.06, 0.06, 0.06], 70 | xform_params={ 71 | 'resolution_x': 16, 72 | 'resolution_y': 32, 73 | 'resolution_z': 32, 74 | }, 75 | xform_class=ct.FullSpatial, 76 | use_smooth_loss=True, 77 | cspace=cs.CIELUVColorSpace(), 78 | ), 79 | )) 80 | norm_weights.append(1.0) 81 | else: 82 | raise ValueError(f'Invalid attack "{attack_part}"') 83 | 84 | sequence_threat = ap.ThreatModel( 85 | ap.SequentialPerturbation, 86 | threats, 87 | ap.PerturbationParameters(norm_weights=norm_weights), 88 | ) 89 | 90 | # use PGD attack 91 | adv_loss = lf.CWLossF6(classifier, normalizer, kappa=float('inf')) 92 | st_loss = lf.PerturbationNormLoss(lp=2) 93 | loss_fxn = lf.RegularizedLoss({'adv': adv_loss, 'pert': st_loss}, 94 | {'adv': 1.0, 'pert': 0.05}, 95 | negate=True) 96 | 97 | pgd_attack = aa.PGD(classifier, normalizer, sequence_threat, loss_fxn) 98 | return advtrain.AdversarialAttackParameters( 99 | pgd_attack, 100 | 1.0, 101 | attack_specific_params={'attack_kwargs': { 102 | 'num_iterations': 100, 103 | 'optimizer': optim.Adam, 104 | 'optimizer_kwargs': {'lr': 0.001}, 105 | 'signed': False, 106 | 'verbose': verbose, 107 | }}, 108 | ) 109 | --------------------------------------------------------------------------------
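A minimal usage sketch for the `SDE.reverse()` machinery in `score_sde/sde_lib.py` above. The zero-valued `score_fn` is an illustrative assumption standing in for a trained score network, and the batch shape and single predictor step are likewise illustrative, not repository code.

import torch
from score_sde.sde_lib import VPSDE

sde = VPSDE(beta_min=0.1, beta_max=20, N=1000)

# Placeholder score function (assumption); in practice this is a trained network s_theta(x, t).
def score_fn(x, t):
    return torch.zeros_like(x)

rsde = sde.reverse(score_fn, probability_flow=False)

x = sde.prior_sampling((4, 3, 32, 32))  # sample x_T ~ p_T
t = torch.ones(4) * sde.T               # batch of time values at t = T
f, G = rsde.discretize(x, t)            # reverse-time Euler-Maruyama terms
x = x - f + G[:, None, None, None] * torch.randn_like(x)  # one reverse-diffusion predictor step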
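The `ColorSpace` subclasses in `color_spaces.py` all share the same `from_rgb`/`to_rgb` interface, so a round trip should approximately reproduce the input. A quick sanity check, with batch size and resolution chosen arbitrarily for illustration:

import torch
from stadv_eot.recoloradv.color_spaces import CIELUVColorSpace

cspace = CIELUVColorSpace()
imgs = torch.rand(2, 3, 32, 32)  # Nx3xHxW images in [0, 1]

luv = cspace.from_rgb(imgs)      # squashed L*u*v* coordinates, same shape
rgb = cspace.to_rgb(luv)         # decode back to RGB

print((rgb - imgs).abs().max())  # round-trip error should be small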
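Similarly, a minimal sketch of invoking `StAdvAttack` from `stadv_eot/attacks.py`. The toy linear classifier, input shapes, and labels here are assumptions for illustration; in the repository the model would be a real classifier, possibly wrapped with diffusion purification.

import torch
from torch import nn
from stadv_eot.attacks import StAdvAttack

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))  # toy stand-in classifier (assumption)

inputs = torch.rand(4, 3, 32, 32)    # images in [0, 1]
labels = torch.randint(0, 10, (4,))

attack = StAdvAttack(model, bound=0.05)
adv_inputs = attack(inputs, labels)  # spatially perturbed images, same shape as inputs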