├── README.md ├── architectures.py ├── archs └── cifar_resnet.py ├── certification_pic.py ├── classifiers ├── attribute_classifier.py ├── attribute_net.py └── cifar10_resnet.py ├── compute_accuracy.py ├── configs ├── cifar10.yml └── imagenet.yml ├── core.py ├── data ├── __init__.py └── datasets.py ├── datasets.py ├── ddpm └── unet_ddpm.py ├── eval_certified_densepure.py ├── guided_diffusion ├── __init__.py ├── dist_util.py ├── fp16_util.py ├── gaussian_diffusion.py ├── image_datasets.py ├── logger.py ├── losses.py ├── nn.py ├── resample.py ├── respace.py ├── script_util.py ├── train_util.py └── unet.py ├── improved_diffusion ├── __init__.py ├── dist_util.py ├── fp16_util.py ├── gaussian_diffusion.py ├── image_datasets.py ├── logger.py ├── losses.py ├── networks │ ├── __init__.py │ ├── lenet.py │ ├── resnet.py │ ├── vggnet.py │ └── wide_resnet.py ├── nn.py ├── resample.py ├── respace.py ├── script_util.py ├── train_util.py └── unet.py ├── networks ├── __init__.py ├── lenet.py ├── resnet.py ├── vggnet.py └── wide_resnet.py ├── pictures └── densepure_flowchart.png ├── requirements.txt ├── results ├── merge_cifar10.sh ├── merge_imagenet.sh └── merge_results.py ├── run_scripts ├── carlini22_cifar10.sh ├── carlini22_imagenet.sh ├── densepure_cifar10.sh └── densepure_imagenet.sh ├── runners ├── diffpure_ddpm_densepure.py └── diffpure_guided_densepure.py ├── utils.py └── zipdata.py /README.md: -------------------------------------------------------------------------------- 1 | # DensePure: Understanding Diffusion Models towards Adversarial Robustness 2 | 3 |

4 | 5 |

6 | 7 | Official PyTorch implementation of the paper:
8 | **[DensePure: Understanding Diffusion Models towards Adversarial Robustness](https://arxiv.org/abs/2211.00322)** 9 |
10 | Chaowei Xiao, Zhongzhu Chen, Kun Jin, Jiongxiao Wang, Weili Nie, Mingyan Liu, Anima Anandkumar, Bo Li, Dawn Song
11 | https://densepure.github.io
12 | 13 | Abstract: *Diffusion models have been recently employed to improve certified robustness through the process of denoising. However, the theoretical understanding of why diffusion models are able to improve the certified robustness is still lacking, preventing from further improvement. In this study, we close this gap by analyzing the fundamental properties of diffusion models and establishing the conditions under which they can enhance certified robustness. This deeper understanding allows us to propose a new method DensePure, designed to improve the certified robustness of a pretrained model (i.e. classifier). Given an (adversarial) input, DensePure consists of multiple runs of denoising via the reverse process of the diffusion model (with different random seeds) to get multiple reversed samples, which are then passed through the classifier, followed by majority voting of inferred labels to make the final prediction. This design of using multiple runs of denoising is informed by our theoretical analysis of the conditional distribution of the reversed sample. Specifically, when the data density of a clean sample is high, its conditional density under the reverse process in a diffusion model is also high; thus sampling from the latter conditional distribution can purify the adversarial example and return the corresponding clean sample with a high probability. By using the highest density point in the conditional distribution as the reversed sample, we identify the robust region of a given instance under the diffusion model's reverse process. We show that this robust region is a union of multiple convex sets, and is potentially much larger than the robust regions identified in previous works. In practice, DensePure can approximate the label of the high density region in the conditional distribution so that it can enhance certified robustness. We conduct extensive experiments to demonstrate the effectiveness of DensePure by evaluating its certified robustness given a standard model via randomized smoothing. We show that DensePure is consistently better than existing methods on ImageNet, with 7% improvement on average.* 14 | 15 | ## Requirements 16 | 17 | - Python 3.8.5 18 | - CUDA=11.1 19 | - Installation of PyTorch 1.8.0: 20 | ```bash 21 | conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge 22 | ``` 23 | - Installation of required packages: 24 | ```bash 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | ## Datasets, Pre-trained Diffusion Models and Classifiers 29 | Before running our code, you first need to prepare two datasets: CIFAR-10 and ImageNet. CIFAR-10 will be downloaded automatically. 30 | For ImageNet, you need to download the ILSVRC2012 validation images from https://www.image-net.org/. The images then need to be preprocessed by running the script `valprep.sh` from https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh 31 | inside the validation directory. 32 | 33 | Please change `IMAGENET_DIR` in `datasets.py` to your own location of the ImageNet dataset before running the code. 
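The snippet below is a small sanity check, not part of this repository, for verifying that the validation set was preprocessed correctly: after `valprep.sh` has been run, the images should sit in one sub-directory per class, which is the folder layout the data loaders expect. `IMAGENET_DIR` here is a placeholder for the same path you set in `datasets.py`.

```python
# Hypothetical sanity check -- not part of the DensePure codebase.
# Verifies that valprep.sh produced the expected one-folder-per-class layout.
import os

from torchvision import datasets, transforms

IMAGENET_DIR = "/path/to/imagenet"  # placeholder: use the path you set in datasets.py
val_dir = os.path.join(IMAGENET_DIR, "val")

val_set = datasets.ImageFolder(
    val_dir,
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
    ]),
)

# A correctly prepared ILSVRC2012 validation set has 50000 images in 1000 classes.
print(f"{len(val_set)} images across {len(val_set.classes)} classes")
```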
34 | 35 | For the pre-trained diffusion models, you first need to download them from the following links: 36 | - [Improved Diffusion](https://github.com/openai/improved-diffusion) for 37 | CIFAR-10: (`cifar10_uncond_50M_500K.pt`: [download link](https://openaipublic.blob.core.windows.net/diffusion/march-2021/cifar10_uncond_50M_500K.pt)) 38 | - [Guided Diffusion](https://github.com/openai/guided-diffusion) for 39 | ImageNet: (`256x256 diffusion unconditional`: [download link](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt)) 40 | 41 | For the pre-trained classifiers, the ViT-B/16 model for CIFAR-10 will be downloaded automatically by `transformers`. For the ImageNet BEiT-large model, you need to download it from the following link: 42 | - [BEiT](https://github.com/microsoft/unilm/tree/master/beit) for 43 | ImageNet: (`beit_large_patch16_512_pt22k_ft22kto1k.pth`: [download link](https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth)) 44 | 45 | Please place all the pre-trained models in the `pretrained` directory. If you want to use your own classifiers, the code in `eval_certified_densepure.py` needs to be changed accordingly. 46 | 47 | ## Run Experiments of Carlini 2022 48 | We provide our own implementation of the paper [(Certified!!) Adversarial Robustness for Free!](https://arxiv.org/abs/2206.10550) to compare with DensePure. 49 | 50 | To reproduce the Carlini22 results in Table 1, please run the following scripts with different noise levels `sigma`: 51 | ``` 52 | cd run_scripts 53 | bash carlini22_cifar10.sh [sigma] # For CIFAR-10 54 | bash carlini22_imagenet.sh [sigma] # For ImageNet 55 | ``` 56 | 57 | ## Run Experiments of DensePure 58 | To obtain certified accuracy with DensePure, please run the following scripts: 59 | ``` 60 | cd run_scripts 61 | bash densepure_cifar10.sh [sigma] [steps] [reverse_seed] # For CIFAR-10 62 | bash densepure_imagenet.sh [sigma] [steps] [reverse_seed] # For ImageNet 63 | ``` 64 | 65 | Note: `sigma` is the noise level of randomized smoothing. `steps` is the number of fast sampling steps described in Section 5.2; it must be larger than one and smaller than the total number of reverse steps. `reverse_seed` is the parameter that controls the majority-vote process in Section 5.2. For example, to run an experiment with 10 majority votes, you need to run `densepure_cifar10.sh` 10 times with 10 different values of `reverse_seed`. After running the above scripts with one `reverse_seed`, you will obtain a `.npy` file containing the predicted labels of the 100000 (for CIFAR-10) randomized smoothing samples. To obtain the final results over 10 majority votes, run the following scripts in the `results` directory: 66 | ``` 67 | cd results 68 | bash merge_cifar10.sh [sigma] [steps] [majority_vote_numbers] # For CIFAR-10 69 | bash merge_imagenet.sh [sigma] [steps] [majority_vote_numbers] # For ImageNet 70 | ``` 71 | 72 | ## Citation 73 | Please cite our paper and Carlini et al. (2022) if you happen to use this codebase: 74 | ``` 75 | @article{xiao2022densepure, 76 | title={DensePure: Understanding Diffusion Models towards Adversarial Robustness}, 77 | author={Xiao, Chaowei and Chen, Zhongzhu and Jin, Kun and Wang, Jiongxiao and Nie, Weili and Liu, Mingyan and Anandkumar, Anima and Li, Bo and Song, Dawn}, 78 | journal={arXiv preprint arXiv:2211.00322}, 79 | year={2022} 80 | } 81 | ``` 82 | 83 | ``` 84 | @article{carlini2022certified, 85 | title={(Certified!!) 
Adversarial Robustness for Free!}, 86 | author={Carlini, Nicholas and Tramer, Florian and Kolter, J Zico and others}, 87 | journal={arXiv preprint arXiv:2206.10550}, 88 | year={2022} 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /architectures.py: -------------------------------------------------------------------------------- 1 | from archs.cifar_resnet import resnet as resnet_cifar 2 | from datasets import get_normalize_layer, get_input_center_layer 3 | import torch 4 | import torch.backends.cudnn as cudnn 5 | import torch.nn as nn 6 | from torch.nn.functional import interpolate 7 | from torchvision.models.resnet import resnet50 8 | 9 | 10 | # resnet50 - the classic ResNet-50, sized for ImageNet 11 | # cifar_resnet20 - a 20-layer residual network sized for CIFAR 12 | # cifar_resnet110 - a 110-layer residual network sized for CIFAR 13 | ARCHITECTURES = ["resnet50", "cifar_resnet110", "imagenet32_resnet110"] 14 | 15 | def get_architecture(arch: str, dataset: str) -> torch.nn.Module: 16 | """ Return a neural network (with random weights) 17 | 18 | :param arch: the architecture - should be in the ARCHITECTURES list above 19 | :param dataset: the dataset - should be in the datasets.DATASETS list 20 | :return: a Pytorch module 21 | """ 22 | if arch == "resnet50" and dataset == "imagenet": 23 | model = torch.nn.DataParallel(resnet50(pretrained=False)).cuda() 24 | cudnn.benchmark = True 25 | elif arch == "cifar_resnet20": 26 | model = resnet_cifar(depth=20, num_classes=10).cuda() 27 | elif arch == "cifar_resnet110": 28 | model = resnet_cifar(depth=110, num_classes=10).cuda() 29 | elif arch == "imagenet32_resnet110": 30 | model = resnet_cifar(depth=110, num_classes=1000).cuda() 31 | 32 | # Both layers work fine, We tried both, and they both 33 | # give very similar results 34 | # IF YOU USE ONE OF THESE FOR TRAINING, MAKE SURE 35 | # TO USE THE SAME WHEN CERTIFYING. 
36 | normalize_layer = get_normalize_layer(dataset) 37 | # normalize_layer = get_input_center_layer(dataset) 38 | return torch.nn.Sequential(normalize_layer, model) 39 | -------------------------------------------------------------------------------- /archs/cifar_resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | __all__ = ['resnet'] 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | "3x3 convolution with padding" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, inplanes, planes, stride=1, downsample=None): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = conv3x3(inplanes, planes, stride) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.relu = nn.ReLU(inplace=True) 22 | self.conv2 = conv3x3(planes, planes) 23 | self.bn2 = nn.BatchNorm2d(planes) 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | residual = x 29 | 30 | out = self.conv1(x) 31 | out = self.bn1(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv2(out) 35 | out = self.bn2(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | 46 | class Bottleneck(nn.Module): 47 | expansion = 4 48 | 49 | def __init__(self, inplanes, planes, stride=1, downsample=None): 50 | super(Bottleneck, self).__init__() 51 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 52 | self.bn1 = nn.BatchNorm2d(planes) 53 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 54 | padding=1, bias=False) 55 | self.bn2 = nn.BatchNorm2d(planes) 56 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 57 | self.bn3 = nn.BatchNorm2d(planes * 4) 58 | self.relu = nn.ReLU(inplace=True) 59 | self.downsample = downsample 60 | self.stride = stride 61 | 62 | def forward(self, x): 63 | residual = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv3(out) 74 | out = self.bn3(out) 75 | 76 | if self.downsample is not None: 77 | residual = self.downsample(x) 78 | 79 | out += residual 80 | out = self.relu(out) 81 | 82 | return out 83 | 84 | 85 | class ResNet(nn.Module): 86 | 87 | def __init__(self, depth, num_classes=1000, block_name='BasicBlock'): 88 | super(ResNet, self).__init__() 89 | # Model type specifies number of layers for CIFAR-10 model 90 | if block_name.lower() == 'basicblock': 91 | assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202' 92 | n = (depth - 2) // 6 93 | block = BasicBlock 94 | elif block_name.lower() == 'bottleneck': 95 | assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 
20, 29, 47, 56, 110, 1199' 96 | n = (depth - 2) // 9 97 | block = Bottleneck 98 | else: 99 | raise ValueError('block_name shoule be Basicblock or Bottleneck') 100 | 101 | 102 | self.inplanes = 16 103 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 104 | bias=False) 105 | self.bn1 = nn.BatchNorm2d(16) 106 | self.relu = nn.ReLU(inplace=True) 107 | self.layer1 = self._make_layer(block, 16, n) 108 | self.layer2 = self._make_layer(block, 32, n, stride=2) 109 | self.layer3 = self._make_layer(block, 64, n, stride=2) 110 | self.avgpool = nn.AvgPool2d(8) 111 | self.fc = nn.Linear(64 * block.expansion, num_classes) 112 | 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 116 | m.weight.data.normal_(0, math.sqrt(2. / n)) 117 | elif isinstance(m, nn.BatchNorm2d): 118 | m.weight.data.fill_(1) 119 | m.bias.data.zero_() 120 | 121 | def _make_layer(self, block, planes, blocks, stride=1): 122 | downsample = None 123 | if stride != 1 or self.inplanes != planes * block.expansion: 124 | downsample = nn.Sequential( 125 | nn.Conv2d(self.inplanes, planes * block.expansion, 126 | kernel_size=1, stride=stride, bias=False), 127 | nn.BatchNorm2d(planes * block.expansion), 128 | ) 129 | 130 | layers = [] 131 | layers.append(block(self.inplanes, planes, stride, downsample)) 132 | self.inplanes = planes * block.expansion 133 | for i in range(1, blocks): 134 | layers.append(block(self.inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | x = self.conv1(x) 140 | x = self.bn1(x) 141 | x = self.relu(x) # 32x32 142 | 143 | x = self.layer1(x) # 32x32 144 | x = self.layer2(x) # 16x16 145 | x = self.layer3(x) # 8x8 146 | 147 | x = self.avgpool(x) 148 | x = x.view(x.size(0), -1) 149 | x = self.fc(x) 150 | 151 | return x 152 | 153 | 154 | def resnet(**kwargs): 155 | """ 156 | Constructs a ResNet model. 157 | """ 158 | return ResNet(**kwargs) -------------------------------------------------------------------------------- /classifiers/attribute_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from . import attribute_net 4 | 5 | softmax = torch.nn.Softmax(dim=1) 6 | 7 | 8 | def downsample(images, size=256): 9 | # Downsample to 256x256. The attribute classifiers were built for 256x256. 
10 | # follows https://github.com/NVlabs/stylegan/blob/master/metrics/linear_separability.py#L127 11 | if images.shape[2] > size: 12 | factor = images.shape[2] // size 13 | assert (factor * size == images.shape[2]) 14 | images = images.view( 15 | [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor]) 16 | images = images.mean(dim=[3, 5]) 17 | return images 18 | else: 19 | assert (images.shape[-1] == 256) 20 | return images 21 | 22 | 23 | def get_logit(net, im): 24 | im_256 = downsample(im) 25 | logit = net(im_256) 26 | return logit 27 | 28 | 29 | def get_softmaxed(net, im): 30 | logit = get_logit(net, im) 31 | logits = torch.cat([logit, -logit], dim=1) 32 | softmaxed = softmax(torch.cat([logit, -logit], dim=1))[:, 1] 33 | return logits, softmaxed 34 | 35 | 36 | def load_attribute_classifier(attribute, ckpt_path=None): 37 | if ckpt_path is None: 38 | base_path = 'pretrained/celebahq' 39 | attribute_pkl = os.path.join(base_path, attribute, 'net_best.pth') 40 | ckpt = torch.load(attribute_pkl) 41 | else: 42 | ckpt = torch.load(ckpt_path) 43 | print("Using classifier at epoch: %d" % ckpt['epoch']) 44 | if 'valacc' in ckpt.keys(): 45 | print("Validation acc on raw images: %0.5f" % ckpt['valacc']) 46 | detector = attribute_net.from_state_dict( 47 | ckpt['state_dict'], fixed_size=True, use_mbstd=False).cuda().eval() 48 | return detector 49 | 50 | 51 | class ClassifierWrapper(torch.nn.Module): 52 | def __init__(self, classifier_name, ckpt_path=None, device='cuda'): 53 | super(ClassifierWrapper, self).__init__() 54 | self.net = load_attribute_classifier(classifier_name, ckpt_path).eval().to(device) 55 | 56 | def forward(self, ims): 57 | out = (ims - 0.5) / 0.5 58 | return get_softmaxed(self.net, out)[0] 59 | -------------------------------------------------------------------------------- /classifiers/attribute_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | def lerp_clip(a, b, t): 7 | return a + (b - a) * torch.clamp(t, 0.0, 1.0) 8 | 9 | 10 | class WScaleLayer(nn.Module): 11 | def __init__(self, size, fan_in, gain=np.sqrt(2), bias=True): 12 | super(WScaleLayer, self).__init__() 13 | self.scale = gain / np.sqrt(fan_in) # No longer a parameter 14 | if bias: 15 | self.b = nn.Parameter(torch.randn(size)) 16 | else: 17 | self.b = 0 18 | self.size = size 19 | 20 | def forward(self, x): 21 | x_size = x.size() 22 | x = x * self.scale 23 | # modified to remove warning 24 | if type(self.b) == nn.Parameter and len(x_size) == 4: 25 | x = x + self.b.view(1, -1, 1, 1).expand( 26 | x_size[0], self.size, x_size[2], x_size[3]) 27 | if type(self.b) == nn.Parameter and len(x_size) == 2: 28 | x = x + self.b.view(1, -1).expand( 29 | x_size[0], self.size) 30 | return x 31 | 32 | 33 | class WScaleConv2d(nn.Module): 34 | def __init__(self, in_channels, out_channels, kernel_size, padding=0, 35 | bias=True, gain=np.sqrt(2)): 36 | super().__init__() 37 | self.conv = nn.Conv2d(in_channels, out_channels, 38 | kernel_size=kernel_size, 39 | padding=padding, 40 | bias=False) 41 | fan_in = in_channels * kernel_size * kernel_size 42 | self.wscale = WScaleLayer(out_channels, fan_in, gain=gain, bias=bias) 43 | 44 | def forward(self, x): 45 | return self.wscale(self.conv(x)) 46 | 47 | 48 | class WScaleLinear(nn.Module): 49 | def __init__(self, in_channels, out_channels, bias=True, gain=np.sqrt(2)): 50 | super().__init__() 51 | self.linear = nn.Linear(in_channels, out_channels, 
bias=False) 52 | self.wscale = WScaleLayer(out_channels, in_channels, gain=gain, 53 | bias=bias) 54 | 55 | def forward(self, x): 56 | return self.wscale(self.linear(x)) 57 | 58 | 59 | class FromRGB(nn.Module): 60 | def __init__(self, in_channels, out_channels, kernel_size, 61 | act=nn.LeakyReLU(0.2), bias=True): 62 | super().__init__() 63 | self.conv = WScaleConv2d(in_channels, out_channels, kernel_size, 64 | padding=0, bias=bias) 65 | self.act = act 66 | 67 | def forward(self, x): 68 | return self.act(self.conv(x)) 69 | 70 | 71 | class Downscale2d(nn.Module): 72 | def __init__(self, factor=2): 73 | super().__init__() 74 | self.downsample = nn.AvgPool2d(kernel_size=factor, stride=factor) 75 | 76 | def forward(self, x): 77 | return self.downsample(x) 78 | 79 | 80 | class DownscaleConvBlock(nn.Module): 81 | def __init__(self, in_channels, conv0_channels, conv1_channels, 82 | kernel_size, padding, bias=True, act=nn.LeakyReLU(0.2)): 83 | super().__init__() 84 | self.downscale = Downscale2d() 85 | self.conv0 = WScaleConv2d(in_channels, conv0_channels, 86 | kernel_size=kernel_size, 87 | padding=padding, 88 | bias=bias) 89 | self.conv1 = WScaleConv2d(conv0_channels, conv1_channels, 90 | kernel_size=kernel_size, 91 | padding=padding, 92 | bias=bias) 93 | self.act = act 94 | 95 | def forward(self, x): 96 | x = self.act(self.conv0(x)) 97 | # conv2d_downscale2d applies downscaling before activation 98 | # the order matters here! has to be conv -> bias -> downscale -> act 99 | x = self.conv1(x) 100 | x = self.downscale(x) 101 | x = self.act(x) 102 | return x 103 | 104 | 105 | class MinibatchStdLayer(nn.Module): 106 | def __init__(self, group_size=4): 107 | super().__init__() 108 | self.group_size = group_size 109 | 110 | def forward(self, x): 111 | group_size = min(self.group_size, x.shape[0]) 112 | s = x.shape 113 | y = x.view([group_size, -1, s[1], s[2], s[3]]) 114 | y = y.float() 115 | y = y - torch.mean(y, dim=0, keepdim=True) 116 | y = torch.mean(y * y, dim=0) 117 | y = torch.sqrt(y + 1e-8) 118 | y = torch.mean(torch.mean(torch.mean(y, dim=3, keepdim=True), 119 | dim=2, keepdim=True), dim=1, keepdim=True) 120 | y = y.type(x.type()) 121 | y = y.repeat(group_size, 1, s[2], s[3]) 122 | return torch.cat([x, y], dim=1) 123 | 124 | 125 | class PredictionBlock(nn.Module): 126 | def __init__(self, in_channels, dense0_feat, dense1_feat, out_feat, 127 | pool_size=2, act=nn.LeakyReLU(0.2), use_mbstd=True): 128 | super().__init__() 129 | self.use_mbstd = use_mbstd # attribute classifiers don't have this 130 | if self.use_mbstd: 131 | self.mbstd_layer = MinibatchStdLayer() 132 | # MinibatchStdLayer adds an additional feature dimension 133 | self.conv = WScaleConv2d(in_channels + int(self.use_mbstd), 134 | dense0_feat, kernel_size=3, padding=1) 135 | self.dense0 = WScaleLinear(dense0_feat * pool_size * pool_size, dense1_feat) 136 | self.dense1 = WScaleLinear(dense1_feat, out_feat, gain=1) 137 | self.act = act 138 | 139 | def forward(self, x): 140 | if self.use_mbstd: 141 | x = self.mbstd_layer(x) 142 | x = self.act(self.conv(x)) 143 | x = x.view([x.shape[0], -1]) 144 | x = self.act(self.dense0(x)) 145 | x = self.dense1(x) 146 | return x 147 | 148 | 149 | class D(nn.Module): 150 | 151 | def __init__( 152 | self, 153 | num_channels=3, # Number of input color channels. Overridden based on dataset. 154 | resolution=128, # Input resolution. Overridden based on dataset. 155 | fmap_base=8192, # Overall multiplier for the number of feature maps. 
156 | fmap_decay=1.0, # log2 feature map reduction when doubling the resolution. 157 | fmap_max=512, # Maximum number of feature maps in any layer. 158 | fixed_size=False, # True = load fromrgb_lod0 weights only 159 | use_mbstd=True, # False = no mbstd layer in PredictionBlock 160 | **kwargs): # Ignore unrecognized keyword args. 161 | super().__init__() 162 | 163 | self.resolution_log2 = resolution_log2 = int(np.log2(resolution)) 164 | assert resolution == 2 ** resolution_log2 and resolution >= 4 165 | 166 | def nf(stage): 167 | return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) 168 | 169 | self.register_buffer('lod_in', torch.from_numpy(np.array(0.0))) 170 | 171 | res = resolution_log2 172 | 173 | setattr(self, 'fromrgb_lod0', FromRGB(num_channels, nf(res - 1), 1)) 174 | 175 | for i, res in enumerate(range(resolution_log2, 2, -1), 1): 176 | lod = resolution_log2 - res 177 | block = DownscaleConvBlock(nf(res - 1), nf(res - 1), nf(res - 2), 178 | kernel_size=3, padding=1) 179 | setattr(self, '%dx%d' % (2 ** res, 2 ** res), block) 180 | fromrgb = FromRGB(3, nf(res - 2), 1) 181 | if not fixed_size: 182 | setattr(self, 'fromrgb_lod%d' % i, fromrgb) 183 | 184 | res = 2 185 | pool_size = 2 ** res 186 | block = PredictionBlock(nf(res + 1 - 2), nf(res - 1), nf(res - 2), 1, 187 | pool_size, use_mbstd=use_mbstd) 188 | setattr(self, '%dx%d' % (pool_size, pool_size), block) 189 | self.downscale = Downscale2d() 190 | self.fixed_size = fixed_size 191 | 192 | def forward(self, img): 193 | x = self.fromrgb_lod0(img) 194 | for i, res in enumerate(range(self.resolution_log2, 2, -1), 1): 195 | lod = self.resolution_log2 - res 196 | x = getattr(self, '%dx%d' % (2 ** res, 2 ** res))(x) 197 | if not self.fixed_size: 198 | img = self.downscale(img) 199 | y = getattr(self, 'fromrgb_lod%d' % i)(img) 200 | x = lerp_clip(x, y, self.lod_in - lod) 201 | res = 2 202 | pool_size = 2 ** res 203 | out = getattr(self, '%dx%d' % (pool_size, pool_size))(x) 204 | return out 205 | 206 | 207 | def max_res_from_state_dict(state_dict): 208 | for i in range(3, 12): 209 | if '%dx%d.conv0.conv.weight' % (2 ** i, 2 ** i) not in state_dict: 210 | break 211 | return 2 ** (i - 1) 212 | 213 | 214 | def from_state_dict(state_dict, fixed_size=False, use_mbstd=True): 215 | res = max_res_from_state_dict(state_dict) 216 | print(f'res: {res}') 217 | d = D(num_channels=3, resolution=res, fixed_size=fixed_size, 218 | use_mbstd=use_mbstd) 219 | d.load_state_dict(state_dict) 220 | return d 221 | -------------------------------------------------------------------------------- /classifiers/cifar10_resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | 7 | 8 | # ---------------------------- ResNet ---------------------------- 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, in_planes, planes, stride=1): 14 | super(Bottleneck, self).__init__() 15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 16 | self.bn1 = nn.BatchNorm2d(planes) 17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | self.bn2 = nn.BatchNorm2d(planes) 19 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 20 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 21 | 22 | self.shortcut = nn.Sequential() 23 | if stride != 1 or in_planes != self.expansion * planes: 24 | self.shortcut = 
nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 26 | nn.BatchNorm2d(self.expansion * planes) 27 | ) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = F.relu(self.bn2(self.conv2(out))) 32 | out = self.bn3(self.conv3(out)) 33 | out += self.shortcut(x) 34 | out = F.relu(out) 35 | return out 36 | 37 | 38 | class ResNet(nn.Module): 39 | def __init__(self, block, num_blocks, num_classes=10): 40 | super(ResNet, self).__init__() 41 | self.in_planes = 64 42 | 43 | num_input_channels = 3 44 | mean = (0.4914, 0.4822, 0.4465) 45 | std = (0.2471, 0.2435, 0.2616) 46 | self.mean = torch.tensor(mean).view(num_input_channels, 1, 1) 47 | self.std = torch.tensor(std).view(num_input_channels, 1, 1) 48 | 49 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 50 | self.bn1 = nn.BatchNorm2d(64) 51 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 52 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 53 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 54 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 55 | self.linear = nn.Linear(512 * block.expansion, num_classes) 56 | 57 | def _make_layer(self, block, planes, num_blocks, stride): 58 | strides = [stride] + [1] * (num_blocks - 1) 59 | layers = [] 60 | for stride in strides: 61 | layers.append(block(self.in_planes, planes, stride)) 62 | self.in_planes = planes * block.expansion 63 | return nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | out = (x - self.mean.to(x.device)) / self.std.to(x.device) 67 | out = F.relu(self.bn1(self.conv1(out))) 68 | out = self.layer1(out) 69 | out = self.layer2(out) 70 | out = self.layer3(out) 71 | out = self.layer4(out) 72 | out = F.avg_pool2d(out, 4) 73 | out = out.view(out.size(0), -1) 74 | out = self.linear(out) 75 | return out 76 | 77 | 78 | def ResNet50(): 79 | return ResNet(Bottleneck, [3, 4, 6, 3]) 80 | 81 | 82 | # ---------------------------- ResNet ---------------------------- 83 | 84 | 85 | # ---------------------------- WideResNet ---------------------------- 86 | 87 | class BasicBlock(nn.Module): 88 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 89 | super(BasicBlock, self).__init__() 90 | self.bn1 = nn.BatchNorm2d(in_planes) 91 | self.relu1 = nn.ReLU(inplace=True) 92 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 93 | padding=1, bias=False) 94 | self.bn2 = nn.BatchNorm2d(out_planes) 95 | self.relu2 = nn.ReLU(inplace=True) 96 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 97 | padding=1, bias=False) 98 | self.droprate = dropRate 99 | self.equalInOut = (in_planes == out_planes) 100 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 101 | padding=0, bias=False) or None 102 | 103 | def forward(self, x): 104 | if not self.equalInOut: 105 | x = self.relu1(self.bn1(x)) 106 | else: 107 | out = self.relu1(self.bn1(x)) 108 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 109 | if self.droprate > 0: 110 | out = F.dropout(out, p=self.droprate, training=self.training) 111 | out = self.conv2(out) 112 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 113 | 114 | 115 | class NetworkBlock(nn.Module): 116 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 117 | super(NetworkBlock, self).__init__() 118 | self.layer 
= self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 119 | 120 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 121 | layers = [] 122 | for i in range(int(nb_layers)): 123 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 124 | return nn.Sequential(*layers) 125 | 126 | def forward(self, x): 127 | return self.layer(x) 128 | 129 | 130 | class WideResNet(nn.Module): 131 | """ Based on code from https://github.com/yaodongyu/TRADES """ 132 | 133 | def __init__(self, depth=28, num_classes=10, widen_factor=10, sub_block1=False, dropRate=0.0, bias_last=True): 134 | super(WideResNet, self).__init__() 135 | 136 | num_input_channels = 3 137 | mean = (0.4914, 0.4822, 0.4465) 138 | std = (0.2471, 0.2435, 0.2616) 139 | self.mean = torch.tensor(mean).view(num_input_channels, 1, 1) 140 | self.std = torch.tensor(std).view(num_input_channels, 1, 1) 141 | 142 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 143 | assert ((depth - 4) % 6 == 0) 144 | n = (depth - 4) / 6 145 | block = BasicBlock 146 | # 1st conv before any network block 147 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 148 | padding=1, bias=False) 149 | # 1st block 150 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 151 | if sub_block1: 152 | # 1st sub-block 153 | self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 154 | # 2nd block 155 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 156 | # 3rd block 157 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 158 | # global average pooling and classifier 159 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 160 | self.relu = nn.ReLU(inplace=True) 161 | self.fc = nn.Linear(nChannels[3], num_classes, bias=bias_last) 162 | self.nChannels = nChannels[3] 163 | 164 | for m in self.modules(): 165 | if isinstance(m, nn.Conv2d): 166 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 167 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 168 | elif isinstance(m, nn.BatchNorm2d): 169 | m.weight.data.fill_(1) 170 | m.bias.data.zero_() 171 | elif isinstance(m, nn.Linear) and not m.bias is None: 172 | m.bias.data.zero_() 173 | 174 | def forward(self, x): 175 | out = (x - self.mean.to(x.device)) / self.std.to(x.device) 176 | out = self.conv1(out) 177 | out = self.block1(out) 178 | out = self.block2(out) 179 | out = self.block3(out) 180 | out = self.relu(self.bn1(out)) 181 | out = F.avg_pool2d(out, 8) 182 | out = out.view(-1, self.nChannels) 183 | return self.fc(out) 184 | 185 | 186 | def WideResNet_70_16(): 187 | return WideResNet(depth=70, widen_factor=16, dropRate=0.0) 188 | 189 | 190 | def WideResNet_70_16_dropout(): 191 | return WideResNet(depth=70, widen_factor=16, dropRate=0.3) 192 | # ---------------------------- WideResNet ---------------------------- 193 | -------------------------------------------------------------------------------- /compute_accuracy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | class Accuracy(object): 5 | def at_radii(self, radii: np.ndarray): 6 | raise NotImplementedError() 7 | 8 | class ApproximateAccuracy(Accuracy): 9 | def __init__(self, data_file_path: str): 10 | self.data_file_path = data_file_path 11 | 12 | def at_radii(self, radii: np.ndarray) -> np.ndarray: 13 | df = pd.read_csv(self.data_file_path, delimiter="\s+") 14 | return np.array([self.at_radius(df, radius) for radius in radii]) 15 | 16 | def at_radius(self, df: pd.DataFrame, radius: float): 17 | return (df["correct"] & (df["radius"] >= radius)).mean() 18 | 19 | def get_abstention_rate(self) -> np.ndarray: 20 | df = pd.read_csv(self.data_file_path, delimiter="\t") 21 | return 1.*(df["predict"]==-1).sum()/len(df["predict"])*100 22 | 23 | acc = ApproximateAccuracy("results_file_path") 24 | 25 | def latex_table_certified_accuracy(radius): 26 | radii = [radius] 27 | accuracy = acc.at_radii(radii) 28 | print('certified_acc:'+str(accuracy)) 29 | 30 | if __name__ == "__main__": 31 | #certified accuracy for imagenet 32 | # latex_table_certified_accuracy(0.00) 33 | # latex_table_certified_accuracy(0.50) 34 | # latex_table_certified_accuracy(1.00) 35 | # latex_table_certified_accuracy(1.50) 36 | # latex_table_certified_accuracy(2.00) 37 | # latex_table_certified_accuracy(3.00) 38 | 39 | # certified accuracy for cifar10 40 | latex_table_certified_accuracy(0.00) 41 | latex_table_certified_accuracy(0.25) 42 | latex_table_certified_accuracy(0.50) 43 | latex_table_certified_accuracy(0.75) 44 | latex_table_certified_accuracy(1.00) 45 | -------------------------------------------------------------------------------- /configs/cifar10.yml: -------------------------------------------------------------------------------- 1 | model: -------------------------------------------------------------------------------- /configs/imagenet.yml: -------------------------------------------------------------------------------- 1 | model: 2 | attention_resolutions: '32,16,8' 3 | class_cond: False 4 | diffusion_steps: 1000 5 | rescale_timesteps: True 6 | timestep_respacing: '1000' # Modify this value to decrease the number of timesteps. 
7 | image_size: 256 8 | learn_sigma: True 9 | noise_schedule: 'linear' 10 | num_channels: 256 11 | num_head_channels: 64 12 | num_res_blocks: 2 13 | resblock_updown: True 14 | use_fp16: True 15 | use_scale_shift_norm: True -------------------------------------------------------------------------------- /core.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | 3 | import numpy as np 4 | from scipy.stats import norm, binom_test 5 | from statsmodels.stats.proportion import proportion_confint 6 | import torch 7 | 8 | class Smooth(object): 9 | """A smoothed classifier g """ 10 | 11 | # to abstain, Smooth returns this int 12 | ABSTAIN = -1 13 | 14 | def __init__(self, base_classifier: torch.nn.Module, num_classes: int, sigma: float): 15 | """ 16 | :param base_classifier: maps from [batch x channel x height x width] to [batch x num_classes] 17 | :param num_classes: 18 | :param sigma: the noise level hyperparameter 19 | """ 20 | self.base_classifier = base_classifier 21 | self.num_classes = num_classes 22 | self.sigma = sigma 23 | 24 | def certify(self, x: torch.tensor, n0: int, n: int, sample_id:int, alpha: float, batch_size: int, clustering_method='none') -> (int, float): 25 | """ Monte Carlo algorithm for certifying that g's prediction around x is constant within some L2 radius. 26 | With probability at least 1 - alpha, the class returned by this method will equal g(x), and g's prediction will 27 | robust within a L2 ball of radius R around x. 28 | 29 | :param x: the input [channel x height x width] 30 | :param n0: the number of Monte Carlo samples to use for selection 31 | :param n: the number of Monte Carlo samples to use for estimation 32 | :param alpha: the failure probability 33 | :param batch_size: batch size to use when evaluating the base classifier 34 | :return: (predicted class, certified radius) 35 | in the case of abstention, the class will be ABSTAIN and the radius 0. 36 | """ 37 | self.base_classifier.eval() 38 | # draw samples of f(x+ epsilon) 39 | counts_selection, n0_predictions = self._sample_noise(x, n0, batch_size, clustering_method, sample_id) 40 | # use these samples to take a guess at the top class 41 | cAHat = counts_selection.argmax().item() 42 | # draw more samples of f(x + epsilon) 43 | counts_estimation, n_predictions = self._sample_noise(x, n, batch_size, clustering_method, sample_id) 44 | # use these samples to estimate a lower bound on pA 45 | nA = counts_estimation[cAHat].item() 46 | pABar = self._lower_confidence_bound(nA, n, alpha) 47 | if pABar < 0.5: 48 | return Smooth.ABSTAIN, 0.0, n0_predictions, n_predictions 49 | else: 50 | radius = self.sigma * norm.ppf(pABar) 51 | return cAHat, radius, n0_predictions, n_predictions 52 | 53 | def certify_noapproximate(self, x: torch.tensor, n0: int, n: int, alpha: float, batch_size: int) -> (int, float): 54 | """ Monte Carlo algorithm for certifying that g's prediction around x is constant within some L2 radius. 55 | With probability at least 1 - alpha, the class returned by this method will equal g(x), and g's prediction will 56 | robust within a L2 ball of radius R around x. 
57 | 58 | :param x: the input [channel x height x width] 59 | :param n0: the number of Monte Carlo samples to use for selection 60 | :param n: the number of Monte Carlo samples to use for estimation 61 | :param alpha: the failure probability 62 | :param batch_size: batch size to use when evaluating the base classifier 63 | :return: (predicted class, certified radius) 64 | in the case of abstention, the class will be ABSTAIN and the radius 0. 65 | """ 66 | self.base_classifier.eval() 67 | # draw samples of f(x+ epsilon) 68 | counts_selection = self._sample_noise(x, n0, batch_size) 69 | # use these samples to take a guess at the top class 70 | cAHat = counts_selection.argmax().item() 71 | # draw more samples of f(x + epsilon) 72 | counts_estimation = self._sample_noise(x, n, batch_size) 73 | # use these samples to estimate a lower bound on pA 74 | top2 = counts_estimation.argsort()[::-1][:2] 75 | nA = counts_estimation[top2[0]].item() 76 | nB = counts_estimation[top2[1]].item() 77 | 78 | pABar = self._lower_confidence_bound(nA, n, alpha) 79 | pBBar = self._upper_confidence_bound(nB, n, alpha) 80 | if pABar < 0.5: 81 | return Smooth.ABSTAIN, 0.0 82 | else: 83 | radius = self.sigma/2 * (norm.ppf(pABar) - norm.ppf(pBBar)) 84 | return cAHat, radius 85 | 86 | def predict(self, x: torch.tensor, n: int, alpha: float, batch_size: int) -> int: 87 | """ Monte Carlo algorithm for evaluating the prediction of g at x. With probability at least 1 - alpha, the 88 | class returned by this method will equal g(x). 89 | 90 | This function uses the hypothesis test described in https://arxiv.org/abs/1610.03944 91 | for identifying the top category of a multinomial distribution. 92 | 93 | :param x: the input [channel x height x width] 94 | :param n: the number of Monte Carlo samples to use 95 | :param alpha: the failure probability 96 | :param batch_size: batch size to use when evaluating the base classifier 97 | :return: the predicted class, or ABSTAIN 98 | """ 99 | self.base_classifier.eval() 100 | counts = self._sample_noise(x, n, batch_size) 101 | top2 = counts.argsort()[::-1][:2] 102 | count1 = counts[top2[0]] 103 | count2 = counts[top2[1]] 104 | if binom_test(count1, count1 + count2, p=0.5) > alpha: 105 | return Smooth.ABSTAIN 106 | else: 107 | return top2[0] 108 | 109 | def _sample_noise(self, x: torch.tensor, num: int, batch_size, clustering_method='none', sample_id=None) -> np.ndarray: 110 | """ Sample the base classifier's prediction under noisy corruptions of the input x. 
111 | 112 | :param x: the input [channel x width x height] 113 | :param num: number of samples to collect 114 | :param batch_size: 115 | :return: an ndarray[int] of length num_classes containing the per-class counts 116 | """ 117 | with torch.no_grad(): 118 | predictions_all = np.array([], dtype=int) 119 | counts = np.zeros(self.num_classes, dtype=int) 120 | for _ in range(ceil(num / batch_size)): 121 | this_batch_size = min(batch_size, num) 122 | num -= this_batch_size 123 | 124 | batch = x.repeat((this_batch_size, 1, 1, 1)) 125 | noise = torch.randn_like(batch, device='cuda') * self.sigma 126 | 127 | if clustering_method == 'classifier': 128 | predictions = self.base_classifier(batch + noise, sample_id).argmax(1) 129 | predictions = predictions.view(this_batch_size,-1).cpu().numpy() 130 | count_max_list = np.zeros(this_batch_size,dtype=int) 131 | for i in range(this_batch_size): 132 | count_max = max(list(predictions[i]),key=list(predictions[i]).count) 133 | count_max_list[i] = count_max 134 | counts += self._count_arr(count_max_list, self.num_classes) 135 | 136 | else: 137 | predictions = self.base_classifier(batch + noise, sample_id).argmax(1) 138 | counts += self._count_arr(predictions.cpu().numpy(), self.num_classes) 139 | predictions_all = np.hstack((predictions_all, predictions.cpu().numpy())) 140 | 141 | return counts, predictions_all 142 | 143 | def _count_arr(self, arr: np.ndarray, length: int) -> np.ndarray: 144 | counts = np.zeros(length, dtype=int) 145 | for idx in arr: 146 | counts[idx] += 1 147 | return counts 148 | 149 | def _lower_confidence_bound(self, NA: int, N: int, alpha: float) -> float: 150 | """ Returns a (1 - alpha) lower confidence bound on a bernoulli proportion. 151 | 152 | This function uses the Clopper-Pearson method. 153 | 154 | :param NA: the number of "successes" 155 | :param N: the number of total draws 156 | :param alpha: the confidence level 157 | :return: a lower bound on the binomial proportion which holds true w.p at least (1 - alpha) over the samples 158 | """ 159 | return proportion_confint(NA, N, alpha=2 * alpha, method="beta")[0] 160 | 161 | def _upper_confidence_bound(self, NA: int, N: int, alpha: float) -> float: 162 | 163 | return proportion_confint(NA, N, alpha=2 * alpha, method="beta")[1] -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import imagenet_lmdb_dataset, imagenet_lmdb_dataset_sub, cifar10_dataset_sub 2 | 3 | def get_transform(dataset_name, transform_type, base_size=256): 4 | from . import datasets 5 | if dataset_name == 'celebahq': 6 | return datasets.get_transform(dataset_name, transform_type, base_size) 7 | elif 'imagenet' in dataset_name: 8 | return datasets.get_transform(dataset_name, transform_type, base_size) 9 | else: 10 | raise NotImplementedError 11 | 12 | 13 | def get_dataset(dataset_name, partition, *args, **kwargs): 14 | from . 
import datasets 15 | if dataset_name == 'celebahq': 16 | return datasets.CelebAHQDataset(partition, *args, **kwargs) 17 | else: 18 | raise NotImplementedError -------------------------------------------------------------------------------- /ddpm/unet_ddpm.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def get_timestep_embedding(timesteps, embedding_dim): 7 | """ 8 | This matches the implementation in Denoising Diffusion Probabilistic Models: 9 | From Fairseq. 10 | Build sinusoidal embeddings. 11 | This matches the implementation in tensor2tensor, but differs slightly 12 | from the description in Section 3.5 of "Attention Is All You Need". 13 | """ 14 | assert len(timesteps.shape) == 1 15 | 16 | half_dim = embedding_dim // 2 17 | emb = math.log(10000) / (half_dim - 1) 18 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 19 | emb = emb.to(device=timesteps.device) 20 | emb = timesteps.float()[:, None] * emb[None, :] 21 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 22 | if embedding_dim % 2 == 1: # zero pad 23 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 24 | return emb 25 | 26 | 27 | def nonlinearity(x): 28 | # swish 29 | return x * torch.sigmoid(x) 30 | 31 | 32 | def Normalize(in_channels): 33 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 34 | 35 | 36 | class Upsample(nn.Module): 37 | def __init__(self, in_channels, with_conv): 38 | super().__init__() 39 | self.with_conv = with_conv 40 | if self.with_conv: 41 | self.conv = torch.nn.Conv2d(in_channels, 42 | in_channels, 43 | kernel_size=3, 44 | stride=1, 45 | padding=1) 46 | 47 | def forward(self, x): 48 | x = torch.nn.functional.interpolate( 49 | x, scale_factor=2.0, mode="nearest") 50 | if self.with_conv: 51 | x = self.conv(x) 52 | return x 53 | 54 | 55 | class Downsample(nn.Module): 56 | def __init__(self, in_channels, with_conv): 57 | super().__init__() 58 | self.with_conv = with_conv 59 | if self.with_conv: 60 | # no asymmetric padding in torch conv, must do it ourselves 61 | self.conv = torch.nn.Conv2d(in_channels, 62 | in_channels, 63 | kernel_size=3, 64 | stride=2, 65 | padding=0) 66 | 67 | def forward(self, x): 68 | if self.with_conv: 69 | pad = (0, 1, 0, 1) 70 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 71 | x = self.conv(x) 72 | else: 73 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 74 | return x 75 | 76 | 77 | class ResnetBlock(nn.Module): 78 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 79 | dropout, temb_channels=512): 80 | super().__init__() 81 | self.in_channels = in_channels 82 | out_channels = in_channels if out_channels is None else out_channels 83 | self.out_channels = out_channels 84 | self.use_conv_shortcut = conv_shortcut 85 | 86 | self.norm1 = Normalize(in_channels) 87 | self.conv1 = torch.nn.Conv2d(in_channels, 88 | out_channels, 89 | kernel_size=3, 90 | stride=1, 91 | padding=1) 92 | self.temb_proj = torch.nn.Linear(temb_channels, 93 | out_channels) 94 | self.norm2 = Normalize(out_channels) 95 | self.dropout = torch.nn.Dropout(dropout) 96 | self.conv2 = torch.nn.Conv2d(out_channels, 97 | out_channels, 98 | kernel_size=3, 99 | stride=1, 100 | padding=1) 101 | if self.in_channels != self.out_channels: 102 | if self.use_conv_shortcut: 103 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 104 | out_channels, 105 | kernel_size=3, 106 | stride=1, 107 | 
padding=1) 108 | else: 109 | self.nin_shortcut = torch.nn.Conv2d(in_channels, 110 | out_channels, 111 | kernel_size=1, 112 | stride=1, 113 | padding=0) 114 | 115 | def forward(self, x, temb): 116 | h = x 117 | h = self.norm1(h) 118 | h = nonlinearity(h) 119 | h = self.conv1(h) 120 | 121 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 122 | 123 | h = self.norm2(h) 124 | h = nonlinearity(h) 125 | h = self.dropout(h) 126 | h = self.conv2(h) 127 | 128 | if self.in_channels != self.out_channels: 129 | if self.use_conv_shortcut: 130 | x = self.conv_shortcut(x) 131 | else: 132 | x = self.nin_shortcut(x) 133 | 134 | return x + h 135 | 136 | 137 | class AttnBlock(nn.Module): 138 | def __init__(self, in_channels): 139 | super().__init__() 140 | self.in_channels = in_channels 141 | 142 | self.norm = Normalize(in_channels) 143 | self.q = torch.nn.Conv2d(in_channels, 144 | in_channels, 145 | kernel_size=1, 146 | stride=1, 147 | padding=0) 148 | self.k = torch.nn.Conv2d(in_channels, 149 | in_channels, 150 | kernel_size=1, 151 | stride=1, 152 | padding=0) 153 | self.v = torch.nn.Conv2d(in_channels, 154 | in_channels, 155 | kernel_size=1, 156 | stride=1, 157 | padding=0) 158 | self.proj_out = torch.nn.Conv2d(in_channels, 159 | in_channels, 160 | kernel_size=1, 161 | stride=1, 162 | padding=0) 163 | 164 | def forward(self, x): 165 | h_ = x 166 | h_ = self.norm(h_) 167 | q = self.q(h_) 168 | k = self.k(h_) 169 | v = self.v(h_) 170 | 171 | # compute attention 172 | b, c, h, w = q.shape 173 | q = q.reshape(b, c, h * w) 174 | q = q.permute(0, 2, 1) # b,hw,c 175 | k = k.reshape(b, c, h * w) # b,c,hw 176 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 177 | w_ = w_ * (int(c) ** (-0.5)) 178 | w_ = torch.nn.functional.softmax(w_, dim=2) 179 | 180 | # attend to values 181 | v = v.reshape(b, c, h * w) 182 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 183 | # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 184 | h_ = torch.bmm(v, w_) 185 | h_ = h_.reshape(b, c, h, w) 186 | 187 | h_ = self.proj_out(h_) 188 | 189 | return x + h_ 190 | 191 | 192 | class Model(nn.Module): 193 | def __init__(self, config): 194 | super().__init__() 195 | self.config = config 196 | ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult) 197 | num_res_blocks = config.model.num_res_blocks 198 | attn_resolutions = config.model.attn_resolutions 199 | dropout = config.model.dropout 200 | in_channels = config.model.in_channels 201 | resolution = config.data.image_size 202 | resamp_with_conv = config.model.resamp_with_conv 203 | 204 | self.ch = ch 205 | self.temb_ch = self.ch * 4 206 | self.num_resolutions = len(ch_mult) 207 | self.num_res_blocks = num_res_blocks 208 | self.resolution = resolution 209 | self.in_channels = in_channels 210 | 211 | # timestep embedding 212 | self.temb = nn.Module() 213 | self.temb.dense = nn.ModuleList([ 214 | torch.nn.Linear(self.ch, 215 | self.temb_ch), 216 | torch.nn.Linear(self.temb_ch, 217 | self.temb_ch), 218 | ]) 219 | 220 | # downsampling 221 | self.conv_in = torch.nn.Conv2d(in_channels, 222 | self.ch, 223 | kernel_size=3, 224 | stride=1, 225 | padding=1) 226 | 227 | curr_res = resolution 228 | in_ch_mult = (1,) + ch_mult 229 | self.down = nn.ModuleList() 230 | block_in = None 231 | for i_level in range(self.num_resolutions): 232 | block = nn.ModuleList() 233 | attn = nn.ModuleList() 234 | block_in = ch * in_ch_mult[i_level] 235 | block_out = ch * ch_mult[i_level] 236 | for i_block in range(self.num_res_blocks): 237 | 
block.append(ResnetBlock(in_channels=block_in, 238 | out_channels=block_out, 239 | temb_channels=self.temb_ch, 240 | dropout=dropout)) 241 | block_in = block_out 242 | if curr_res in attn_resolutions: 243 | attn.append(AttnBlock(block_in)) 244 | down = nn.Module() 245 | down.block = block 246 | down.attn = attn 247 | if i_level != self.num_resolutions - 1: 248 | down.downsample = Downsample(block_in, resamp_with_conv) 249 | curr_res = curr_res // 2 250 | self.down.append(down) 251 | 252 | # middle 253 | self.mid = nn.Module() 254 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 255 | out_channels=block_in, 256 | temb_channels=self.temb_ch, 257 | dropout=dropout) 258 | self.mid.attn_1 = AttnBlock(block_in) 259 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 260 | out_channels=block_in, 261 | temb_channels=self.temb_ch, 262 | dropout=dropout) 263 | 264 | # upsampling 265 | self.up = nn.ModuleList() 266 | for i_level in reversed(range(self.num_resolutions)): 267 | block = nn.ModuleList() 268 | attn = nn.ModuleList() 269 | block_out = ch * ch_mult[i_level] 270 | skip_in = ch * ch_mult[i_level] 271 | for i_block in range(self.num_res_blocks + 1): 272 | if i_block == self.num_res_blocks: 273 | skip_in = ch * in_ch_mult[i_level] 274 | block.append(ResnetBlock(in_channels=block_in + skip_in, 275 | out_channels=block_out, 276 | temb_channels=self.temb_ch, 277 | dropout=dropout)) 278 | block_in = block_out 279 | if curr_res in attn_resolutions: 280 | attn.append(AttnBlock(block_in)) 281 | up = nn.Module() 282 | up.block = block 283 | up.attn = attn 284 | if i_level != 0: 285 | up.upsample = Upsample(block_in, resamp_with_conv) 286 | curr_res = curr_res * 2 287 | self.up.insert(0, up) # prepend to get consistent order 288 | 289 | # end 290 | self.norm_out = Normalize(block_in) 291 | self.conv_out = torch.nn.Conv2d(block_in, 292 | out_ch, 293 | kernel_size=3, 294 | stride=1, 295 | padding=1) 296 | 297 | def forward(self, x, t): 298 | assert x.shape[2] == x.shape[3] == self.resolution 299 | 300 | # timestep embedding 301 | temb = get_timestep_embedding(t, self.ch) 302 | temb = self.temb.dense[0](temb) 303 | temb = nonlinearity(temb) 304 | temb = self.temb.dense[1](temb) 305 | 306 | # downsampling 307 | hs = [self.conv_in(x)] 308 | for i_level in range(self.num_resolutions): 309 | for i_block in range(self.num_res_blocks): 310 | h = self.down[i_level].block[i_block](hs[-1], temb) 311 | if len(self.down[i_level].attn) > 0: 312 | h = self.down[i_level].attn[i_block](h) 313 | hs.append(h) 314 | if i_level != self.num_resolutions - 1: 315 | hs.append(self.down[i_level].downsample(hs[-1])) 316 | 317 | # middle 318 | h = hs[-1] 319 | h = self.mid.block_1(h, temb) 320 | h = self.mid.attn_1(h) 321 | h = self.mid.block_2(h, temb) 322 | 323 | # upsampling 324 | for i_level in reversed(range(self.num_resolutions)): 325 | for i_block in range(self.num_res_blocks + 1): 326 | h = self.up[i_level].block[i_block]( 327 | torch.cat([h, hs.pop()], dim=1), temb) 328 | if len(self.up[i_level].attn) > 0: 329 | h = self.up[i_level].attn[i_block](h) 330 | if i_level != 0: 331 | h = self.up[i_level].upsample(h) 332 | 333 | # end 334 | h = self.norm_out(h) 335 | h = nonlinearity(h) 336 | h = self.conv_out(h) 337 | return h 338 | -------------------------------------------------------------------------------- /guided_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 
3 | """ 4 | -------------------------------------------------------------------------------- /guided_diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 8 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(): 22 | """ 23 | Setup a distributed process group. 24 | """ 25 | if dist.is_initialized(): 26 | return 27 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" 28 | 29 | comm = MPI.COMM_WORLD 30 | backend = "gloo" if not th.cuda.is_available() else "nccl" 31 | 32 | if backend == "gloo": 33 | hostname = "localhost" 34 | else: 35 | hostname = socket.gethostbyname(socket.getfqdn()) 36 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 37 | os.environ["RANK"] = str(comm.rank) 38 | os.environ["WORLD_SIZE"] = str(comm.size) 39 | 40 | port = comm.bcast(_find_free_port(), root=0) 41 | os.environ["MASTER_PORT"] = str(port) 42 | dist.init_process_group(backend=backend, init_method="env://") 43 | 44 | 45 | def dev(): 46 | """ 47 | Get the device to use for torch.distributed. 48 | """ 49 | if th.cuda.is_available(): 50 | return th.device(f"cuda") 51 | return th.device("cpu") 52 | 53 | 54 | def load_state_dict(path, **kwargs): 55 | """ 56 | Load a PyTorch file without redundant fetches across MPI ranks. 57 | """ 58 | chunk_size = 2 ** 30 # MPI has a relatively small size limit 59 | if MPI.COMM_WORLD.Get_rank() == 0: 60 | with bf.BlobFile(path, "rb") as f: 61 | data = f.read() 62 | num_chunks = len(data) // chunk_size 63 | if len(data) % chunk_size: 64 | num_chunks += 1 65 | MPI.COMM_WORLD.bcast(num_chunks) 66 | for i in range(0, len(data), chunk_size): 67 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) 68 | else: 69 | num_chunks = MPI.COMM_WORLD.bcast(None) 70 | data = bytes() 71 | for _ in range(num_chunks): 72 | data += MPI.COMM_WORLD.bcast(None) 73 | 74 | return th.load(io.BytesIO(data), **kwargs) 75 | 76 | 77 | def sync_params(params): 78 | """ 79 | Synchronize a sequence of Tensors across ranks from rank 0. 80 | """ 81 | for p in params: 82 | with th.no_grad(): 83 | dist.broadcast(p, 0) 84 | 85 | 86 | def _find_free_port(): 87 | try: 88 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 89 | s.bind(("", 0)) 90 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 91 | return s.getsockname()[1] 92 | finally: 93 | s.close() 94 | -------------------------------------------------------------------------------- /guided_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 
18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = 
get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 203 | opt.step() 204 | zero_master_grads(self.master_params) 205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 206 | self.lg_loss_scale += self.fp16_scale_growth 207 | return True 208 | 209 | def _optimize_normal(self, opt: th.optim.Optimizer): 210 | grad_norm, param_norm = self._compute_norms() 211 | logger.logkv_mean("grad_norm", grad_norm) 212 | logger.logkv_mean("param_norm", param_norm) 213 | opt.step() 214 | return True 215 | 216 | def _compute_norms(self, grad_scale=1.0): 217 | grad_norm = 0.0 218 | param_norm = 0.0 219 | for p in self.master_params: 220 | with th.no_grad(): 221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 222 | if p.grad is not None: 223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 225 | 226 | def 
master_params_to_state_dict(self, master_params): 227 | return master_params_to_state_dict( 228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 229 | ) 230 | 231 | def state_dict_to_master_params(self, state_dict): 232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 233 | 234 | 235 | def check_overflow(value): 236 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 237 | -------------------------------------------------------------------------------- /guided_diffusion/image_datasets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | from PIL import Image 5 | import blobfile as bf 6 | from mpi4py import MPI 7 | import numpy as np 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | 11 | def load_data( 12 | *, 13 | data_dir, 14 | batch_size, 15 | image_size, 16 | class_cond=False, 17 | deterministic=False, 18 | random_crop=False, 19 | random_flip=True, 20 | ): 21 | """ 22 | For a dataset, create a generator over (images, kwargs) pairs. 23 | 24 | Each images is an NCHW float tensor, and the kwargs dict contains zero or 25 | more keys, each of which map to a batched Tensor of their own. 26 | The kwargs dict can be used for class labels, in which case the key is "y" 27 | and the values are integer tensors of class labels. 28 | 29 | :param data_dir: a dataset directory. 30 | :param batch_size: the batch size of each returned pair. 31 | :param image_size: the size to which images are resized. 32 | :param class_cond: if True, include a "y" key in returned dicts for class 33 | label. If classes are not available and this is true, an 34 | exception will be raised. 35 | :param deterministic: if True, yield results in a deterministic order. 36 | :param random_crop: if True, randomly crop the images for augmentation. 37 | :param random_flip: if True, randomly flip the images for augmentation. 38 | """ 39 | if not data_dir: 40 | raise ValueError("unspecified data directory") 41 | all_files = _list_image_files_recursively(data_dir) 42 | classes = None 43 | if class_cond: 44 | # Assume classes are the first part of the filename, 45 | # before an underscore. 46 | class_names = [bf.basename(path).split("_")[0] for path in all_files] 47 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 48 | classes = [sorted_classes[x] for x in class_names] 49 | dataset = ImageDataset( 50 | image_size, 51 | all_files, 52 | classes=classes, 53 | shard=MPI.COMM_WORLD.Get_rank(), 54 | num_shards=MPI.COMM_WORLD.Get_size(), 55 | random_crop=random_crop, 56 | random_flip=random_flip, 57 | ) 58 | if deterministic: 59 | loader = DataLoader( 60 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True 61 | ) 62 | else: 63 | loader = DataLoader( 64 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True 65 | ) 66 | while True: 67 | yield from loader 68 | 69 | 70 | def _list_image_files_recursively(data_dir): 71 | results = [] 72 | for entry in sorted(bf.listdir(data_dir)): 73 | full_path = bf.join(data_dir, entry) 74 | ext = entry.split(".")[-1] 75 | if "." 
in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 76 | results.append(full_path) 77 | elif bf.isdir(full_path): 78 | results.extend(_list_image_files_recursively(full_path)) 79 | return results 80 | 81 | 82 | class ImageDataset(Dataset): 83 | def __init__( 84 | self, 85 | resolution, 86 | image_paths, 87 | classes=None, 88 | shard=0, 89 | num_shards=1, 90 | random_crop=False, 91 | random_flip=True, 92 | ): 93 | super().__init__() 94 | self.resolution = resolution 95 | self.local_images = image_paths[shard:][::num_shards] 96 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 97 | self.random_crop = random_crop 98 | self.random_flip = random_flip 99 | 100 | def __len__(self): 101 | return len(self.local_images) 102 | 103 | def __getitem__(self, idx): 104 | path = self.local_images[idx] 105 | with bf.BlobFile(path, "rb") as f: 106 | pil_image = Image.open(f) 107 | pil_image.load() 108 | pil_image = pil_image.convert("RGB") 109 | 110 | if self.random_crop: 111 | arr = random_crop_arr(pil_image, self.resolution) 112 | else: 113 | arr = center_crop_arr(pil_image, self.resolution) 114 | 115 | if self.random_flip and random.random() < 0.5: 116 | arr = arr[:, ::-1] 117 | 118 | arr = arr.astype(np.float32) / 127.5 - 1 119 | 120 | out_dict = {} 121 | if self.local_classes is not None: 122 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) 123 | return np.transpose(arr, [2, 0, 1]), out_dict 124 | 125 | 126 | def center_crop_arr(pil_image, image_size): 127 | # We are not on a new enough PIL to support the `reducing_gap` 128 | # argument, which uses BOX downsampling at powers of two first. 129 | # Thus, we do it by hand to improve downsample quality. 130 | while min(*pil_image.size) >= 2 * image_size: 131 | pil_image = pil_image.resize( 132 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 133 | ) 134 | 135 | scale = image_size / min(*pil_image.size) 136 | pil_image = pil_image.resize( 137 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 138 | ) 139 | 140 | arr = np.array(pil_image) 141 | crop_y = (arr.shape[0] - image_size) // 2 142 | crop_x = (arr.shape[1] - image_size) // 2 143 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 144 | 145 | 146 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0): 147 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac) 148 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac) 149 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) 150 | 151 | # We are not on a new enough PIL to support the `reducing_gap` 152 | # argument, which uses BOX downsampling at powers of two first. 153 | # Thus, we do it by hand to improve downsample quality. 
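# (The loop below repeatedly halves the image with BOX resampling until the
# shorter side is < 2 * smaller_dim_size, then one BICUBIC resize finishes the
# job. With the default crop fractions and image_size=256, smaller_dim_size is
# drawn from [256, 320], so the final 256x256 crop covers 80-100% of the
# shorter side.)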
154 | while min(*pil_image.size) >= 2 * smaller_dim_size: 155 | pil_image = pil_image.resize( 156 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 157 | ) 158 | 159 | scale = smaller_dim_size / min(*pil_image.size) 160 | pil_image = pil_image.resize( 161 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 162 | ) 163 | 164 | arr = np.array(pil_image) 165 | crop_y = random.randrange(arr.shape[0] - image_size + 1) 166 | crop_x = random.randrange(arr.shape[1] - image_size + 1) 167 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 168 | -------------------------------------------------------------------------------- /guided_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch as th 4 | 5 | 6 | def normal_kl(mean1, logvar1, mean2, logvar2): 7 | """ 8 | Compute the KL divergence between two gaussians. 9 | 10 | Shapes are automatically broadcasted, so batches can be compared to 11 | scalars, among other use cases. 12 | """ 13 | tensor = None 14 | for obj in (mean1, logvar1, mean2, logvar2): 15 | if isinstance(obj, th.Tensor): 16 | tensor = obj 17 | break 18 | assert tensor is not None, "at least one argument must be a Tensor" 19 | 20 | # Force variances to be Tensors. Broadcasting helps convert scalars to 21 | # Tensors, but it does not work for th.exp(). 22 | logvar1, logvar2 = [ 23 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 24 | for x in (logvar1, logvar2) 25 | ] 26 | 27 | return 0.5 * ( 28 | -1.0 29 | + logvar2 30 | - logvar1 31 | + th.exp(logvar1 - logvar2) 32 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 33 | ) 34 | 35 | 36 | def approx_standard_normal_cdf(x): 37 | """ 38 | A fast approximation of the cumulative distribution function of the 39 | standard normal. 40 | """ 41 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 42 | 43 | 44 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 45 | """ 46 | Compute the log-likelihood of a Gaussian distribution discretizing to a 47 | given image. 48 | 49 | :param x: the target images. It is assumed that this was uint8 values, 50 | rescaled to the range [-1, 1]. 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 54 | """ 55 | assert x.shape == means.shape == log_scales.shape 56 | centered_x = x - means 57 | inv_stdv = th.exp(-log_scales) 58 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 59 | cdf_plus = approx_standard_normal_cdf(plus_in) 60 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 61 | cdf_min = approx_standard_normal_cdf(min_in) 62 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 63 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 64 | cdf_delta = cdf_plus - cdf_min 65 | log_probs = th.where( 66 | x < -0.999, 67 | log_cdf_plus, 68 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 69 | ) 70 | assert log_probs.shape == x.shape 71 | return log_probs 72 | -------------------------------------------------------------------------------- /guided_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch as th 4 | import torch.nn as nn 5 | 6 | 7 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 
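# SiLU(x) = x * sigmoid(x); on PyTorch >= 1.7 this module is interchangeable
# with nn.SiLU.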
8 | class SiLU(nn.Module): 9 | def forward(self, x): 10 | return x * th.sigmoid(x) 11 | 12 | 13 | class GroupNorm32(nn.GroupNorm): 14 | def forward(self, x): 15 | return super().forward(x.float()).type(x.dtype) 16 | 17 | 18 | def conv_nd(dims, *args, **kwargs): 19 | """ 20 | Create a 1D, 2D, or 3D convolution module. 21 | """ 22 | if dims == 1: 23 | return nn.Conv1d(*args, **kwargs) 24 | elif dims == 2: 25 | return nn.Conv2d(*args, **kwargs) 26 | elif dims == 3: 27 | return nn.Conv3d(*args, **kwargs) 28 | raise ValueError(f"unsupported dimensions: {dims}") 29 | 30 | 31 | def linear(*args, **kwargs): 32 | """ 33 | Create a linear module. 34 | """ 35 | return nn.Linear(*args, **kwargs) 36 | 37 | 38 | def avg_pool_nd(dims, *args, **kwargs): 39 | """ 40 | Create a 1D, 2D, or 3D average pooling module. 41 | """ 42 | if dims == 1: 43 | return nn.AvgPool1d(*args, **kwargs) 44 | elif dims == 2: 45 | return nn.AvgPool2d(*args, **kwargs) 46 | elif dims == 3: 47 | return nn.AvgPool3d(*args, **kwargs) 48 | raise ValueError(f"unsupported dimensions: {dims}") 49 | 50 | 51 | def update_ema(target_params, source_params, rate=0.99): 52 | """ 53 | Update target parameters to be closer to those of source parameters using 54 | an exponential moving average. 55 | 56 | :param target_params: the target parameter sequence. 57 | :param source_params: the source parameter sequence. 58 | :param rate: the EMA rate (closer to 1 means slower). 59 | """ 60 | for targ, src in zip(target_params, source_params): 61 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 62 | 63 | 64 | def zero_module(module): 65 | """ 66 | Zero out the parameters of a module and return it. 67 | """ 68 | for p in module.parameters(): 69 | p.detach().zero_() 70 | return module 71 | 72 | 73 | def scale_module(module, scale): 74 | """ 75 | Scale the parameters of a module and return it. 76 | """ 77 | for p in module.parameters(): 78 | p.detach().mul_(scale) 79 | return module 80 | 81 | 82 | def mean_flat(tensor): 83 | """ 84 | Take the mean over all non-batch dimensions. 85 | """ 86 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 87 | 88 | 89 | def normalization(channels): 90 | """ 91 | Make a standard normalization layer. 92 | 93 | :param channels: number of input channels. 94 | :return: an nn.Module for normalization. 95 | """ 96 | return GroupNorm32(32, channels) 97 | 98 | 99 | def timestep_embedding(timesteps, dim, max_period=10000): 100 | """ 101 | Create sinusoidal timestep embeddings. 102 | 103 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 104 | These may be fractional. 105 | :param dim: the dimension of the output. 106 | :param max_period: controls the minimum frequency of the embeddings. 107 | :return: an [N x dim] Tensor of positional embeddings. 108 | """ 109 | half = dim // 2 110 | freqs = th.exp( 111 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 112 | ).to(device=timesteps.device) 113 | args = timesteps[:, None].float() * freqs[None] 114 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 115 | if dim % 2: 116 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 117 | return embedding 118 | 119 | 120 | def checkpoint(func, inputs, params, flag): 121 | """ 122 | Evaluate a function without caching intermediate activations, allowing for 123 | reduced memory at the expense of extra compute in the backward pass. 124 | 125 | :param func: the function to evaluate. 126 | :param inputs: the argument sequence to pass to `func`. 
127 | :param params: a sequence of parameters `func` depends on but does not 128 | explicitly take as arguments. 129 | :param flag: if False, disable gradient checkpointing. 130 | """ 131 | if flag: 132 | args = tuple(inputs) + tuple(params) 133 | return CheckpointFunction.apply(func, len(inputs), *args) 134 | else: 135 | return func(*inputs) 136 | 137 | 138 | class CheckpointFunction(th.autograd.Function): 139 | @staticmethod 140 | def forward(ctx, run_function, length, *args): 141 | ctx.run_function = run_function 142 | ctx.input_tensors = list(args[:length]) 143 | ctx.input_params = list(args[length:]) 144 | with th.no_grad(): 145 | output_tensors = ctx.run_function(*ctx.input_tensors) 146 | return output_tensors 147 | 148 | @staticmethod 149 | def backward(ctx, *output_grads): 150 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 151 | with th.enable_grad(): 152 | # Fixes a bug where the first op in run_function modifies the 153 | # Tensor storage in place, which is not allowed for detach()'d 154 | # Tensors. 155 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 156 | output_tensors = ctx.run_function(*shallow_copies) 157 | input_grads = th.autograd.grad( 158 | output_tensors, 159 | ctx.input_tensors + ctx.input_params, 160 | output_grads, 161 | allow_unused=True, 162 | ) 163 | del ctx.input_tensors 164 | del ctx.input_params 165 | del output_tensors 166 | return (None, None) + input_grads 167 | -------------------------------------------------------------------------------- /guided_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 
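With UniformSampler, for example, every timestep has weight 1, so indices are drawn uniformly and the returned loss weights are all 1.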
51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 
121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /guided_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 
28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 92 | 93 | def training_losses( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 97 | 98 | def condition_mean(self, cond_fn, *args, **kwargs): 99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 100 | 101 | def condition_score(self, cond_fn, *args, **kwargs): 102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 103 | 104 | def _wrap_model(self, model): 105 | if isinstance(model, _WrappedModel): 106 | return model 107 | return _WrappedModel( 108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 109 | ) 110 | 111 | def _scale_timesteps(self, t): 112 | # Scaling is done by the wrapped model. 
113 | return t 114 | 115 | 116 | class _WrappedModel: 117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 118 | self.model = model 119 | self.timestep_map = timestep_map 120 | self.rescale_timesteps = rescale_timesteps 121 | self.original_num_steps = original_num_steps 122 | 123 | def __call__(self, x, ts, **kwargs): 124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 125 | new_ts = map_tensor[ts] 126 | if self.rescale_timesteps: 127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 128 | return self.model(x, new_ts, **kwargs) 129 | -------------------------------------------------------------------------------- /guided_diffusion/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | 5 | import blobfile as bf 6 | import torch as th 7 | import torch.distributed as dist 8 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 9 | from torch.optim import AdamW 10 | 11 | from . import dist_util, logger 12 | from .fp16_util import MixedPrecisionTrainer 13 | from .nn import update_ema 14 | from .resample import LossAwareSampler, UniformSampler 15 | 16 | # For ImageNet experiments, this was a good default value. 17 | # We found that the lg_loss_scale quickly climbed to 18 | # 20-21 within the first ~1K steps of training. 19 | INITIAL_LOG_LOSS_SCALE = 20.0 20 | 21 | 22 | class TrainLoop: 23 | def __init__( 24 | self, 25 | *, 26 | model, 27 | diffusion, 28 | data, 29 | batch_size, 30 | microbatch, 31 | lr, 32 | ema_rate, 33 | log_interval, 34 | save_interval, 35 | resume_checkpoint, 36 | use_fp16=False, 37 | fp16_scale_growth=1e-3, 38 | schedule_sampler=None, 39 | weight_decay=0.0, 40 | lr_anneal_steps=0, 41 | ): 42 | self.model = model 43 | self.diffusion = diffusion 44 | self.data = data 45 | self.batch_size = batch_size 46 | self.microbatch = microbatch if microbatch > 0 else batch_size 47 | self.lr = lr 48 | self.ema_rate = ( 49 | [ema_rate] 50 | if isinstance(ema_rate, float) 51 | else [float(x) for x in ema_rate.split(",")] 52 | ) 53 | self.log_interval = log_interval 54 | self.save_interval = save_interval 55 | self.resume_checkpoint = resume_checkpoint 56 | self.use_fp16 = use_fp16 57 | self.fp16_scale_growth = fp16_scale_growth 58 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 59 | self.weight_decay = weight_decay 60 | self.lr_anneal_steps = lr_anneal_steps 61 | 62 | self.step = 0 63 | self.resume_step = 0 64 | self.global_batch = self.batch_size * dist.get_world_size() 65 | 66 | self.sync_cuda = th.cuda.is_available() 67 | 68 | self._load_and_sync_parameters() 69 | self.mp_trainer = MixedPrecisionTrainer( 70 | model=self.model, 71 | use_fp16=self.use_fp16, 72 | fp16_scale_growth=fp16_scale_growth, 73 | ) 74 | 75 | self.opt = AdamW( 76 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay 77 | ) 78 | if self.resume_step: 79 | self._load_optimizer_state() 80 | # Model was resumed, either due to a restart or a checkpoint 81 | # being specified at the command line. 
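# One EMA copy of the master parameters is kept per rate in self.ema_rate.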
82 | self.ema_params = [ 83 | self._load_ema_parameters(rate) for rate in self.ema_rate 84 | ] 85 | else: 86 | self.ema_params = [ 87 | copy.deepcopy(self.mp_trainer.master_params) 88 | for _ in range(len(self.ema_rate)) 89 | ] 90 | 91 | if th.cuda.is_available(): 92 | self.use_ddp = True 93 | self.ddp_model = DDP( 94 | self.model, 95 | device_ids=[dist_util.dev()], 96 | output_device=dist_util.dev(), 97 | broadcast_buffers=False, 98 | bucket_cap_mb=128, 99 | find_unused_parameters=False, 100 | ) 101 | else: 102 | if dist.get_world_size() > 1: 103 | logger.warn( 104 | "Distributed training requires CUDA. " 105 | "Gradients will not be synchronized properly!" 106 | ) 107 | self.use_ddp = False 108 | self.ddp_model = self.model 109 | 110 | def _load_and_sync_parameters(self): 111 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 112 | 113 | if resume_checkpoint: 114 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 115 | if dist.get_rank() == 0: 116 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 117 | self.model.load_state_dict( 118 | dist_util.load_state_dict( 119 | resume_checkpoint, map_location=dist_util.dev() 120 | ) 121 | ) 122 | 123 | dist_util.sync_params(self.model.parameters()) 124 | 125 | def _load_ema_parameters(self, rate): 126 | ema_params = copy.deepcopy(self.mp_trainer.master_params) 127 | 128 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 129 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 130 | if ema_checkpoint: 131 | if dist.get_rank() == 0: 132 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 133 | state_dict = dist_util.load_state_dict( 134 | ema_checkpoint, map_location=dist_util.dev() 135 | ) 136 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict) 137 | 138 | dist_util.sync_params(ema_params) 139 | return ema_params 140 | 141 | def _load_optimizer_state(self): 142 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 143 | opt_checkpoint = bf.join( 144 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" 145 | ) 146 | if bf.exists(opt_checkpoint): 147 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 148 | state_dict = dist_util.load_state_dict( 149 | opt_checkpoint, map_location=dist_util.dev() 150 | ) 151 | self.opt.load_state_dict(state_dict) 152 | 153 | def run_loop(self): 154 | while ( 155 | not self.lr_anneal_steps 156 | or self.step + self.resume_step < self.lr_anneal_steps 157 | ): 158 | batch, cond = next(self.data) 159 | self.run_step(batch, cond) 160 | if self.step % self.log_interval == 0: 161 | logger.dumpkvs() 162 | if self.step % self.save_interval == 0: 163 | self.save() 164 | # Run for a finite amount of time in integration tests. 165 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 166 | return 167 | self.step += 1 168 | # Save the last checkpoint if it wasn't already saved. 
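# (self.step has already been advanced past the last completed step here.)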
169 | if (self.step - 1) % self.save_interval != 0: 170 | self.save() 171 | 172 | def run_step(self, batch, cond): 173 | self.forward_backward(batch, cond) 174 | took_step = self.mp_trainer.optimize(self.opt) 175 | if took_step: 176 | self._update_ema() 177 | self._anneal_lr() 178 | self.log_step() 179 | 180 | def forward_backward(self, batch, cond): 181 | self.mp_trainer.zero_grad() 182 | for i in range(0, batch.shape[0], self.microbatch): 183 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 184 | micro_cond = { 185 | k: v[i : i + self.microbatch].to(dist_util.dev()) 186 | for k, v in cond.items() 187 | } 188 | last_batch = (i + self.microbatch) >= batch.shape[0] 189 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 190 | 191 | compute_losses = functools.partial( 192 | self.diffusion.training_losses, 193 | self.ddp_model, 194 | micro, 195 | t, 196 | model_kwargs=micro_cond, 197 | ) 198 | 199 | if last_batch or not self.use_ddp: 200 | losses = compute_losses() 201 | else: 202 | with self.ddp_model.no_sync(): 203 | losses = compute_losses() 204 | 205 | if isinstance(self.schedule_sampler, LossAwareSampler): 206 | self.schedule_sampler.update_with_local_losses( 207 | t, losses["loss"].detach() 208 | ) 209 | 210 | loss = (losses["loss"] * weights).mean() 211 | log_loss_dict( 212 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 213 | ) 214 | self.mp_trainer.backward(loss) 215 | 216 | def _update_ema(self): 217 | for rate, params in zip(self.ema_rate, self.ema_params): 218 | update_ema(params, self.mp_trainer.master_params, rate=rate) 219 | 220 | def _anneal_lr(self): 221 | if not self.lr_anneal_steps: 222 | return 223 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 224 | lr = self.lr * (1 - frac_done) 225 | for param_group in self.opt.param_groups: 226 | param_group["lr"] = lr 227 | 228 | def log_step(self): 229 | logger.logkv("step", self.step + self.resume_step) 230 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 231 | 232 | def save(self): 233 | def save_checkpoint(rate, params): 234 | state_dict = self.mp_trainer.master_params_to_state_dict(params) 235 | if dist.get_rank() == 0: 236 | logger.log(f"saving model {rate}...") 237 | if not rate: 238 | filename = f"model{(self.step+self.resume_step):06d}.pt" 239 | else: 240 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" 241 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 242 | th.save(state_dict, f) 243 | 244 | save_checkpoint(0, self.mp_trainer.master_params) 245 | for rate, params in zip(self.ema_rate, self.ema_params): 246 | save_checkpoint(rate, params) 247 | 248 | if dist.get_rank() == 0: 249 | with bf.BlobFile( 250 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), 251 | "wb", 252 | ) as f: 253 | th.save(self.opt.state_dict(), f) 254 | 255 | dist.barrier() 256 | 257 | 258 | def parse_resume_step_from_filename(filename): 259 | """ 260 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 261 | checkpoint's number of steps. 262 | """ 263 | split = filename.split("model") 264 | if len(split) < 2: 265 | return 0 266 | split1 = split[-1].split(".")[0] 267 | try: 268 | return int(split1) 269 | except ValueError: 270 | return 0 271 | 272 | 273 | def get_blob_logdir(): 274 | # You can change this to be a separate path to save checkpoints to 275 | # a blobstore or some external drive. 
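# By default this is the logger's output directory, typically OPENAI_LOGDIR
# if that environment variable is set, otherwise a temporary directory.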
276 | return logger.get_dir() 277 | 278 | 279 | def find_resume_checkpoint(): 280 | # On your infrastructure, you may want to override this to automatically 281 | # discover the latest checkpoint on your blob storage, etc. 282 | return None 283 | 284 | 285 | def find_ema_checkpoint(main_checkpoint, step, rate): 286 | if main_checkpoint is None: 287 | return None 288 | filename = f"ema_{rate}_{(step):06d}.pt" 289 | path = bf.join(bf.dirname(main_checkpoint), filename) 290 | if bf.exists(path): 291 | return path 292 | return None 293 | 294 | 295 | def log_loss_dict(diffusion, ts, losses): 296 | for key, values in losses.items(): 297 | logger.logkv_mean(key, values.mean().item()) 298 | # Log the quantiles (four quartiles, in particular). 299 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 300 | quartile = int(4 * sub_t / diffusion.num_timesteps) 301 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 302 | -------------------------------------------------------------------------------- /improved_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 3 | """ 4 | -------------------------------------------------------------------------------- /improved_diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | # from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 8 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(): 22 | """ 23 | Setup a distributed process group. 24 | """ 25 | if dist.is_initialized(): 26 | return 27 | 28 | comm = MPI.COMM_WORLD 29 | backend = "gloo" if not th.cuda.is_available() else "nccl" 30 | 31 | if backend == "gloo": 32 | hostname = "localhost" 33 | else: 34 | hostname = socket.gethostbyname(socket.getfqdn()) 35 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 36 | os.environ["RANK"] = str(comm.rank) 37 | os.environ["WORLD_SIZE"] = str(comm.size) 38 | 39 | port = comm.bcast(_find_free_port(), root=0) 40 | os.environ["MASTER_PORT"] = str(port) 41 | dist.init_process_group(backend=backend, init_method="env://") 42 | 43 | 44 | def dev(): 45 | """ 46 | Get the device to use for torch.distributed. 47 | """ 48 | if th.cuda.is_available(): 49 | return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}") 50 | return th.device("cpu") 51 | 52 | 53 | def load_state_dict(path, **kwargs): 54 | """ 55 | Load a PyTorch file without redundant fetches across MPI ranks. 56 | """ 57 | if 0 == 0: 58 | with bf.BlobFile(path, "rb") as f: 59 | data = f.read() 60 | else: 61 | data = None 62 | # data = MPI.COMM_WORLD.bcast(data) 63 | return th.load(io.BytesIO(data), **kwargs) 64 | 65 | 66 | def sync_params(params): 67 | """ 68 | Synchronize a sequence of Tensors across ranks from rank 0. 
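Typically called right after rank 0 loads a checkpoint so that all ranks start from identical weights.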
69 | """ 70 | for p in params: 71 | with th.no_grad(): 72 | dist.broadcast(p, 0) 73 | 74 | 75 | def _find_free_port(): 76 | try: 77 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 78 | s.bind(("", 0)) 79 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 80 | return s.getsockname()[1] 81 | finally: 82 | s.close() 83 | -------------------------------------------------------------------------------- /improved_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import torch.nn as nn 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 7 | 8 | 9 | def convert_module_to_f16(l): 10 | """ 11 | Convert primitive modules to float16. 12 | """ 13 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 14 | l.weight.data = l.weight.data.half() 15 | l.bias.data = l.bias.data.half() 16 | 17 | 18 | def convert_module_to_f32(l): 19 | """ 20 | Convert primitive modules to float32, undoing convert_module_to_f16(). 21 | """ 22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 23 | l.weight.data = l.weight.data.float() 24 | l.bias.data = l.bias.data.float() 25 | 26 | 27 | def make_master_params(model_params): 28 | """ 29 | Copy model parameters into a (differently-shaped) list of full-precision 30 | parameters. 31 | """ 32 | master_params = _flatten_dense_tensors( 33 | [param.detach().float() for param in model_params] 34 | ) 35 | master_params = nn.Parameter(master_params) 36 | master_params.requires_grad = True 37 | return [master_params] 38 | 39 | 40 | def model_grads_to_master_grads(model_params, master_params): 41 | """ 42 | Copy the gradients from the model parameters into the master parameters 43 | from make_master_params(). 44 | """ 45 | master_params[0].grad = _flatten_dense_tensors( 46 | [param.grad.data.detach().float() for param in model_params] 47 | ) 48 | 49 | 50 | def master_params_to_model_params(model_params, master_params): 51 | """ 52 | Copy the master parameter data back into the model parameters. 53 | """ 54 | # Without copying to a list, if a generator is passed, this will 55 | # silently not copy any parameters. 56 | model_params = list(model_params) 57 | 58 | for param, master_param in zip( 59 | model_params, unflatten_master_params(model_params, master_params) 60 | ): 61 | param.detach().copy_(master_param) 62 | 63 | 64 | def unflatten_master_params(model_params, master_params): 65 | """ 66 | Unflatten the master parameters to look like model_params. 67 | """ 68 | return _unflatten_dense_tensors(master_params[0].detach(), model_params) 69 | 70 | 71 | def zero_grad(model_params): 72 | for param in model_params: 73 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 74 | if param.grad is not None: 75 | param.grad.detach_() 76 | param.grad.zero_() 77 | -------------------------------------------------------------------------------- /improved_diffusion/image_datasets.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import blobfile as bf 3 | from mpi4py import MPI 4 | import numpy as np 5 | from torch.utils.data import DataLoader, Dataset 6 | 7 | 8 | def load_data( 9 | *, data_dir, batch_size, image_size, class_cond=False, deterministic=False 10 | ): 11 | """ 12 | For a dataset, create a generator over (images, kwargs) pairs. 
13 | 14 | Each images is an NCHW float tensor, and the kwargs dict contains zero or 15 | more keys, each of which map to a batched Tensor of their own. 16 | The kwargs dict can be used for class labels, in which case the key is "y" 17 | and the values are integer tensors of class labels. 18 | 19 | :param data_dir: a dataset directory. 20 | :param batch_size: the batch size of each returned pair. 21 | :param image_size: the size to which images are resized. 22 | :param class_cond: if True, include a "y" key in returned dicts for class 23 | label. If classes are not available and this is true, an 24 | exception will be raised. 25 | :param deterministic: if True, yield results in a deterministic order. 26 | """ 27 | if not data_dir: 28 | raise ValueError("unspecified data directory") 29 | all_files = _list_image_files_recursively(data_dir) 30 | classes = None 31 | if class_cond: 32 | # Assume classes are the first part of the filename, 33 | # before an underscore. 34 | class_names = [bf.basename(path).split("_")[0] for path in all_files] 35 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 36 | classes = [sorted_classes[x] for x in class_names] 37 | dataset = ImageDataset( 38 | image_size, 39 | all_files, 40 | classes=classes, 41 | shard=MPI.COMM_WORLD.Get_rank(), 42 | num_shards=MPI.COMM_WORLD.Get_size(), 43 | ) 44 | if deterministic: 45 | loader = DataLoader( 46 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True 47 | ) 48 | else: 49 | loader = DataLoader( 50 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True 51 | ) 52 | while True: 53 | yield from loader 54 | 55 | 56 | def _list_image_files_recursively(data_dir): 57 | results = [] 58 | for entry in sorted(bf.listdir(data_dir)): 59 | full_path = bf.join(data_dir, entry) 60 | ext = entry.split(".")[-1] 61 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 62 | results.append(full_path) 63 | elif bf.isdir(full_path): 64 | results.extend(_list_image_files_recursively(full_path)) 65 | return results 66 | 67 | 68 | class ImageDataset(Dataset): 69 | def __init__(self, resolution, image_paths, classes=None, shard=0, num_shards=1): 70 | super().__init__() 71 | self.resolution = resolution 72 | self.local_images = image_paths[shard:][::num_shards] 73 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 74 | 75 | def __len__(self): 76 | return len(self.local_images) 77 | 78 | def __getitem__(self, idx): 79 | path = self.local_images[idx] 80 | with bf.BlobFile(path, "rb") as f: 81 | pil_image = Image.open(f) 82 | pil_image.load() 83 | 84 | # We are not on a new enough PIL to support the `reducing_gap` 85 | # argument, which uses BOX downsampling at powers of two first. 86 | # Thus, we do it by hand to improve downsample quality. 
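# After resizing, the image is center-cropped to `resolution` and rescaled to
# [-1, 1] (arr / 127.5 - 1) before being returned in CHW order.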
87 | while min(*pil_image.size) >= 2 * self.resolution: 88 | pil_image = pil_image.resize( 89 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 90 | ) 91 | 92 | scale = self.resolution / min(*pil_image.size) 93 | pil_image = pil_image.resize( 94 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 95 | ) 96 | 97 | arr = np.array(pil_image.convert("RGB")) 98 | crop_y = (arr.shape[0] - self.resolution) // 2 99 | crop_x = (arr.shape[1] - self.resolution) // 2 100 | arr = arr[crop_y : crop_y + self.resolution, crop_x : crop_x + self.resolution] 101 | arr = arr.astype(np.float32) / 127.5 - 1 102 | 103 | out_dict = {} 104 | if self.local_classes is not None: 105 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) 106 | return np.transpose(arr, [2, 0, 1]), out_dict 107 | -------------------------------------------------------------------------------- /improved_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch as th 4 | 5 | 6 | def normal_kl(mean1, logvar1, mean2, logvar2): 7 | """ 8 | Compute the KL divergence between two gaussians. 9 | 10 | Shapes are automatically broadcasted, so batches can be compared to 11 | scalars, among other use cases. 12 | """ 13 | tensor = None 14 | for obj in (mean1, logvar1, mean2, logvar2): 15 | if isinstance(obj, th.Tensor): 16 | tensor = obj 17 | break 18 | assert tensor is not None, "at least one argument must be a Tensor" 19 | 20 | # Force variances to be Tensors. Broadcasting helps convert scalars to 21 | # Tensors, but it does not work for th.exp(). 22 | logvar1, logvar2 = [ 23 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 24 | for x in (logvar1, logvar2) 25 | ] 26 | 27 | return 0.5 * ( 28 | -1.0 29 | + logvar2 30 | - logvar1 31 | + th.exp(logvar1 - logvar2) 32 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 33 | ) 34 | 35 | 36 | def approx_standard_normal_cdf(x): 37 | """ 38 | A fast approximation of the cumulative distribution function of the 39 | standard normal. 40 | """ 41 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 42 | 43 | 44 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 45 | """ 46 | Compute the log-likelihood of a Gaussian distribution discretizing to a 47 | given image. 48 | 49 | :param x: the target images. It is assumed that this was uint8 values, 50 | rescaled to the range [-1, 1]. 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 
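Each discrete pixel value corresponds to a bin of width 2/255 in [-1, 1] space, which is why the CDF is evaluated at centered_x +/- 1/255; the extreme bins at -1 and 1 are treated as open-ended.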
54 | """ 55 | assert x.shape == means.shape == log_scales.shape 56 | centered_x = x - means 57 | inv_stdv = th.exp(-log_scales) 58 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 59 | cdf_plus = approx_standard_normal_cdf(plus_in) 60 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 61 | cdf_min = approx_standard_normal_cdf(min_in) 62 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 63 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 64 | cdf_delta = cdf_plus - cdf_min 65 | log_probs = th.where( 66 | x < -0.999, 67 | log_cdf_plus, 68 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 69 | ) 70 | assert log_probs.shape == x.shape 71 | return log_probs 72 | -------------------------------------------------------------------------------- /improved_diffusion/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .lenet import * 2 | from .vggnet import * 3 | from .resnet import * 4 | from .wide_resnet import * 5 | -------------------------------------------------------------------------------- /improved_diffusion/networks/lenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | def conv_init(m): 5 | classname = m.__class__.__name__ 6 | if classname.find('Conv') != -1: 7 | init.xavier_uniform(m.weight, gain=np.sqrt(2)) 8 | init.constant(m.bias, 0) 9 | 10 | class LeNet(nn.Module): 11 | def __init__(self, num_classes): 12 | super(LeNet, self).__init__() 13 | self.conv1 = nn.Conv2d(3, 6, 5) 14 | self.conv2 = nn.Conv2d(6, 16, 5) 15 | self.fc1 = nn.Linear(16*5*5, 120) 16 | self.fc2 = nn.Linear(120, 84) 17 | self.fc3 = nn.Linear(84, num_classes) 18 | 19 | def forward(self, x): 20 | out = F.relu(self.conv1(x)) 21 | out = F.max_pool2d(out, 2) 22 | out = F.relu(self.conv2(out)) 23 | out = F.max_pool2d(out, 2) 24 | out = out.view(out.size(0), -1) 25 | out = F.relu(self.fc1(out)) 26 | out = F.relu(self.fc2(out)) 27 | out = self.fc3(out) 28 | 29 | return(out) 30 | -------------------------------------------------------------------------------- /improved_diffusion/networks/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.autograd import Variable 6 | import sys 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 10 | 11 | def conv_init(m): 12 | classname = m.__class__.__name__ 13 | if classname.find('Conv') != -1: 14 | init.xavier_uniform(m.weight, gain=np.sqrt(2)) 15 | init.constant(m.bias, 0) 16 | 17 | def cfg(depth): 18 | depth_lst = [18, 34, 50, 101, 152] 19 | assert (depth in depth_lst), "Error : Resnet depth should be either 18, 34, 50, 101, 152" 20 | cf_dict = { 21 | '18': (BasicBlock, [2,2,2,2]), 22 | '34': (BasicBlock, [3,4,6,3]), 23 | '50': (Bottleneck, [3,4,6,3]), 24 | '101':(Bottleneck, [3,4,23,3]), 25 | '152':(Bottleneck, [3,8,36,3]), 26 | } 27 | 28 | return cf_dict[str(depth)] 29 | 30 | class BasicBlock(nn.Module): 31 | expansion = 1 32 | 33 | def __init__(self, in_planes, planes, stride=1): 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = conv3x3(in_planes, planes, stride) 36 | self.bn1 = nn.BatchNorm2d(planes) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | 40 | self.shortcut = nn.Sequential() 41 | if 
stride != 1 or in_planes != self.expansion * planes: 42 | self.shortcut = nn.Sequential( 43 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True), 44 | nn.BatchNorm2d(self.expansion*planes) 45 | ) 46 | 47 | def forward(self, x): 48 | out = F.relu(self.bn1(self.conv1(x))) 49 | out = self.bn2(self.conv2(out)) 50 | out += self.shortcut(x) 51 | out = F.relu(out) 52 | 53 | return out 54 | 55 | class Bottleneck(nn.Module): 56 | expansion = 4 57 | 58 | def __init__(self, in_planes, planes, stride=1): 59 | super(Bottleneck, self).__init__() 60 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=True) 61 | self.bn1 = nn.BatchNorm2d(planes) 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 63 | self.bn2 = nn.BatchNorm2d(planes) 64 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=True) 65 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 66 | 67 | self.shortcut = nn.Sequential() 68 | if stride != 1 or in_planes != self.expansion*planes: 69 | self.shortcut = nn.Sequential( 70 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True), 71 | nn.BatchNorm2d(self.expansion*planes) 72 | ) 73 | 74 | def forward(self, x): 75 | out = F.relu(self.bn1(self.conv1(x))) 76 | out = F.relu(self.bn2(self.conv2(out))) 77 | out = self.bn3(self.conv3(out)) 78 | out += self.shortcut(x) 79 | out = F.relu(out) 80 | 81 | return out 82 | 83 | class ResNet(nn.Module): 84 | def __init__(self, depth, num_classes): 85 | super(ResNet, self).__init__() 86 | self.in_planes = 16 87 | 88 | block, num_blocks = cfg(depth) 89 | 90 | self.conv1 = conv3x3(3,16) 91 | self.bn1 = nn.BatchNorm2d(16) 92 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 93 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 94 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 95 | self.linear = nn.Linear(64*block.expansion, num_classes) 96 | 97 | def _make_layer(self, block, planes, num_blocks, stride): 98 | strides = [stride] + [1]*(num_blocks-1) 99 | layers = [] 100 | 101 | for stride in strides: 102 | layers.append(block(self.in_planes, planes, stride)) 103 | self.in_planes = planes * block.expansion 104 | 105 | return nn.Sequential(*layers) 106 | 107 | def forward(self, x): 108 | out = F.relu(self.bn1(self.conv1(x))) 109 | out = self.layer1(out) 110 | out = self.layer2(out) 111 | out = self.layer3(out) 112 | out = F.avg_pool2d(out, 8) 113 | out = out.view(out.size(0), -1) 114 | out = self.linear(out) 115 | 116 | return out 117 | 118 | if __name__ == '__main__': 119 | net=ResNet(50, 10) 120 | y = net(Variable(torch.randn(1,3,32,32))) 121 | print(y.size()) 122 | -------------------------------------------------------------------------------- /improved_diffusion/networks/vggnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | def conv_init(m): 6 | classname = m.__class__.__name__ 7 | if classname.find('Conv') != -1: 8 | init.xavier_uniform(m.weight, gain=np.sqrt(2)) 9 | init.constant(m.bias, 0) 10 | 11 | def cfg(depth): 12 | depth_lst = [11, 13, 16, 19] 13 | assert (depth in depth_lst), "Error : VGGnet depth should be either 11, 13, 16, 19" 14 | cf_dict = { 15 | '11': [ 16 | 64, 'mp', 17 | 128, 'mp', 18 | 256, 256, 'mp', 19 | 512, 512, 'mp', 20 | 512, 512, 'mp'], 21 | '13': [ 22 | 64, 64, 'mp', 23 | 128, 128, 'mp', 24 | 256, 256, 'mp', 
25 | 512, 512, 'mp', 26 | 512, 512, 'mp' 27 | ], 28 | '16': [ 29 | 64, 64, 'mp', 30 | 128, 128, 'mp', 31 | 256, 256, 256, 'mp', 32 | 512, 512, 512, 'mp', 33 | 512, 512, 512, 'mp' 34 | ], 35 | '19': [ 36 | 64, 64, 'mp', 37 | 128, 128, 'mp', 38 | 256, 256, 256, 256, 'mp', 39 | 512, 512, 512, 512, 'mp', 40 | 512, 512, 512, 512, 'mp' 41 | ], 42 | } 43 | 44 | return cf_dict[str(depth)] 45 | 46 | def conv3x3(in_planes, out_planes, stride=1): 47 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 48 | 49 | class VGG(nn.Module): 50 | def __init__(self, depth, num_classes): 51 | super(VGG, self).__init__() 52 | self.features = self._make_layers(cfg(depth)) 53 | self.linear = nn.Linear(512, num_classes) 54 | 55 | def forward(self, x): 56 | out = self.features(x) 57 | out = out.view(out.size(0), -1) 58 | out = self.linear(out) 59 | 60 | return out 61 | 62 | def _make_layers(self, cfg): 63 | layers = [] 64 | in_planes = 3 65 | 66 | for x in cfg: 67 | if x == 'mp': 68 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 69 | else: 70 | layers += [conv3x3(in_planes, x), nn.BatchNorm2d(x), nn.ReLU(inplace=True)] 71 | in_planes = x 72 | 73 | # After cfg convolution 74 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 75 | return nn.Sequential(*layers) 76 | 77 | if __name__ == "__main__": 78 | net = VGG(16, 10) 79 | y = net(Variable(torch.randn(1,3,32,32))) 80 | print(y.size()) 81 | -------------------------------------------------------------------------------- /improved_diffusion/networks/wide_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | import sys 8 | import numpy as np 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 12 | 13 | def conv_init(m): 14 | classname = m.__class__.__name__ 15 | if classname.find('Conv') != -1: 16 | init.xavier_uniform_(m.weight, gain=np.sqrt(2)) 17 | init.constant_(m.bias, 0) 18 | elif classname.find('BatchNorm') != -1: 19 | init.constant_(m.weight, 1) 20 | init.constant_(m.bias, 0) 21 | 22 | class wide_basic(nn.Module): 23 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 24 | super(wide_basic, self).__init__() 25 | self.bn1 = nn.BatchNorm2d(in_planes) 26 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 27 | self.dropout = nn.Dropout(p=dropout_rate) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 30 | 31 | self.shortcut = nn.Sequential() 32 | if stride != 1 or in_planes != planes: 33 | self.shortcut = nn.Sequential( 34 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 35 | ) 36 | 37 | def forward(self, x): 38 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 39 | out = self.conv2(F.relu(self.bn2(out))) 40 | out += self.shortcut(x) 41 | 42 | return out 43 | 44 | class Wide_ResNet(nn.Module): 45 | def __init__(self, depth, widen_factor, dropout_rate, num_classes): 46 | super(Wide_ResNet, self).__init__() 47 | self.in_planes = 16 48 | 49 | assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4' 50 | n = (depth-4)/6 51 | k = widen_factor 52 | 53 | print('| Wide-Resnet %dx%d' %(depth, k)) 54 | nStages = [16, 16*k, 32*k, 64*k] 55 | 56 | self.conv1 = conv3x3(3,nStages[0]) 
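        # Note: stage widths are [16, 16*k, 32*k, 64*k]; layer1 keeps the spatial size,
        # while layer2 and layer3 each halve it, so a 32x32 CIFAR input is 8x8 before
        # the final 8x8 average pooling in forward().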
57 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 58 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 59 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 60 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 61 | self.linear = nn.Linear(nStages[3], num_classes) 62 | 63 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 64 | strides = [stride] + [1]*(int(num_blocks)-1) 65 | layers = [] 66 | 67 | for stride in strides: 68 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 69 | self.in_planes = planes 70 | 71 | return nn.Sequential(*layers) 72 | 73 | def forward(self, x): 74 | out = self.conv1(x) 75 | out = self.layer1(out) 76 | out = self.layer2(out) 77 | out = self.layer3(out) 78 | out = F.relu(self.bn1(out)) 79 | out = F.avg_pool2d(out, 8) 80 | out = out.view(out.size(0), -1) 81 | out = self.linear(out) 82 | 83 | return out 84 | 85 | if __name__ == '__main__': 86 | net=Wide_ResNet(28, 10, 0.3, 10) 87 | y = net(Variable(torch.randn(1,3,32,32))) 88 | 89 | print(y.size()) 90 | -------------------------------------------------------------------------------- /improved_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 
80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /improved_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 
82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /improved_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 
12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 
70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 92 | 93 | def training_losses( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 97 | 98 | def _wrap_model(self, model): 99 | if isinstance(model, _WrappedModel): 100 | return model 101 | return _WrappedModel( 102 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 103 | ) 104 | 105 | def _scale_timesteps(self, t): 106 | # Scaling is done by the wrapped model. 107 | return t 108 | 109 | 110 | class _WrappedModel: 111 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 112 | self.model = model 113 | self.timestep_map = timestep_map 114 | self.rescale_timesteps = rescale_timesteps 115 | self.original_num_steps = original_num_steps 116 | 117 | def __call__(self, x, ts, **kwargs): 118 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 119 | new_ts = map_tensor[ts] 120 | if self.rescale_timesteps: 121 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 122 | return self.model(x, new_ts, **kwargs) 123 | -------------------------------------------------------------------------------- /improved_diffusion/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | 4 | from . import gaussian_diffusion as gd 5 | from .respace import SpacedDiffusion, space_timesteps 6 | from .unet import SuperResModel, UNetModel 7 | 8 | NUM_CLASSES = 1000 9 | 10 | 11 | def model_and_diffusion_defaults(t_total): 12 | """ 13 | Defaults for image training. 
14 | """ 15 | return dict( 16 | image_size=32, 17 | num_channels=128, 18 | num_res_blocks=3, 19 | num_heads=4, 20 | num_heads_upsample=-1, 21 | attention_resolutions="16,8", 22 | dropout=0.3, 23 | learn_sigma=True, 24 | sigma_small=False, 25 | class_cond=False, 26 | diffusion_steps=t_total, 27 | noise_schedule="cosine", 28 | timestep_respacing="", 29 | use_kl=False, 30 | predict_xstart=False, 31 | rescale_timesteps=True, 32 | rescale_learned_sigmas=True, 33 | use_checkpoint=False, 34 | use_scale_shift_norm=True, 35 | ) 36 | 37 | 38 | def create_model_and_diffusion( 39 | image_size, 40 | class_cond, 41 | learn_sigma, 42 | sigma_small, 43 | num_channels, 44 | num_res_blocks, 45 | num_heads, 46 | num_heads_upsample, 47 | attention_resolutions, 48 | dropout, 49 | diffusion_steps, 50 | noise_schedule, 51 | timestep_respacing, 52 | use_kl, 53 | predict_xstart, 54 | rescale_timesteps, 55 | rescale_learned_sigmas, 56 | use_checkpoint, 57 | use_scale_shift_norm, 58 | ): 59 | model = create_model( 60 | image_size, 61 | num_channels, 62 | num_res_blocks, 63 | learn_sigma=learn_sigma, 64 | class_cond=class_cond, 65 | use_checkpoint=use_checkpoint, 66 | attention_resolutions=attention_resolutions, 67 | num_heads=num_heads, 68 | num_heads_upsample=num_heads_upsample, 69 | use_scale_shift_norm=use_scale_shift_norm, 70 | dropout=dropout, 71 | ) 72 | diffusion = create_gaussian_diffusion( 73 | steps=diffusion_steps, 74 | learn_sigma=learn_sigma, 75 | sigma_small=sigma_small, 76 | noise_schedule=noise_schedule, 77 | use_kl=use_kl, 78 | predict_xstart=predict_xstart, 79 | rescale_timesteps=rescale_timesteps, 80 | rescale_learned_sigmas=rescale_learned_sigmas, 81 | timestep_respacing=timestep_respacing, 82 | ) 83 | return model, diffusion 84 | 85 | 86 | def create_model( 87 | image_size, 88 | num_channels, 89 | num_res_blocks, 90 | learn_sigma, 91 | class_cond, 92 | use_checkpoint, 93 | attention_resolutions, 94 | num_heads, 95 | num_heads_upsample, 96 | use_scale_shift_norm, 97 | dropout, 98 | ): 99 | if image_size == 256: 100 | channel_mult = (1, 1, 2, 2, 4, 4) 101 | elif image_size == 64: 102 | channel_mult = (1, 2, 3, 4) 103 | elif image_size == 32: 104 | channel_mult = (1, 2, 2, 2) 105 | else: 106 | raise ValueError(f"unsupported image size: {image_size}") 107 | 108 | attention_ds = [] 109 | for res in attention_resolutions.split(","): 110 | attention_ds.append(image_size // int(res)) 111 | 112 | return UNetModel( 113 | in_channels=3, 114 | model_channels=num_channels, 115 | out_channels=(3 if not learn_sigma else 6), 116 | num_res_blocks=num_res_blocks, 117 | attention_resolutions=tuple(attention_ds), 118 | dropout=dropout, 119 | channel_mult=channel_mult, 120 | num_classes=(NUM_CLASSES if class_cond else None), 121 | use_checkpoint=use_checkpoint, 122 | num_heads=num_heads, 123 | num_heads_upsample=num_heads_upsample, 124 | use_scale_shift_norm=use_scale_shift_norm, 125 | ) 126 | 127 | 128 | def sr_model_and_diffusion_defaults(): 129 | res = model_and_diffusion_defaults() 130 | res["large_size"] = 256 131 | res["small_size"] = 64 132 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] 133 | for k in res.copy().keys(): 134 | if k not in arg_names: 135 | del res[k] 136 | return res 137 | 138 | 139 | def sr_create_model_and_diffusion( 140 | large_size, 141 | small_size, 142 | class_cond, 143 | learn_sigma, 144 | num_channels, 145 | num_res_blocks, 146 | num_heads, 147 | num_heads_upsample, 148 | attention_resolutions, 149 | dropout, 150 | diffusion_steps, 151 | noise_schedule, 152 | 
timestep_respacing, 153 | use_kl, 154 | predict_xstart, 155 | rescale_timesteps, 156 | rescale_learned_sigmas, 157 | use_checkpoint, 158 | use_scale_shift_norm, 159 | ): 160 | model = sr_create_model( 161 | large_size, 162 | small_size, 163 | num_channels, 164 | num_res_blocks, 165 | learn_sigma=learn_sigma, 166 | class_cond=class_cond, 167 | use_checkpoint=use_checkpoint, 168 | attention_resolutions=attention_resolutions, 169 | num_heads=num_heads, 170 | num_heads_upsample=num_heads_upsample, 171 | use_scale_shift_norm=use_scale_shift_norm, 172 | dropout=dropout, 173 | ) 174 | diffusion = create_gaussian_diffusion( 175 | steps=diffusion_steps, 176 | learn_sigma=learn_sigma, 177 | noise_schedule=noise_schedule, 178 | use_kl=use_kl, 179 | predict_xstart=predict_xstart, 180 | rescale_timesteps=rescale_timesteps, 181 | rescale_learned_sigmas=rescale_learned_sigmas, 182 | timestep_respacing=timestep_respacing, 183 | ) 184 | return model, diffusion 185 | 186 | 187 | def sr_create_model( 188 | large_size, 189 | small_size, 190 | num_channels, 191 | num_res_blocks, 192 | learn_sigma, 193 | class_cond, 194 | use_checkpoint, 195 | attention_resolutions, 196 | num_heads, 197 | num_heads_upsample, 198 | use_scale_shift_norm, 199 | dropout, 200 | ): 201 | _ = small_size # hack to prevent unused variable 202 | 203 | if large_size == 256: 204 | channel_mult = (1, 1, 2, 2, 4, 4) 205 | elif large_size == 64: 206 | channel_mult = (1, 2, 3, 4) 207 | else: 208 | raise ValueError(f"unsupported large size: {large_size}") 209 | 210 | attention_ds = [] 211 | for res in attention_resolutions.split(","): 212 | attention_ds.append(large_size // int(res)) 213 | 214 | return SuperResModel( 215 | in_channels=3, 216 | model_channels=num_channels, 217 | out_channels=(3 if not learn_sigma else 6), 218 | num_res_blocks=num_res_blocks, 219 | attention_resolutions=tuple(attention_ds), 220 | dropout=dropout, 221 | channel_mult=channel_mult, 222 | num_classes=(NUM_CLASSES if class_cond else None), 223 | use_checkpoint=use_checkpoint, 224 | num_heads=num_heads, 225 | num_heads_upsample=num_heads_upsample, 226 | use_scale_shift_norm=use_scale_shift_norm, 227 | ) 228 | 229 | 230 | def create_gaussian_diffusion( 231 | *, 232 | steps=1000, 233 | learn_sigma=False, 234 | sigma_small=False, 235 | noise_schedule="linear", 236 | use_kl=False, 237 | predict_xstart=False, 238 | rescale_timesteps=False, 239 | rescale_learned_sigmas=False, 240 | timestep_respacing="", 241 | ): 242 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 243 | if use_kl: 244 | loss_type = gd.LossType.RESCALED_KL 245 | elif rescale_learned_sigmas: 246 | loss_type = gd.LossType.RESCALED_MSE 247 | else: 248 | loss_type = gd.LossType.MSE 249 | if not timestep_respacing: 250 | timestep_respacing = [steps] 251 | return SpacedDiffusion( 252 | use_timesteps=space_timesteps(steps, timestep_respacing), 253 | betas=betas, 254 | model_mean_type=( 255 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 256 | ), 257 | model_var_type=( 258 | ( 259 | gd.ModelVarType.FIXED_LARGE 260 | if not sigma_small 261 | else gd.ModelVarType.FIXED_SMALL 262 | ) 263 | if not learn_sigma 264 | else gd.ModelVarType.LEARNED_RANGE 265 | ), 266 | loss_type=loss_type, 267 | rescale_timesteps=rescale_timesteps, 268 | ) 269 | 270 | 271 | def add_dict_to_argparser(parser, default_dict): 272 | for k, v in default_dict.items(): 273 | v_type = type(v) 274 | if v is None: 275 | v_type = str 276 | elif isinstance(v, bool): 277 | v_type = str2bool 278 | 
parser.add_argument(f"--{k}", default=v, type=v_type) 279 | 280 | 281 | def args_to_dict(args, keys): 282 | return {k: getattr(args, k) for k in keys} 283 | 284 | 285 | def str2bool(v): 286 | """ 287 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 288 | """ 289 | if isinstance(v, bool): 290 | return v 291 | if v.lower() in ("yes", "true", "t", "y", "1"): 292 | return True 293 | elif v.lower() in ("no", "false", "f", "n", "0"): 294 | return False 295 | else: 296 | raise argparse.ArgumentTypeError("boolean value expected") 297 | -------------------------------------------------------------------------------- /improved_diffusion/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | 5 | import blobfile as bf 6 | import numpy as np 7 | import torch as th 8 | import torch.distributed as dist 9 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 10 | from torch.optim import AdamW 11 | 12 | from . import dist_util, logger 13 | from .fp16_util import ( 14 | make_master_params, 15 | master_params_to_model_params, 16 | model_grads_to_master_grads, 17 | unflatten_master_params, 18 | zero_grad, 19 | ) 20 | from .nn import update_ema 21 | from .resample import LossAwareSampler, UniformSampler 22 | 23 | # For ImageNet experiments, this was a good default value. 24 | # We found that the lg_loss_scale quickly climbed to 25 | # 20-21 within the first ~1K steps of training. 26 | INITIAL_LOG_LOSS_SCALE = 20.0 27 | 28 | 29 | class TrainLoop: 30 | def __init__( 31 | self, 32 | *, 33 | model, 34 | diffusion, 35 | data, 36 | batch_size, 37 | microbatch, 38 | lr, 39 | ema_rate, 40 | log_interval, 41 | save_interval, 42 | resume_checkpoint, 43 | use_fp16=False, 44 | fp16_scale_growth=1e-3, 45 | schedule_sampler=None, 46 | weight_decay=0.0, 47 | lr_anneal_steps=0, 48 | ): 49 | self.model = model 50 | self.diffusion = diffusion 51 | self.data = data 52 | self.batch_size = batch_size 53 | self.microbatch = microbatch if microbatch > 0 else batch_size 54 | self.lr = lr 55 | self.ema_rate = ( 56 | [ema_rate] 57 | if isinstance(ema_rate, float) 58 | else [float(x) for x in ema_rate.split(",")] 59 | ) 60 | self.log_interval = log_interval 61 | self.save_interval = save_interval 62 | self.resume_checkpoint = resume_checkpoint 63 | self.use_fp16 = use_fp16 64 | self.fp16_scale_growth = fp16_scale_growth 65 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 66 | self.weight_decay = weight_decay 67 | self.lr_anneal_steps = lr_anneal_steps 68 | 69 | self.step = 0 70 | self.resume_step = 0 71 | self.global_batch = self.batch_size * dist.get_world_size() 72 | 73 | self.model_params = list(self.model.parameters()) 74 | self.master_params = self.model_params 75 | self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE 76 | self.sync_cuda = th.cuda.is_available() 77 | 78 | self._load_and_sync_parameters() 79 | if self.use_fp16: 80 | self._setup_fp16() 81 | 82 | self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay) 83 | if self.resume_step: 84 | self._load_optimizer_state() 85 | # Model was resumed, either due to a restart or a checkpoint 86 | # being specified at the command line. 
87 | self.ema_params = [ 88 | self._load_ema_parameters(rate) for rate in self.ema_rate 89 | ] 90 | else: 91 | self.ema_params = [ 92 | copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate)) 93 | ] 94 | 95 | if th.cuda.is_available(): 96 | self.use_ddp = True 97 | self.ddp_model = DDP( 98 | self.model, 99 | device_ids=[dist_util.dev()], 100 | output_device=dist_util.dev(), 101 | broadcast_buffers=False, 102 | bucket_cap_mb=128, 103 | find_unused_parameters=False, 104 | ) 105 | else: 106 | if dist.get_world_size() > 1: 107 | logger.warn( 108 | "Distributed training requires CUDA. " 109 | "Gradients will not be synchronized properly!" 110 | ) 111 | self.use_ddp = False 112 | self.ddp_model = self.model 113 | 114 | def _load_and_sync_parameters(self): 115 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 116 | 117 | if resume_checkpoint: 118 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 119 | if dist.get_rank() == 0: 120 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 121 | self.model.load_state_dict( 122 | dist_util.load_state_dict( 123 | resume_checkpoint, map_location=dist_util.dev() 124 | ) 125 | ) 126 | 127 | dist_util.sync_params(self.model.parameters()) 128 | 129 | def _load_ema_parameters(self, rate): 130 | ema_params = copy.deepcopy(self.master_params) 131 | 132 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 133 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 134 | if ema_checkpoint: 135 | if dist.get_rank() == 0: 136 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 137 | state_dict = dist_util.load_state_dict( 138 | ema_checkpoint, map_location=dist_util.dev() 139 | ) 140 | ema_params = self._state_dict_to_master_params(state_dict) 141 | 142 | dist_util.sync_params(ema_params) 143 | return ema_params 144 | 145 | def _load_optimizer_state(self): 146 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 147 | opt_checkpoint = bf.join( 148 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" 149 | ) 150 | if bf.exists(opt_checkpoint): 151 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 152 | state_dict = dist_util.load_state_dict( 153 | opt_checkpoint, map_location=dist_util.dev() 154 | ) 155 | self.opt.load_state_dict(state_dict) 156 | 157 | def _setup_fp16(self): 158 | self.master_params = make_master_params(self.model_params) 159 | self.model.convert_to_fp16() 160 | 161 | def run_loop(self): 162 | while ( 163 | not self.lr_anneal_steps 164 | or self.step + self.resume_step < self.lr_anneal_steps 165 | ): 166 | batch, cond = next(self.data) 167 | self.run_step(batch, cond) 168 | if self.step % self.log_interval == 0: 169 | logger.dumpkvs() 170 | if self.step % self.save_interval == 0: 171 | self.save() 172 | # Run for a finite amount of time in integration tests. 173 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 174 | return 175 | self.step += 1 176 | # Save the last checkpoint if it wasn't already saved. 
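        # self.step was incremented at the end of the loop, so step - 1 is the last
        # completed step; only save again if that step did not already trigger a
        # periodic save above.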
177 | if (self.step - 1) % self.save_interval != 0: 178 | self.save() 179 | 180 | def run_step(self, batch, cond): 181 | self.forward_backward(batch, cond) 182 | if self.use_fp16: 183 | self.optimize_fp16() 184 | else: 185 | self.optimize_normal() 186 | self.log_step() 187 | 188 | def forward_backward(self, batch, cond): 189 | zero_grad(self.model_params) 190 | for i in range(0, batch.shape[0], self.microbatch): 191 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 192 | micro_cond = { 193 | k: v[i : i + self.microbatch].to(dist_util.dev()) 194 | for k, v in cond.items() 195 | } 196 | last_batch = (i + self.microbatch) >= batch.shape[0] 197 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 198 | 199 | compute_losses = functools.partial( 200 | self.diffusion.training_losses, 201 | self.ddp_model, 202 | micro, 203 | t, 204 | model_kwargs=micro_cond, 205 | ) 206 | 207 | if last_batch or not self.use_ddp: 208 | losses = compute_losses() 209 | else: 210 | with self.ddp_model.no_sync(): 211 | losses = compute_losses() 212 | 213 | if isinstance(self.schedule_sampler, LossAwareSampler): 214 | self.schedule_sampler.update_with_local_losses( 215 | t, losses["loss"].detach() 216 | ) 217 | 218 | loss = (losses["loss"] * weights).mean() 219 | log_loss_dict( 220 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 221 | ) 222 | if self.use_fp16: 223 | loss_scale = 2 ** self.lg_loss_scale 224 | (loss * loss_scale).backward() 225 | else: 226 | loss.backward() 227 | 228 | def optimize_fp16(self): 229 | if any(not th.isfinite(p.grad).all() for p in self.model_params): 230 | self.lg_loss_scale -= 1 231 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 232 | return 233 | 234 | model_grads_to_master_grads(self.model_params, self.master_params) 235 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 236 | self._log_grad_norm() 237 | self._anneal_lr() 238 | self.opt.step() 239 | for rate, params in zip(self.ema_rate, self.ema_params): 240 | update_ema(params, self.master_params, rate=rate) 241 | master_params_to_model_params(self.model_params, self.master_params) 242 | self.lg_loss_scale += self.fp16_scale_growth 243 | 244 | def optimize_normal(self): 245 | self._log_grad_norm() 246 | self._anneal_lr() 247 | self.opt.step() 248 | for rate, params in zip(self.ema_rate, self.ema_params): 249 | update_ema(params, self.master_params, rate=rate) 250 | 251 | def _log_grad_norm(self): 252 | sqsum = 0.0 253 | for p in self.master_params: 254 | sqsum += (p.grad ** 2).sum().item() 255 | logger.logkv_mean("grad_norm", np.sqrt(sqsum)) 256 | 257 | def _anneal_lr(self): 258 | if not self.lr_anneal_steps: 259 | return 260 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 261 | lr = self.lr * (1 - frac_done) 262 | for param_group in self.opt.param_groups: 263 | param_group["lr"] = lr 264 | 265 | def log_step(self): 266 | logger.logkv("step", self.step + self.resume_step) 267 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 268 | if self.use_fp16: 269 | logger.logkv("lg_loss_scale", self.lg_loss_scale) 270 | 271 | def save(self): 272 | def save_checkpoint(rate, params): 273 | state_dict = self._master_params_to_state_dict(params) 274 | if dist.get_rank() == 0: 275 | logger.log(f"saving model {rate}...") 276 | if not rate: 277 | filename = f"model{(self.step+self.resume_step):06d}.pt" 278 | else: 279 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" 280 | with 
bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 281 | th.save(state_dict, f) 282 | 283 | save_checkpoint(0, self.master_params) 284 | for rate, params in zip(self.ema_rate, self.ema_params): 285 | save_checkpoint(rate, params) 286 | 287 | if dist.get_rank() == 0: 288 | with bf.BlobFile( 289 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), 290 | "wb", 291 | ) as f: 292 | th.save(self.opt.state_dict(), f) 293 | 294 | dist.barrier() 295 | 296 | def _master_params_to_state_dict(self, master_params): 297 | if self.use_fp16: 298 | master_params = unflatten_master_params( 299 | self.model.parameters(), master_params 300 | ) 301 | state_dict = self.model.state_dict() 302 | for i, (name, _value) in enumerate(self.model.named_parameters()): 303 | assert name in state_dict 304 | state_dict[name] = master_params[i] 305 | return state_dict 306 | 307 | def _state_dict_to_master_params(self, state_dict): 308 | params = [state_dict[name] for name, _ in self.model.named_parameters()] 309 | if self.use_fp16: 310 | return make_master_params(params) 311 | else: 312 | return params 313 | 314 | 315 | def parse_resume_step_from_filename(filename): 316 | """ 317 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 318 | checkpoint's number of steps. 319 | """ 320 | split = filename.split("model") 321 | if len(split) < 2: 322 | return 0 323 | split1 = split[-1].split(".")[0] 324 | try: 325 | return int(split1) 326 | except ValueError: 327 | return 0 328 | 329 | 330 | def get_blob_logdir(): 331 | return os.environ.get("DIFFUSION_BLOB_LOGDIR", logger.get_dir()) 332 | 333 | 334 | def find_resume_checkpoint(): 335 | # On your infrastructure, you may want to override this to automatically 336 | # discover the latest checkpoint on your blob storage, etc. 337 | return None 338 | 339 | 340 | def find_ema_checkpoint(main_checkpoint, step, rate): 341 | if main_checkpoint is None: 342 | return None 343 | filename = f"ema_{rate}_{(step):06d}.pt" 344 | path = bf.join(bf.dirname(main_checkpoint), filename) 345 | if bf.exists(path): 346 | return path 347 | return None 348 | 349 | 350 | def log_loss_dict(diffusion, ts, losses): 351 | for key, values in losses.items(): 352 | logger.logkv_mean(key, values.mean().item()) 353 | # Log the quantiles (four quartiles, in particular). 
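        # Each per-sample loss is also binned by where its timestep t falls in the
        # diffusion schedule (quartile 0-3) and logged as "{key}_q{quartile}", making
        # it easy to compare losses on early vs. late timesteps.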
354 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 355 | quartile = int(4 * sub_t / diffusion.num_timesteps) 356 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 357 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .lenet import * 2 | from .vggnet import * 3 | from .resnet import * 4 | from .wide_resnet import * 5 | -------------------------------------------------------------------------------- /networks/lenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | def conv_init(m): 5 | classname = m.__class__.__name__ 6 | if classname.find('Conv') != -1: 7 | init.xavier_uniform(m.weight, gain=np.sqrt(2)) 8 | init.constant(m.bias, 0) 9 | 10 | class LeNet(nn.Module): 11 | def __init__(self, num_classes): 12 | super(LeNet, self).__init__() 13 | self.conv1 = nn.Conv2d(3, 6, 5) 14 | self.conv2 = nn.Conv2d(6, 16, 5) 15 | self.fc1 = nn.Linear(16*5*5, 120) 16 | self.fc2 = nn.Linear(120, 84) 17 | self.fc3 = nn.Linear(84, num_classes) 18 | 19 | def forward(self, x): 20 | out = F.relu(self.conv1(x)) 21 | out = F.max_pool2d(out, 2) 22 | out = F.relu(self.conv2(out)) 23 | out = F.max_pool2d(out, 2) 24 | out = out.view(out.size(0), -1) 25 | out = F.relu(self.fc1(out)) 26 | out = F.relu(self.fc2(out)) 27 | out = self.fc3(out) 28 | 29 | return(out) 30 | -------------------------------------------------------------------------------- /networks/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.autograd import Variable 6 | import sys 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 10 | 11 | def conv_init(m): 12 | classname = m.__class__.__name__ 13 | if classname.find('Conv') != -1: 14 | init.xavier_uniform(m.weight, gain=np.sqrt(2)) 15 | init.constant(m.bias, 0) 16 | 17 | def cfg(depth): 18 | depth_lst = [18, 34, 50, 101, 152] 19 | assert (depth in depth_lst), "Error : Resnet depth should be either 18, 34, 50, 101, 152" 20 | cf_dict = { 21 | '18': (BasicBlock, [2,2,2,2]), 22 | '34': (BasicBlock, [3,4,6,3]), 23 | '50': (Bottleneck, [3,4,6,3]), 24 | '101':(Bottleneck, [3,4,23,3]), 25 | '152':(Bottleneck, [3,8,36,3]), 26 | } 27 | 28 | return cf_dict[str(depth)] 29 | 30 | class BasicBlock(nn.Module): 31 | expansion = 1 32 | 33 | def __init__(self, in_planes, planes, stride=1): 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = conv3x3(in_planes, planes, stride) 36 | self.bn1 = nn.BatchNorm2d(planes) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | 40 | self.shortcut = nn.Sequential() 41 | if stride != 1 or in_planes != self.expansion * planes: 42 | self.shortcut = nn.Sequential( 43 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True), 44 | nn.BatchNorm2d(self.expansion*planes) 45 | ) 46 | 47 | def forward(self, x): 48 | out = F.relu(self.bn1(self.conv1(x))) 49 | out = self.bn2(self.conv2(out)) 50 | out += self.shortcut(x) 51 | out = F.relu(out) 52 | 53 | return out 54 | 55 | class Bottleneck(nn.Module): 56 | expansion = 4 57 | 58 | def __init__(self, in_planes, planes, stride=1): 59 | super(Bottleneck, self).__init__() 
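        # Standard 1x1 -> 3x3 -> 1x1 bottleneck: the final 1x1 conv expands the
        # channel count by `expansion` (4), and the shortcut is projected with a
        # 1x1 conv whenever the stride or channel count changes.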
60 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=True) 61 | self.bn1 = nn.BatchNorm2d(planes) 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 63 | self.bn2 = nn.BatchNorm2d(planes) 64 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=True) 65 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 66 | 67 | self.shortcut = nn.Sequential() 68 | if stride != 1 or in_planes != self.expansion*planes: 69 | self.shortcut = nn.Sequential( 70 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True), 71 | nn.BatchNorm2d(self.expansion*planes) 72 | ) 73 | 74 | def forward(self, x): 75 | out = F.relu(self.bn1(self.conv1(x))) 76 | out = F.relu(self.bn2(self.conv2(out))) 77 | out = self.bn3(self.conv3(out)) 78 | out += self.shortcut(x) 79 | out = F.relu(out) 80 | 81 | return out 82 | 83 | class ResNet(nn.Module): 84 | def __init__(self, depth, num_classes): 85 | super(ResNet, self).__init__() 86 | self.in_planes = 16 87 | 88 | block, num_blocks = cfg(depth) 89 | 90 | self.conv1 = conv3x3(3,16) 91 | self.bn1 = nn.BatchNorm2d(16) 92 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 93 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 94 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 95 | self.linear = nn.Linear(64*block.expansion, num_classes) 96 | 97 | def _make_layer(self, block, planes, num_blocks, stride): 98 | strides = [stride] + [1]*(num_blocks-1) 99 | layers = [] 100 | 101 | for stride in strides: 102 | layers.append(block(self.in_planes, planes, stride)) 103 | self.in_planes = planes * block.expansion 104 | 105 | return nn.Sequential(*layers) 106 | 107 | def forward(self, x): 108 | out = F.relu(self.bn1(self.conv1(x))) 109 | out = self.layer1(out) 110 | out = self.layer2(out) 111 | out = self.layer3(out) 112 | out = F.avg_pool2d(out, 8) 113 | out = out.view(out.size(0), -1) 114 | out = self.linear(out) 115 | 116 | return out 117 | 118 | if __name__ == '__main__': 119 | net=ResNet(50, 10) 120 | y = net(Variable(torch.randn(1,3,32,32))) 121 | print(y.size()) 122 | -------------------------------------------------------------------------------- /networks/vggnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | def conv_init(m): 6 | classname = m.__class__.__name__ 7 | if classname.find('Conv') != -1: 8 | init.xavier_uniform(m.weight, gain=np.sqrt(2)) 9 | init.constant(m.bias, 0) 10 | 11 | def cfg(depth): 12 | depth_lst = [11, 13, 16, 19] 13 | assert (depth in depth_lst), "Error : VGGnet depth should be either 11, 13, 16, 19" 14 | cf_dict = { 15 | '11': [ 16 | 64, 'mp', 17 | 128, 'mp', 18 | 256, 256, 'mp', 19 | 512, 512, 'mp', 20 | 512, 512, 'mp'], 21 | '13': [ 22 | 64, 64, 'mp', 23 | 128, 128, 'mp', 24 | 256, 256, 'mp', 25 | 512, 512, 'mp', 26 | 512, 512, 'mp' 27 | ], 28 | '16': [ 29 | 64, 64, 'mp', 30 | 128, 128, 'mp', 31 | 256, 256, 256, 'mp', 32 | 512, 512, 512, 'mp', 33 | 512, 512, 512, 'mp' 34 | ], 35 | '19': [ 36 | 64, 64, 'mp', 37 | 128, 128, 'mp', 38 | 256, 256, 256, 256, 'mp', 39 | 512, 512, 512, 512, 'mp', 40 | 512, 512, 512, 512, 'mp' 41 | ], 42 | } 43 | 44 | return cf_dict[str(depth)] 45 | 46 | def conv3x3(in_planes, out_planes, stride=1): 47 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 48 | 49 | class VGG(nn.Module): 50 | def __init__(self, 
depth, num_classes): 51 | super(VGG, self).__init__() 52 | self.features = self._make_layers(cfg(depth)) 53 | self.linear = nn.Linear(512, num_classes) 54 | 55 | def forward(self, x): 56 | out = self.features(x) 57 | out = out.view(out.size(0), -1) 58 | out = self.linear(out) 59 | 60 | return out 61 | 62 | def _make_layers(self, cfg): 63 | layers = [] 64 | in_planes = 3 65 | 66 | for x in cfg: 67 | if x == 'mp': 68 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 69 | else: 70 | layers += [conv3x3(in_planes, x), nn.BatchNorm2d(x), nn.ReLU(inplace=True)] 71 | in_planes = x 72 | 73 | # After cfg convolution 74 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 75 | return nn.Sequential(*layers) 76 | 77 | if __name__ == "__main__": 78 | net = VGG(16, 10) 79 | y = net(Variable(torch.randn(1,3,32,32))) 80 | print(y.size()) 81 | -------------------------------------------------------------------------------- /networks/wide_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | import sys 8 | import numpy as np 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 12 | 13 | def conv_init(m): 14 | classname = m.__class__.__name__ 15 | if classname.find('Conv') != -1: 16 | init.xavier_uniform_(m.weight, gain=np.sqrt(2)) 17 | init.constant_(m.bias, 0) 18 | elif classname.find('BatchNorm') != -1: 19 | init.constant_(m.weight, 1) 20 | init.constant_(m.bias, 0) 21 | 22 | class wide_basic(nn.Module): 23 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 24 | super(wide_basic, self).__init__() 25 | self.bn1 = nn.BatchNorm2d(in_planes) 26 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 27 | self.dropout = nn.Dropout(p=dropout_rate) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 30 | 31 | self.shortcut = nn.Sequential() 32 | if stride != 1 or in_planes != planes: 33 | self.shortcut = nn.Sequential( 34 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 35 | ) 36 | 37 | def forward(self, x): 38 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 39 | out = self.conv2(F.relu(self.bn2(out))) 40 | out += self.shortcut(x) 41 | 42 | return out 43 | 44 | class Wide_ResNet(nn.Module): 45 | def __init__(self, depth, widen_factor, dropout_rate, num_classes): 46 | super(Wide_ResNet, self).__init__() 47 | self.in_planes = 16 48 | 49 | assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4' 50 | n = (depth-4)/6 51 | k = widen_factor 52 | 53 | print('| Wide-Resnet %dx%d' %(depth, k)) 54 | nStages = [16, 16*k, 32*k, 64*k] 55 | 56 | self.conv1 = conv3x3(3,nStages[0]) 57 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 58 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 59 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 60 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 61 | self.linear = nn.Linear(nStages[3], num_classes) 62 | 63 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 64 | strides = [stride] + [1]*(int(num_blocks)-1) 65 | layers = [] 66 | 67 | for stride in strides: 68 | layers.append(block(self.in_planes, planes, 
dropout_rate, stride)) 69 | self.in_planes = planes 70 | 71 | return nn.Sequential(*layers) 72 | 73 | def forward(self, x): 74 | out = self.conv1(x) 75 | out = self.layer1(out) 76 | out = self.layer2(out) 77 | out = self.layer3(out) 78 | out = F.relu(self.bn1(out)) 79 | out = F.avg_pool2d(out, 8) 80 | out = out.view(out.size(0), -1) 81 | out = self.linear(out) 82 | 83 | return out 84 | 85 | if __name__ == '__main__': 86 | net=Wide_ResNet(28, 10, 0.3, 10) 87 | y = net(Variable(torch.randn(1,3,32,32))) 88 | 89 | print(y.size()) 90 | -------------------------------------------------------------------------------- /pictures/densepure_flowchart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayfeather1024/DensePure/1ed105d13a6ccfaf34a4fc240609dae89abc6a0d/pictures/densepure_flowchart.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.4 2 | pyyaml==5.3.1 3 | wheel==0.34.2 4 | scipy==1.5.2 5 | pillow==7.2.0 6 | matplotlib==3.3.0 7 | tqdm==4.56.1 8 | tensorboardX==2.0 9 | seaborn==0.10.1 10 | pandas==1.2.0 11 | requests==2.25.0 12 | xvfbwrapper==0.2.9 13 | torchdiffeq==0.2.1 14 | timm==0.5.4 15 | lmdb 16 | Ninja 17 | foolbox 18 | torchsde 19 | statsmodels 20 | transformers 21 | blobfile 22 | mpi4py -------------------------------------------------------------------------------- /results/merge_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | sigma=$1 4 | steps=$2 5 | majority_vote_num=$3 6 | 7 | python merge_results.py \ 8 | --sample_id_list $(seq -s ' ' 0 20 9980) \ 9 | --sample_num 500 \ 10 | --majority_vote_num $majority_vote_num \ 11 | --N 100000 \ 12 | --N0 100 \ 13 | --sigma $sigma \ 14 | --classes_num 10 \ 15 | --datasets cifar10 \ 16 | --steps $steps 17 | -------------------------------------------------------------------------------- /results/merge_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | sigma=$1 4 | steps=$2 5 | majority_vote_num=$3 6 | 7 | python merge_results.py \ 8 | --sample_id_list $(seq -s ' ' 0 500 49500) \ 9 | --sample_num 100 \ 10 | --majority_vote_num $majority_vote_num \ 11 | --N 10000 \ 12 | --N0 100 \ 13 | --sigma $sigma \ 14 | --classes_num 1000 \ 15 | --datasets imagenet \ 16 | --steps $steps -------------------------------------------------------------------------------- /results/merge_results.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import pandas as pd 4 | from statsmodels.stats.proportion import proportion_confint 5 | from scipy.stats import norm 6 | 7 | def gain_results(args): 8 | file_merge = open(str(args.datasets)+'_'+str(args.sigma)+'_'+str(args.results_file), 'w') 9 | file_merge.write("idx\tlabel\tpredict\tradius\tcorrect\n") 10 | for sample_id in range(args.sample_num): 11 | 12 | n0_predictions_list = [] 13 | n_predictions_list = [] 14 | for i in range(args.majority_vote_num): 15 | id_file = open(str(args.datasets)+'-densepure-sample_num_'+str(args.N0)+'-noise_'+str(args.sigma)+'-'+str(args.steps)+'-steps-'+str(i), 'r') 16 | lines = id_file.readlines() 17 | line = lines[sample_id+1].split('\t') 18 | label = int(line[1]) 19 | 20 | n0_predictions = 
np.load('exp/'+str(args.datasets)+'/'+str(args.sigma)+'-'+str(args.sample_id_list[sample_id])+'-'+str(i)+'-n0_predictions.npy') 21 | n_predictions = np.load('exp/'+str(args.datasets)+'/'+str(args.sigma)+'-'+str(args.sample_id_list[sample_id])+'-'+str(i)+'-n_predictions.npy') 22 | 23 | n0_predictions_list.append(n0_predictions) 24 | n_predictions_list.append(n_predictions) 25 | 26 | n0_predictions_list = np.array(n0_predictions_list).T 27 | n_predictions_list = np.array(n_predictions_list).T 28 | count_max_list = np.zeros(args.N0,dtype=int) 29 | 30 | for i in range(args.N0): 31 | count_max = max(list(n0_predictions_list[i]),key=list(n0_predictions_list[i]).count) 32 | count_max_list[i] = count_max 33 | counts = np.zeros(args.classes_num, dtype=int) 34 | for idx in count_max_list: 35 | counts[idx] += 1 36 | prediction = counts.argmax().item() 37 | 38 | count_max_list = np.zeros(args.N,dtype=int) 39 | for i in range(args.N): 40 | count_max = max(list(n_predictions_list[i]),key=list(n_predictions_list[i]).count) 41 | count_max_list[i] = count_max 42 | counts = np.zeros(args.classes_num, dtype=int) 43 | for idx in count_max_list: 44 | counts[idx] += 1 45 | 46 | nA = counts[prediction].item() 47 | pABar = proportion_confint(nA, args.N, alpha=2 * 0.001, method="beta")[0] 48 | if pABar < 0.5: 49 | prediction = -1 50 | radius = 0.0 51 | else: 52 | radius = args.sigma * norm.ppf(pABar) 53 | 54 | correct = int(prediction == label) 55 | 56 | file_merge.write("{}\t{}\t{}\t{:.3}\t{}".format(args.sample_id_list[sample_id], label, prediction, radius, correct)) 57 | file_merge.write("\n") 58 | 59 | 60 | def parse_args(): 61 | parser = argparse.ArgumentParser(description=globals()['__doc__']) 62 | parser.add_argument("--sample_id_list", type=int, nargs='+', default=[0], help="sample id for evaluation") 63 | parser.add_argument('--sample_num', type=int, default=100, help='sample numbers') 64 | parser.add_argument('--majority_vote_num', type=int, default=10, help='majority vote numbers') 65 | parser.add_argument("--N0", type=int, default=100) 66 | parser.add_argument("--N", type=int, default=100000, help="number of samples to use") 67 | parser.add_argument('--sigma', type=float, default=0.25, help='noise hyperparameter') 68 | parser.add_argument('--classes_num', type=int, default=10, help='classes numbers of datasets') 69 | parser.add_argument("--results_file", type=str, default='merge_results.txt', help="output file") 70 | parser.add_argument("--datasets", type=str, default='cifar10', help="cifar10 or imagenet") 71 | parser.add_argument("--steps", type=int, default=2) 72 | 73 | args = parser.parse_args() 74 | return args 75 | 76 | 77 | 78 | if __name__ == '__main__': 79 | args = parse_args() 80 | print(args) 81 | gain_results(args) -------------------------------------------------------------------------------- /run_scripts/carlini22_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd .. 
3 | 4 | sigma=$1 5 | 6 | python eval_certified_densepure.py \ 7 | --exp exp \ 8 | --config cifar10.yml \ 9 | -i cifar10-carlini22-sample_num_100000-noise_$sigma-1step \ 10 | --domain cifar10 \ 11 | --seed 0 \ 12 | --diffusion_type ddpm \ 13 | --lp_norm L2 \ 14 | --outfile results/cifar10-carlini22-sample_num_100000-noise_$sigma-1step \ 15 | --sigma $sigma \ 16 | --N 100000 \ 17 | --N0 100 \ 18 | --certified_batch 100 \ 19 | --sample_id $(seq -s ' ' 0 20 9980) \ 20 | --use_id \ 21 | --certify_mode purify \ 22 | --advanced_classifier vit \ 23 | --use_one_step -------------------------------------------------------------------------------- /run_scripts/carlini22_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd .. 3 | 4 | sigma=$1 5 | 6 | python eval_certified_densepure.py \ 7 | --exp exp \ 8 | --config imagenet.yml \ 9 | -i imagenet-carlini22-sample_num_10000-noise_$sigma-1step \ 10 | --domain imagenet \ 11 | --seed 0 \ 12 | --diffusion_type guided-ddpm \ 13 | --lp_norm L2 \ 14 | --outfile results/imagenet-carlini22-sample_num_10000-noise_$sigma-1step \ 15 | --sigma $sigma \ 16 | --N 10000 \ 17 | --N0 100 \ 18 | --certified_batch 16 \ 19 | --sample_id $(seq -s ' ' 0 500 49500) \ 20 | --use_id \ 21 | --certify_mode purify \ 22 | --advanced_classifier beit \ 23 | --use_one_step -------------------------------------------------------------------------------- /run_scripts/densepure_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd .. 3 | 4 | sigma=$1 5 | steps=$2 6 | reverse_seed=$3 7 | 8 | python eval_certified_densepure.py \ 9 | --exp exp/cifar10 \ 10 | --config cifar10.yml \ 11 | -i cifar10-densepure-sample_num_100000-noise_$sigma-$steps-$reverse_seed \ 12 | --domain cifar10 \ 13 | --seed 0 \ 14 | --diffusion_type ddpm \ 15 | --lp_norm L2 \ 16 | --outfile results/cifar10-densepure-sample_num_100000-noise_$sigma-$steps-steps-$reverse_seed \ 17 | --sigma $sigma \ 18 | --N 100000 \ 19 | --N0 100 \ 20 | --certified_batch 100 \ 21 | --sample_id $(seq -s ' ' 0 20 9980) \ 22 | --use_id \ 23 | --certify_mode purify \ 24 | --advanced_classifier vit \ 25 | --use_t_steps \ 26 | --num_t_steps $steps \ 27 | --save_predictions \ 28 | --predictions_path exp/cifar10/$sigma- \ 29 | --reverse_seed $reverse_seed -------------------------------------------------------------------------------- /run_scripts/densepure_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd .. 
3 | 4 | sigma=$1 5 | steps=$2 6 | reverse_seed=$3 7 | 8 | python eval_certified_densepure.py \ 9 | --exp exp/imagenet \ 10 | --config imagenet.yml \ 11 | -i imagenet-densepure-sample_num_10000-noise_$sigma-$steps-$reverse_seed \ 12 | --domain imagenet \ 13 | --seed 0 \ 14 | --diffusion_type guided-ddpm \ 15 | --lp_norm L2 \ 16 | --outfile imagenet-densepure-sample_num_10000-noise_$sigma-$steps-$reverse_seed \ 17 | --sigma $sigma \ 18 | --N 10000 \ 19 | --N0 100 \ 20 | --certified_batch 16 \ 21 | --sample_id $(seq -s ' ' 0 500 49500) \ 22 | --use_id \ 23 | --certify_mode purify \ 24 | --advanced_classifier beit \ 25 | --use_t_steps \ 26 | --num_t_steps $steps \ 27 | --save_predictions \ 28 | --predictions_path exp/imagenet/$sigma- \ 29 | --reverse_seed $reverse_seed -------------------------------------------------------------------------------- /runners/diffpure_ddpm_densepure.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | 6 | import torch 7 | import torchvision.utils as tvu 8 | 9 | from improved_diffusion import dist_util, logger 10 | from improved_diffusion.script_util import ( 11 | NUM_CLASSES, 12 | model_and_diffusion_defaults, 13 | create_model_and_diffusion, 14 | add_dict_to_argparser, 15 | args_to_dict, 16 | ) 17 | from improved_diffusion import gaussian_diffusion 18 | 19 | import math 20 | 21 | 22 | def get_beta_schedule(*, beta_start, beta_end, num_diffusion_timesteps): 23 | betas = np.linspace(beta_start, beta_end, 24 | num_diffusion_timesteps, dtype=np.float64) 25 | assert betas.shape == (num_diffusion_timesteps,) 26 | return betas 27 | 28 | 29 | def extract(a, t, x_shape): 30 | """Extract coefficients from a based on t and reshape to make it 31 | broadcastable with x_shape.""" 32 | bs, = t.shape 33 | assert x_shape[0] == bs 34 | out = torch.gather(torch.tensor(a, dtype=torch.float, device=t.device), 0, t.long()) 35 | assert out.shape == (bs,) 36 | out = out.reshape((bs,) + (1,) * (len(x_shape) - 1)) 37 | return out 38 | 39 | 40 | def image_editing_denoising_step_flexible_mask(x, t, *, model, logvar, betas): 41 | """ 42 | Sample from p(x_{t-1} | x_t) 43 | """ 44 | alphas = 1.0 - betas 45 | alphas_cumprod = alphas.cumprod(dim=0) 46 | 47 | model_output = model(x, t) 48 | weighted_score = betas / torch.sqrt(1 - alphas_cumprod) 49 | mean = extract(1 / torch.sqrt(alphas), t, x.shape) * (x - extract(weighted_score, t, x.shape) * model_output) 50 | 51 | logvar = extract(logvar, t, x.shape) 52 | noise = torch.randn_like(x) 53 | mask = 1 - (t == 0).float() 54 | mask = mask.reshape((x.shape[0],) + (1,) * (len(x.shape) - 1)) 55 | sample = mean + mask * torch.exp(0.5 * logvar) * noise 56 | sample = sample.float() 57 | return sample 58 | 59 | 60 | class Diffusion(torch.nn.Module): 61 | def __init__(self, args, config, device=None): 62 | super().__init__() 63 | self.args = args 64 | self.config = config 65 | if device is None: 66 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 67 | self.device = device 68 | self.reverse_state = None 69 | self.reverse_state_cuda = None 70 | 71 | print("Loading model") 72 | defaults = model_and_diffusion_defaults(self.args.t_total) 73 | model, diffusion = create_model_and_diffusion(**defaults) 74 | model.load_state_dict( 75 | dist_util.load_state_dict("pretrained/cifar10_uncond_50M_500K.pt", map_location="cpu") 76 | ) 77 | model.to(self.device) 78 | model.eval() 79 | 80 | self.model = model 81 | self.diffusion = diffusion 82 
| sigma = self.args.sigma 83 | a = 1/(1+(sigma*2)**2) 84 | self.scale = a**0.5 85 | sigma = sigma*2 86 | T = self.args.t_total 87 | self.t = T*(1-(2*1.008*math.asin(math.sin(math.pi/(2*1.008))/(1+sigma**2)**0.5))/math.pi) 88 | 89 | 90 | def image_editing_sample(self, img=None, bs_id=0, tag=None, sigma=0.0): 91 | assert isinstance(img, torch.Tensor) 92 | batch_size = img.shape[0] 93 | 94 | with torch.no_grad(): 95 | if tag is None: 96 | tag = 'rnd' + str(random.randint(0, 10000)) 97 | out_dir = os.path.join(self.args.log_dir, 'bs' + str(bs_id) + '_' + tag) 98 | 99 | assert img.ndim == 4, img.ndim 100 | x0 = img 101 | 102 | x0 = self.scale*(img) 103 | t = self.t 104 | 105 | if self.args.use_clustering: 106 | x0 = x0.unsqueeze(1).repeat(1,self.args.clustering_batch,1,1,1).view(batch_size*self.args.clustering_batch,3,32,32) 107 | 108 | if self.args.use_one_step: 109 | # one step denoise 110 | t = torch.tensor([round(t)] * x0.shape[0], device=self.device) 111 | out = self.diffusion.p_sample( 112 | self.model, 113 | x0, 114 | t+self.args.t_plus, 115 | clip_denoised=True, 116 | ) 117 | x0 = out["pred_xstart"] 118 | 119 | elif self.args.use_t_steps: 120 | 121 | #save random state 122 | if self.args.save_predictions: 123 | global_seed_state = torch.random.get_rng_state() 124 | if torch.cuda.is_available(): 125 | global_cuda_state = torch.cuda.random.get_rng_state_all() 126 | 127 | if self.reverse_state==None: 128 | torch.manual_seed(self.args.reverse_seed) 129 | if torch.cuda.is_available(): 130 | torch.cuda.manual_seed_all(self.args.reverse_seed) 131 | else: 132 | torch.random.set_rng_state(self.reverse_state) 133 | if torch.cuda.is_available(): 134 | torch.cuda.random.set_rng_state_all(self.reverse_state_cuda) 135 | 136 | # t steps denoise 137 | inter = t/self.args.num_t_steps 138 | indices_t_steps = [round(t-i*inter) for i in range(self.args.num_t_steps)] 139 | 140 | for i in range(len(indices_t_steps)): 141 | t = torch.tensor([len(indices_t_steps)-i-1] * x0.shape[0], device=self.device) 142 | real_t = torch.tensor([indices_t_steps[i]] * x0.shape[0], device=self.device) 143 | with torch.no_grad(): 144 | out = self.diffusion.p_sample( 145 | self.model, 146 | x0, 147 | t, 148 | clip_denoised=True, 149 | indices_t_steps = indices_t_steps.copy(), 150 | T = self.args.t_total, 151 | step = len(indices_t_steps)-i, 152 | real_t = real_t 153 | ) 154 | x0 = out["sample"] 155 | 156 | #load random state 157 | if self.args.save_predictions: 158 | self.reverse_state = torch.random.get_rng_state() 159 | if torch.cuda.is_available(): 160 | self.reverse_state_cuda = torch.cuda.random.get_rng_state_all() 161 | 162 | torch.random.set_rng_state(global_seed_state) 163 | if torch.cuda.is_available(): 164 | torch.cuda.random.set_rng_state_all(global_cuda_state) 165 | 166 | else: 167 | #save random state 168 | if self.args.save_predictions: 169 | global_seed_state = torch.random.get_rng_state() 170 | if torch.cuda.is_available(): 171 | global_cuda_state = torch.cuda.random.get_rng_state_all() 172 | 173 | if self.reverse_state==None: 174 | torch.manual_seed(self.args.reverse_seed) 175 | if torch.cuda.is_available(): 176 | torch.cuda.manual_seed_all(self.args.reverse_seed) 177 | else: 178 | torch.random.set_rng_state(self.reverse_state) 179 | if torch.cuda.is_available(): 180 | torch.cuda.random.set_rng_state_all(self.reverse_state_cuda) 181 | 182 | # full steps denoise 183 | indices = list(range(round(t)))[::-1] 184 | for i in indices: 185 | t = torch.tensor([i] * x0.shape[0], device=self.device) 186 | with 
torch.no_grad(): 187 | out = self.diffusion.p_sample( 188 | self.model, 189 | x0, 190 | t, 191 | clip_denoised=True, 192 | ) 193 | x0 = out["sample"] 194 | 195 | #load random state 196 | if self.args.save_predictions: 197 | self.reverse_state = torch.random.get_rng_state() 198 | if torch.cuda.is_available(): 199 | self.reverse_state_cuda = torch.cuda.random.get_rng_state_all() 200 | 201 | torch.random.set_rng_state(global_seed_state) 202 | if torch.cuda.is_available(): 203 | torch.cuda.random.set_rng_state_all(global_cuda_state) 204 | 205 | return x0 206 | -------------------------------------------------------------------------------- /runners/diffpure_guided_densepure.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import torch 5 | import torchvision.utils as tvu 6 | 7 | from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults 8 | import math 9 | import numpy as np 10 | 11 | 12 | class GuidedDiffusion(torch.nn.Module): 13 | def __init__(self, args, config, device=None, model_dir='pretrained'): 14 | super().__init__() 15 | self.args = args 16 | self.config = config 17 | if device is None: 18 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 19 | self.device = device 20 | self.reverse_state = None 21 | self.reverse_state_cuda = None 22 | 23 | # load model 24 | model_config = model_and_diffusion_defaults() 25 | model_config.update(vars(self.config.model)) 26 | print(f'model_config: {model_config}') 27 | model, diffusion = create_model_and_diffusion(**model_config) 28 | model.load_state_dict(torch.load(f'{model_dir}/256x256_diffusion_uncond.pt', map_location='cpu')) 29 | model.requires_grad_(False).eval().to(self.device) 30 | 31 | if model_config['use_fp16']: 32 | model.convert_to_fp16() 33 | 34 | self.model = model 35 | self.model.eval() 36 | self.diffusion = diffusion 37 | self.betas = diffusion.betas 38 | alphas = 1.0 - self.betas 39 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 40 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 41 | 42 | sigma = self.args.sigma 43 | 44 | a = 1/(1+(sigma*2)**2) 45 | self.scale = a**0.5 46 | 47 | sigma = sigma*2 48 | T = self.args.t_total 49 | for t in range(len(self.sqrt_recipm1_alphas_cumprod)-1): # match the smoothing noise level (2*sigma) to the nearest diffusion timestep 50 | if self.sqrt_recipm1_alphas_cumprod[t]<sigma and self.sqrt_recipm1_alphas_cumprod[t+1]>=sigma: 51 | if sigma - self.sqrt_recipm1_alphas_cumprod[t] > self.sqrt_recipm1_alphas_cumprod[t+1] - sigma: 52 | self.t = t+1 53 | break 54 | else: 55 | self.t = t 56 | break 57 | self.t = len(diffusion.alphas_cumprod)-1 58 | 59 | def image_editing_sample(self, img=None, bs_id=0, tag=None, sigma=0.0): 60 | assert isinstance(img, torch.Tensor) 61 | batch_size = img.shape[0] 62 | 63 | with torch.no_grad(): 64 | if tag is None: 65 | tag = 'rnd' + str(random.randint(0, 10000)) 66 | out_dir = os.path.join(self.args.log_dir, 'bs' + str(bs_id) + '_' + tag) 67 | 68 | assert img.ndim == 4, img.ndim 69 | x0 = img 70 | 71 | x0 = self.scale*(img) 72 | t = self.t 73 | 74 | if self.args.use_clustering: 75 | x0 = x0.unsqueeze(1).repeat(1,self.args.clustering_batch,1,1,1).view(batch_size*self.args.clustering_batch,3,256,256) 76 | self.model.eval() 77 | 78 | if self.args.use_one_step: 79 | # one step denoise 80 | t = torch.tensor([round(t)] * x0.shape[0], device=self.device) 81 | out = self.diffusion.p_sample( 82 | self.model, 83 | x0, 84 | t+self.args.t_plus, 85 | clip_denoised=True, 86 | ) 87 | 88 | x0 = out["pred_xstart"] 89 | 90 | elif 
self.args.use_t_steps: 91 | #save random state 92 | if self.args.save_predictions: 93 | global_seed_state = torch.random.get_rng_state() 94 | if torch.cuda.is_available(): 95 | global_cuda_state = torch.cuda.random.get_rng_state_all() 96 | 97 | if self.reverse_state==None: 98 | torch.manual_seed(self.args.reverse_seed) 99 | if torch.cuda.is_available(): 100 | torch.cuda.manual_seed_all(self.args.reverse_seed) 101 | else: 102 | torch.random.set_rng_state(self.reverse_state) 103 | if torch.cuda.is_available(): 104 | torch.cuda.random.set_rng_state_all(self.reverse_state_cuda) 105 | 106 | # t steps denoise 107 | inter = t/self.args.num_t_steps 108 | indices_t_steps = [round(t-i*inter) for i in range(self.args.num_t_steps)] 109 | 110 | for i in range(len(indices_t_steps)): 111 | t = torch.tensor([len(indices_t_steps)-i-1] * x0.shape[0], device=self.device) 112 | real_t = torch.tensor([indices_t_steps[i]] * x0.shape[0], device=self.device) 113 | with torch.no_grad(): 114 | out = self.diffusion.p_sample( 115 | self.model, 116 | x0, 117 | t, 118 | clip_denoised=True, 119 | indices_t_steps = indices_t_steps.copy(), 120 | T = self.args.t_total, 121 | step = len(indices_t_steps)-i, 122 | real_t = real_t 123 | ) 124 | x0 = out["sample"] 125 | 126 | #load random state 127 | if self.args.save_predictions: 128 | self.reverse_state = torch.random.get_rng_state() 129 | if torch.cuda.is_available(): 130 | self.reverse_state_cuda = torch.cuda.random.get_rng_state_all() 131 | 132 | torch.random.set_rng_state(global_seed_state) 133 | if torch.cuda.is_available(): 134 | torch.cuda.random.set_rng_state_all(global_cuda_state) 135 | 136 | else: 137 | #save random state 138 | if self.args.save_predictions: 139 | global_seed_state = torch.random.get_rng_state() 140 | if torch.cuda.is_available(): 141 | global_cuda_state = torch.cuda.random.get_rng_state_all() 142 | 143 | if self.reverse_state==None: 144 | torch.manual_seed(self.args.reverse_seed) 145 | if torch.cuda.is_available(): 146 | torch.cuda.manual_seed_all(self.args.reverse_seed) 147 | else: 148 | torch.random.set_rng_state(self.reverse_state) 149 | if torch.cuda.is_available(): 150 | torch.cuda.random.set_rng_state_all(self.reverse_state_cuda) 151 | 152 | # full steps denoise 153 | indices = list(range(round(t)))[::-1] 154 | for i in indices: 155 | t = torch.tensor([i] * x0.shape[0], device=self.device) 156 | with torch.no_grad(): 157 | out = self.diffusion.p_sample( 158 | self.model, 159 | x0, 160 | t, 161 | clip_denoised=True, 162 | ) 163 | x0 = out["sample"] 164 | 165 | #load random state 166 | if self.args.save_predictions: 167 | self.reverse_state = torch.random.get_rng_state() 168 | if torch.cuda.is_available(): 169 | self.reverse_state_cuda = torch.cuda.random.get_rng_state_all() 170 | 171 | torch.random.set_rng_state(global_seed_state) 172 | if torch.cuda.is_available(): 173 | torch.cuda.random.set_rng_state_all(global_cuda_state) 174 | 175 | return x0 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | from typing import Any 4 | import torch 5 | import torch.nn as nn 6 | import torchvision.models as models 7 | from torch.utils.data import DataLoader 8 | import torchvision.transforms as transforms 9 | from architectures import get_architecture 10 | import data 11 | 12 | 13 | def compute_n_params(model, return_str=True): 14 | tot = 0 15 | for p in model.parameters(): 16 | w = 1 17 | for x in 
p.shape: 18 | w *= x 19 | tot += w 20 | if return_str: 21 | if tot >= 1e6: 22 | return '{:.1f}M'.format(tot / 1e6) 23 | else: 24 | return '{:.1f}K'.format(tot / 1e3) 25 | else: 26 | return tot 27 | 28 | 29 | class Logger(object): 30 | """ 31 | Redirect stderr to stdout, optionally print stdout to a file, 32 | and optionally force flushing on both stdout and the file. 33 | """ 34 | 35 | def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): 36 | self.file = None 37 | 38 | if file_name is not None: 39 | self.file = open(file_name, file_mode) 40 | 41 | self.should_flush = should_flush 42 | self.stdout = sys.stdout 43 | self.stderr = sys.stderr 44 | 45 | sys.stdout = self 46 | sys.stderr = self 47 | 48 | def __enter__(self) -> "Logger": 49 | return self 50 | 51 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 52 | self.close() 53 | 54 | def write(self, text: str) -> None: 55 | """Write text to stdout (and a file) and optionally flush.""" 56 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 57 | return 58 | 59 | if self.file is not None: 60 | self.file.write(text) 61 | 62 | self.stdout.write(text) 63 | 64 | if self.should_flush: 65 | self.flush() 66 | 67 | def flush(self) -> None: 68 | """Flush written text to both stdout and a file, if open.""" 69 | if self.file is not None: 70 | self.file.flush() 71 | 72 | self.stdout.flush() 73 | 74 | def close(self) -> None: 75 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 76 | self.flush() 77 | 78 | # if using multiple loggers, prevent closing in wrong order 79 | if sys.stdout is self: 80 | sys.stdout = self.stdout 81 | if sys.stderr is self: 82 | sys.stderr = self.stderr 83 | 84 | if self.file is not None: 85 | self.file.close() 86 | 87 | 88 | def dict2namespace(config): 89 | namespace = argparse.Namespace() 90 | for key, value in config.items(): 91 | if isinstance(value, dict): 92 | new_value = dict2namespace(value) 93 | else: 94 | new_value = value 95 | setattr(namespace, key, new_value) 96 | return namespace 97 | 98 | 99 | def str2bool(v): 100 | if isinstance(v, bool): 101 | return v 102 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 103 | return True 104 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 105 | return False 106 | else: 107 | raise argparse.ArgumentTypeError('Boolean value expected.') 108 | 109 | 110 | def update_state_dict(state_dict, idx_start=9): 111 | 112 | from collections import OrderedDict 113 | new_state_dict = OrderedDict() 114 | for k, v in state_dict.items(): 115 | name = k[idx_start:] # remove 'module.0.' of dataparallel 116 | new_state_dict[name]=v 117 | 118 | return new_state_dict 119 | 120 | 121 | # ------------------------------------------------------------------------ 122 | def get_accuracy(model, x_orig, y_orig, bs=64, device=torch.device('cuda:0')): 123 | n_batches = x_orig.shape[0] // bs 124 | acc = 0. 
125 | for counter in range(n_batches): 126 | x = x_orig[counter * bs:min((counter + 1) * bs, x_orig.shape[0])].clone().to(device) 127 | y = y_orig[counter * bs:min((counter + 1) * bs, x_orig.shape[0])].clone().to(device) 128 | output = model(x) 129 | acc += (output.max(1)[1] == y).float().sum() 130 | 131 | return (acc / x_orig.shape[0]).item() 132 | 133 | def get_image_classifier_certified(classifier_path, dataset): 134 | checkpoint = torch.load(classifier_path) 135 | base_classifier = get_architecture(checkpoint["arch"], dataset) 136 | base_classifier.load_state_dict(checkpoint['state_dict']) 137 | return base_classifier 138 | 139 | 140 | def load_data(args, adv_batch_size): 141 | if 'imagenet' in args.domain: 142 | val_dir = '/home/data/imagenet/imagenet' # using imagenet lmdb data 143 | val_transform = data.get_transform(args.domain, 'imval', base_size=224) 144 | val_data = data.imagenet_lmdb_dataset_sub(val_dir, transform=val_transform, 145 | num_sub=args.num_sub, data_seed=args.data_seed) 146 | n_samples = len(val_data) 147 | val_loader = DataLoader(val_data, batch_size=n_samples, shuffle=False, pin_memory=True, num_workers=4) 148 | x_val, y_val = next(iter(val_loader)) 149 | elif 'cifar10' in args.domain: 150 | data_dir = './dataset' 151 | transform = transforms.Compose([transforms.ToTensor()]) 152 | val_data = data.cifar10_dataset_sub(data_dir, transform=transform, 153 | num_sub=args.num_sub, data_seed=args.data_seed) 154 | n_samples = len(val_data) 155 | val_loader = DataLoader(val_data, batch_size=n_samples, shuffle=False, pin_memory=True, num_workers=4) 156 | x_val, y_val = next(iter(val_loader)) 157 | elif 'celebahq' in args.domain: 158 | data_dir = './dataset/celebahq' 159 | attribute = args.classifier_name.split('__')[-1] # `celebahq__Smiling` 160 | val_transform = data.get_transform('celebahq', 'imval') 161 | clean_dset = data.get_dataset('celebahq', 'val', attribute, root=data_dir, transform=val_transform, 162 | fraction=2, data_seed=args.data_seed) # data_seed randomizes here 163 | loader = DataLoader(clean_dset, batch_size=adv_batch_size, shuffle=False, 164 | pin_memory=True, num_workers=4) 165 | x_val, y_val = next(iter(loader)) # [0, 1], 256x256 166 | else: 167 | raise NotImplementedError(f'Unknown domain: {args.domain}!') 168 | 169 | print(f'x_val shape: {x_val.shape}') 170 | x_val, y_val = x_val.contiguous().requires_grad_(True), y_val.contiguous() 171 | print(f'x (min, max): ({x_val.min()}, {x_val.max()})') 172 | 173 | return x_val, y_val 174 | -------------------------------------------------------------------------------- /zipdata.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os.path as op 3 | from threading import local 4 | from zipfile import ZipFile, BadZipFile 5 | 6 | from PIL import Image 7 | from io import BytesIO 8 | import torch.utils.data as data 9 | 10 | _VALID_IMAGE_TYPES = ['.jpg', '.jpeg', '.tiff', '.bmp', '.png'] 11 | 12 | class ZipData(data.Dataset): 13 | _IGNORE_ATTRS = {'_zip_file'} 14 | 15 | def __init__(self, path, map_file, 16 | transform=None, target_transform=None, 17 | extensions=None): 18 | self._path = path 19 | if not extensions: 20 | extensions = _VALID_IMAGE_TYPES 21 | self._zip_file = ZipFile(path) 22 | self.zip_dict = {} 23 | self.samples = [] 24 | self.transform = transform 25 | self.target_transform = target_transform 26 | self.class_to_idx = {} 27 | with open(map_file, 'r') as f: 28 | for line in iter(f.readline, ""): 29 | line = line.strip() 30 | if not 
line: 31 | continue 32 | cls_idx = [l for l in line.split('\t') if l] 33 | if not cls_idx: 34 | continue 35 | assert len(cls_idx) >= 2, "invalid line: {}".format(line) 36 | idx = int(cls_idx[1]) 37 | cls = cls_idx[0] 38 | del cls_idx 39 | at_idx = cls.find('@') 40 | assert at_idx >= 0, "invalid class: {}".format(cls) 41 | cls = cls[at_idx + 1:] 42 | if cls.startswith('/'): 43 | # Python ZipFile expects no root 44 | cls = cls[1:] 45 | assert cls, "invalid class in line {}".format(line) 46 | prev_idx = self.class_to_idx.get(cls) 47 | assert prev_idx is None or prev_idx == idx, "class: {} idx: {} previously had idx: {}".format( 48 | cls, idx, prev_idx 49 | ) 50 | self.class_to_idx[cls] = idx 51 | 52 | for fst in self._zip_file.infolist(): 53 | fname = fst.filename 54 | target = self.class_to_idx.get(fname) 55 | if target is None: 56 | continue 57 | if fname.endswith('/') or fname.startswith('.') or fst.file_size == 0: 58 | continue 59 | ext = op.splitext(fname)[1].lower() 60 | if ext in extensions: 61 | self.samples.append((fname, target)) 62 | assert len(self), "No images found in: {} with map: {}".format(self._path, map_file) 63 | 64 | def __repr__(self): 65 | return 'ZipData({}, size={})'.format(self._path, len(self)) 66 | 67 | def __getstate__(self): # drop the open ZipFile handle when pickling 68 | return { 69 | key: val if key not in self._IGNORE_ATTRS else None 70 | for key, val in self.__dict__.items() 71 | } 72 | 73 | def __getitem__(self, index): 74 | proc = multiprocessing.current_process() 75 | pid = proc.pid # get pid of this process. 76 | if pid not in self.zip_dict: 77 | self.zip_dict[pid] = ZipFile(self._path) 78 | zip_file = self.zip_dict[pid] 79 | 80 | if index >= len(self) or index < 0: 81 | raise KeyError("{} is invalid".format(index)) 82 | path, target = self.samples[index] 83 | try: 84 | sample = Image.open(BytesIO(zip_file.read(path))).convert('RGB') 85 | except BadZipFile: 86 | print("bad zip file") 87 | return None, None 88 | if self.transform is not None: 89 | sample = self.transform(sample) 90 | if self.target_transform is not None: 91 | target = self.target_transform(target) 92 | return sample, target 93 | 94 | def __len__(self): 95 | return len(self.samples) 96 | --------------------------------------------------------------------------------