├── README.md
├── architectures.py
├── archs
│   └── cifar_resnet.py
├── certification_pic.py
├── classifiers
│   ├── attribute_classifier.py
│   ├── attribute_net.py
│   └── cifar10_resnet.py
├── compute_accuracy.py
├── configs
│   ├── cifar10.yml
│   └── imagenet.yml
├── core.py
├── data
│   ├── __init__.py
│   └── datasets.py
├── datasets.py
├── ddpm
│   └── unet_ddpm.py
├── eval_certified_densepure.py
├── guided_diffusion
│   ├── __init__.py
│   ├── dist_util.py
│   ├── fp16_util.py
│   ├── gaussian_diffusion.py
│   ├── image_datasets.py
│   ├── logger.py
│   ├── losses.py
│   ├── nn.py
│   ├── resample.py
│   ├── respace.py
│   ├── script_util.py
│   ├── train_util.py
│   └── unet.py
├── improved_diffusion
│   ├── __init__.py
│   ├── dist_util.py
│   ├── fp16_util.py
│   ├── gaussian_diffusion.py
│   ├── image_datasets.py
│   ├── logger.py
│   ├── losses.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── lenet.py
│   │   ├── resnet.py
│   │   ├── vggnet.py
│   │   └── wide_resnet.py
│   ├── nn.py
│   ├── resample.py
│   ├── respace.py
│   ├── script_util.py
│   ├── train_util.py
│   └── unet.py
├── networks
│   ├── __init__.py
│   ├── lenet.py
│   ├── resnet.py
│   ├── vggnet.py
│   └── wide_resnet.py
├── pictures
│   └── densepure_flowchart.png
├── requirements.txt
├── results
│   ├── merge_cifar10.sh
│   ├── merge_imagenet.sh
│   └── merge_results.py
├── run_scripts
│   ├── carlini22_cifar10.sh
│   ├── carlini22_imagenet.sh
│   ├── densepure_cifar10.sh
│   └── densepure_imagenet.sh
├── runners
│   ├── diffpure_ddpm_densepure.py
│   └── diffpure_guided_densepure.py
├── utils.py
└── zipdata.py
/README.md:
--------------------------------------------------------------------------------
1 | # DensePure: Understanding Diffusion Models towards Adversarial Robustness
2 |
3 |
4 |
5 |
6 |
7 | Official PyTorch implementation of the paper:
8 | **[DensePure: Understanding Diffusion Models towards Adversarial Robustness](https://arxiv.org/abs/2211.00322)**
9 |
10 | Chaowei Xiao, Zhongzhu Chen, Kun Jin, Jiongxiao Wang, Weili Nie, Mingyan Liu, Anima Anandkumar, Bo Li, Dawn Song
11 | https://densepure.github.io
12 |
13 | Abstract: *Diffusion models have been recently employed to improve certified robustness through the process of denoising. However, the theoretical understanding of why diffusion models are able to improve the certified robustness is still lacking, preventing from further improvement. In this study, we close this gap by analyzing the fundamental properties of diffusion models and establishing the conditions under which they can enhance certified robustness. This deeper understanding allows us to propose a new method DensePure, designed to improve the certified robustness of a pretrained model (i.e. classifier). Given an (adversarial) input, DensePure consists of multiple runs of denoising via the reverse process of the diffusion model (with different random seeds) to get multiple reversed samples, which are then passed through the classifier, followed by majority voting of inferred labels to make the final prediction. This design of using multiple runs of denoising is informed by our theoretical analysis of the conditional distribution of the reversed sample. Specifically, when the data density of a clean sample is high, its conditional density under the reverse process in a diffusion model is also high; thus sampling from the latter conditional distribution can purify the adversarial example and return the corresponding clean sample with a high probability. By using the highest density point in the conditional distribution as the reversed sample, we identify the robust region of a given instance under the diffusion model's reverse process. We show that this robust region is a union of multiple convex sets, and is potentially much larger than the robust regions identified in previous works. In practice, DensePure can approximate the label of the high density region in the conditional distribution so that it can enhance certified robustness. We conduct extensive experiments to demonstrate the effectiveness of DensePure by evaluating its certified robustness given a standard model via randomized smoothing. We show that DensePure is consistently better than existing methods on ImageNet, with 7% improvement on average.*
14 |
15 | ## Requirements
16 |
17 | - Python 3.8.5
18 | - CUDA=11.1
19 | - Installation of PyTorch 1.8.0:
20 | ```bash
21 | conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge
22 | ```
23 | - Installation of required packages:
24 | ```bash
25 | pip install -r requirements.txt
26 | ```
27 |
28 | ## Datasets, Pre-trained Diffusion Models and Classifiers
29 | Before running our code, you need to prepare two datasets: CIFAR-10 and ImageNet. CIFAR-10 will be downloaded automatically.
30 | For ImageNet, you need to download the ILSVRC2012 validation images from https://www.image-net.org/ and preprocess them by running the script `valprep.sh` from https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh
31 | in the validation directory.
32 |
33 | Please change `IMAGENET_DIR` in `datasets.py` to the location of your ImageNet dataset before running the code.
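
For reference, a minimal preprocessing sketch (the ImageNet path below is illustrative; `valprep.sh` moves each validation image into a subfolder named after its class):
```bash
cd /path/to/imagenet/val   # directory holding the 50,000 ILSVRC2012 validation JPEGs
wget https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh
bash valprep.sh            # creates one folder per class and sorts the images into them
```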
34 |
35 | For the pre-trained diffusion models, you need to first download them from the following links:
36 | - [Improved Diffusion](https://github.com/openai/improved-diffusion) for
37 | CIFAR-10: (`cifar10_uncond_50M_500K.pt`: [download link](https://openaipublic.blob.core.windows.net/diffusion/march-2021/cifar10_uncond_50M_500K.pt))
38 | - [Guided Diffusion](https://github.com/openai/guided-diffusion) for
39 | ImageNet: (`256x256 diffusion unconditional`: [download link](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt))
40 |
41 | For the pre-trained classifiers, the ViT-B/16 model for CIFAR-10 will be downloaded automatically by `transformers`. The BEiT large model for ImageNet needs to be downloaded from the following link:
42 | - [BEiT](https://github.com/microsoft/unilm/tree/master/beit) for
43 | ImageNet: (`beit_large_patch16_512_pt22k_ft22kto1k.pth`: [download link](https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth))
44 |
45 | Please place all the pre-trained models in the `pretrained` directory. If you want to use your own classifiers, the code in `eval_certified_densepure.py` needs to be changed accordingly.
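
Assuming the default checkpoint file names from the download links above, the `pretrained` directory would look roughly like this (the names are taken from the links and may need to match what the run scripts expect):
```
pretrained/
├── cifar10_uncond_50M_500K.pt                    # Improved Diffusion, CIFAR-10
├── 256x256_diffusion_uncond.pt                   # Guided Diffusion, ImageNet
└── beit_large_patch16_512_pt22k_ft22kto1k.pth    # BEiT-large classifier, ImageNet
```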
46 |
47 | ## Run Experiments of Carlini 2022
48 | We provide our own implementation of the paper [(Certified!!) Adversarial Robustness for Free!](https://arxiv.org/abs/2206.10550) as a baseline for comparison with DensePure.
49 |
50 | To reproduce the Carlini22 results reported in Table 1, please run the following scripts with different noise levels `sigma`:
51 | ```
52 | cd run_scripts
53 | bash carlini22_cifar10.sh [sigma] # For CIFAR-10
54 | bash carlini22_imagenet.sh [sigma] # For ImageNet
55 | ```
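
For example, to certify CIFAR-10 at a single noise level (the value below is illustrative):
```bash
cd run_scripts
bash carlini22_cifar10.sh 0.25   # sigma = 0.25
```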
56 |
57 | ## Run Experiments of DensePure
58 | To get certified accuracy under DensePure, please run the following scripts:
59 | ```
60 | cd run_scripts
61 | bash densepure_cifar10.sh [sigma] [steps] [reverse_seed] # For CIFAR-10
62 | bash densepure_imagenet.sh [sigma] [steps] [reverse_seed] # For ImageNet
63 | ```
64 |
65 | Note: `sigma` is the noise level of randomized smoothing. `steps` is the number of fast-sampling steps described in Section 5.2; it must be larger than one and smaller than the total number of reverse steps. `reverse_seed` is the parameter that controls the majority-vote process in Section 5.2. For example, you need to run `densepure_cifar10.sh` 10 times with 10 different values of `reverse_seed` to complete an experiment with 10 majority votes. Running the above scripts with one `reverse_seed` produces a `.npy` file containing the predicted labels of the 100000 (for CIFAR-10) randomized-smoothing samples. To obtain the final results over 10 majority votes, run the following scripts in the `results` directory:
66 | ```
67 | cd results
68 | bash merge_cifar10.sh [sigma] [steps] [majority_vote_numbers] # For CIFAR-10
69 | bash merge_imagenet.sh [sigma] [steps] [majority_vote_numbers] # For ImageNet
70 | ```
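
For example, a complete CIFAR-10 experiment with 10 majority votes at one noise level could be scripted as follows (the `sigma` and `steps` values are illustrative):
```bash
cd run_scripts
for seed in 0 1 2 3 4 5 6 7 8 9; do
    bash densepure_cifar10.sh 0.5 10 ${seed}   # one .npy file of sampled labels per reverse_seed
done
cd ../results
bash merge_cifar10.sh 0.5 10 10                # merge the 10 runs into the final certified accuracy
```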
71 |
72 | ## Citation
73 | Please cite our paper and Carlini et al. (2022) if you use this codebase:
74 | ```
75 | @article{xiao2022densepure,
76 | title={DensePure: Understanding Diffusion Models towards Adversarial Robustness},
77 | author={Xiao, Chaowei and Chen, Zhongzhu and Jin, Kun and Wang, Jiongxiao and Nie, Weili and Liu, Mingyan and Anandkumar, Anima and Li, Bo and Song, Dawn},
78 | journal={arXiv preprint arXiv:2211.00322},
79 | year={2022}
80 | }
81 | ```
82 |
83 | ```
84 | @article{carlini2022certified,
85 | title={(Certified!!) Adversarial Robustness for Free!},
86 | author={Carlini, Nicholas and Tramer, Florian and Kolter, J Zico and others},
87 | journal={arXiv preprint arXiv:2206.10550},
88 | year={2022}
89 | }
90 | ```
91 |
--------------------------------------------------------------------------------
/architectures.py:
--------------------------------------------------------------------------------
1 | from archs.cifar_resnet import resnet as resnet_cifar
2 | from datasets import get_normalize_layer, get_input_center_layer
3 | import torch
4 | import torch.backends.cudnn as cudnn
5 | import torch.nn as nn
6 | from torch.nn.functional import interpolate
7 | from torchvision.models.resnet import resnet50
8 |
9 |
10 | # resnet50 - the classic ResNet-50, sized for ImageNet
11 | # cifar_resnet20 - a 20-layer residual network sized for CIFAR
12 | # cifar_resnet110 - a 110-layer residual network sized for CIFAR
13 | ARCHITECTURES = ["resnet50", "cifar_resnet110", "imagenet32_resnet110"]
14 |
15 | def get_architecture(arch: str, dataset: str) -> torch.nn.Module:
16 | """ Return a neural network (with random weights)
17 |
18 | :param arch: the architecture - should be in the ARCHITECTURES list above
19 | :param dataset: the dataset - should be in the datasets.DATASETS list
20 | :return: a Pytorch module
21 | """
22 | if arch == "resnet50" and dataset == "imagenet":
23 | model = torch.nn.DataParallel(resnet50(pretrained=False)).cuda()
24 | cudnn.benchmark = True
25 | elif arch == "cifar_resnet20":
26 | model = resnet_cifar(depth=20, num_classes=10).cuda()
27 | elif arch == "cifar_resnet110":
28 | model = resnet_cifar(depth=110, num_classes=10).cuda()
29 | elif arch == "imagenet32_resnet110":
30 | model = resnet_cifar(depth=110, num_classes=1000).cuda()
31 |
32 | # Both layers work fine. We tried both, and they both
33 | # give very similar results.
34 | # IF YOU USE ONE OF THESE FOR TRAINING, MAKE SURE
35 | # TO USE THE SAME WHEN CERTIFYING.
36 | normalize_layer = get_normalize_layer(dataset)
37 | # normalize_layer = get_input_center_layer(dataset)
38 | return torch.nn.Sequential(normalize_layer, model)
39 |
--------------------------------------------------------------------------------
/archs/cifar_resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import torch.nn as nn
3 | import math
4 |
5 |
6 | __all__ = ['resnet']
7 |
8 | def conv3x3(in_planes, out_planes, stride=1):
9 | "3x3 convolution with padding"
10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
11 | padding=1, bias=False)
12 |
13 |
14 | class BasicBlock(nn.Module):
15 | expansion = 1
16 |
17 | def __init__(self, inplanes, planes, stride=1, downsample=None):
18 | super(BasicBlock, self).__init__()
19 | self.conv1 = conv3x3(inplanes, planes, stride)
20 | self.bn1 = nn.BatchNorm2d(planes)
21 | self.relu = nn.ReLU(inplace=True)
22 | self.conv2 = conv3x3(planes, planes)
23 | self.bn2 = nn.BatchNorm2d(planes)
24 | self.downsample = downsample
25 | self.stride = stride
26 |
27 | def forward(self, x):
28 | residual = x
29 |
30 | out = self.conv1(x)
31 | out = self.bn1(out)
32 | out = self.relu(out)
33 |
34 | out = self.conv2(out)
35 | out = self.bn2(out)
36 |
37 | if self.downsample is not None:
38 | residual = self.downsample(x)
39 |
40 | out += residual
41 | out = self.relu(out)
42 |
43 | return out
44 |
45 |
46 | class Bottleneck(nn.Module):
47 | expansion = 4
48 |
49 | def __init__(self, inplanes, planes, stride=1, downsample=None):
50 | super(Bottleneck, self).__init__()
51 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
52 | self.bn1 = nn.BatchNorm2d(planes)
53 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
54 | padding=1, bias=False)
55 | self.bn2 = nn.BatchNorm2d(planes)
56 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
57 | self.bn3 = nn.BatchNorm2d(planes * 4)
58 | self.relu = nn.ReLU(inplace=True)
59 | self.downsample = downsample
60 | self.stride = stride
61 |
62 | def forward(self, x):
63 | residual = x
64 |
65 | out = self.conv1(x)
66 | out = self.bn1(out)
67 | out = self.relu(out)
68 |
69 | out = self.conv2(out)
70 | out = self.bn2(out)
71 | out = self.relu(out)
72 |
73 | out = self.conv3(out)
74 | out = self.bn3(out)
75 |
76 | if self.downsample is not None:
77 | residual = self.downsample(x)
78 |
79 | out += residual
80 | out = self.relu(out)
81 |
82 | return out
83 |
84 |
85 | class ResNet(nn.Module):
86 |
87 | def __init__(self, depth, num_classes=1000, block_name='BasicBlock'):
88 | super(ResNet, self).__init__()
89 | # Model type specifies number of layers for CIFAR-10 model
90 | if block_name.lower() == 'basicblock':
91 | assert (depth - 2) % 6 == 0, 'When using basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
92 | n = (depth - 2) // 6
93 | block = BasicBlock
94 | elif block_name.lower() == 'bottleneck':
95 | assert (depth - 2) % 9 == 0, 'When using bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
96 | n = (depth - 2) // 9
97 | block = Bottleneck
98 | else:
99 | raise ValueError('block_name should be BasicBlock or Bottleneck')
100 |
101 |
102 | self.inplanes = 16
103 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
104 | bias=False)
105 | self.bn1 = nn.BatchNorm2d(16)
106 | self.relu = nn.ReLU(inplace=True)
107 | self.layer1 = self._make_layer(block, 16, n)
108 | self.layer2 = self._make_layer(block, 32, n, stride=2)
109 | self.layer3 = self._make_layer(block, 64, n, stride=2)
110 | self.avgpool = nn.AvgPool2d(8)
111 | self.fc = nn.Linear(64 * block.expansion, num_classes)
112 |
113 | for m in self.modules():
114 | if isinstance(m, nn.Conv2d):
115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
116 | m.weight.data.normal_(0, math.sqrt(2. / n))
117 | elif isinstance(m, nn.BatchNorm2d):
118 | m.weight.data.fill_(1)
119 | m.bias.data.zero_()
120 |
121 | def _make_layer(self, block, planes, blocks, stride=1):
122 | downsample = None
123 | if stride != 1 or self.inplanes != planes * block.expansion:
124 | downsample = nn.Sequential(
125 | nn.Conv2d(self.inplanes, planes * block.expansion,
126 | kernel_size=1, stride=stride, bias=False),
127 | nn.BatchNorm2d(planes * block.expansion),
128 | )
129 |
130 | layers = []
131 | layers.append(block(self.inplanes, planes, stride, downsample))
132 | self.inplanes = planes * block.expansion
133 | for i in range(1, blocks):
134 | layers.append(block(self.inplanes, planes))
135 |
136 | return nn.Sequential(*layers)
137 |
138 | def forward(self, x):
139 | x = self.conv1(x)
140 | x = self.bn1(x)
141 | x = self.relu(x) # 32x32
142 |
143 | x = self.layer1(x) # 32x32
144 | x = self.layer2(x) # 16x16
145 | x = self.layer3(x) # 8x8
146 |
147 | x = self.avgpool(x)
148 | x = x.view(x.size(0), -1)
149 | x = self.fc(x)
150 |
151 | return x
152 |
153 |
154 | def resnet(**kwargs):
155 | """
156 | Constructs a ResNet model.
157 | """
158 | return ResNet(**kwargs)
--------------------------------------------------------------------------------
/classifiers/attribute_classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from . import attribute_net
4 |
5 | softmax = torch.nn.Softmax(dim=1)
6 |
7 |
8 | def downsample(images, size=256):
9 | # Downsample to 256x256. The attribute classifiers were built for 256x256.
10 | # follows https://github.com/NVlabs/stylegan/blob/master/metrics/linear_separability.py#L127
11 | if images.shape[2] > size:
12 | factor = images.shape[2] // size
13 | assert (factor * size == images.shape[2])
14 | images = images.view(
15 | [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor])
16 | images = images.mean(dim=[3, 5])
17 | return images
18 | else:
19 | assert (images.shape[-1] == 256)
20 | return images
21 |
22 |
23 | def get_logit(net, im):
24 | im_256 = downsample(im)
25 | logit = net(im_256)
26 | return logit
27 |
28 |
29 | def get_softmaxed(net, im):
30 | logit = get_logit(net, im)
31 | logits = torch.cat([logit, -logit], dim=1)
32 | softmaxed = softmax(logits)[:, 1]
33 | return logits, softmaxed
34 |
35 |
36 | def load_attribute_classifier(attribute, ckpt_path=None):
37 | if ckpt_path is None:
38 | base_path = 'pretrained/celebahq'
39 | attribute_pkl = os.path.join(base_path, attribute, 'net_best.pth')
40 | ckpt = torch.load(attribute_pkl)
41 | else:
42 | ckpt = torch.load(ckpt_path)
43 | print("Using classifier at epoch: %d" % ckpt['epoch'])
44 | if 'valacc' in ckpt.keys():
45 | print("Validation acc on raw images: %0.5f" % ckpt['valacc'])
46 | detector = attribute_net.from_state_dict(
47 | ckpt['state_dict'], fixed_size=True, use_mbstd=False).cuda().eval()
48 | return detector
49 |
50 |
51 | class ClassifierWrapper(torch.nn.Module):
52 | def __init__(self, classifier_name, ckpt_path=None, device='cuda'):
53 | super(ClassifierWrapper, self).__init__()
54 | self.net = load_attribute_classifier(classifier_name, ckpt_path).eval().to(device)
55 |
56 | def forward(self, ims):
57 | out = (ims - 0.5) / 0.5
58 | return get_softmaxed(self.net, out)[0]
59 |
--------------------------------------------------------------------------------
/classifiers/attribute_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | def lerp_clip(a, b, t):
7 | return a + (b - a) * torch.clamp(t, 0.0, 1.0)
8 |
9 |
10 | class WScaleLayer(nn.Module):
11 | def __init__(self, size, fan_in, gain=np.sqrt(2), bias=True):
12 | super(WScaleLayer, self).__init__()
13 | self.scale = gain / np.sqrt(fan_in) # No longer a parameter
14 | if bias:
15 | self.b = nn.Parameter(torch.randn(size))
16 | else:
17 | self.b = 0
18 | self.size = size
19 |
20 | def forward(self, x):
21 | x_size = x.size()
22 | x = x * self.scale
23 | # modified to remove warning
24 | if type(self.b) == nn.Parameter and len(x_size) == 4:
25 | x = x + self.b.view(1, -1, 1, 1).expand(
26 | x_size[0], self.size, x_size[2], x_size[3])
27 | if type(self.b) == nn.Parameter and len(x_size) == 2:
28 | x = x + self.b.view(1, -1).expand(
29 | x_size[0], self.size)
30 | return x
31 |
32 |
33 | class WScaleConv2d(nn.Module):
34 | def __init__(self, in_channels, out_channels, kernel_size, padding=0,
35 | bias=True, gain=np.sqrt(2)):
36 | super().__init__()
37 | self.conv = nn.Conv2d(in_channels, out_channels,
38 | kernel_size=kernel_size,
39 | padding=padding,
40 | bias=False)
41 | fan_in = in_channels * kernel_size * kernel_size
42 | self.wscale = WScaleLayer(out_channels, fan_in, gain=gain, bias=bias)
43 |
44 | def forward(self, x):
45 | return self.wscale(self.conv(x))
46 |
47 |
48 | class WScaleLinear(nn.Module):
49 | def __init__(self, in_channels, out_channels, bias=True, gain=np.sqrt(2)):
50 | super().__init__()
51 | self.linear = nn.Linear(in_channels, out_channels, bias=False)
52 | self.wscale = WScaleLayer(out_channels, in_channels, gain=gain,
53 | bias=bias)
54 |
55 | def forward(self, x):
56 | return self.wscale(self.linear(x))
57 |
58 |
59 | class FromRGB(nn.Module):
60 | def __init__(self, in_channels, out_channels, kernel_size,
61 | act=nn.LeakyReLU(0.2), bias=True):
62 | super().__init__()
63 | self.conv = WScaleConv2d(in_channels, out_channels, kernel_size,
64 | padding=0, bias=bias)
65 | self.act = act
66 |
67 | def forward(self, x):
68 | return self.act(self.conv(x))
69 |
70 |
71 | class Downscale2d(nn.Module):
72 | def __init__(self, factor=2):
73 | super().__init__()
74 | self.downsample = nn.AvgPool2d(kernel_size=factor, stride=factor)
75 |
76 | def forward(self, x):
77 | return self.downsample(x)
78 |
79 |
80 | class DownscaleConvBlock(nn.Module):
81 | def __init__(self, in_channels, conv0_channels, conv1_channels,
82 | kernel_size, padding, bias=True, act=nn.LeakyReLU(0.2)):
83 | super().__init__()
84 | self.downscale = Downscale2d()
85 | self.conv0 = WScaleConv2d(in_channels, conv0_channels,
86 | kernel_size=kernel_size,
87 | padding=padding,
88 | bias=bias)
89 | self.conv1 = WScaleConv2d(conv0_channels, conv1_channels,
90 | kernel_size=kernel_size,
91 | padding=padding,
92 | bias=bias)
93 | self.act = act
94 |
95 | def forward(self, x):
96 | x = self.act(self.conv0(x))
97 | # conv2d_downscale2d applies downscaling before activation
98 | # the order matters here! has to be conv -> bias -> downscale -> act
99 | x = self.conv1(x)
100 | x = self.downscale(x)
101 | x = self.act(x)
102 | return x
103 |
104 |
105 | class MinibatchStdLayer(nn.Module):
106 | def __init__(self, group_size=4):
107 | super().__init__()
108 | self.group_size = group_size
109 |
110 | def forward(self, x):
111 | group_size = min(self.group_size, x.shape[0])
112 | s = x.shape
113 | y = x.view([group_size, -1, s[1], s[2], s[3]])
114 | y = y.float()
115 | y = y - torch.mean(y, dim=0, keepdim=True)
116 | y = torch.mean(y * y, dim=0)
117 | y = torch.sqrt(y + 1e-8)
118 | y = torch.mean(torch.mean(torch.mean(y, dim=3, keepdim=True),
119 | dim=2, keepdim=True), dim=1, keepdim=True)
120 | y = y.type(x.type())
121 | y = y.repeat(group_size, 1, s[2], s[3])
122 | return torch.cat([x, y], dim=1)
123 |
124 |
125 | class PredictionBlock(nn.Module):
126 | def __init__(self, in_channels, dense0_feat, dense1_feat, out_feat,
127 | pool_size=2, act=nn.LeakyReLU(0.2), use_mbstd=True):
128 | super().__init__()
129 | self.use_mbstd = use_mbstd # attribute classifiers don't have this
130 | if self.use_mbstd:
131 | self.mbstd_layer = MinibatchStdLayer()
132 | # MinibatchStdLayer adds an additional feature dimension
133 | self.conv = WScaleConv2d(in_channels + int(self.use_mbstd),
134 | dense0_feat, kernel_size=3, padding=1)
135 | self.dense0 = WScaleLinear(dense0_feat * pool_size * pool_size, dense1_feat)
136 | self.dense1 = WScaleLinear(dense1_feat, out_feat, gain=1)
137 | self.act = act
138 |
139 | def forward(self, x):
140 | if self.use_mbstd:
141 | x = self.mbstd_layer(x)
142 | x = self.act(self.conv(x))
143 | x = x.view([x.shape[0], -1])
144 | x = self.act(self.dense0(x))
145 | x = self.dense1(x)
146 | return x
147 |
148 |
149 | class D(nn.Module):
150 |
151 | def __init__(
152 | self,
153 | num_channels=3, # Number of input color channels. Overridden based on dataset.
154 | resolution=128, # Input resolution. Overridden based on dataset.
155 | fmap_base=8192, # Overall multiplier for the number of feature maps.
156 | fmap_decay=1.0, # log2 feature map reduction when doubling the resolution.
157 | fmap_max=512, # Maximum number of feature maps in any layer.
158 | fixed_size=False, # True = load fromrgb_lod0 weights only
159 | use_mbstd=True, # False = no mbstd layer in PredictionBlock
160 | **kwargs): # Ignore unrecognized keyword args.
161 | super().__init__()
162 |
163 | self.resolution_log2 = resolution_log2 = int(np.log2(resolution))
164 | assert resolution == 2 ** resolution_log2 and resolution >= 4
165 |
166 | def nf(stage):
167 | return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)
168 |
169 | self.register_buffer('lod_in', torch.from_numpy(np.array(0.0)))
170 |
171 | res = resolution_log2
172 |
173 | setattr(self, 'fromrgb_lod0', FromRGB(num_channels, nf(res - 1), 1))
174 |
175 | for i, res in enumerate(range(resolution_log2, 2, -1), 1):
176 | lod = resolution_log2 - res
177 | block = DownscaleConvBlock(nf(res - 1), nf(res - 1), nf(res - 2),
178 | kernel_size=3, padding=1)
179 | setattr(self, '%dx%d' % (2 ** res, 2 ** res), block)
180 | fromrgb = FromRGB(3, nf(res - 2), 1)
181 | if not fixed_size:
182 | setattr(self, 'fromrgb_lod%d' % i, fromrgb)
183 |
184 | res = 2
185 | pool_size = 2 ** res
186 | block = PredictionBlock(nf(res + 1 - 2), nf(res - 1), nf(res - 2), 1,
187 | pool_size, use_mbstd=use_mbstd)
188 | setattr(self, '%dx%d' % (pool_size, pool_size), block)
189 | self.downscale = Downscale2d()
190 | self.fixed_size = fixed_size
191 |
192 | def forward(self, img):
193 | x = self.fromrgb_lod0(img)
194 | for i, res in enumerate(range(self.resolution_log2, 2, -1), 1):
195 | lod = self.resolution_log2 - res
196 | x = getattr(self, '%dx%d' % (2 ** res, 2 ** res))(x)
197 | if not self.fixed_size:
198 | img = self.downscale(img)
199 | y = getattr(self, 'fromrgb_lod%d' % i)(img)
200 | x = lerp_clip(x, y, self.lod_in - lod)
201 | res = 2
202 | pool_size = 2 ** res
203 | out = getattr(self, '%dx%d' % (pool_size, pool_size))(x)
204 | return out
205 |
206 |
207 | def max_res_from_state_dict(state_dict):
208 | for i in range(3, 12):
209 | if '%dx%d.conv0.conv.weight' % (2 ** i, 2 ** i) not in state_dict:
210 | break
211 | return 2 ** (i - 1)
212 |
213 |
214 | def from_state_dict(state_dict, fixed_size=False, use_mbstd=True):
215 | res = max_res_from_state_dict(state_dict)
216 | print(f'res: {res}')
217 | d = D(num_channels=3, resolution=res, fixed_size=fixed_size,
218 | use_mbstd=use_mbstd)
219 | d.load_state_dict(state_dict)
220 | return d
221 |
--------------------------------------------------------------------------------
/classifiers/cifar10_resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | import torch.nn as nn
6 |
7 |
8 | # ---------------------------- ResNet ----------------------------
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, in_planes, planes, stride=1):
14 | super(Bottleneck, self).__init__()
15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
16 | self.bn1 = nn.BatchNorm2d(planes)
17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
18 | self.bn2 = nn.BatchNorm2d(planes)
19 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
20 | self.bn3 = nn.BatchNorm2d(self.expansion * planes)
21 |
22 | self.shortcut = nn.Sequential()
23 | if stride != 1 or in_planes != self.expansion * planes:
24 | self.shortcut = nn.Sequential(
25 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
26 | nn.BatchNorm2d(self.expansion * planes)
27 | )
28 |
29 | def forward(self, x):
30 | out = F.relu(self.bn1(self.conv1(x)))
31 | out = F.relu(self.bn2(self.conv2(out)))
32 | out = self.bn3(self.conv3(out))
33 | out += self.shortcut(x)
34 | out = F.relu(out)
35 | return out
36 |
37 |
38 | class ResNet(nn.Module):
39 | def __init__(self, block, num_blocks, num_classes=10):
40 | super(ResNet, self).__init__()
41 | self.in_planes = 64
42 |
43 | num_input_channels = 3
44 | mean = (0.4914, 0.4822, 0.4465)
45 | std = (0.2471, 0.2435, 0.2616)
46 | self.mean = torch.tensor(mean).view(num_input_channels, 1, 1)
47 | self.std = torch.tensor(std).view(num_input_channels, 1, 1)
48 |
49 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
50 | self.bn1 = nn.BatchNorm2d(64)
51 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
52 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
53 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
54 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
55 | self.linear = nn.Linear(512 * block.expansion, num_classes)
56 |
57 | def _make_layer(self, block, planes, num_blocks, stride):
58 | strides = [stride] + [1] * (num_blocks - 1)
59 | layers = []
60 | for stride in strides:
61 | layers.append(block(self.in_planes, planes, stride))
62 | self.in_planes = planes * block.expansion
63 | return nn.Sequential(*layers)
64 |
65 | def forward(self, x):
66 | out = (x - self.mean.to(x.device)) / self.std.to(x.device)
67 | out = F.relu(self.bn1(self.conv1(out)))
68 | out = self.layer1(out)
69 | out = self.layer2(out)
70 | out = self.layer3(out)
71 | out = self.layer4(out)
72 | out = F.avg_pool2d(out, 4)
73 | out = out.view(out.size(0), -1)
74 | out = self.linear(out)
75 | return out
76 |
77 |
78 | def ResNet50():
79 | return ResNet(Bottleneck, [3, 4, 6, 3])
80 |
81 |
82 | # ---------------------------- ResNet ----------------------------
83 |
84 |
85 | # ---------------------------- WideResNet ----------------------------
86 |
87 | class BasicBlock(nn.Module):
88 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
89 | super(BasicBlock, self).__init__()
90 | self.bn1 = nn.BatchNorm2d(in_planes)
91 | self.relu1 = nn.ReLU(inplace=True)
92 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
93 | padding=1, bias=False)
94 | self.bn2 = nn.BatchNorm2d(out_planes)
95 | self.relu2 = nn.ReLU(inplace=True)
96 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
97 | padding=1, bias=False)
98 | self.droprate = dropRate
99 | self.equalInOut = (in_planes == out_planes)
100 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
101 | padding=0, bias=False) or None
102 |
103 | def forward(self, x):
104 | if not self.equalInOut:
105 | x = self.relu1(self.bn1(x))
106 | else:
107 | out = self.relu1(self.bn1(x))
108 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
109 | if self.droprate > 0:
110 | out = F.dropout(out, p=self.droprate, training=self.training)
111 | out = self.conv2(out)
112 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
113 |
114 |
115 | class NetworkBlock(nn.Module):
116 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
117 | super(NetworkBlock, self).__init__()
118 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
119 |
120 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
121 | layers = []
122 | for i in range(int(nb_layers)):
123 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
124 | return nn.Sequential(*layers)
125 |
126 | def forward(self, x):
127 | return self.layer(x)
128 |
129 |
130 | class WideResNet(nn.Module):
131 | """ Based on code from https://github.com/yaodongyu/TRADES """
132 |
133 | def __init__(self, depth=28, num_classes=10, widen_factor=10, sub_block1=False, dropRate=0.0, bias_last=True):
134 | super(WideResNet, self).__init__()
135 |
136 | num_input_channels = 3
137 | mean = (0.4914, 0.4822, 0.4465)
138 | std = (0.2471, 0.2435, 0.2616)
139 | self.mean = torch.tensor(mean).view(num_input_channels, 1, 1)
140 | self.std = torch.tensor(std).view(num_input_channels, 1, 1)
141 |
142 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
143 | assert ((depth - 4) % 6 == 0)
144 | n = (depth - 4) / 6
145 | block = BasicBlock
146 | # 1st conv before any network block
147 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
148 | padding=1, bias=False)
149 | # 1st block
150 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
151 | if sub_block1:
152 | # 1st sub-block
153 | self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
154 | # 2nd block
155 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
156 | # 3rd block
157 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
158 | # global average pooling and classifier
159 | self.bn1 = nn.BatchNorm2d(nChannels[3])
160 | self.relu = nn.ReLU(inplace=True)
161 | self.fc = nn.Linear(nChannels[3], num_classes, bias=bias_last)
162 | self.nChannels = nChannels[3]
163 |
164 | for m in self.modules():
165 | if isinstance(m, nn.Conv2d):
166 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
167 | m.weight.data.normal_(0, math.sqrt(2. / n))
168 | elif isinstance(m, nn.BatchNorm2d):
169 | m.weight.data.fill_(1)
170 | m.bias.data.zero_()
171 | elif isinstance(m, nn.Linear) and m.bias is not None:
172 | m.bias.data.zero_()
173 |
174 | def forward(self, x):
175 | out = (x - self.mean.to(x.device)) / self.std.to(x.device)
176 | out = self.conv1(out)
177 | out = self.block1(out)
178 | out = self.block2(out)
179 | out = self.block3(out)
180 | out = self.relu(self.bn1(out))
181 | out = F.avg_pool2d(out, 8)
182 | out = out.view(-1, self.nChannels)
183 | return self.fc(out)
184 |
185 |
186 | def WideResNet_70_16():
187 | return WideResNet(depth=70, widen_factor=16, dropRate=0.0)
188 |
189 |
190 | def WideResNet_70_16_dropout():
191 | return WideResNet(depth=70, widen_factor=16, dropRate=0.3)
192 | # ---------------------------- WideResNet ----------------------------
193 |
--------------------------------------------------------------------------------
/compute_accuracy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 | class Accuracy(object):
5 | def at_radii(self, radii: np.ndarray):
6 | raise NotImplementedError()
7 |
8 | class ApproximateAccuracy(Accuracy):
9 | def __init__(self, data_file_path: str):
10 | self.data_file_path = data_file_path
11 |
12 | def at_radii(self, radii: np.ndarray) -> np.ndarray:
13 | df = pd.read_csv(self.data_file_path, delimiter=r"\s+")
14 | return np.array([self.at_radius(df, radius) for radius in radii])
15 |
16 | def at_radius(self, df: pd.DataFrame, radius: float):
17 | return (df["correct"] & (df["radius"] >= radius)).mean()
18 |
19 | def get_abstention_rate(self) -> np.ndarray:
20 | df = pd.read_csv(self.data_file_path, delimiter="\t")
21 | return 1.*(df["predict"]==-1).sum()/len(df["predict"])*100
22 |
23 | acc = ApproximateAccuracy("results_file_path")
24 |
25 | def latex_table_certified_accuracy(radius):
26 | radii = [radius]
27 | accuracy = acc.at_radii(radii)
28 | print('certified_acc:'+str(accuracy))
29 |
30 | if __name__ == "__main__":
31 | #certified accuracy for imagenet
32 | # latex_table_certified_accuracy(0.00)
33 | # latex_table_certified_accuracy(0.50)
34 | # latex_table_certified_accuracy(1.00)
35 | # latex_table_certified_accuracy(1.50)
36 | # latex_table_certified_accuracy(2.00)
37 | # latex_table_certified_accuracy(3.00)
38 |
39 | # certified accuracy for cifar10
40 | latex_table_certified_accuracy(0.00)
41 | latex_table_certified_accuracy(0.25)
42 | latex_table_certified_accuracy(0.50)
43 | latex_table_certified_accuracy(0.75)
44 | latex_table_certified_accuracy(1.00)
45 |
--------------------------------------------------------------------------------
/configs/cifar10.yml:
--------------------------------------------------------------------------------
1 | model:
--------------------------------------------------------------------------------
/configs/imagenet.yml:
--------------------------------------------------------------------------------
1 | model:
2 | attention_resolutions: '32,16,8'
3 | class_cond: False
4 | diffusion_steps: 1000
5 | rescale_timesteps: True
6 | timestep_respacing: '1000' # Modify this value to decrease the number of timesteps.
7 | image_size: 256
8 | learn_sigma: True
9 | noise_schedule: 'linear'
10 | num_channels: 256
11 | num_head_channels: 64
12 | num_res_blocks: 2
13 | resblock_updown: True
14 | use_fp16: True
15 | use_scale_shift_norm: True
--------------------------------------------------------------------------------
/core.py:
--------------------------------------------------------------------------------
1 | from math import ceil
2 |
3 | import numpy as np
4 | from scipy.stats import norm, binom_test
5 | from statsmodels.stats.proportion import proportion_confint
6 | import torch
7 |
8 | class Smooth(object):
9 | """A smoothed classifier g """
10 |
11 | # to abstain, Smooth returns this int
12 | ABSTAIN = -1
13 |
14 | def __init__(self, base_classifier: torch.nn.Module, num_classes: int, sigma: float):
15 | """
16 | :param base_classifier: maps from [batch x channel x height x width] to [batch x num_classes]
17 | :param num_classes:
18 | :param sigma: the noise level hyperparameter
19 | """
20 | self.base_classifier = base_classifier
21 | self.num_classes = num_classes
22 | self.sigma = sigma
23 |
24 | def certify(self, x: torch.tensor, n0: int, n: int, sample_id:int, alpha: float, batch_size: int, clustering_method='none') -> (int, float):
25 | """ Monte Carlo algorithm for certifying that g's prediction around x is constant within some L2 radius.
26 | With probability at least 1 - alpha, the class returned by this method will equal g(x), and g's prediction will
27 | be robust within an L2 ball of radius R around x.
28 |
29 | :param x: the input [channel x height x width]
30 | :param n0: the number of Monte Carlo samples to use for selection
31 | :param n: the number of Monte Carlo samples to use for estimation
32 | :param alpha: the failure probability
33 | :param batch_size: batch size to use when evaluating the base classifier
34 | :return: (predicted class, certified radius)
35 | in the case of abstention, the class will be ABSTAIN and the radius 0.
36 | """
37 | self.base_classifier.eval()
38 | # draw samples of f(x+ epsilon)
39 | counts_selection, n0_predictions = self._sample_noise(x, n0, batch_size, clustering_method, sample_id)
40 | # use these samples to take a guess at the top class
41 | cAHat = counts_selection.argmax().item()
42 | # draw more samples of f(x + epsilon)
43 | counts_estimation, n_predictions = self._sample_noise(x, n, batch_size, clustering_method, sample_id)
44 | # use these samples to estimate a lower bound on pA
45 | nA = counts_estimation[cAHat].item()
46 | pABar = self._lower_confidence_bound(nA, n, alpha)
47 | if pABar < 0.5:
48 | return Smooth.ABSTAIN, 0.0, n0_predictions, n_predictions
49 | else:
50 | radius = self.sigma * norm.ppf(pABar)
51 | return cAHat, radius, n0_predictions, n_predictions
52 |
53 | def certify_noapproximate(self, x: torch.tensor, n0: int, n: int, alpha: float, batch_size: int) -> (int, float):
54 | """ Monte Carlo algorithm for certifying that g's prediction around x is constant within some L2 radius.
55 | With probability at least 1 - alpha, the class returned by this method will equal g(x), and g's prediction will
56 | be robust within an L2 ball of radius R around x.
57 |
58 | :param x: the input [channel x height x width]
59 | :param n0: the number of Monte Carlo samples to use for selection
60 | :param n: the number of Monte Carlo samples to use for estimation
61 | :param alpha: the failure probability
62 | :param batch_size: batch size to use when evaluating the base classifier
63 | :return: (predicted class, certified radius)
64 | in the case of abstention, the class will be ABSTAIN and the radius 0.
65 | """
66 | self.base_classifier.eval()
67 | # draw samples of f(x+ epsilon)
68 | counts_selection = self._sample_noise(x, n0, batch_size)
69 | # use these samples to take a guess at the top class
70 | cAHat = counts_selection.argmax().item()
71 | # draw more samples of f(x + epsilon)
72 | counts_estimation = self._sample_noise(x, n, batch_size)
73 | # use these samples to estimate a lower bound on pA
74 | top2 = counts_estimation.argsort()[::-1][:2]
75 | nA = counts_estimation[top2[0]].item()
76 | nB = counts_estimation[top2[1]].item()
77 |
78 | pABar = self._lower_confidence_bound(nA, n, alpha)
79 | pBBar = self._upper_confidence_bound(nB, n, alpha)
80 | if pABar < 0.5:
81 | return Smooth.ABSTAIN, 0.0
82 | else:
83 | radius = self.sigma/2 * (norm.ppf(pABar) - norm.ppf(pBBar))
84 | return cAHat, radius
85 |
86 | def predict(self, x: torch.tensor, n: int, alpha: float, batch_size: int) -> int:
87 | """ Monte Carlo algorithm for evaluating the prediction of g at x. With probability at least 1 - alpha, the
88 | class returned by this method will equal g(x).
89 |
90 | This function uses the hypothesis test described in https://arxiv.org/abs/1610.03944
91 | for identifying the top category of a multinomial distribution.
92 |
93 | :param x: the input [channel x height x width]
94 | :param n: the number of Monte Carlo samples to use
95 | :param alpha: the failure probability
96 | :param batch_size: batch size to use when evaluating the base classifier
97 | :return: the predicted class, or ABSTAIN
98 | """
99 | self.base_classifier.eval()
100 | counts = self._sample_noise(x, n, batch_size)
101 | top2 = counts.argsort()[::-1][:2]
102 | count1 = counts[top2[0]]
103 | count2 = counts[top2[1]]
104 | if binom_test(count1, count1 + count2, p=0.5) > alpha:
105 | return Smooth.ABSTAIN
106 | else:
107 | return top2[0]
108 |
109 | def _sample_noise(self, x: torch.tensor, num: int, batch_size, clustering_method='none', sample_id=None) -> np.ndarray:
110 | """ Sample the base classifier's prediction under noisy corruptions of the input x.
111 |
112 | :param x: the input [channel x width x height]
113 | :param num: number of samples to collect
114 | :param batch_size:
115 | :return: an ndarray[int] of length num_classes containing the per-class counts
116 | """
117 | with torch.no_grad():
118 | predictions_all = np.array([], dtype=int)
119 | counts = np.zeros(self.num_classes, dtype=int)
120 | for _ in range(ceil(num / batch_size)):
121 | this_batch_size = min(batch_size, num)
122 | num -= this_batch_size
123 |
124 | batch = x.repeat((this_batch_size, 1, 1, 1))
125 | noise = torch.randn_like(batch, device='cuda') * self.sigma
126 |
127 | if clustering_method == 'classifier':
128 | predictions = self.base_classifier(batch + noise, sample_id).argmax(1)
129 | predictions = predictions.view(this_batch_size,-1).cpu().numpy()
130 | count_max_list = np.zeros(this_batch_size,dtype=int)
131 | for i in range(this_batch_size):
132 | count_max = max(list(predictions[i]),key=list(predictions[i]).count)
133 | count_max_list[i] = count_max
134 | counts += self._count_arr(count_max_list, self.num_classes)
135 |
136 | else:
137 | predictions = self.base_classifier(batch + noise, sample_id).argmax(1)
138 | counts += self._count_arr(predictions.cpu().numpy(), self.num_classes)
139 | predictions_all = np.hstack((predictions_all, predictions.cpu().numpy()))
140 |
141 | return counts, predictions_all
142 |
143 | def _count_arr(self, arr: np.ndarray, length: int) -> np.ndarray:
144 | counts = np.zeros(length, dtype=int)
145 | for idx in arr:
146 | counts[idx] += 1
147 | return counts
148 |
149 | def _lower_confidence_bound(self, NA: int, N: int, alpha: float) -> float:
150 | """ Returns a (1 - alpha) lower confidence bound on a bernoulli proportion.
151 |
152 | This function uses the Clopper-Pearson method.
153 |
154 | :param NA: the number of "successes"
155 | :param N: the number of total draws
156 | :param alpha: the confidence level
157 | :return: a lower bound on the binomial proportion which holds true w.p at least (1 - alpha) over the samples
158 | """
159 | return proportion_confint(NA, N, alpha=2 * alpha, method="beta")[0]
160 |
161 | def _upper_confidence_bound(self, NA: int, N: int, alpha: float) -> float:
162 |
163 | return proportion_confint(NA, N, alpha=2 * alpha, method="beta")[1]
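
# Usage sketch (illustrative; `base_classifier` is assumed to be any module mapping
# [batch x channel x height x width] images to [batch x num_classes] logits):
#
#   smoothed = Smooth(base_classifier, num_classes=10, sigma=0.25)
#   prediction, radius, _, _ = smoothed.certify(x, n0=100, n=100000, sample_id=0,
#                                               alpha=0.001, batch_size=400)
#
# If the Clopper-Pearson lower bound pABar comes out to 0.99, the certified L2 radius is
# sigma * norm.ppf(0.99) ≈ 0.25 * 2.33 ≈ 0.58.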
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import imagenet_lmdb_dataset, imagenet_lmdb_dataset_sub, cifar10_dataset_sub
2 |
3 | def get_transform(dataset_name, transform_type, base_size=256):
4 | from . import datasets
5 | if dataset_name == 'celebahq':
6 | return datasets.get_transform(dataset_name, transform_type, base_size)
7 | elif 'imagenet' in dataset_name:
8 | return datasets.get_transform(dataset_name, transform_type, base_size)
9 | else:
10 | raise NotImplementedError
11 |
12 |
13 | def get_dataset(dataset_name, partition, *args, **kwargs):
14 | from . import datasets
15 | if dataset_name == 'celebahq':
16 | return datasets.CelebAHQDataset(partition, *args, **kwargs)
17 | else:
18 | raise NotImplementedError
--------------------------------------------------------------------------------
/ddpm/unet_ddpm.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | def get_timestep_embedding(timesteps, embedding_dim):
7 | """
8 | This matches the implementation in Denoising Diffusion Probabilistic Models:
9 | From Fairseq.
10 | Build sinusoidal embeddings.
11 | This matches the implementation in tensor2tensor, but differs slightly
12 | from the description in Section 3.5 of "Attention Is All You Need".
13 | """
14 | assert len(timesteps.shape) == 1
15 |
16 | half_dim = embedding_dim // 2
17 | emb = math.log(10000) / (half_dim - 1)
18 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
19 | emb = emb.to(device=timesteps.device)
20 | emb = timesteps.float()[:, None] * emb[None, :]
21 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
22 | if embedding_dim % 2 == 1: # zero pad
23 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
24 | return emb
25 |
26 |
27 | def nonlinearity(x):
28 | # swish
29 | return x * torch.sigmoid(x)
30 |
31 |
32 | def Normalize(in_channels):
33 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
34 |
35 |
36 | class Upsample(nn.Module):
37 | def __init__(self, in_channels, with_conv):
38 | super().__init__()
39 | self.with_conv = with_conv
40 | if self.with_conv:
41 | self.conv = torch.nn.Conv2d(in_channels,
42 | in_channels,
43 | kernel_size=3,
44 | stride=1,
45 | padding=1)
46 |
47 | def forward(self, x):
48 | x = torch.nn.functional.interpolate(
49 | x, scale_factor=2.0, mode="nearest")
50 | if self.with_conv:
51 | x = self.conv(x)
52 | return x
53 |
54 |
55 | class Downsample(nn.Module):
56 | def __init__(self, in_channels, with_conv):
57 | super().__init__()
58 | self.with_conv = with_conv
59 | if self.with_conv:
60 | # no asymmetric padding in torch conv, must do it ourselves
61 | self.conv = torch.nn.Conv2d(in_channels,
62 | in_channels,
63 | kernel_size=3,
64 | stride=2,
65 | padding=0)
66 |
67 | def forward(self, x):
68 | if self.with_conv:
69 | pad = (0, 1, 0, 1)
70 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
71 | x = self.conv(x)
72 | else:
73 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
74 | return x
75 |
76 |
77 | class ResnetBlock(nn.Module):
78 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
79 | dropout, temb_channels=512):
80 | super().__init__()
81 | self.in_channels = in_channels
82 | out_channels = in_channels if out_channels is None else out_channels
83 | self.out_channels = out_channels
84 | self.use_conv_shortcut = conv_shortcut
85 |
86 | self.norm1 = Normalize(in_channels)
87 | self.conv1 = torch.nn.Conv2d(in_channels,
88 | out_channels,
89 | kernel_size=3,
90 | stride=1,
91 | padding=1)
92 | self.temb_proj = torch.nn.Linear(temb_channels,
93 | out_channels)
94 | self.norm2 = Normalize(out_channels)
95 | self.dropout = torch.nn.Dropout(dropout)
96 | self.conv2 = torch.nn.Conv2d(out_channels,
97 | out_channels,
98 | kernel_size=3,
99 | stride=1,
100 | padding=1)
101 | if self.in_channels != self.out_channels:
102 | if self.use_conv_shortcut:
103 | self.conv_shortcut = torch.nn.Conv2d(in_channels,
104 | out_channels,
105 | kernel_size=3,
106 | stride=1,
107 | padding=1)
108 | else:
109 | self.nin_shortcut = torch.nn.Conv2d(in_channels,
110 | out_channels,
111 | kernel_size=1,
112 | stride=1,
113 | padding=0)
114 |
115 | def forward(self, x, temb):
116 | h = x
117 | h = self.norm1(h)
118 | h = nonlinearity(h)
119 | h = self.conv1(h)
120 |
121 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
122 |
123 | h = self.norm2(h)
124 | h = nonlinearity(h)
125 | h = self.dropout(h)
126 | h = self.conv2(h)
127 |
128 | if self.in_channels != self.out_channels:
129 | if self.use_conv_shortcut:
130 | x = self.conv_shortcut(x)
131 | else:
132 | x = self.nin_shortcut(x)
133 |
134 | return x + h
135 |
136 |
137 | class AttnBlock(nn.Module):
138 | def __init__(self, in_channels):
139 | super().__init__()
140 | self.in_channels = in_channels
141 |
142 | self.norm = Normalize(in_channels)
143 | self.q = torch.nn.Conv2d(in_channels,
144 | in_channels,
145 | kernel_size=1,
146 | stride=1,
147 | padding=0)
148 | self.k = torch.nn.Conv2d(in_channels,
149 | in_channels,
150 | kernel_size=1,
151 | stride=1,
152 | padding=0)
153 | self.v = torch.nn.Conv2d(in_channels,
154 | in_channels,
155 | kernel_size=1,
156 | stride=1,
157 | padding=0)
158 | self.proj_out = torch.nn.Conv2d(in_channels,
159 | in_channels,
160 | kernel_size=1,
161 | stride=1,
162 | padding=0)
163 |
164 | def forward(self, x):
165 | h_ = x
166 | h_ = self.norm(h_)
167 | q = self.q(h_)
168 | k = self.k(h_)
169 | v = self.v(h_)
170 |
171 | # compute attention
172 | b, c, h, w = q.shape
173 | q = q.reshape(b, c, h * w)
174 | q = q.permute(0, 2, 1) # b,hw,c
175 | k = k.reshape(b, c, h * w) # b,c,hw
176 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
177 | w_ = w_ * (int(c) ** (-0.5))
178 | w_ = torch.nn.functional.softmax(w_, dim=2)
179 |
180 | # attend to values
181 | v = v.reshape(b, c, h * w)
182 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
183 | # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
184 | h_ = torch.bmm(v, w_)
185 | h_ = h_.reshape(b, c, h, w)
186 |
187 | h_ = self.proj_out(h_)
188 |
189 | return x + h_
190 |
191 |
192 | class Model(nn.Module):
193 | def __init__(self, config):
194 | super().__init__()
195 | self.config = config
196 | ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult)
197 | num_res_blocks = config.model.num_res_blocks
198 | attn_resolutions = config.model.attn_resolutions
199 | dropout = config.model.dropout
200 | in_channels = config.model.in_channels
201 | resolution = config.data.image_size
202 | resamp_with_conv = config.model.resamp_with_conv
203 |
204 | self.ch = ch
205 | self.temb_ch = self.ch * 4
206 | self.num_resolutions = len(ch_mult)
207 | self.num_res_blocks = num_res_blocks
208 | self.resolution = resolution
209 | self.in_channels = in_channels
210 |
211 | # timestep embedding
212 | self.temb = nn.Module()
213 | self.temb.dense = nn.ModuleList([
214 | torch.nn.Linear(self.ch,
215 | self.temb_ch),
216 | torch.nn.Linear(self.temb_ch,
217 | self.temb_ch),
218 | ])
219 |
220 | # downsampling
221 | self.conv_in = torch.nn.Conv2d(in_channels,
222 | self.ch,
223 | kernel_size=3,
224 | stride=1,
225 | padding=1)
226 |
227 | curr_res = resolution
228 | in_ch_mult = (1,) + ch_mult
229 | self.down = nn.ModuleList()
230 | block_in = None
231 | for i_level in range(self.num_resolutions):
232 | block = nn.ModuleList()
233 | attn = nn.ModuleList()
234 | block_in = ch * in_ch_mult[i_level]
235 | block_out = ch * ch_mult[i_level]
236 | for i_block in range(self.num_res_blocks):
237 | block.append(ResnetBlock(in_channels=block_in,
238 | out_channels=block_out,
239 | temb_channels=self.temb_ch,
240 | dropout=dropout))
241 | block_in = block_out
242 | if curr_res in attn_resolutions:
243 | attn.append(AttnBlock(block_in))
244 | down = nn.Module()
245 | down.block = block
246 | down.attn = attn
247 | if i_level != self.num_resolutions - 1:
248 | down.downsample = Downsample(block_in, resamp_with_conv)
249 | curr_res = curr_res // 2
250 | self.down.append(down)
251 |
252 | # middle
253 | self.mid = nn.Module()
254 | self.mid.block_1 = ResnetBlock(in_channels=block_in,
255 | out_channels=block_in,
256 | temb_channels=self.temb_ch,
257 | dropout=dropout)
258 | self.mid.attn_1 = AttnBlock(block_in)
259 | self.mid.block_2 = ResnetBlock(in_channels=block_in,
260 | out_channels=block_in,
261 | temb_channels=self.temb_ch,
262 | dropout=dropout)
263 |
264 | # upsampling
265 | self.up = nn.ModuleList()
266 | for i_level in reversed(range(self.num_resolutions)):
267 | block = nn.ModuleList()
268 | attn = nn.ModuleList()
269 | block_out = ch * ch_mult[i_level]
270 | skip_in = ch * ch_mult[i_level]
271 | for i_block in range(self.num_res_blocks + 1):
272 | if i_block == self.num_res_blocks:
273 | skip_in = ch * in_ch_mult[i_level]
274 | block.append(ResnetBlock(in_channels=block_in + skip_in,
275 | out_channels=block_out,
276 | temb_channels=self.temb_ch,
277 | dropout=dropout))
278 | block_in = block_out
279 | if curr_res in attn_resolutions:
280 | attn.append(AttnBlock(block_in))
281 | up = nn.Module()
282 | up.block = block
283 | up.attn = attn
284 | if i_level != 0:
285 | up.upsample = Upsample(block_in, resamp_with_conv)
286 | curr_res = curr_res * 2
287 | self.up.insert(0, up) # prepend to get consistent order
288 |
289 | # end
290 | self.norm_out = Normalize(block_in)
291 | self.conv_out = torch.nn.Conv2d(block_in,
292 | out_ch,
293 | kernel_size=3,
294 | stride=1,
295 | padding=1)
296 |
297 | def forward(self, x, t):
298 | assert x.shape[2] == x.shape[3] == self.resolution
299 |
300 | # timestep embedding
301 | temb = get_timestep_embedding(t, self.ch)
302 | temb = self.temb.dense[0](temb)
303 | temb = nonlinearity(temb)
304 | temb = self.temb.dense[1](temb)
305 |
306 | # downsampling
307 | hs = [self.conv_in(x)]
308 | for i_level in range(self.num_resolutions):
309 | for i_block in range(self.num_res_blocks):
310 | h = self.down[i_level].block[i_block](hs[-1], temb)
311 | if len(self.down[i_level].attn) > 0:
312 | h = self.down[i_level].attn[i_block](h)
313 | hs.append(h)
314 | if i_level != self.num_resolutions - 1:
315 | hs.append(self.down[i_level].downsample(hs[-1]))
316 |
317 | # middle
318 | h = hs[-1]
319 | h = self.mid.block_1(h, temb)
320 | h = self.mid.attn_1(h)
321 | h = self.mid.block_2(h, temb)
322 |
323 | # upsampling
324 | for i_level in reversed(range(self.num_resolutions)):
325 | for i_block in range(self.num_res_blocks + 1):
326 | h = self.up[i_level].block[i_block](
327 | torch.cat([h, hs.pop()], dim=1), temb)
328 | if len(self.up[i_level].attn) > 0:
329 | h = self.up[i_level].attn[i_block](h)
330 | if i_level != 0:
331 | h = self.up[i_level].upsample(h)
332 |
333 | # end
334 | h = self.norm_out(h)
335 | h = nonlinearity(h)
336 | h = self.conv_out(h)
337 | return h
338 |
--------------------------------------------------------------------------------
/guided_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Codebase for "Improved Denoising Diffusion Probabilistic Models".
3 | """
4 |
--------------------------------------------------------------------------------
/guided_diffusion/dist_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for distributed training.
3 | """
4 |
5 | import io
6 | import os
7 | import socket
8 |
9 | import blobfile as bf
10 | from mpi4py import MPI
11 | import torch as th
12 | import torch.distributed as dist
13 |
14 | # Change this to reflect your cluster layout.
15 | # The GPU for a given rank is (rank % GPUS_PER_NODE).
16 | GPUS_PER_NODE = 8
17 |
18 | SETUP_RETRY_COUNT = 3
19 |
20 |
21 | def setup_dist():
22 | """
23 | Setup a distributed process group.
24 | """
25 | if dist.is_initialized():
26 | return
27 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"
28 |
29 | comm = MPI.COMM_WORLD
30 | backend = "gloo" if not th.cuda.is_available() else "nccl"
31 |
32 | if backend == "gloo":
33 | hostname = "localhost"
34 | else:
35 | hostname = socket.gethostbyname(socket.getfqdn())
36 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
37 | os.environ["RANK"] = str(comm.rank)
38 | os.environ["WORLD_SIZE"] = str(comm.size)
39 |
40 | port = comm.bcast(_find_free_port(), root=0)
41 | os.environ["MASTER_PORT"] = str(port)
42 | dist.init_process_group(backend=backend, init_method="env://")
43 |
44 |
45 | def dev():
46 | """
47 | Get the device to use for torch.distributed.
48 | """
49 | if th.cuda.is_available():
50 | return th.device(f"cuda")
51 | return th.device("cpu")
52 |
53 |
54 | def load_state_dict(path, **kwargs):
55 | """
56 | Load a PyTorch file without redundant fetches across MPI ranks.
57 | """
58 | chunk_size = 2 ** 30 # MPI has a relatively small size limit
59 | if MPI.COMM_WORLD.Get_rank() == 0:
60 | with bf.BlobFile(path, "rb") as f:
61 | data = f.read()
62 | num_chunks = len(data) // chunk_size
63 | if len(data) % chunk_size:
64 | num_chunks += 1
65 | MPI.COMM_WORLD.bcast(num_chunks)
66 | for i in range(0, len(data), chunk_size):
67 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
68 | else:
69 | num_chunks = MPI.COMM_WORLD.bcast(None)
70 | data = bytes()
71 | for _ in range(num_chunks):
72 | data += MPI.COMM_WORLD.bcast(None)
73 |
74 | return th.load(io.BytesIO(data), **kwargs)
75 |
76 |
77 | def sync_params(params):
78 | """
79 | Synchronize a sequence of Tensors across ranks from rank 0.
80 | """
81 | for p in params:
82 | with th.no_grad():
83 | dist.broadcast(p, 0)
84 |
85 |
86 | def _find_free_port():
87 | try:
88 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
89 | s.bind(("", 0))
90 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
91 | return s.getsockname()[1]
92 | finally:
93 | s.close()
94 |
--------------------------------------------------------------------------------
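`load_state_dict` above broadcasts the checkpoint in 1 GiB chunks because a single MPI broadcast has a fairly small message-size limit. A minimal, MPI-free sketch of the same split-and-reassemble arithmetic (illustrative only):

```python
def split_into_chunks(data, chunk_size=2 ** 30):
    """Mirror of the chunking arithmetic used by load_state_dict above."""
    num_chunks = len(data) // chunk_size
    if len(data) % chunk_size:
        num_chunks += 1
    chunks = [data[i : i + chunk_size] for i in range(0, len(data), chunk_size)]
    assert len(chunks) == num_chunks
    return chunks

# Tiny chunk size purely for illustration; the real code uses 1 GiB chunks.
assert split_into_chunks(b"abcdefghij", chunk_size=4) == [b"abcd", b"efgh", b"ij"]
```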
/guided_diffusion/fp16_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers to train with 16-bit precision.
3 | """
4 |
5 | import numpy as np
6 | import torch as th
7 | import torch.nn as nn
8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
9 |
10 | from . import logger
11 |
12 | INITIAL_LOG_LOSS_SCALE = 20.0
13 |
14 |
15 | def convert_module_to_f16(l):
16 | """
17 | Convert primitive modules to float16.
18 | """
19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
20 | l.weight.data = l.weight.data.half()
21 | if l.bias is not None:
22 | l.bias.data = l.bias.data.half()
23 |
24 |
25 | def convert_module_to_f32(l):
26 | """
27 | Convert primitive modules to float32, undoing convert_module_to_f16().
28 | """
29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
30 | l.weight.data = l.weight.data.float()
31 | if l.bias is not None:
32 | l.bias.data = l.bias.data.float()
33 |
34 |
35 | def make_master_params(param_groups_and_shapes):
36 | """
37 | Copy model parameters into a (differently-shaped) list of full-precision
38 | parameters.
39 | """
40 | master_params = []
41 | for param_group, shape in param_groups_and_shapes:
42 | master_param = nn.Parameter(
43 | _flatten_dense_tensors(
44 | [param.detach().float() for (_, param) in param_group]
45 | ).view(shape)
46 | )
47 | master_param.requires_grad = True
48 | master_params.append(master_param)
49 | return master_params
50 |
51 |
52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params):
53 | """
54 | Copy the gradients from the model parameters into the master parameters
55 | from make_master_params().
56 | """
57 | for master_param, (param_group, shape) in zip(
58 | master_params, param_groups_and_shapes
59 | ):
60 | master_param.grad = _flatten_dense_tensors(
61 | [param_grad_or_zeros(param) for (_, param) in param_group]
62 | ).view(shape)
63 |
64 |
65 | def master_params_to_model_params(param_groups_and_shapes, master_params):
66 | """
67 | Copy the master parameter data back into the model parameters.
68 | """
69 | # Without copying to a list, if a generator is passed, this will
70 | # silently not copy any parameters.
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
72 | for (_, param), unflat_master_param in zip(
73 | param_group, unflatten_master_params(param_group, master_param.view(-1))
74 | ):
75 | param.detach().copy_(unflat_master_param)
76 |
77 |
78 | def unflatten_master_params(param_group, master_param):
79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
80 |
81 |
82 | def get_param_groups_and_shapes(named_model_params):
83 | named_model_params = list(named_model_params)
84 | scalar_vector_named_params = (
85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
86 | (-1),
87 | )
88 | matrix_named_params = (
89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1],
90 | (1, -1),
91 | )
92 | return [scalar_vector_named_params, matrix_named_params]
93 |
94 |
95 | def master_params_to_state_dict(
96 | model, param_groups_and_shapes, master_params, use_fp16
97 | ):
98 | if use_fp16:
99 | state_dict = model.state_dict()
100 | for master_param, (param_group, _) in zip(
101 | master_params, param_groups_and_shapes
102 | ):
103 | for (name, _), unflat_master_param in zip(
104 | param_group, unflatten_master_params(param_group, master_param.view(-1))
105 | ):
106 | assert name in state_dict
107 | state_dict[name] = unflat_master_param
108 | else:
109 | state_dict = model.state_dict()
110 | for i, (name, _value) in enumerate(model.named_parameters()):
111 | assert name in state_dict
112 | state_dict[name] = master_params[i]
113 | return state_dict
114 |
115 |
116 | def state_dict_to_master_params(model, state_dict, use_fp16):
117 | if use_fp16:
118 | named_model_params = [
119 | (name, state_dict[name]) for name, _ in model.named_parameters()
120 | ]
121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
122 | master_params = make_master_params(param_groups_and_shapes)
123 | else:
124 | master_params = [state_dict[name] for name, _ in model.named_parameters()]
125 | return master_params
126 |
127 |
128 | def zero_master_grads(master_params):
129 | for param in master_params:
130 | param.grad = None
131 |
132 |
133 | def zero_grad(model_params):
134 | for param in model_params:
135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
136 | if param.grad is not None:
137 | param.grad.detach_()
138 | param.grad.zero_()
139 |
140 |
141 | def param_grad_or_zeros(param):
142 | if param.grad is not None:
143 | return param.grad.data.detach()
144 | else:
145 | return th.zeros_like(param)
146 |
147 |
148 | class MixedPrecisionTrainer:
149 | def __init__(
150 | self,
151 | *,
152 | model,
153 | use_fp16=False,
154 | fp16_scale_growth=1e-3,
155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
156 | ):
157 | self.model = model
158 | self.use_fp16 = use_fp16
159 | self.fp16_scale_growth = fp16_scale_growth
160 |
161 | self.model_params = list(self.model.parameters())
162 | self.master_params = self.model_params
163 | self.param_groups_and_shapes = None
164 | self.lg_loss_scale = initial_lg_loss_scale
165 |
166 | if self.use_fp16:
167 | self.param_groups_and_shapes = get_param_groups_and_shapes(
168 | self.model.named_parameters()
169 | )
170 | self.master_params = make_master_params(self.param_groups_and_shapes)
171 | self.model.convert_to_fp16()
172 |
173 | def zero_grad(self):
174 | zero_grad(self.model_params)
175 |
176 | def backward(self, loss: th.Tensor):
177 | if self.use_fp16:
178 | loss_scale = 2 ** self.lg_loss_scale
179 | (loss * loss_scale).backward()
180 | else:
181 | loss.backward()
182 |
183 | def optimize(self, opt: th.optim.Optimizer):
184 | if self.use_fp16:
185 | return self._optimize_fp16(opt)
186 | else:
187 | return self._optimize_normal(opt)
188 |
189 | def _optimize_fp16(self, opt: th.optim.Optimizer):
190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
193 | if check_overflow(grad_norm):
194 | self.lg_loss_scale -= 1
195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
196 | zero_master_grads(self.master_params)
197 | return False
198 |
199 | logger.logkv_mean("grad_norm", grad_norm)
200 | logger.logkv_mean("param_norm", param_norm)
201 |
202 |         for p in self.master_params: p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))  # unscale all param groups
203 | opt.step()
204 | zero_master_grads(self.master_params)
205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
206 | self.lg_loss_scale += self.fp16_scale_growth
207 | return True
208 |
209 | def _optimize_normal(self, opt: th.optim.Optimizer):
210 | grad_norm, param_norm = self._compute_norms()
211 | logger.logkv_mean("grad_norm", grad_norm)
212 | logger.logkv_mean("param_norm", param_norm)
213 | opt.step()
214 | return True
215 |
216 | def _compute_norms(self, grad_scale=1.0):
217 | grad_norm = 0.0
218 | param_norm = 0.0
219 | for p in self.master_params:
220 | with th.no_grad():
221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
222 | if p.grad is not None:
223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
225 |
226 | def master_params_to_state_dict(self, master_params):
227 | return master_params_to_state_dict(
228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16
229 | )
230 |
231 | def state_dict_to_master_params(self, state_dict):
232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
233 |
234 |
235 | def check_overflow(value):
236 | return (value == float("inf")) or (value == -float("inf")) or (value != value)
237 |
--------------------------------------------------------------------------------
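`MixedPrecisionTrainer` above uses a simple dynamic loss-scaling policy: the loss is multiplied by `2 ** lg_loss_scale` before `backward()`, the (log2) scale drops by 1 whenever the gradients overflow, and it grows by `fp16_scale_growth` after every successful step. A standalone sketch of that policy:

```python
import math

# Mirror of the scale update in MixedPrecisionTrainer._optimize_fp16 above.
def update_lg_loss_scale(lg_loss_scale, overflowed, fp16_scale_growth=1e-3):
    if overflowed:
        return lg_loss_scale - 1.0          # back off quickly on NaN/inf gradients
    return lg_loss_scale + fp16_scale_growth  # creep back up on good steps

scale = 20.0  # INITIAL_LOG_LOSS_SCALE
scale = update_lg_loss_scale(scale, overflowed=True)   # -> 19.0
scale = update_lg_loss_scale(scale, overflowed=False)  # -> 19.001
assert math.isclose(scale, 19.001)
```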
/guided_diffusion/image_datasets.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | from PIL import Image
5 | import blobfile as bf
6 | from mpi4py import MPI
7 | import numpy as np
8 | from torch.utils.data import DataLoader, Dataset
9 |
10 |
11 | def load_data(
12 | *,
13 | data_dir,
14 | batch_size,
15 | image_size,
16 | class_cond=False,
17 | deterministic=False,
18 | random_crop=False,
19 | random_flip=True,
20 | ):
21 | """
22 | For a dataset, create a generator over (images, kwargs) pairs.
23 |
24 |     Each batch of images is an NCHW float tensor, and the kwargs dict contains zero or
25 |     more keys, each of which maps to a batched Tensor of its own.
26 | The kwargs dict can be used for class labels, in which case the key is "y"
27 | and the values are integer tensors of class labels.
28 |
29 | :param data_dir: a dataset directory.
30 | :param batch_size: the batch size of each returned pair.
31 | :param image_size: the size to which images are resized.
32 | :param class_cond: if True, include a "y" key in returned dicts for class
33 | label. If classes are not available and this is true, an
34 | exception will be raised.
35 | :param deterministic: if True, yield results in a deterministic order.
36 | :param random_crop: if True, randomly crop the images for augmentation.
37 | :param random_flip: if True, randomly flip the images for augmentation.
38 | """
39 | if not data_dir:
40 | raise ValueError("unspecified data directory")
41 | all_files = _list_image_files_recursively(data_dir)
42 | classes = None
43 | if class_cond:
44 | # Assume classes are the first part of the filename,
45 | # before an underscore.
46 | class_names = [bf.basename(path).split("_")[0] for path in all_files]
47 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
48 | classes = [sorted_classes[x] for x in class_names]
49 | dataset = ImageDataset(
50 | image_size,
51 | all_files,
52 | classes=classes,
53 | shard=MPI.COMM_WORLD.Get_rank(),
54 | num_shards=MPI.COMM_WORLD.Get_size(),
55 | random_crop=random_crop,
56 | random_flip=random_flip,
57 | )
58 | if deterministic:
59 | loader = DataLoader(
60 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
61 | )
62 | else:
63 | loader = DataLoader(
64 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
65 | )
66 | while True:
67 | yield from loader
68 |
69 |
70 | def _list_image_files_recursively(data_dir):
71 | results = []
72 | for entry in sorted(bf.listdir(data_dir)):
73 | full_path = bf.join(data_dir, entry)
74 | ext = entry.split(".")[-1]
75 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
76 | results.append(full_path)
77 | elif bf.isdir(full_path):
78 | results.extend(_list_image_files_recursively(full_path))
79 | return results
80 |
81 |
82 | class ImageDataset(Dataset):
83 | def __init__(
84 | self,
85 | resolution,
86 | image_paths,
87 | classes=None,
88 | shard=0,
89 | num_shards=1,
90 | random_crop=False,
91 | random_flip=True,
92 | ):
93 | super().__init__()
94 | self.resolution = resolution
95 | self.local_images = image_paths[shard:][::num_shards]
96 | self.local_classes = None if classes is None else classes[shard:][::num_shards]
97 | self.random_crop = random_crop
98 | self.random_flip = random_flip
99 |
100 | def __len__(self):
101 | return len(self.local_images)
102 |
103 | def __getitem__(self, idx):
104 | path = self.local_images[idx]
105 | with bf.BlobFile(path, "rb") as f:
106 | pil_image = Image.open(f)
107 | pil_image.load()
108 | pil_image = pil_image.convert("RGB")
109 |
110 | if self.random_crop:
111 | arr = random_crop_arr(pil_image, self.resolution)
112 | else:
113 | arr = center_crop_arr(pil_image, self.resolution)
114 |
115 | if self.random_flip and random.random() < 0.5:
116 | arr = arr[:, ::-1]
117 |
118 | arr = arr.astype(np.float32) / 127.5 - 1
119 |
120 | out_dict = {}
121 | if self.local_classes is not None:
122 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
123 | return np.transpose(arr, [2, 0, 1]), out_dict
124 |
125 |
126 | def center_crop_arr(pil_image, image_size):
127 | # We are not on a new enough PIL to support the `reducing_gap`
128 | # argument, which uses BOX downsampling at powers of two first.
129 | # Thus, we do it by hand to improve downsample quality.
130 | while min(*pil_image.size) >= 2 * image_size:
131 | pil_image = pil_image.resize(
132 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
133 | )
134 |
135 | scale = image_size / min(*pil_image.size)
136 | pil_image = pil_image.resize(
137 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
138 | )
139 |
140 | arr = np.array(pil_image)
141 | crop_y = (arr.shape[0] - image_size) // 2
142 | crop_x = (arr.shape[1] - image_size) // 2
143 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
144 |
145 |
146 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
147 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
148 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
149 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
150 |
151 | # We are not on a new enough PIL to support the `reducing_gap`
152 | # argument, which uses BOX downsampling at powers of two first.
153 | # Thus, we do it by hand to improve downsample quality.
154 | while min(*pil_image.size) >= 2 * smaller_dim_size:
155 | pil_image = pil_image.resize(
156 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
157 | )
158 |
159 | scale = smaller_dim_size / min(*pil_image.size)
160 | pil_image = pil_image.resize(
161 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
162 | )
163 |
164 | arr = np.array(pil_image)
165 | crop_y = random.randrange(arr.shape[0] - image_size + 1)
166 | crop_x = random.randrange(arr.shape[1] - image_size + 1)
167 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
168 |
--------------------------------------------------------------------------------
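A hypothetical usage sketch of `load_data` above; the directory path is a placeholder, and note that the function is an infinite generator rather than a finite DataLoader:

```python
data = load_data(
    data_dir="/path/to/train_images",  # hypothetical path
    batch_size=16,
    image_size=256,
    class_cond=True,   # expects filenames of the form "<classname>_xxx.jpg"
    random_flip=True,
)
images, cond = next(data)
# images: float tensor of shape [16, 3, 256, 256], scaled to [-1, 1]
# cond["y"]: int64 tensor of 16 class indices (only present when class_cond=True)
```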
/guided_diffusion/losses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch as th
4 |
5 |
6 | def normal_kl(mean1, logvar1, mean2, logvar2):
7 | """
8 | Compute the KL divergence between two gaussians.
9 |
10 | Shapes are automatically broadcasted, so batches can be compared to
11 | scalars, among other use cases.
12 | """
13 | tensor = None
14 | for obj in (mean1, logvar1, mean2, logvar2):
15 | if isinstance(obj, th.Tensor):
16 | tensor = obj
17 | break
18 | assert tensor is not None, "at least one argument must be a Tensor"
19 |
20 | # Force variances to be Tensors. Broadcasting helps convert scalars to
21 | # Tensors, but it does not work for th.exp().
22 | logvar1, logvar2 = [
23 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
24 | for x in (logvar1, logvar2)
25 | ]
26 |
27 | return 0.5 * (
28 | -1.0
29 | + logvar2
30 | - logvar1
31 | + th.exp(logvar1 - logvar2)
32 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
33 | )
34 |
35 |
36 | def approx_standard_normal_cdf(x):
37 | """
38 | A fast approximation of the cumulative distribution function of the
39 | standard normal.
40 | """
41 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
42 |
43 |
44 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
45 | """
46 | Compute the log-likelihood of a Gaussian distribution discretizing to a
47 | given image.
48 |
49 |     :param x: the target images. It is assumed that these were uint8 values,
50 | rescaled to the range [-1, 1].
51 | :param means: the Gaussian mean Tensor.
52 | :param log_scales: the Gaussian log stddev Tensor.
53 | :return: a tensor like x of log probabilities (in nats).
54 | """
55 | assert x.shape == means.shape == log_scales.shape
56 | centered_x = x - means
57 | inv_stdv = th.exp(-log_scales)
58 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
59 | cdf_plus = approx_standard_normal_cdf(plus_in)
60 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
61 | cdf_min = approx_standard_normal_cdf(min_in)
62 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
63 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
64 | cdf_delta = cdf_plus - cdf_min
65 | log_probs = th.where(
66 | x < -0.999,
67 | log_cdf_plus,
68 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
69 | )
70 | assert log_probs.shape == x.shape
71 | return log_probs
72 |
--------------------------------------------------------------------------------
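A quick numerical check of `normal_kl` above (element-wise KL in nats), assuming the function is in scope:

```python
import torch as th

# Identical Gaussians give KL = 0, and N(0, 1) vs N(1, 1) gives
# 0.5 * (mean difference)^2 = 0.5.
same = normal_kl(th.tensor(0.0), th.tensor(0.0), th.tensor(0.0), th.tensor(0.0))
shifted = normal_kl(th.tensor(0.0), th.tensor(0.0), th.tensor(1.0), th.tensor(0.0))
assert th.allclose(same, th.tensor(0.0)) and th.allclose(shifted, th.tensor(0.5))
```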
/guided_diffusion/nn.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch as th
4 | import torch.nn as nn
5 |
6 |
7 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
8 | class SiLU(nn.Module):
9 | def forward(self, x):
10 | return x * th.sigmoid(x)
11 |
12 |
13 | class GroupNorm32(nn.GroupNorm):
14 | def forward(self, x):
15 | return super().forward(x.float()).type(x.dtype)
16 |
17 |
18 | def conv_nd(dims, *args, **kwargs):
19 | """
20 | Create a 1D, 2D, or 3D convolution module.
21 | """
22 | if dims == 1:
23 | return nn.Conv1d(*args, **kwargs)
24 | elif dims == 2:
25 | return nn.Conv2d(*args, **kwargs)
26 | elif dims == 3:
27 | return nn.Conv3d(*args, **kwargs)
28 | raise ValueError(f"unsupported dimensions: {dims}")
29 |
30 |
31 | def linear(*args, **kwargs):
32 | """
33 | Create a linear module.
34 | """
35 | return nn.Linear(*args, **kwargs)
36 |
37 |
38 | def avg_pool_nd(dims, *args, **kwargs):
39 | """
40 | Create a 1D, 2D, or 3D average pooling module.
41 | """
42 | if dims == 1:
43 | return nn.AvgPool1d(*args, **kwargs)
44 | elif dims == 2:
45 | return nn.AvgPool2d(*args, **kwargs)
46 | elif dims == 3:
47 | return nn.AvgPool3d(*args, **kwargs)
48 | raise ValueError(f"unsupported dimensions: {dims}")
49 |
50 |
51 | def update_ema(target_params, source_params, rate=0.99):
52 | """
53 | Update target parameters to be closer to those of source parameters using
54 | an exponential moving average.
55 |
56 | :param target_params: the target parameter sequence.
57 | :param source_params: the source parameter sequence.
58 | :param rate: the EMA rate (closer to 1 means slower).
59 | """
60 | for targ, src in zip(target_params, source_params):
61 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
62 |
63 |
64 | def zero_module(module):
65 | """
66 | Zero out the parameters of a module and return it.
67 | """
68 | for p in module.parameters():
69 | p.detach().zero_()
70 | return module
71 |
72 |
73 | def scale_module(module, scale):
74 | """
75 | Scale the parameters of a module and return it.
76 | """
77 | for p in module.parameters():
78 | p.detach().mul_(scale)
79 | return module
80 |
81 |
82 | def mean_flat(tensor):
83 | """
84 | Take the mean over all non-batch dimensions.
85 | """
86 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
87 |
88 |
89 | def normalization(channels):
90 | """
91 | Make a standard normalization layer.
92 |
93 | :param channels: number of input channels.
94 | :return: an nn.Module for normalization.
95 | """
96 | return GroupNorm32(32, channels)
97 |
98 |
99 | def timestep_embedding(timesteps, dim, max_period=10000):
100 | """
101 | Create sinusoidal timestep embeddings.
102 |
103 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
104 | These may be fractional.
105 | :param dim: the dimension of the output.
106 | :param max_period: controls the minimum frequency of the embeddings.
107 | :return: an [N x dim] Tensor of positional embeddings.
108 | """
109 | half = dim // 2
110 | freqs = th.exp(
111 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
112 | ).to(device=timesteps.device)
113 | args = timesteps[:, None].float() * freqs[None]
114 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
115 | if dim % 2:
116 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
117 | return embedding
118 |
119 |
120 | def checkpoint(func, inputs, params, flag):
121 | """
122 | Evaluate a function without caching intermediate activations, allowing for
123 | reduced memory at the expense of extra compute in the backward pass.
124 |
125 | :param func: the function to evaluate.
126 | :param inputs: the argument sequence to pass to `func`.
127 | :param params: a sequence of parameters `func` depends on but does not
128 | explicitly take as arguments.
129 | :param flag: if False, disable gradient checkpointing.
130 | """
131 | if flag:
132 | args = tuple(inputs) + tuple(params)
133 | return CheckpointFunction.apply(func, len(inputs), *args)
134 | else:
135 | return func(*inputs)
136 |
137 |
138 | class CheckpointFunction(th.autograd.Function):
139 | @staticmethod
140 | def forward(ctx, run_function, length, *args):
141 | ctx.run_function = run_function
142 | ctx.input_tensors = list(args[:length])
143 | ctx.input_params = list(args[length:])
144 | with th.no_grad():
145 | output_tensors = ctx.run_function(*ctx.input_tensors)
146 | return output_tensors
147 |
148 | @staticmethod
149 | def backward(ctx, *output_grads):
150 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
151 | with th.enable_grad():
152 | # Fixes a bug where the first op in run_function modifies the
153 | # Tensor storage in place, which is not allowed for detach()'d
154 | # Tensors.
155 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
156 | output_tensors = ctx.run_function(*shallow_copies)
157 | input_grads = th.autograd.grad(
158 | output_tensors,
159 | ctx.input_tensors + ctx.input_params,
160 | output_grads,
161 | allow_unused=True,
162 | )
163 | del ctx.input_tensors
164 | del ctx.input_params
165 | del output_tensors
166 | return (None, None) + input_grads
167 |
--------------------------------------------------------------------------------
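A small shape/value check of `timestep_embedding` above, assuming the function is in scope:

```python
import torch as th

# The first dim//2 channels are cosines and the rest are sines, so at t = 0
# they are all 1 and 0 respectively.
t = th.tensor([0, 250, 999])
emb = timestep_embedding(t, dim=128)
assert emb.shape == (3, 128)
assert th.allclose(emb[0, :64], th.ones(64)) and th.allclose(emb[0, 64:], th.zeros(64))
```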
/guided_diffusion/resample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | import torch.distributed as dist
6 |
7 |
8 | def create_named_schedule_sampler(name, diffusion):
9 | """
10 | Create a ScheduleSampler from a library of pre-defined samplers.
11 |
12 | :param name: the name of the sampler.
13 | :param diffusion: the diffusion object to sample for.
14 | """
15 | if name == "uniform":
16 | return UniformSampler(diffusion)
17 | elif name == "loss-second-moment":
18 | return LossSecondMomentResampler(diffusion)
19 | else:
20 | raise NotImplementedError(f"unknown schedule sampler: {name}")
21 |
22 |
23 | class ScheduleSampler(ABC):
24 | """
25 | A distribution over timesteps in the diffusion process, intended to reduce
26 | variance of the objective.
27 |
28 | By default, samplers perform unbiased importance sampling, in which the
29 | objective's mean is unchanged.
30 | However, subclasses may override sample() to change how the resampled
31 | terms are reweighted, allowing for actual changes in the objective.
32 | """
33 |
34 | @abstractmethod
35 | def weights(self):
36 | """
37 | Get a numpy array of weights, one per diffusion step.
38 |
39 | The weights needn't be normalized, but must be positive.
40 | """
41 |
42 | def sample(self, batch_size, device):
43 | """
44 | Importance-sample timesteps for a batch.
45 |
 46 |         :param batch_size: the number of timesteps to sample.
 47 |         :param device: the torch device to place the sampled timesteps on.
48 | :return: a tuple (timesteps, weights):
49 | - timesteps: a tensor of timestep indices.
50 | - weights: a tensor of weights to scale the resulting losses.
51 | """
52 | w = self.weights()
53 | p = w / np.sum(w)
54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
55 | indices = th.from_numpy(indices_np).long().to(device)
56 | weights_np = 1 / (len(p) * p[indices_np])
57 | weights = th.from_numpy(weights_np).float().to(device)
58 | return indices, weights
59 |
60 |
61 | class UniformSampler(ScheduleSampler):
62 | def __init__(self, diffusion):
63 | self.diffusion = diffusion
64 | self._weights = np.ones([diffusion.num_timesteps])
65 |
66 | def weights(self):
67 | return self._weights
68 |
69 |
70 | class LossAwareSampler(ScheduleSampler):
71 | def update_with_local_losses(self, local_ts, local_losses):
72 | """
73 | Update the reweighting using losses from a model.
74 |
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 |
80 | :param local_ts: an integer Tensor of timesteps.
81 | :param local_losses: a 1D Tensor of losses.
82 | """
83 | batch_sizes = [
84 | th.tensor([0], dtype=th.int32, device=local_ts.device)
85 | for _ in range(dist.get_world_size())
86 | ]
87 | dist.all_gather(
88 | batch_sizes,
89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
90 | )
91 |
92 | # Pad all_gather batches to be the maximum batch size.
93 | batch_sizes = [x.item() for x in batch_sizes]
94 | max_bs = max(batch_sizes)
95 |
96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
98 | dist.all_gather(timestep_batches, local_ts)
99 | dist.all_gather(loss_batches, local_losses)
100 | timesteps = [
101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
102 | ]
103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
104 | self.update_with_all_losses(timesteps, losses)
105 |
106 | @abstractmethod
107 | def update_with_all_losses(self, ts, losses):
108 | """
109 | Update the reweighting using losses from a model.
110 |
111 | Sub-classes should override this method to update the reweighting
112 | using losses from the model.
113 |
114 | This method directly updates the reweighting without synchronizing
115 | between workers. It is called by update_with_local_losses from all
116 | ranks with identical arguments. Thus, it should have deterministic
117 | behavior to maintain state across workers.
118 |
119 | :param ts: a list of int timesteps.
120 | :param losses: a list of float losses, one per timestep.
121 | """
122 |
123 |
124 | class LossSecondMomentResampler(LossAwareSampler):
125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
126 | self.diffusion = diffusion
127 | self.history_per_term = history_per_term
128 | self.uniform_prob = uniform_prob
129 | self._loss_history = np.zeros(
130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
131 | )
132 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
133 |
134 | def weights(self):
135 | if not self._warmed_up():
136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
138 | weights /= np.sum(weights)
139 | weights *= 1 - self.uniform_prob
140 | weights += self.uniform_prob / len(weights)
141 | return weights
142 |
143 | def update_with_all_losses(self, ts, losses):
144 | for t, loss in zip(ts, losses):
145 | if self._loss_counts[t] == self.history_per_term:
146 | # Shift out the oldest loss term.
147 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
148 | self._loss_history[t, -1] = loss
149 | else:
150 | self._loss_history[t, self._loss_counts[t]] = loss
151 | self._loss_counts[t] += 1
152 |
153 | def _warmed_up(self):
154 | return (self._loss_counts == self.history_per_term).all()
155 |
--------------------------------------------------------------------------------
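A minimal sketch of the schedule-sampler API above; the `dummy` object is a stand-in for a real diffusion instance and only provides the `num_timesteps` attribute the samplers need:

```python
import types
import torch as th

dummy = types.SimpleNamespace(num_timesteps=1000)
sampler = create_named_schedule_sampler("uniform", dummy)
t, weights = sampler.sample(batch_size=4, device=th.device("cpu"))
# Uniform sampling means every importance weight is exactly 1.
assert t.shape == (4,) and bool((weights == 1.0).all())
```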
/guided_diffusion/respace.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 |
4 | from .gaussian_diffusion import GaussianDiffusion
5 |
6 |
7 | def space_timesteps(num_timesteps, section_counts):
8 | """
9 | Create a list of timesteps to use from an original diffusion process,
10 | given the number of timesteps we want to take from equally-sized portions
11 | of the original process.
12 |
13 |     For example, if there are 300 timesteps and the section counts are [10,15,20]
14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
15 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
16 |
17 | If the stride is a string starting with "ddim", then the fixed striding
18 | from the DDIM paper is used, and only one section is allowed.
19 |
20 | :param num_timesteps: the number of diffusion steps in the original
21 | process to divide up.
22 | :param section_counts: either a list of numbers, or a string containing
23 | comma-separated numbers, indicating the step count
24 | per section. As a special case, use "ddimN" where N
25 | is a number of steps to use the striding from the
26 | DDIM paper.
27 | :return: a set of diffusion steps from the original process to use.
28 | """
29 | if isinstance(section_counts, str):
30 | if section_counts.startswith("ddim"):
31 | desired_count = int(section_counts[len("ddim") :])
32 | for i in range(1, num_timesteps):
33 | if len(range(0, num_timesteps, i)) == desired_count:
34 | return set(range(0, num_timesteps, i))
35 | raise ValueError(
36 |                 f"cannot create exactly {desired_count} steps with an integer stride"
37 | )
38 | section_counts = [int(x) for x in section_counts.split(",")]
39 | size_per = num_timesteps // len(section_counts)
40 | extra = num_timesteps % len(section_counts)
41 | start_idx = 0
42 | all_steps = []
43 | for i, section_count in enumerate(section_counts):
44 | size = size_per + (1 if i < extra else 0)
45 | if size < section_count:
46 | raise ValueError(
47 | f"cannot divide section of {size} steps into {section_count}"
48 | )
49 | if section_count <= 1:
50 | frac_stride = 1
51 | else:
52 | frac_stride = (size - 1) / (section_count - 1)
53 | cur_idx = 0.0
54 | taken_steps = []
55 | for _ in range(section_count):
56 | taken_steps.append(start_idx + round(cur_idx))
57 | cur_idx += frac_stride
58 | all_steps += taken_steps
59 | start_idx += size
60 | return set(all_steps)
61 |
62 |
63 | class SpacedDiffusion(GaussianDiffusion):
64 | """
65 | A diffusion process which can skip steps in a base diffusion process.
66 |
67 | :param use_timesteps: a collection (sequence or set) of timesteps from the
68 | original diffusion process to retain.
69 | :param kwargs: the kwargs to create the base diffusion process.
70 | """
71 |
72 | def __init__(self, use_timesteps, **kwargs):
73 | self.use_timesteps = set(use_timesteps)
74 | self.timestep_map = []
75 | self.original_num_steps = len(kwargs["betas"])
76 |
77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
78 | last_alpha_cumprod = 1.0
79 | new_betas = []
80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
81 | if i in self.use_timesteps:
82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
83 | last_alpha_cumprod = alpha_cumprod
84 | self.timestep_map.append(i)
85 | kwargs["betas"] = np.array(new_betas)
86 | super().__init__(**kwargs)
87 |
88 | def p_mean_variance(
89 | self, model, *args, **kwargs
90 | ): # pylint: disable=signature-differs
91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
92 |
93 | def training_losses(
94 | self, model, *args, **kwargs
95 | ): # pylint: disable=signature-differs
96 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
97 |
98 | def condition_mean(self, cond_fn, *args, **kwargs):
99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
100 |
101 | def condition_score(self, cond_fn, *args, **kwargs):
102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
103 |
104 | def _wrap_model(self, model):
105 | if isinstance(model, _WrappedModel):
106 | return model
107 | return _WrappedModel(
108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
109 | )
110 |
111 | def _scale_timesteps(self, t):
112 | # Scaling is done by the wrapped model.
113 | return t
114 |
115 |
116 | class _WrappedModel:
117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
118 | self.model = model
119 | self.timestep_map = timestep_map
120 | self.rescale_timesteps = rescale_timesteps
121 | self.original_num_steps = original_num_steps
122 |
123 | def __call__(self, x, ts, **kwargs):
124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
125 | new_ts = map_tensor[ts]
126 | if self.rescale_timesteps:
127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
128 | return self.model(x, new_ts, **kwargs)
129 |
--------------------------------------------------------------------------------
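Concrete examples of `space_timesteps` above, assuming the function is in scope: the `[10, 15, 20]` case from its docstring and a DDIM-style stride:

```python
# The docstring case: three 100-step sections strided down to 10, 15 and 20 steps.
steps = space_timesteps(300, [10, 15, 20])
assert len(steps) == 45 and min(steps) == 0 and max(steps) < 300

# DDIM-style spacing: 50 evenly strided steps 0, 20, 40, ..., 980.
ddim_steps = space_timesteps(1000, "ddim50")
assert ddim_steps == set(range(0, 1000, 20))
```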
/guided_diffusion/train_util.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import functools
3 | import os
4 |
5 | import blobfile as bf
6 | import torch as th
7 | import torch.distributed as dist
8 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
9 | from torch.optim import AdamW
10 |
11 | from . import dist_util, logger
12 | from .fp16_util import MixedPrecisionTrainer
13 | from .nn import update_ema
14 | from .resample import LossAwareSampler, UniformSampler
15 |
16 | # For ImageNet experiments, this was a good default value.
17 | # We found that the lg_loss_scale quickly climbed to
18 | # 20-21 within the first ~1K steps of training.
19 | INITIAL_LOG_LOSS_SCALE = 20.0
20 |
21 |
22 | class TrainLoop:
23 | def __init__(
24 | self,
25 | *,
26 | model,
27 | diffusion,
28 | data,
29 | batch_size,
30 | microbatch,
31 | lr,
32 | ema_rate,
33 | log_interval,
34 | save_interval,
35 | resume_checkpoint,
36 | use_fp16=False,
37 | fp16_scale_growth=1e-3,
38 | schedule_sampler=None,
39 | weight_decay=0.0,
40 | lr_anneal_steps=0,
41 | ):
42 | self.model = model
43 | self.diffusion = diffusion
44 | self.data = data
45 | self.batch_size = batch_size
46 | self.microbatch = microbatch if microbatch > 0 else batch_size
47 | self.lr = lr
48 | self.ema_rate = (
49 | [ema_rate]
50 | if isinstance(ema_rate, float)
51 | else [float(x) for x in ema_rate.split(",")]
52 | )
53 | self.log_interval = log_interval
54 | self.save_interval = save_interval
55 | self.resume_checkpoint = resume_checkpoint
56 | self.use_fp16 = use_fp16
57 | self.fp16_scale_growth = fp16_scale_growth
58 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
59 | self.weight_decay = weight_decay
60 | self.lr_anneal_steps = lr_anneal_steps
61 |
62 | self.step = 0
63 | self.resume_step = 0
64 | self.global_batch = self.batch_size * dist.get_world_size()
65 |
66 | self.sync_cuda = th.cuda.is_available()
67 |
68 | self._load_and_sync_parameters()
69 | self.mp_trainer = MixedPrecisionTrainer(
70 | model=self.model,
71 | use_fp16=self.use_fp16,
72 | fp16_scale_growth=fp16_scale_growth,
73 | )
74 |
75 | self.opt = AdamW(
76 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
77 | )
78 | if self.resume_step:
79 | self._load_optimizer_state()
80 | # Model was resumed, either due to a restart or a checkpoint
81 | # being specified at the command line.
82 | self.ema_params = [
83 | self._load_ema_parameters(rate) for rate in self.ema_rate
84 | ]
85 | else:
86 | self.ema_params = [
87 | copy.deepcopy(self.mp_trainer.master_params)
88 | for _ in range(len(self.ema_rate))
89 | ]
90 |
91 | if th.cuda.is_available():
92 | self.use_ddp = True
93 | self.ddp_model = DDP(
94 | self.model,
95 | device_ids=[dist_util.dev()],
96 | output_device=dist_util.dev(),
97 | broadcast_buffers=False,
98 | bucket_cap_mb=128,
99 | find_unused_parameters=False,
100 | )
101 | else:
102 | if dist.get_world_size() > 1:
103 | logger.warn(
104 | "Distributed training requires CUDA. "
105 | "Gradients will not be synchronized properly!"
106 | )
107 | self.use_ddp = False
108 | self.ddp_model = self.model
109 |
110 | def _load_and_sync_parameters(self):
111 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
112 |
113 | if resume_checkpoint:
114 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
115 | if dist.get_rank() == 0:
116 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
117 | self.model.load_state_dict(
118 | dist_util.load_state_dict(
119 | resume_checkpoint, map_location=dist_util.dev()
120 | )
121 | )
122 |
123 | dist_util.sync_params(self.model.parameters())
124 |
125 | def _load_ema_parameters(self, rate):
126 | ema_params = copy.deepcopy(self.mp_trainer.master_params)
127 |
128 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
129 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
130 | if ema_checkpoint:
131 | if dist.get_rank() == 0:
132 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
133 | state_dict = dist_util.load_state_dict(
134 | ema_checkpoint, map_location=dist_util.dev()
135 | )
136 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)
137 |
138 | dist_util.sync_params(ema_params)
139 | return ema_params
140 |
141 | def _load_optimizer_state(self):
142 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
143 | opt_checkpoint = bf.join(
144 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
145 | )
146 | if bf.exists(opt_checkpoint):
147 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
148 | state_dict = dist_util.load_state_dict(
149 | opt_checkpoint, map_location=dist_util.dev()
150 | )
151 | self.opt.load_state_dict(state_dict)
152 |
153 | def run_loop(self):
154 | while (
155 | not self.lr_anneal_steps
156 | or self.step + self.resume_step < self.lr_anneal_steps
157 | ):
158 | batch, cond = next(self.data)
159 | self.run_step(batch, cond)
160 | if self.step % self.log_interval == 0:
161 | logger.dumpkvs()
162 | if self.step % self.save_interval == 0:
163 | self.save()
164 | # Run for a finite amount of time in integration tests.
165 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
166 | return
167 | self.step += 1
168 | # Save the last checkpoint if it wasn't already saved.
169 | if (self.step - 1) % self.save_interval != 0:
170 | self.save()
171 |
172 | def run_step(self, batch, cond):
173 | self.forward_backward(batch, cond)
174 | took_step = self.mp_trainer.optimize(self.opt)
175 | if took_step:
176 | self._update_ema()
177 | self._anneal_lr()
178 | self.log_step()
179 |
180 | def forward_backward(self, batch, cond):
181 | self.mp_trainer.zero_grad()
182 | for i in range(0, batch.shape[0], self.microbatch):
183 | micro = batch[i : i + self.microbatch].to(dist_util.dev())
184 | micro_cond = {
185 | k: v[i : i + self.microbatch].to(dist_util.dev())
186 | for k, v in cond.items()
187 | }
188 | last_batch = (i + self.microbatch) >= batch.shape[0]
189 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
190 |
191 | compute_losses = functools.partial(
192 | self.diffusion.training_losses,
193 | self.ddp_model,
194 | micro,
195 | t,
196 | model_kwargs=micro_cond,
197 | )
198 |
199 | if last_batch or not self.use_ddp:
200 | losses = compute_losses()
201 | else:
202 | with self.ddp_model.no_sync():
203 | losses = compute_losses()
204 |
205 | if isinstance(self.schedule_sampler, LossAwareSampler):
206 | self.schedule_sampler.update_with_local_losses(
207 | t, losses["loss"].detach()
208 | )
209 |
210 | loss = (losses["loss"] * weights).mean()
211 | log_loss_dict(
212 | self.diffusion, t, {k: v * weights for k, v in losses.items()}
213 | )
214 | self.mp_trainer.backward(loss)
215 |
216 | def _update_ema(self):
217 | for rate, params in zip(self.ema_rate, self.ema_params):
218 | update_ema(params, self.mp_trainer.master_params, rate=rate)
219 |
220 | def _anneal_lr(self):
221 | if not self.lr_anneal_steps:
222 | return
223 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
224 | lr = self.lr * (1 - frac_done)
225 | for param_group in self.opt.param_groups:
226 | param_group["lr"] = lr
227 |
228 | def log_step(self):
229 | logger.logkv("step", self.step + self.resume_step)
230 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
231 |
232 | def save(self):
233 | def save_checkpoint(rate, params):
234 | state_dict = self.mp_trainer.master_params_to_state_dict(params)
235 | if dist.get_rank() == 0:
236 | logger.log(f"saving model {rate}...")
237 | if not rate:
238 | filename = f"model{(self.step+self.resume_step):06d}.pt"
239 | else:
240 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
241 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
242 | th.save(state_dict, f)
243 |
244 | save_checkpoint(0, self.mp_trainer.master_params)
245 | for rate, params in zip(self.ema_rate, self.ema_params):
246 | save_checkpoint(rate, params)
247 |
248 | if dist.get_rank() == 0:
249 | with bf.BlobFile(
250 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
251 | "wb",
252 | ) as f:
253 | th.save(self.opt.state_dict(), f)
254 |
255 | dist.barrier()
256 |
257 |
258 | def parse_resume_step_from_filename(filename):
259 | """
260 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
261 | checkpoint's number of steps.
262 | """
263 | split = filename.split("model")
264 | if len(split) < 2:
265 | return 0
266 | split1 = split[-1].split(".")[0]
267 | try:
268 | return int(split1)
269 | except ValueError:
270 | return 0
271 |
272 |
273 | def get_blob_logdir():
274 | # You can change this to be a separate path to save checkpoints to
275 | # a blobstore or some external drive.
276 | return logger.get_dir()
277 |
278 |
279 | def find_resume_checkpoint():
280 | # On your infrastructure, you may want to override this to automatically
281 | # discover the latest checkpoint on your blob storage, etc.
282 | return None
283 |
284 |
285 | def find_ema_checkpoint(main_checkpoint, step, rate):
286 | if main_checkpoint is None:
287 | return None
288 | filename = f"ema_{rate}_{(step):06d}.pt"
289 | path = bf.join(bf.dirname(main_checkpoint), filename)
290 | if bf.exists(path):
291 | return path
292 | return None
293 |
294 |
295 | def log_loss_dict(diffusion, ts, losses):
296 | for key, values in losses.items():
297 | logger.logkv_mean(key, values.mean().item())
298 | # Log the quantiles (four quartiles, in particular).
299 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
300 | quartile = int(4 * sub_t / diffusion.num_timesteps)
301 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
302 |
--------------------------------------------------------------------------------
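The learning-rate schedule in `TrainLoop._anneal_lr` above is a plain linear decay toward zero over `lr_anneal_steps`; a standalone sketch:

```python
# lr_anneal_steps == 0 disables annealing entirely, matching the early return above.
def annealed_lr(base_lr, step, lr_anneal_steps):
    if not lr_anneal_steps:
        return base_lr
    return base_lr * (1 - step / lr_anneal_steps)

assert annealed_lr(1e-4, 0, 10_000) == 1e-4
assert abs(annealed_lr(1e-4, 5_000, 10_000) - 5e-5) < 1e-12
assert annealed_lr(1e-4, 0, 0) == 1e-4
```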
/improved_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Codebase for "Improved Denoising Diffusion Probabilistic Models".
3 | """
4 |
--------------------------------------------------------------------------------
/improved_diffusion/dist_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for distributed training.
3 | """
4 |
5 | import io
6 | import os
7 | import socket
8 |
9 | import blobfile as bf
10 | # from mpi4py import MPI
11 | import torch as th
12 | import torch.distributed as dist
13 |
14 | # Change this to reflect your cluster layout.
15 | # The GPU for a given rank is (rank % GPUS_PER_NODE).
16 | GPUS_PER_NODE = 8
17 |
18 | SETUP_RETRY_COUNT = 3
19 |
20 |
21 | def setup_dist():
22 | """
23 | Setup a distributed process group.
24 | """
25 | if dist.is_initialized():
26 | return
27 |
28 | comm = MPI.COMM_WORLD
29 | backend = "gloo" if not th.cuda.is_available() else "nccl"
30 |
31 | if backend == "gloo":
32 | hostname = "localhost"
33 | else:
34 | hostname = socket.gethostbyname(socket.getfqdn())
35 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
36 | os.environ["RANK"] = str(comm.rank)
37 | os.environ["WORLD_SIZE"] = str(comm.size)
38 |
39 | port = comm.bcast(_find_free_port(), root=0)
40 | os.environ["MASTER_PORT"] = str(port)
41 | dist.init_process_group(backend=backend, init_method="env://")
42 |
43 |
44 | def dev():
45 | """
46 | Get the device to use for torch.distributed.
47 | """
48 | if th.cuda.is_available():
49 |         return th.device("cuda")  # mpi4py is disabled in this repo, so no per-rank device index
50 | return th.device("cpu")
51 |
52 |
53 | def load_state_dict(path, **kwargs):
54 | """
55 | Load a PyTorch file without redundant fetches across MPI ranks.
56 | """
57 |     # mpi4py is disabled in this repo (the import above is commented out),
58 |     # so every rank simply reads the checkpoint from disk instead of
59 |     # receiving a broadcast from rank 0.
60 |     with bf.BlobFile(path, "rb") as f:
61 |         data = f.read()
62 |     # data = MPI.COMM_WORLD.bcast(data)
63 |     return th.load(io.BytesIO(data), **kwargs)
64 |
65 |
66 | def sync_params(params):
67 | """
68 | Synchronize a sequence of Tensors across ranks from rank 0.
69 | """
70 | for p in params:
71 | with th.no_grad():
72 | dist.broadcast(p, 0)
73 |
74 |
75 | def _find_free_port():
76 | try:
77 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
78 | s.bind(("", 0))
79 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
80 | return s.getsockname()[1]
81 | finally:
82 | s.close()
83 |
--------------------------------------------------------------------------------
/improved_diffusion/fp16_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers to train with 16-bit precision.
3 | """
4 |
5 | import torch.nn as nn
6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
7 |
8 |
9 | def convert_module_to_f16(l):
10 | """
11 | Convert primitive modules to float16.
12 | """
13 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
14 | l.weight.data = l.weight.data.half()
15 | l.bias.data = l.bias.data.half()
16 |
17 |
18 | def convert_module_to_f32(l):
19 | """
20 | Convert primitive modules to float32, undoing convert_module_to_f16().
21 | """
22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
23 | l.weight.data = l.weight.data.float()
24 | l.bias.data = l.bias.data.float()
25 |
26 |
27 | def make_master_params(model_params):
28 | """
29 | Copy model parameters into a (differently-shaped) list of full-precision
30 | parameters.
31 | """
32 | master_params = _flatten_dense_tensors(
33 | [param.detach().float() for param in model_params]
34 | )
35 | master_params = nn.Parameter(master_params)
36 | master_params.requires_grad = True
37 | return [master_params]
38 |
39 |
40 | def model_grads_to_master_grads(model_params, master_params):
41 | """
42 | Copy the gradients from the model parameters into the master parameters
43 | from make_master_params().
44 | """
45 | master_params[0].grad = _flatten_dense_tensors(
46 | [param.grad.data.detach().float() for param in model_params]
47 | )
48 |
49 |
50 | def master_params_to_model_params(model_params, master_params):
51 | """
52 | Copy the master parameter data back into the model parameters.
53 | """
54 | # Without copying to a list, if a generator is passed, this will
55 | # silently not copy any parameters.
56 | model_params = list(model_params)
57 |
58 | for param, master_param in zip(
59 | model_params, unflatten_master_params(model_params, master_params)
60 | ):
61 | param.detach().copy_(master_param)
62 |
63 |
64 | def unflatten_master_params(model_params, master_params):
65 | """
66 | Unflatten the master parameters to look like model_params.
67 | """
68 | return _unflatten_dense_tensors(master_params[0].detach(), model_params)
69 |
70 |
71 | def zero_grad(model_params):
72 | for param in model_params:
73 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
74 | if param.grad is not None:
75 | param.grad.detach_()
76 | param.grad.zero_()
77 |
--------------------------------------------------------------------------------
/improved_diffusion/image_datasets.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import blobfile as bf
3 | from mpi4py import MPI
4 | import numpy as np
5 | from torch.utils.data import DataLoader, Dataset
6 |
7 |
8 | def load_data(
9 | *, data_dir, batch_size, image_size, class_cond=False, deterministic=False
10 | ):
11 | """
12 | For a dataset, create a generator over (images, kwargs) pairs.
13 |
14 |     Each batch of images is an NCHW float tensor, and the kwargs dict contains zero or
15 |     more keys, each of which maps to a batched Tensor of its own.
16 | The kwargs dict can be used for class labels, in which case the key is "y"
17 | and the values are integer tensors of class labels.
18 |
19 | :param data_dir: a dataset directory.
20 | :param batch_size: the batch size of each returned pair.
21 | :param image_size: the size to which images are resized.
22 | :param class_cond: if True, include a "y" key in returned dicts for class
23 | label. If classes are not available and this is true, an
24 | exception will be raised.
25 | :param deterministic: if True, yield results in a deterministic order.
26 | """
27 | if not data_dir:
28 | raise ValueError("unspecified data directory")
29 | all_files = _list_image_files_recursively(data_dir)
30 | classes = None
31 | if class_cond:
32 | # Assume classes are the first part of the filename,
33 | # before an underscore.
34 | class_names = [bf.basename(path).split("_")[0] for path in all_files]
35 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
36 | classes = [sorted_classes[x] for x in class_names]
37 | dataset = ImageDataset(
38 | image_size,
39 | all_files,
40 | classes=classes,
41 | shard=MPI.COMM_WORLD.Get_rank(),
42 | num_shards=MPI.COMM_WORLD.Get_size(),
43 | )
44 | if deterministic:
45 | loader = DataLoader(
46 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
47 | )
48 | else:
49 | loader = DataLoader(
50 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
51 | )
52 | while True:
53 | yield from loader
54 |
55 |
56 | def _list_image_files_recursively(data_dir):
57 | results = []
58 | for entry in sorted(bf.listdir(data_dir)):
59 | full_path = bf.join(data_dir, entry)
60 | ext = entry.split(".")[-1]
61 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
62 | results.append(full_path)
63 | elif bf.isdir(full_path):
64 | results.extend(_list_image_files_recursively(full_path))
65 | return results
66 |
67 |
68 | class ImageDataset(Dataset):
69 | def __init__(self, resolution, image_paths, classes=None, shard=0, num_shards=1):
70 | super().__init__()
71 | self.resolution = resolution
72 | self.local_images = image_paths[shard:][::num_shards]
73 | self.local_classes = None if classes is None else classes[shard:][::num_shards]
74 |
75 | def __len__(self):
76 | return len(self.local_images)
77 |
78 | def __getitem__(self, idx):
79 | path = self.local_images[idx]
80 | with bf.BlobFile(path, "rb") as f:
81 | pil_image = Image.open(f)
82 | pil_image.load()
83 |
84 | # We are not on a new enough PIL to support the `reducing_gap`
85 | # argument, which uses BOX downsampling at powers of two first.
86 | # Thus, we do it by hand to improve downsample quality.
87 | while min(*pil_image.size) >= 2 * self.resolution:
88 | pil_image = pil_image.resize(
89 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
90 | )
91 |
92 | scale = self.resolution / min(*pil_image.size)
93 | pil_image = pil_image.resize(
94 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
95 | )
96 |
97 | arr = np.array(pil_image.convert("RGB"))
98 | crop_y = (arr.shape[0] - self.resolution) // 2
99 | crop_x = (arr.shape[1] - self.resolution) // 2
100 | arr = arr[crop_y : crop_y + self.resolution, crop_x : crop_x + self.resolution]
101 | arr = arr.astype(np.float32) / 127.5 - 1
102 |
103 | out_dict = {}
104 | if self.local_classes is not None:
105 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
106 | return np.transpose(arr, [2, 0, 1]), out_dict
107 |
--------------------------------------------------------------------------------
/improved_diffusion/losses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch as th
4 |
5 |
6 | def normal_kl(mean1, logvar1, mean2, logvar2):
7 | """
8 | Compute the KL divergence between two gaussians.
9 |
10 | Shapes are automatically broadcasted, so batches can be compared to
11 | scalars, among other use cases.
12 | """
13 | tensor = None
14 | for obj in (mean1, logvar1, mean2, logvar2):
15 | if isinstance(obj, th.Tensor):
16 | tensor = obj
17 | break
18 | assert tensor is not None, "at least one argument must be a Tensor"
19 |
20 | # Force variances to be Tensors. Broadcasting helps convert scalars to
21 | # Tensors, but it does not work for th.exp().
22 | logvar1, logvar2 = [
23 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
24 | for x in (logvar1, logvar2)
25 | ]
26 |
27 | return 0.5 * (
28 | -1.0
29 | + logvar2
30 | - logvar1
31 | + th.exp(logvar1 - logvar2)
32 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
33 | )
34 |
35 |
36 | def approx_standard_normal_cdf(x):
37 | """
38 | A fast approximation of the cumulative distribution function of the
39 | standard normal.
40 | """
41 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
42 |
43 |
44 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
45 | """
46 | Compute the log-likelihood of a Gaussian distribution discretizing to a
47 | given image.
48 |
49 |     :param x: the target images. It is assumed that these were uint8 values,
50 | rescaled to the range [-1, 1].
51 | :param means: the Gaussian mean Tensor.
52 | :param log_scales: the Gaussian log stddev Tensor.
53 | :return: a tensor like x of log probabilities (in nats).
54 | """
55 | assert x.shape == means.shape == log_scales.shape
56 | centered_x = x - means
57 | inv_stdv = th.exp(-log_scales)
58 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
59 | cdf_plus = approx_standard_normal_cdf(plus_in)
60 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
61 | cdf_min = approx_standard_normal_cdf(min_in)
62 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
63 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
64 | cdf_delta = cdf_plus - cdf_min
65 | log_probs = th.where(
66 | x < -0.999,
67 | log_cdf_plus,
68 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
69 | )
70 | assert log_probs.shape == x.shape
71 | return log_probs
72 |
--------------------------------------------------------------------------------
/improved_diffusion/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .lenet import *
2 | from .vggnet import *
3 | from .resnet import *
4 | from .wide_resnet import *
5 |
--------------------------------------------------------------------------------
/improved_diffusion/networks/lenet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | def conv_init(m):
5 | classname = m.__class__.__name__
6 | if classname.find('Conv') != -1:
7 |         nn.init.xavier_uniform_(m.weight, gain=2 ** 0.5)
8 |         nn.init.constant_(m.bias, 0)
9 |
10 | class LeNet(nn.Module):
11 | def __init__(self, num_classes):
12 | super(LeNet, self).__init__()
13 | self.conv1 = nn.Conv2d(3, 6, 5)
14 | self.conv2 = nn.Conv2d(6, 16, 5)
15 | self.fc1 = nn.Linear(16*5*5, 120)
16 | self.fc2 = nn.Linear(120, 84)
17 | self.fc3 = nn.Linear(84, num_classes)
18 |
19 | def forward(self, x):
20 | out = F.relu(self.conv1(x))
21 | out = F.max_pool2d(out, 2)
22 | out = F.relu(self.conv2(out))
23 | out = F.max_pool2d(out, 2)
24 | out = out.view(out.size(0), -1)
25 | out = F.relu(self.fc1(out))
26 | out = F.relu(self.fc2(out))
27 | out = self.fc3(out)
28 |
29 |         return out
30 |
--------------------------------------------------------------------------------
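`LeNet` above assumes 32x32 inputs (e.g. CIFAR-10); a small shape check, assuming the class is in scope:

```python
import torch

# Two 5x5 convs and two 2x2 max-pools shrink 32 -> 28 -> 14 -> 10 -> 5,
# giving the 16*5*5 = 400 features that fc1 expects.
net = LeNet(num_classes=10)
logits = net(torch.randn(2, 3, 32, 32))
assert logits.shape == (2, 10)
```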
/improved_diffusion/networks/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from torch.autograd import Variable
6 | import sys
7 |
8 | def conv3x3(in_planes, out_planes, stride=1):
9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
10 |
11 | def conv_init(m):
12 | classname = m.__class__.__name__
13 | if classname.find('Conv') != -1:
14 |         nn.init.xavier_uniform_(m.weight, gain=2 ** 0.5)
15 |         nn.init.constant_(m.bias, 0)
16 |
17 | def cfg(depth):
18 | depth_lst = [18, 34, 50, 101, 152]
19 | assert (depth in depth_lst), "Error : Resnet depth should be either 18, 34, 50, 101, 152"
20 | cf_dict = {
21 | '18': (BasicBlock, [2,2,2,2]),
22 | '34': (BasicBlock, [3,4,6,3]),
23 | '50': (Bottleneck, [3,4,6,3]),
24 | '101':(Bottleneck, [3,4,23,3]),
25 | '152':(Bottleneck, [3,8,36,3]),
26 | }
27 |
28 | return cf_dict[str(depth)]
29 |
30 | class BasicBlock(nn.Module):
31 | expansion = 1
32 |
33 | def __init__(self, in_planes, planes, stride=1):
34 | super(BasicBlock, self).__init__()
35 | self.conv1 = conv3x3(in_planes, planes, stride)
36 | self.bn1 = nn.BatchNorm2d(planes)
37 | self.conv2 = conv3x3(planes, planes)
38 | self.bn2 = nn.BatchNorm2d(planes)
39 |
40 | self.shortcut = nn.Sequential()
41 | if stride != 1 or in_planes != self.expansion * planes:
42 | self.shortcut = nn.Sequential(
43 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True),
44 | nn.BatchNorm2d(self.expansion*planes)
45 | )
46 |
47 | def forward(self, x):
48 | out = F.relu(self.bn1(self.conv1(x)))
49 | out = self.bn2(self.conv2(out))
50 | out += self.shortcut(x)
51 | out = F.relu(out)
52 |
53 | return out
54 |
55 | class Bottleneck(nn.Module):
56 | expansion = 4
57 |
58 | def __init__(self, in_planes, planes, stride=1):
59 | super(Bottleneck, self).__init__()
60 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=True)
61 | self.bn1 = nn.BatchNorm2d(planes)
62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
63 | self.bn2 = nn.BatchNorm2d(planes)
64 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=True)
65 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
66 |
67 | self.shortcut = nn.Sequential()
68 | if stride != 1 or in_planes != self.expansion*planes:
69 | self.shortcut = nn.Sequential(
70 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True),
71 | nn.BatchNorm2d(self.expansion*planes)
72 | )
73 |
74 | def forward(self, x):
75 | out = F.relu(self.bn1(self.conv1(x)))
76 | out = F.relu(self.bn2(self.conv2(out)))
77 | out = self.bn3(self.conv3(out))
78 | out += self.shortcut(x)
79 | out = F.relu(out)
80 |
81 | return out
82 |
83 | class ResNet(nn.Module):
84 | def __init__(self, depth, num_classes):
85 | super(ResNet, self).__init__()
86 | self.in_planes = 16
87 |
88 | block, num_blocks = cfg(depth)
89 |
90 | self.conv1 = conv3x3(3,16)
91 | self.bn1 = nn.BatchNorm2d(16)
92 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
93 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
94 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
95 | self.linear = nn.Linear(64*block.expansion, num_classes)
96 |
97 | def _make_layer(self, block, planes, num_blocks, stride):
98 | strides = [stride] + [1]*(num_blocks-1)
99 | layers = []
100 |
101 | for stride in strides:
102 | layers.append(block(self.in_planes, planes, stride))
103 | self.in_planes = planes * block.expansion
104 |
105 | return nn.Sequential(*layers)
106 |
107 | def forward(self, x):
108 | out = F.relu(self.bn1(self.conv1(x)))
109 | out = self.layer1(out)
110 | out = self.layer2(out)
111 | out = self.layer3(out)
112 | out = F.avg_pool2d(out, 8)
113 | out = out.view(out.size(0), -1)
114 | out = self.linear(out)
115 |
116 | return out
117 |
118 | if __name__ == '__main__':
119 | net=ResNet(50, 10)
120 | y = net(Variable(torch.randn(1,3,32,32)))
121 | print(y.size())
122 |
--------------------------------------------------------------------------------
/improved_diffusion/networks/vggnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 |
5 | def conv_init(m):
6 | classname = m.__class__.__name__
7 | if classname.find('Conv') != -1:
8 |         nn.init.xavier_uniform_(m.weight, gain=2 ** 0.5)
9 |         nn.init.constant_(m.bias, 0)
10 |
11 | def cfg(depth):
12 | depth_lst = [11, 13, 16, 19]
13 | assert (depth in depth_lst), "Error : VGGnet depth should be either 11, 13, 16, 19"
14 | cf_dict = {
15 | '11': [
16 | 64, 'mp',
17 | 128, 'mp',
18 | 256, 256, 'mp',
19 | 512, 512, 'mp',
20 | 512, 512, 'mp'],
21 | '13': [
22 | 64, 64, 'mp',
23 | 128, 128, 'mp',
24 | 256, 256, 'mp',
25 | 512, 512, 'mp',
26 | 512, 512, 'mp'
27 | ],
28 | '16': [
29 | 64, 64, 'mp',
30 | 128, 128, 'mp',
31 | 256, 256, 256, 'mp',
32 | 512, 512, 512, 'mp',
33 | 512, 512, 512, 'mp'
34 | ],
35 | '19': [
36 | 64, 64, 'mp',
37 | 128, 128, 'mp',
38 | 256, 256, 256, 256, 'mp',
39 | 512, 512, 512, 512, 'mp',
40 | 512, 512, 512, 512, 'mp'
41 | ],
42 | }
43 |
44 | return cf_dict[str(depth)]
45 |
46 | def conv3x3(in_planes, out_planes, stride=1):
47 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
48 |
49 | class VGG(nn.Module):
50 | def __init__(self, depth, num_classes):
51 | super(VGG, self).__init__()
52 | self.features = self._make_layers(cfg(depth))
53 | self.linear = nn.Linear(512, num_classes)
54 |
55 | def forward(self, x):
56 | out = self.features(x)
57 | out = out.view(out.size(0), -1)
58 | out = self.linear(out)
59 |
60 | return out
61 |
62 | def _make_layers(self, cfg):
63 | layers = []
64 | in_planes = 3
65 |
66 | for x in cfg:
67 | if x == 'mp':
68 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
69 | else:
70 | layers += [conv3x3(in_planes, x), nn.BatchNorm2d(x), nn.ReLU(inplace=True)]
71 | in_planes = x
72 |
73 | # After cfg convolution
74 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
75 | return nn.Sequential(*layers)
76 |
77 | if __name__ == "__main__":
78 | net = VGG(16, 10)
79 | y = net(Variable(torch.randn(1,3,32,32)))
80 | print(y.size())
81 |
--------------------------------------------------------------------------------
/improved_diffusion/networks/wide_resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 | import sys
8 | import numpy as np
9 |
10 | def conv3x3(in_planes, out_planes, stride=1):
11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
12 |
13 | def conv_init(m):
14 | classname = m.__class__.__name__
15 | if classname.find('Conv') != -1:
16 | init.xavier_uniform_(m.weight, gain=np.sqrt(2))
17 | init.constant_(m.bias, 0)
18 | elif classname.find('BatchNorm') != -1:
19 | init.constant_(m.weight, 1)
20 | init.constant_(m.bias, 0)
21 |
22 | class wide_basic(nn.Module):
23 | def __init__(self, in_planes, planes, dropout_rate, stride=1):
24 | super(wide_basic, self).__init__()
25 | self.bn1 = nn.BatchNorm2d(in_planes)
26 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
27 | self.dropout = nn.Dropout(p=dropout_rate)
28 | self.bn2 = nn.BatchNorm2d(planes)
29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
30 |
31 | self.shortcut = nn.Sequential()
32 | if stride != 1 or in_planes != planes:
33 | self.shortcut = nn.Sequential(
34 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
35 | )
36 |
37 | def forward(self, x):
38 | out = self.dropout(self.conv1(F.relu(self.bn1(x))))
39 | out = self.conv2(F.relu(self.bn2(out)))
40 | out += self.shortcut(x)
41 |
42 | return out
43 |
44 | class Wide_ResNet(nn.Module):
45 | def __init__(self, depth, widen_factor, dropout_rate, num_classes):
46 | super(Wide_ResNet, self).__init__()
47 | self.in_planes = 16
48 |
49 | assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4'
50 | n = (depth-4)/6
51 | k = widen_factor
52 |
53 | print('| Wide-Resnet %dx%d' %(depth, k))
54 | nStages = [16, 16*k, 32*k, 64*k]
55 |
56 | self.conv1 = conv3x3(3,nStages[0])
57 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1)
58 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2)
59 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2)
60 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)
61 | self.linear = nn.Linear(nStages[3], num_classes)
62 |
63 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
64 | strides = [stride] + [1]*(int(num_blocks)-1)
65 | layers = []
66 |
67 | for stride in strides:
68 | layers.append(block(self.in_planes, planes, dropout_rate, stride))
69 | self.in_planes = planes
70 |
71 | return nn.Sequential(*layers)
72 |
73 | def forward(self, x):
74 | out = self.conv1(x)
75 | out = self.layer1(out)
76 | out = self.layer2(out)
77 | out = self.layer3(out)
78 | out = F.relu(self.bn1(out))
79 | out = F.avg_pool2d(out, 8)
80 | out = out.view(out.size(0), -1)
81 | out = self.linear(out)
82 |
83 | return out
84 |
85 | if __name__ == '__main__':
86 | net=Wide_ResNet(28, 10, 0.3, 10)
87 | y = net(Variable(torch.randn(1,3,32,32)))
88 |
89 | print(y.size())
90 |
--------------------------------------------------------------------------------
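The conv_init function above is defined but never called inside the file; it is meant to be passed to nn.Module.apply. A short sketch, with the import path assumed from the directory tree:

import torch
from improved_diffusion.networks.wide_resnet import Wide_ResNet, conv_init  # import path assumed

net = Wide_ResNet(depth=28, widen_factor=10, dropout_rate=0.3, num_classes=10)
net.apply(conv_init)  # Xavier init for conv weights, constants for BatchNorm

y = net(torch.randn(2, 3, 32, 32))
print(y.size())  # torch.Size([2, 10])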
/improved_diffusion/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 |
10 |
11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12 | class SiLU(nn.Module):
13 | def forward(self, x):
14 | return x * th.sigmoid(x)
15 |
16 |
17 | class GroupNorm32(nn.GroupNorm):
18 | def forward(self, x):
19 | return super().forward(x.float()).type(x.dtype)
20 |
21 |
22 | def conv_nd(dims, *args, **kwargs):
23 | """
24 | Create a 1D, 2D, or 3D convolution module.
25 | """
26 | if dims == 1:
27 | return nn.Conv1d(*args, **kwargs)
28 | elif dims == 2:
29 | return nn.Conv2d(*args, **kwargs)
30 | elif dims == 3:
31 | return nn.Conv3d(*args, **kwargs)
32 | raise ValueError(f"unsupported dimensions: {dims}")
33 |
34 |
35 | def linear(*args, **kwargs):
36 | """
37 | Create a linear module.
38 | """
39 | return nn.Linear(*args, **kwargs)
40 |
41 |
42 | def avg_pool_nd(dims, *args, **kwargs):
43 | """
44 | Create a 1D, 2D, or 3D average pooling module.
45 | """
46 | if dims == 1:
47 | return nn.AvgPool1d(*args, **kwargs)
48 | elif dims == 2:
49 | return nn.AvgPool2d(*args, **kwargs)
50 | elif dims == 3:
51 | return nn.AvgPool3d(*args, **kwargs)
52 | raise ValueError(f"unsupported dimensions: {dims}")
53 |
54 |
55 | def update_ema(target_params, source_params, rate=0.99):
56 | """
57 | Update target parameters to be closer to those of source parameters using
58 | an exponential moving average.
59 |
60 | :param target_params: the target parameter sequence.
61 | :param source_params: the source parameter sequence.
62 | :param rate: the EMA rate (closer to 1 means slower).
63 | """
64 | for targ, src in zip(target_params, source_params):
65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66 |
67 |
68 | def zero_module(module):
69 | """
70 | Zero out the parameters of a module and return it.
71 | """
72 | for p in module.parameters():
73 | p.detach().zero_()
74 | return module
75 |
76 |
77 | def scale_module(module, scale):
78 | """
79 | Scale the parameters of a module and return it.
80 | """
81 | for p in module.parameters():
82 | p.detach().mul_(scale)
83 | return module
84 |
85 |
86 | def mean_flat(tensor):
87 | """
88 | Take the mean over all non-batch dimensions.
89 | """
90 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
91 |
92 |
93 | def normalization(channels):
94 | """
95 | Make a standard normalization layer.
96 |
97 | :param channels: number of input channels.
98 | :return: an nn.Module for normalization.
99 | """
100 | return GroupNorm32(32, channels)
101 |
102 |
103 | def timestep_embedding(timesteps, dim, max_period=10000):
104 | """
105 | Create sinusoidal timestep embeddings.
106 |
107 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
108 | These may be fractional.
109 | :param dim: the dimension of the output.
110 | :param max_period: controls the minimum frequency of the embeddings.
111 | :return: an [N x dim] Tensor of positional embeddings.
112 | """
113 | half = dim // 2
114 | freqs = th.exp(
115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116 | ).to(device=timesteps.device)
117 | args = timesteps[:, None].float() * freqs[None]
118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119 | if dim % 2:
120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121 | return embedding
122 |
123 |
124 | def checkpoint(func, inputs, params, flag):
125 | """
126 | Evaluate a function without caching intermediate activations, allowing for
127 | reduced memory at the expense of extra compute in the backward pass.
128 |
129 | :param func: the function to evaluate.
130 | :param inputs: the argument sequence to pass to `func`.
131 | :param params: a sequence of parameters `func` depends on but does not
132 | explicitly take as arguments.
133 | :param flag: if False, disable gradient checkpointing.
134 | """
135 | if flag:
136 | args = tuple(inputs) + tuple(params)
137 | return CheckpointFunction.apply(func, len(inputs), *args)
138 | else:
139 | return func(*inputs)
140 |
141 |
142 | class CheckpointFunction(th.autograd.Function):
143 | @staticmethod
144 | def forward(ctx, run_function, length, *args):
145 | ctx.run_function = run_function
146 | ctx.input_tensors = list(args[:length])
147 | ctx.input_params = list(args[length:])
148 | with th.no_grad():
149 | output_tensors = ctx.run_function(*ctx.input_tensors)
150 | return output_tensors
151 |
152 | @staticmethod
153 | def backward(ctx, *output_grads):
154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155 | with th.enable_grad():
156 | # Fixes a bug where the first op in run_function modifies the
157 | # Tensor storage in place, which is not allowed for detach()'d
158 | # Tensors.
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160 | output_tensors = ctx.run_function(*shallow_copies)
161 | input_grads = th.autograd.grad(
162 | output_tensors,
163 | ctx.input_tensors + ctx.input_params,
164 | output_grads,
165 | allow_unused=True,
166 | )
167 | del ctx.input_tensors
168 | del ctx.input_params
169 | del output_tensors
170 | return (None, None) + input_grads
171 |
--------------------------------------------------------------------------------
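Two of the helpers above are easy to sanity-check in isolation. A minimal sketch, assuming the file is importable as improved_diffusion.nn:

import torch as th
from improved_diffusion.nn import timestep_embedding, update_ema  # module path assumed

# Sinusoidal embeddings for a batch of (possibly fractional) timesteps.
t = th.tensor([0.0, 10.0, 250.5, 999.0])
emb = timestep_embedding(t, dim=128)
print(emb.shape)  # torch.Size([4, 128])

# update_ema: targ <- rate * targ + (1 - rate) * src, in place.
target, source = [th.zeros(3)], [th.ones(3)]
update_ema(target, source, rate=0.99)
print(target[0])  # tensor([0.0100, 0.0100, 0.0100])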
/improved_diffusion/resample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | import torch.distributed as dist
6 |
7 |
8 | def create_named_schedule_sampler(name, diffusion):
9 | """
10 | Create a ScheduleSampler from a library of pre-defined samplers.
11 |
12 | :param name: the name of the sampler.
13 | :param diffusion: the diffusion object to sample for.
14 | """
15 | if name == "uniform":
16 | return UniformSampler(diffusion)
17 | elif name == "loss-second-moment":
18 | return LossSecondMomentResampler(diffusion)
19 | else:
20 | raise NotImplementedError(f"unknown schedule sampler: {name}")
21 |
22 |
23 | class ScheduleSampler(ABC):
24 | """
25 | A distribution over timesteps in the diffusion process, intended to reduce
26 | variance of the objective.
27 |
28 | By default, samplers perform unbiased importance sampling, in which the
29 | objective's mean is unchanged.
30 | However, subclasses may override sample() to change how the resampled
31 | terms are reweighted, allowing for actual changes in the objective.
32 | """
33 |
34 | @abstractmethod
35 | def weights(self):
36 | """
37 | Get a numpy array of weights, one per diffusion step.
38 |
39 | The weights needn't be normalized, but must be positive.
40 | """
41 |
42 | def sample(self, batch_size, device):
43 | """
44 | Importance-sample timesteps for a batch.
45 |
46 | :param batch_size: the number of timesteps.
47 | :param device: the torch device to save to.
48 | :return: a tuple (timesteps, weights):
49 | - timesteps: a tensor of timestep indices.
50 | - weights: a tensor of weights to scale the resulting losses.
51 | """
52 | w = self.weights()
53 | p = w / np.sum(w)
54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
55 | indices = th.from_numpy(indices_np).long().to(device)
56 | weights_np = 1 / (len(p) * p[indices_np])
57 | weights = th.from_numpy(weights_np).float().to(device)
58 | return indices, weights
59 |
60 |
61 | class UniformSampler(ScheduleSampler):
62 | def __init__(self, diffusion):
63 | self.diffusion = diffusion
64 | self._weights = np.ones([diffusion.num_timesteps])
65 |
66 | def weights(self):
67 | return self._weights
68 |
69 |
70 | class LossAwareSampler(ScheduleSampler):
71 | def update_with_local_losses(self, local_ts, local_losses):
72 | """
73 | Update the reweighting using losses from a model.
74 |
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 |
80 | :param local_ts: an integer Tensor of timesteps.
81 | :param local_losses: a 1D Tensor of losses.
82 | """
83 | batch_sizes = [
84 | th.tensor([0], dtype=th.int32, device=local_ts.device)
85 | for _ in range(dist.get_world_size())
86 | ]
87 | dist.all_gather(
88 | batch_sizes,
89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
90 | )
91 |
92 | # Pad all_gather batches to be the maximum batch size.
93 | batch_sizes = [x.item() for x in batch_sizes]
94 | max_bs = max(batch_sizes)
95 |
96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
98 | dist.all_gather(timestep_batches, local_ts)
99 | dist.all_gather(loss_batches, local_losses)
100 | timesteps = [
101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
102 | ]
103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
104 | self.update_with_all_losses(timesteps, losses)
105 |
106 | @abstractmethod
107 | def update_with_all_losses(self, ts, losses):
108 | """
109 | Update the reweighting using losses from a model.
110 |
111 | Sub-classes should override this method to update the reweighting
112 | using losses from the model.
113 |
114 | This method directly updates the reweighting without synchronizing
115 | between workers. It is called by update_with_local_losses from all
116 | ranks with identical arguments. Thus, it should have deterministic
117 | behavior to maintain state across workers.
118 |
119 | :param ts: a list of int timesteps.
120 | :param losses: a list of float losses, one per timestep.
121 | """
122 |
123 |
124 | class LossSecondMomentResampler(LossAwareSampler):
125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
126 | self.diffusion = diffusion
127 | self.history_per_term = history_per_term
128 | self.uniform_prob = uniform_prob
129 | self._loss_history = np.zeros(
130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
131 | )
132 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int)
133 |
134 | def weights(self):
135 | if not self._warmed_up():
136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
138 | weights /= np.sum(weights)
139 | weights *= 1 - self.uniform_prob
140 | weights += self.uniform_prob / len(weights)
141 | return weights
142 |
143 | def update_with_all_losses(self, ts, losses):
144 | for t, loss in zip(ts, losses):
145 | if self._loss_counts[t] == self.history_per_term:
146 | # Shift out the oldest loss term.
147 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
148 | self._loss_history[t, -1] = loss
149 | else:
150 | self._loss_history[t, self._loss_counts[t]] = loss
151 | self._loss_counts[t] += 1
152 |
153 | def _warmed_up(self):
154 | return (self._loss_counts == self.history_per_term).all()
155 |
--------------------------------------------------------------------------------
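For the uniform sampler, the importance weights above reduce to 1 everywhere (w = 1 / (T * p) with p = 1/T). A minimal sketch, using a stand-in object for the diffusion since only num_timesteps is read; the module path is assumed:

import torch as th
from types import SimpleNamespace
from improved_diffusion.resample import UniformSampler  # module path assumed

diffusion = SimpleNamespace(num_timesteps=1000)  # stand-in: only num_timesteps is used
sampler = UniformSampler(diffusion)

t, weights = sampler.sample(batch_size=4, device=th.device("cpu"))
print(t)        # e.g. tensor([ 17, 402, 733, 980])
print(weights)  # tensor([1., 1., 1., 1.]) for the uniform case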
/improved_diffusion/respace.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 |
4 | from .gaussian_diffusion import GaussianDiffusion
5 |
6 |
7 | def space_timesteps(num_timesteps, section_counts):
8 | """
9 | Create a list of timesteps to use from an original diffusion process,
10 | given the number of timesteps we want to take from equally-sized portions
11 | of the original process.
12 |
13 |     For example, if there are 300 timesteps and the section counts are [10,15,20]
14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
15 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
16 |
17 | If the stride is a string starting with "ddim", then the fixed striding
18 | from the DDIM paper is used, and only one section is allowed.
19 |
20 | :param num_timesteps: the number of diffusion steps in the original
21 | process to divide up.
22 | :param section_counts: either a list of numbers, or a string containing
23 | comma-separated numbers, indicating the step count
24 | per section. As a special case, use "ddimN" where N
25 | is a number of steps to use the striding from the
26 | DDIM paper.
27 | :return: a set of diffusion steps from the original process to use.
28 | """
29 | if isinstance(section_counts, str):
30 | if section_counts.startswith("ddim"):
31 | desired_count = int(section_counts[len("ddim") :])
32 | for i in range(1, num_timesteps):
33 | if len(range(0, num_timesteps, i)) == desired_count:
34 | return set(range(0, num_timesteps, i))
35 | raise ValueError(
36 | f"cannot create exactly {num_timesteps} steps with an integer stride"
37 | )
38 | section_counts = [int(x) for x in section_counts.split(",")]
39 | size_per = num_timesteps // len(section_counts)
40 | extra = num_timesteps % len(section_counts)
41 | start_idx = 0
42 | all_steps = []
43 | for i, section_count in enumerate(section_counts):
44 | size = size_per + (1 if i < extra else 0)
45 | if size < section_count:
46 | raise ValueError(
47 | f"cannot divide section of {size} steps into {section_count}"
48 | )
49 | if section_count <= 1:
50 | frac_stride = 1
51 | else:
52 | frac_stride = (size - 1) / (section_count - 1)
53 | cur_idx = 0.0
54 | taken_steps = []
55 | for _ in range(section_count):
56 | taken_steps.append(start_idx + round(cur_idx))
57 | cur_idx += frac_stride
58 | all_steps += taken_steps
59 | start_idx += size
60 | return set(all_steps)
61 |
62 |
63 | class SpacedDiffusion(GaussianDiffusion):
64 | """
65 | A diffusion process which can skip steps in a base diffusion process.
66 |
67 | :param use_timesteps: a collection (sequence or set) of timesteps from the
68 | original diffusion process to retain.
69 | :param kwargs: the kwargs to create the base diffusion process.
70 | """
71 |
72 | def __init__(self, use_timesteps, **kwargs):
73 | self.use_timesteps = set(use_timesteps)
74 | self.timestep_map = []
75 | self.original_num_steps = len(kwargs["betas"])
76 |
77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
78 | last_alpha_cumprod = 1.0
79 | new_betas = []
80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
81 | if i in self.use_timesteps:
82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
83 | last_alpha_cumprod = alpha_cumprod
84 | self.timestep_map.append(i)
85 | kwargs["betas"] = np.array(new_betas)
86 | super().__init__(**kwargs)
87 |
88 | def p_mean_variance(
89 | self, model, *args, **kwargs
90 | ): # pylint: disable=signature-differs
91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
92 |
93 | def training_losses(
94 | self, model, *args, **kwargs
95 | ): # pylint: disable=signature-differs
96 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
97 |
98 | def _wrap_model(self, model):
99 | if isinstance(model, _WrappedModel):
100 | return model
101 | return _WrappedModel(
102 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
103 | )
104 |
105 | def _scale_timesteps(self, t):
106 | # Scaling is done by the wrapped model.
107 | return t
108 |
109 |
110 | class _WrappedModel:
111 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
112 | self.model = model
113 | self.timestep_map = timestep_map
114 | self.rescale_timesteps = rescale_timesteps
115 | self.original_num_steps = original_num_steps
116 |
117 | def __call__(self, x, ts, **kwargs):
118 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
119 | new_ts = map_tensor[ts]
120 | if self.rescale_timesteps:
121 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
122 | return self.model(x, new_ts, **kwargs)
123 |
--------------------------------------------------------------------------------
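The behaviour documented in the space_timesteps docstring is easiest to see with small numbers. A couple of worked calls, with the module path assumed:

from improved_diffusion.respace import space_timesteps  # module path assumed

# 10 original steps, 4 kept from each half.
print(sorted(space_timesteps(10, [4, 4])))   # [0, 1, 3, 4, 5, 6, 8, 9]

# "ddimN" uses a fixed integer stride; 250 of 1000 steps -> stride 4.
steps = space_timesteps(1000, "ddim250")
print(len(steps), min(steps), max(steps))    # 250 0 996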
/improved_diffusion/script_util.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import inspect
3 |
4 | from . import gaussian_diffusion as gd
5 | from .respace import SpacedDiffusion, space_timesteps
6 | from .unet import SuperResModel, UNetModel
7 |
8 | NUM_CLASSES = 1000
9 |
10 |
11 | def model_and_diffusion_defaults(t_total=1000):  # default assumed; lets sr_model_and_diffusion_defaults() call it with no argument
12 | """
13 | Defaults for image training.
14 | """
15 | return dict(
16 | image_size=32,
17 | num_channels=128,
18 | num_res_blocks=3,
19 | num_heads=4,
20 | num_heads_upsample=-1,
21 | attention_resolutions="16,8",
22 | dropout=0.3,
23 | learn_sigma=True,
24 | sigma_small=False,
25 | class_cond=False,
26 | diffusion_steps=t_total,
27 | noise_schedule="cosine",
28 | timestep_respacing="",
29 | use_kl=False,
30 | predict_xstart=False,
31 | rescale_timesteps=True,
32 | rescale_learned_sigmas=True,
33 | use_checkpoint=False,
34 | use_scale_shift_norm=True,
35 | )
36 |
37 |
38 | def create_model_and_diffusion(
39 | image_size,
40 | class_cond,
41 | learn_sigma,
42 | sigma_small,
43 | num_channels,
44 | num_res_blocks,
45 | num_heads,
46 | num_heads_upsample,
47 | attention_resolutions,
48 | dropout,
49 | diffusion_steps,
50 | noise_schedule,
51 | timestep_respacing,
52 | use_kl,
53 | predict_xstart,
54 | rescale_timesteps,
55 | rescale_learned_sigmas,
56 | use_checkpoint,
57 | use_scale_shift_norm,
58 | ):
59 | model = create_model(
60 | image_size,
61 | num_channels,
62 | num_res_blocks,
63 | learn_sigma=learn_sigma,
64 | class_cond=class_cond,
65 | use_checkpoint=use_checkpoint,
66 | attention_resolutions=attention_resolutions,
67 | num_heads=num_heads,
68 | num_heads_upsample=num_heads_upsample,
69 | use_scale_shift_norm=use_scale_shift_norm,
70 | dropout=dropout,
71 | )
72 | diffusion = create_gaussian_diffusion(
73 | steps=diffusion_steps,
74 | learn_sigma=learn_sigma,
75 | sigma_small=sigma_small,
76 | noise_schedule=noise_schedule,
77 | use_kl=use_kl,
78 | predict_xstart=predict_xstart,
79 | rescale_timesteps=rescale_timesteps,
80 | rescale_learned_sigmas=rescale_learned_sigmas,
81 | timestep_respacing=timestep_respacing,
82 | )
83 | return model, diffusion
84 |
85 |
86 | def create_model(
87 | image_size,
88 | num_channels,
89 | num_res_blocks,
90 | learn_sigma,
91 | class_cond,
92 | use_checkpoint,
93 | attention_resolutions,
94 | num_heads,
95 | num_heads_upsample,
96 | use_scale_shift_norm,
97 | dropout,
98 | ):
99 | if image_size == 256:
100 | channel_mult = (1, 1, 2, 2, 4, 4)
101 | elif image_size == 64:
102 | channel_mult = (1, 2, 3, 4)
103 | elif image_size == 32:
104 | channel_mult = (1, 2, 2, 2)
105 | else:
106 | raise ValueError(f"unsupported image size: {image_size}")
107 |
108 | attention_ds = []
109 | for res in attention_resolutions.split(","):
110 | attention_ds.append(image_size // int(res))
111 |
112 | return UNetModel(
113 | in_channels=3,
114 | model_channels=num_channels,
115 | out_channels=(3 if not learn_sigma else 6),
116 | num_res_blocks=num_res_blocks,
117 | attention_resolutions=tuple(attention_ds),
118 | dropout=dropout,
119 | channel_mult=channel_mult,
120 | num_classes=(NUM_CLASSES if class_cond else None),
121 | use_checkpoint=use_checkpoint,
122 | num_heads=num_heads,
123 | num_heads_upsample=num_heads_upsample,
124 | use_scale_shift_norm=use_scale_shift_norm,
125 | )
126 |
127 |
128 | def sr_model_and_diffusion_defaults():
129 | res = model_and_diffusion_defaults()
130 | res["large_size"] = 256
131 | res["small_size"] = 64
132 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]
133 | for k in res.copy().keys():
134 | if k not in arg_names:
135 | del res[k]
136 | return res
137 |
138 |
139 | def sr_create_model_and_diffusion(
140 | large_size,
141 | small_size,
142 | class_cond,
143 | learn_sigma,
144 | num_channels,
145 | num_res_blocks,
146 | num_heads,
147 | num_heads_upsample,
148 | attention_resolutions,
149 | dropout,
150 | diffusion_steps,
151 | noise_schedule,
152 | timestep_respacing,
153 | use_kl,
154 | predict_xstart,
155 | rescale_timesteps,
156 | rescale_learned_sigmas,
157 | use_checkpoint,
158 | use_scale_shift_norm,
159 | ):
160 | model = sr_create_model(
161 | large_size,
162 | small_size,
163 | num_channels,
164 | num_res_blocks,
165 | learn_sigma=learn_sigma,
166 | class_cond=class_cond,
167 | use_checkpoint=use_checkpoint,
168 | attention_resolutions=attention_resolutions,
169 | num_heads=num_heads,
170 | num_heads_upsample=num_heads_upsample,
171 | use_scale_shift_norm=use_scale_shift_norm,
172 | dropout=dropout,
173 | )
174 | diffusion = create_gaussian_diffusion(
175 | steps=diffusion_steps,
176 | learn_sigma=learn_sigma,
177 | noise_schedule=noise_schedule,
178 | use_kl=use_kl,
179 | predict_xstart=predict_xstart,
180 | rescale_timesteps=rescale_timesteps,
181 | rescale_learned_sigmas=rescale_learned_sigmas,
182 | timestep_respacing=timestep_respacing,
183 | )
184 | return model, diffusion
185 |
186 |
187 | def sr_create_model(
188 | large_size,
189 | small_size,
190 | num_channels,
191 | num_res_blocks,
192 | learn_sigma,
193 | class_cond,
194 | use_checkpoint,
195 | attention_resolutions,
196 | num_heads,
197 | num_heads_upsample,
198 | use_scale_shift_norm,
199 | dropout,
200 | ):
201 | _ = small_size # hack to prevent unused variable
202 |
203 | if large_size == 256:
204 | channel_mult = (1, 1, 2, 2, 4, 4)
205 | elif large_size == 64:
206 | channel_mult = (1, 2, 3, 4)
207 | else:
208 | raise ValueError(f"unsupported large size: {large_size}")
209 |
210 | attention_ds = []
211 | for res in attention_resolutions.split(","):
212 | attention_ds.append(large_size // int(res))
213 |
214 | return SuperResModel(
215 | in_channels=3,
216 | model_channels=num_channels,
217 | out_channels=(3 if not learn_sigma else 6),
218 | num_res_blocks=num_res_blocks,
219 | attention_resolutions=tuple(attention_ds),
220 | dropout=dropout,
221 | channel_mult=channel_mult,
222 | num_classes=(NUM_CLASSES if class_cond else None),
223 | use_checkpoint=use_checkpoint,
224 | num_heads=num_heads,
225 | num_heads_upsample=num_heads_upsample,
226 | use_scale_shift_norm=use_scale_shift_norm,
227 | )
228 |
229 |
230 | def create_gaussian_diffusion(
231 | *,
232 | steps=1000,
233 | learn_sigma=False,
234 | sigma_small=False,
235 | noise_schedule="linear",
236 | use_kl=False,
237 | predict_xstart=False,
238 | rescale_timesteps=False,
239 | rescale_learned_sigmas=False,
240 | timestep_respacing="",
241 | ):
242 | betas = gd.get_named_beta_schedule(noise_schedule, steps)
243 | if use_kl:
244 | loss_type = gd.LossType.RESCALED_KL
245 | elif rescale_learned_sigmas:
246 | loss_type = gd.LossType.RESCALED_MSE
247 | else:
248 | loss_type = gd.LossType.MSE
249 | if not timestep_respacing:
250 | timestep_respacing = [steps]
251 | return SpacedDiffusion(
252 | use_timesteps=space_timesteps(steps, timestep_respacing),
253 | betas=betas,
254 | model_mean_type=(
255 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
256 | ),
257 | model_var_type=(
258 | (
259 | gd.ModelVarType.FIXED_LARGE
260 | if not sigma_small
261 | else gd.ModelVarType.FIXED_SMALL
262 | )
263 | if not learn_sigma
264 | else gd.ModelVarType.LEARNED_RANGE
265 | ),
266 | loss_type=loss_type,
267 | rescale_timesteps=rescale_timesteps,
268 | )
269 |
270 |
271 | def add_dict_to_argparser(parser, default_dict):
272 | for k, v in default_dict.items():
273 | v_type = type(v)
274 | if v is None:
275 | v_type = str
276 | elif isinstance(v, bool):
277 | v_type = str2bool
278 | parser.add_argument(f"--{k}", default=v, type=v_type)
279 |
280 |
281 | def args_to_dict(args, keys):
282 | return {k: getattr(args, k) for k in keys}
283 |
284 |
285 | def str2bool(v):
286 | """
287 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
288 | """
289 | if isinstance(v, bool):
290 | return v
291 | if v.lower() in ("yes", "true", "t", "y", "1"):
292 | return True
293 | elif v.lower() in ("no", "false", "f", "n", "0"):
294 | return False
295 | else:
296 | raise argparse.ArgumentTypeError("boolean value expected")
297 |
--------------------------------------------------------------------------------
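A small end-to-end sketch of the diffusion-construction path above: build the defaults for a 1000-step process (the t_total value here is only an example), then respace it down to 50 sampling steps. The module path is assumed.

from improved_diffusion import script_util  # module path assumed

defaults = script_util.model_and_diffusion_defaults(t_total=1000)  # t_total value is illustrative
diffusion = script_util.create_gaussian_diffusion(
    steps=defaults["diffusion_steps"],
    learn_sigma=defaults["learn_sigma"],
    noise_schedule=defaults["noise_schedule"],
    timestep_respacing="50",  # keep 50 of the 1000 steps at sampling time
)
print(diffusion.num_timesteps)       # 50
print(diffusion.original_num_steps)  # 1000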
/improved_diffusion/train_util.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import functools
3 | import os
4 |
5 | import blobfile as bf
6 | import numpy as np
7 | import torch as th
8 | import torch.distributed as dist
9 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
10 | from torch.optim import AdamW
11 |
12 | from . import dist_util, logger
13 | from .fp16_util import (
14 | make_master_params,
15 | master_params_to_model_params,
16 | model_grads_to_master_grads,
17 | unflatten_master_params,
18 | zero_grad,
19 | )
20 | from .nn import update_ema
21 | from .resample import LossAwareSampler, UniformSampler
22 |
23 | # For ImageNet experiments, this was a good default value.
24 | # We found that the lg_loss_scale quickly climbed to
25 | # 20-21 within the first ~1K steps of training.
26 | INITIAL_LOG_LOSS_SCALE = 20.0
27 |
28 |
29 | class TrainLoop:
30 | def __init__(
31 | self,
32 | *,
33 | model,
34 | diffusion,
35 | data,
36 | batch_size,
37 | microbatch,
38 | lr,
39 | ema_rate,
40 | log_interval,
41 | save_interval,
42 | resume_checkpoint,
43 | use_fp16=False,
44 | fp16_scale_growth=1e-3,
45 | schedule_sampler=None,
46 | weight_decay=0.0,
47 | lr_anneal_steps=0,
48 | ):
49 | self.model = model
50 | self.diffusion = diffusion
51 | self.data = data
52 | self.batch_size = batch_size
53 | self.microbatch = microbatch if microbatch > 0 else batch_size
54 | self.lr = lr
55 | self.ema_rate = (
56 | [ema_rate]
57 | if isinstance(ema_rate, float)
58 | else [float(x) for x in ema_rate.split(",")]
59 | )
60 | self.log_interval = log_interval
61 | self.save_interval = save_interval
62 | self.resume_checkpoint = resume_checkpoint
63 | self.use_fp16 = use_fp16
64 | self.fp16_scale_growth = fp16_scale_growth
65 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
66 | self.weight_decay = weight_decay
67 | self.lr_anneal_steps = lr_anneal_steps
68 |
69 | self.step = 0
70 | self.resume_step = 0
71 | self.global_batch = self.batch_size * dist.get_world_size()
72 |
73 | self.model_params = list(self.model.parameters())
74 | self.master_params = self.model_params
75 | self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
76 | self.sync_cuda = th.cuda.is_available()
77 |
78 | self._load_and_sync_parameters()
79 | if self.use_fp16:
80 | self._setup_fp16()
81 |
82 | self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
83 | if self.resume_step:
84 | self._load_optimizer_state()
85 | # Model was resumed, either due to a restart or a checkpoint
86 | # being specified at the command line.
87 | self.ema_params = [
88 | self._load_ema_parameters(rate) for rate in self.ema_rate
89 | ]
90 | else:
91 | self.ema_params = [
92 | copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
93 | ]
94 |
95 | if th.cuda.is_available():
96 | self.use_ddp = True
97 | self.ddp_model = DDP(
98 | self.model,
99 | device_ids=[dist_util.dev()],
100 | output_device=dist_util.dev(),
101 | broadcast_buffers=False,
102 | bucket_cap_mb=128,
103 | find_unused_parameters=False,
104 | )
105 | else:
106 | if dist.get_world_size() > 1:
107 | logger.warn(
108 | "Distributed training requires CUDA. "
109 | "Gradients will not be synchronized properly!"
110 | )
111 | self.use_ddp = False
112 | self.ddp_model = self.model
113 |
114 | def _load_and_sync_parameters(self):
115 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
116 |
117 | if resume_checkpoint:
118 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
119 | if dist.get_rank() == 0:
120 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
121 | self.model.load_state_dict(
122 | dist_util.load_state_dict(
123 | resume_checkpoint, map_location=dist_util.dev()
124 | )
125 | )
126 |
127 | dist_util.sync_params(self.model.parameters())
128 |
129 | def _load_ema_parameters(self, rate):
130 | ema_params = copy.deepcopy(self.master_params)
131 |
132 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
133 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
134 | if ema_checkpoint:
135 | if dist.get_rank() == 0:
136 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
137 | state_dict = dist_util.load_state_dict(
138 | ema_checkpoint, map_location=dist_util.dev()
139 | )
140 | ema_params = self._state_dict_to_master_params(state_dict)
141 |
142 | dist_util.sync_params(ema_params)
143 | return ema_params
144 |
145 | def _load_optimizer_state(self):
146 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
147 | opt_checkpoint = bf.join(
148 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
149 | )
150 | if bf.exists(opt_checkpoint):
151 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
152 | state_dict = dist_util.load_state_dict(
153 | opt_checkpoint, map_location=dist_util.dev()
154 | )
155 | self.opt.load_state_dict(state_dict)
156 |
157 | def _setup_fp16(self):
158 | self.master_params = make_master_params(self.model_params)
159 | self.model.convert_to_fp16()
160 |
161 | def run_loop(self):
162 | while (
163 | not self.lr_anneal_steps
164 | or self.step + self.resume_step < self.lr_anneal_steps
165 | ):
166 | batch, cond = next(self.data)
167 | self.run_step(batch, cond)
168 | if self.step % self.log_interval == 0:
169 | logger.dumpkvs()
170 | if self.step % self.save_interval == 0:
171 | self.save()
172 | # Run for a finite amount of time in integration tests.
173 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
174 | return
175 | self.step += 1
176 | # Save the last checkpoint if it wasn't already saved.
177 | if (self.step - 1) % self.save_interval != 0:
178 | self.save()
179 |
180 | def run_step(self, batch, cond):
181 | self.forward_backward(batch, cond)
182 | if self.use_fp16:
183 | self.optimize_fp16()
184 | else:
185 | self.optimize_normal()
186 | self.log_step()
187 |
188 | def forward_backward(self, batch, cond):
189 | zero_grad(self.model_params)
190 | for i in range(0, batch.shape[0], self.microbatch):
191 | micro = batch[i : i + self.microbatch].to(dist_util.dev())
192 | micro_cond = {
193 | k: v[i : i + self.microbatch].to(dist_util.dev())
194 | for k, v in cond.items()
195 | }
196 | last_batch = (i + self.microbatch) >= batch.shape[0]
197 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
198 |
199 | compute_losses = functools.partial(
200 | self.diffusion.training_losses,
201 | self.ddp_model,
202 | micro,
203 | t,
204 | model_kwargs=micro_cond,
205 | )
206 |
207 | if last_batch or not self.use_ddp:
208 | losses = compute_losses()
209 | else:
210 | with self.ddp_model.no_sync():
211 | losses = compute_losses()
212 |
213 | if isinstance(self.schedule_sampler, LossAwareSampler):
214 | self.schedule_sampler.update_with_local_losses(
215 | t, losses["loss"].detach()
216 | )
217 |
218 | loss = (losses["loss"] * weights).mean()
219 | log_loss_dict(
220 | self.diffusion, t, {k: v * weights for k, v in losses.items()}
221 | )
222 | if self.use_fp16:
223 | loss_scale = 2 ** self.lg_loss_scale
224 | (loss * loss_scale).backward()
225 | else:
226 | loss.backward()
227 |
228 | def optimize_fp16(self):
229 | if any(not th.isfinite(p.grad).all() for p in self.model_params):
230 | self.lg_loss_scale -= 1
231 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
232 | return
233 |
234 | model_grads_to_master_grads(self.model_params, self.master_params)
235 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
236 | self._log_grad_norm()
237 | self._anneal_lr()
238 | self.opt.step()
239 | for rate, params in zip(self.ema_rate, self.ema_params):
240 | update_ema(params, self.master_params, rate=rate)
241 | master_params_to_model_params(self.model_params, self.master_params)
242 | self.lg_loss_scale += self.fp16_scale_growth
243 |
244 | def optimize_normal(self):
245 | self._log_grad_norm()
246 | self._anneal_lr()
247 | self.opt.step()
248 | for rate, params in zip(self.ema_rate, self.ema_params):
249 | update_ema(params, self.master_params, rate=rate)
250 |
251 | def _log_grad_norm(self):
252 | sqsum = 0.0
253 | for p in self.master_params:
254 | sqsum += (p.grad ** 2).sum().item()
255 | logger.logkv_mean("grad_norm", np.sqrt(sqsum))
256 |
257 | def _anneal_lr(self):
258 | if not self.lr_anneal_steps:
259 | return
260 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
261 | lr = self.lr * (1 - frac_done)
262 | for param_group in self.opt.param_groups:
263 | param_group["lr"] = lr
264 |
265 | def log_step(self):
266 | logger.logkv("step", self.step + self.resume_step)
267 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
268 | if self.use_fp16:
269 | logger.logkv("lg_loss_scale", self.lg_loss_scale)
270 |
271 | def save(self):
272 | def save_checkpoint(rate, params):
273 | state_dict = self._master_params_to_state_dict(params)
274 | if dist.get_rank() == 0:
275 | logger.log(f"saving model {rate}...")
276 | if not rate:
277 | filename = f"model{(self.step+self.resume_step):06d}.pt"
278 | else:
279 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
280 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
281 | th.save(state_dict, f)
282 |
283 | save_checkpoint(0, self.master_params)
284 | for rate, params in zip(self.ema_rate, self.ema_params):
285 | save_checkpoint(rate, params)
286 |
287 | if dist.get_rank() == 0:
288 | with bf.BlobFile(
289 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
290 | "wb",
291 | ) as f:
292 | th.save(self.opt.state_dict(), f)
293 |
294 | dist.barrier()
295 |
296 | def _master_params_to_state_dict(self, master_params):
297 | if self.use_fp16:
298 | master_params = unflatten_master_params(
299 | self.model.parameters(), master_params
300 | )
301 | state_dict = self.model.state_dict()
302 | for i, (name, _value) in enumerate(self.model.named_parameters()):
303 | assert name in state_dict
304 | state_dict[name] = master_params[i]
305 | return state_dict
306 |
307 | def _state_dict_to_master_params(self, state_dict):
308 | params = [state_dict[name] for name, _ in self.model.named_parameters()]
309 | if self.use_fp16:
310 | return make_master_params(params)
311 | else:
312 | return params
313 |
314 |
315 | def parse_resume_step_from_filename(filename):
316 | """
317 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
318 | checkpoint's number of steps.
319 | """
320 | split = filename.split("model")
321 | if len(split) < 2:
322 | return 0
323 | split1 = split[-1].split(".")[0]
324 | try:
325 | return int(split1)
326 | except ValueError:
327 | return 0
328 |
329 |
330 | def get_blob_logdir():
331 | return os.environ.get("DIFFUSION_BLOB_LOGDIR", logger.get_dir())
332 |
333 |
334 | def find_resume_checkpoint():
335 | # On your infrastructure, you may want to override this to automatically
336 | # discover the latest checkpoint on your blob storage, etc.
337 | return None
338 |
339 |
340 | def find_ema_checkpoint(main_checkpoint, step, rate):
341 | if main_checkpoint is None:
342 | return None
343 | filename = f"ema_{rate}_{(step):06d}.pt"
344 | path = bf.join(bf.dirname(main_checkpoint), filename)
345 | if bf.exists(path):
346 | return path
347 | return None
348 |
349 |
350 | def log_loss_dict(diffusion, ts, losses):
351 | for key, values in losses.items():
352 | logger.logkv_mean(key, values.mean().item())
353 | # Log the quantiles (four quartiles, in particular).
354 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
355 | quartile = int(4 * sub_t / diffusion.num_timesteps)
356 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
357 |
--------------------------------------------------------------------------------
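Checkpoint filenames in the training loop above encode the step count (model{step:06d}.pt, ema_{rate}_{step:06d}.pt, opt{step:06d}.pt), which is what parse_resume_step_from_filename relies on. A tiny sketch, module path assumed:

from improved_diffusion.train_util import parse_resume_step_from_filename  # module path assumed

print(parse_resume_step_from_filename("ckpts/model010000.pt"))  # 10000
print(parse_resume_step_from_filename("ckpts/opt010000.pt"))    # 0 (no "model" prefix to parse)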
/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .lenet import *
2 | from .vggnet import *
3 | from .resnet import *
4 | from .wide_resnet import *
5 |
--------------------------------------------------------------------------------
/networks/lenet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | def conv_init(m):
5 | classname = m.__class__.__name__
6 | if classname.find('Conv') != -1:
7 |         nn.init.xavier_uniform_(m.weight, gain=2 ** 0.5)
8 |         nn.init.constant_(m.bias, 0)
9 |
10 | class LeNet(nn.Module):
11 | def __init__(self, num_classes):
12 | super(LeNet, self).__init__()
13 | self.conv1 = nn.Conv2d(3, 6, 5)
14 | self.conv2 = nn.Conv2d(6, 16, 5)
15 | self.fc1 = nn.Linear(16*5*5, 120)
16 | self.fc2 = nn.Linear(120, 84)
17 | self.fc3 = nn.Linear(84, num_classes)
18 |
19 | def forward(self, x):
20 | out = F.relu(self.conv1(x))
21 | out = F.max_pool2d(out, 2)
22 | out = F.relu(self.conv2(out))
23 | out = F.max_pool2d(out, 2)
24 | out = out.view(out.size(0), -1)
25 | out = F.relu(self.fc1(out))
26 | out = F.relu(self.fc2(out))
27 | out = self.fc3(out)
28 |
29 |         return out
30 |
--------------------------------------------------------------------------------
/networks/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from torch.autograd import Variable
6 | import sys
7 |
8 | def conv3x3(in_planes, out_planes, stride=1):
9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
10 |
11 | def conv_init(m):
12 | classname = m.__class__.__name__
13 | if classname.find('Conv') != -1:
14 |         nn.init.xavier_uniform_(m.weight, gain=2 ** 0.5)
15 |         nn.init.constant_(m.bias, 0)
16 |
17 | def cfg(depth):
18 | depth_lst = [18, 34, 50, 101, 152]
19 | assert (depth in depth_lst), "Error : Resnet depth should be either 18, 34, 50, 101, 152"
20 | cf_dict = {
21 | '18': (BasicBlock, [2,2,2,2]),
22 | '34': (BasicBlock, [3,4,6,3]),
23 | '50': (Bottleneck, [3,4,6,3]),
24 | '101':(Bottleneck, [3,4,23,3]),
25 | '152':(Bottleneck, [3,8,36,3]),
26 | }
27 |
28 | return cf_dict[str(depth)]
29 |
30 | class BasicBlock(nn.Module):
31 | expansion = 1
32 |
33 | def __init__(self, in_planes, planes, stride=1):
34 | super(BasicBlock, self).__init__()
35 | self.conv1 = conv3x3(in_planes, planes, stride)
36 | self.bn1 = nn.BatchNorm2d(planes)
37 | self.conv2 = conv3x3(planes, planes)
38 | self.bn2 = nn.BatchNorm2d(planes)
39 |
40 | self.shortcut = nn.Sequential()
41 | if stride != 1 or in_planes != self.expansion * planes:
42 | self.shortcut = nn.Sequential(
43 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True),
44 | nn.BatchNorm2d(self.expansion*planes)
45 | )
46 |
47 | def forward(self, x):
48 | out = F.relu(self.bn1(self.conv1(x)))
49 | out = self.bn2(self.conv2(out))
50 | out += self.shortcut(x)
51 | out = F.relu(out)
52 |
53 | return out
54 |
55 | class Bottleneck(nn.Module):
56 | expansion = 4
57 |
58 | def __init__(self, in_planes, planes, stride=1):
59 | super(Bottleneck, self).__init__()
60 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=True)
61 | self.bn1 = nn.BatchNorm2d(planes)
62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
63 | self.bn2 = nn.BatchNorm2d(planes)
64 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=True)
65 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
66 |
67 | self.shortcut = nn.Sequential()
68 | if stride != 1 or in_planes != self.expansion*planes:
69 | self.shortcut = nn.Sequential(
70 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True),
71 | nn.BatchNorm2d(self.expansion*planes)
72 | )
73 |
74 | def forward(self, x):
75 | out = F.relu(self.bn1(self.conv1(x)))
76 | out = F.relu(self.bn2(self.conv2(out)))
77 | out = self.bn3(self.conv3(out))
78 | out += self.shortcut(x)
79 | out = F.relu(out)
80 |
81 | return out
82 |
83 | class ResNet(nn.Module):
84 | def __init__(self, depth, num_classes):
85 | super(ResNet, self).__init__()
86 | self.in_planes = 16
87 |
88 | block, num_blocks = cfg(depth)
89 |
90 | self.conv1 = conv3x3(3,16)
91 | self.bn1 = nn.BatchNorm2d(16)
92 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
93 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
94 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
95 | self.linear = nn.Linear(64*block.expansion, num_classes)
96 |
97 | def _make_layer(self, block, planes, num_blocks, stride):
98 | strides = [stride] + [1]*(num_blocks-1)
99 | layers = []
100 |
101 | for stride in strides:
102 | layers.append(block(self.in_planes, planes, stride))
103 | self.in_planes = planes * block.expansion
104 |
105 | return nn.Sequential(*layers)
106 |
107 | def forward(self, x):
108 | out = F.relu(self.bn1(self.conv1(x)))
109 | out = self.layer1(out)
110 | out = self.layer2(out)
111 | out = self.layer3(out)
112 | out = F.avg_pool2d(out, 8)
113 | out = out.view(out.size(0), -1)
114 | out = self.linear(out)
115 |
116 | return out
117 |
118 | if __name__ == '__main__':
119 | net=ResNet(50, 10)
120 | y = net(Variable(torch.randn(1,3,32,32)))
121 | print(y.size())
122 |
--------------------------------------------------------------------------------
/networks/vggnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 |
5 | def conv_init(m):
6 | classname = m.__class__.__name__
7 | if classname.find('Conv') != -1:
8 |         nn.init.xavier_uniform_(m.weight, gain=2 ** 0.5)
9 |         nn.init.constant_(m.bias, 0)
10 |
11 | def cfg(depth):
12 | depth_lst = [11, 13, 16, 19]
13 | assert (depth in depth_lst), "Error : VGGnet depth should be either 11, 13, 16, 19"
14 | cf_dict = {
15 | '11': [
16 | 64, 'mp',
17 | 128, 'mp',
18 | 256, 256, 'mp',
19 | 512, 512, 'mp',
20 | 512, 512, 'mp'],
21 | '13': [
22 | 64, 64, 'mp',
23 | 128, 128, 'mp',
24 | 256, 256, 'mp',
25 | 512, 512, 'mp',
26 | 512, 512, 'mp'
27 | ],
28 | '16': [
29 | 64, 64, 'mp',
30 | 128, 128, 'mp',
31 | 256, 256, 256, 'mp',
32 | 512, 512, 512, 'mp',
33 | 512, 512, 512, 'mp'
34 | ],
35 | '19': [
36 | 64, 64, 'mp',
37 | 128, 128, 'mp',
38 | 256, 256, 256, 256, 'mp',
39 | 512, 512, 512, 512, 'mp',
40 | 512, 512, 512, 512, 'mp'
41 | ],
42 | }
43 |
44 | return cf_dict[str(depth)]
45 |
46 | def conv3x3(in_planes, out_planes, stride=1):
47 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
48 |
49 | class VGG(nn.Module):
50 | def __init__(self, depth, num_classes):
51 | super(VGG, self).__init__()
52 | self.features = self._make_layers(cfg(depth))
53 | self.linear = nn.Linear(512, num_classes)
54 |
55 | def forward(self, x):
56 | out = self.features(x)
57 | out = out.view(out.size(0), -1)
58 | out = self.linear(out)
59 |
60 | return out
61 |
62 | def _make_layers(self, cfg):
63 | layers = []
64 | in_planes = 3
65 |
66 | for x in cfg:
67 | if x == 'mp':
68 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
69 | else:
70 | layers += [conv3x3(in_planes, x), nn.BatchNorm2d(x), nn.ReLU(inplace=True)]
71 | in_planes = x
72 |
73 | # After cfg convolution
74 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
75 | return nn.Sequential(*layers)
76 |
77 | if __name__ == "__main__":
78 | net = VGG(16, 10)
79 | y = net(Variable(torch.randn(1,3,32,32)))
80 | print(y.size())
81 |
--------------------------------------------------------------------------------
/networks/wide_resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 | import sys
8 | import numpy as np
9 |
10 | def conv3x3(in_planes, out_planes, stride=1):
11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
12 |
13 | def conv_init(m):
14 | classname = m.__class__.__name__
15 | if classname.find('Conv') != -1:
16 | init.xavier_uniform_(m.weight, gain=np.sqrt(2))
17 | init.constant_(m.bias, 0)
18 | elif classname.find('BatchNorm') != -1:
19 | init.constant_(m.weight, 1)
20 | init.constant_(m.bias, 0)
21 |
22 | class wide_basic(nn.Module):
23 | def __init__(self, in_planes, planes, dropout_rate, stride=1):
24 | super(wide_basic, self).__init__()
25 | self.bn1 = nn.BatchNorm2d(in_planes)
26 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
27 | self.dropout = nn.Dropout(p=dropout_rate)
28 | self.bn2 = nn.BatchNorm2d(planes)
29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
30 |
31 | self.shortcut = nn.Sequential()
32 | if stride != 1 or in_planes != planes:
33 | self.shortcut = nn.Sequential(
34 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
35 | )
36 |
37 | def forward(self, x):
38 | out = self.dropout(self.conv1(F.relu(self.bn1(x))))
39 | out = self.conv2(F.relu(self.bn2(out)))
40 | out += self.shortcut(x)
41 |
42 | return out
43 |
44 | class Wide_ResNet(nn.Module):
45 | def __init__(self, depth, widen_factor, dropout_rate, num_classes):
46 | super(Wide_ResNet, self).__init__()
47 | self.in_planes = 16
48 |
49 | assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4'
50 | n = (depth-4)/6
51 | k = widen_factor
52 |
53 | print('| Wide-Resnet %dx%d' %(depth, k))
54 | nStages = [16, 16*k, 32*k, 64*k]
55 |
56 | self.conv1 = conv3x3(3,nStages[0])
57 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1)
58 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2)
59 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2)
60 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)
61 | self.linear = nn.Linear(nStages[3], num_classes)
62 |
63 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
64 | strides = [stride] + [1]*(int(num_blocks)-1)
65 | layers = []
66 |
67 | for stride in strides:
68 | layers.append(block(self.in_planes, planes, dropout_rate, stride))
69 | self.in_planes = planes
70 |
71 | return nn.Sequential(*layers)
72 |
73 | def forward(self, x):
74 | out = self.conv1(x)
75 | out = self.layer1(out)
76 | out = self.layer2(out)
77 | out = self.layer3(out)
78 | out = F.relu(self.bn1(out))
79 | out = F.avg_pool2d(out, 8)
80 | out = out.view(out.size(0), -1)
81 | out = self.linear(out)
82 |
83 | return out
84 |
85 | if __name__ == '__main__':
86 | net=Wide_ResNet(28, 10, 0.3, 10)
87 | y = net(Variable(torch.randn(1,3,32,32)))
88 |
89 | print(y.size())
90 |
--------------------------------------------------------------------------------
/pictures/densepure_flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jayfeather1024/DensePure/1ed105d13a6ccfaf34a4fc240609dae89abc6a0d/pictures/densepure_flowchart.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.19.4
2 | pyyaml==5.3.1
3 | wheel==0.34.2
4 | scipy==1.5.2
5 | pillow==7.2.0
6 | matplotlib==3.3.0
7 | tqdm==4.56.1
8 | tensorboardX==2.0
9 | seaborn==0.10.1
10 | pandas==1.2.0
11 | requests==2.25.0
12 | xvfbwrapper==0.2.9
13 | torchdiffeq==0.2.1
14 | timm==0.5.4
15 | lmdb
16 | Ninja
17 | foolbox
18 | torchsde
19 | statsmodels
20 | transformers
21 | blobfile
22 | mpi4py
--------------------------------------------------------------------------------
/results/merge_cifar10.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | sigma=$1
4 | steps=$2
5 | majority_vote_num=$3
6 |
7 | python merge_results.py \
8 | --sample_id_list $(seq -s ' ' 0 20 9980) \
9 | --sample_num 500 \
10 | --majority_vote_num $majority_vote_num \
11 | --N 100000 \
12 | --N0 100 \
13 | --sigma $sigma \
14 | --classes_num 10 \
15 | --datasets cifar10 \
16 | --steps $steps
17 |
--------------------------------------------------------------------------------
/results/merge_imagenet.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | sigma=$1
4 | steps=$2
5 | majority_vote_num=$3
6 |
7 | python merge_results.py \
8 | --sample_id_list $(seq -s ' ' 0 500 49500) \
9 | --sample_num 100 \
10 | --majority_vote_num $majority_vote_num \
11 | --N 10000 \
12 | --N0 100 \
13 | --sigma $sigma \
14 | --classes_num 1000 \
15 | --datasets imagenet \
16 | --steps $steps
--------------------------------------------------------------------------------
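Both merge scripts take three positional arguments, `sigma`, `steps`, and `majority_vote_num`, which must match the values used for the per-seed certification runs (see `run_scripts/densepure_*.sh` below). For example, to merge a 10-vote CIFAR-10 ensemble certified at sigma = 0.5 with 10 reverse steps: `bash merge_cifar10.sh 0.5 10 10`, run from a directory where both the per-seed result files and the `exp/<dataset>/` prediction dumps are reachable at the relative paths hard-coded in `merge_results.py`.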
/results/merge_results.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import pandas as pd
4 | from statsmodels.stats.proportion import proportion_confint
5 | from scipy.stats import norm
6 |
7 | def gain_results(args):
8 | file_merge = open(str(args.datasets)+'_'+str(args.sigma)+'_'+str(args.results_file), 'w')
9 | file_merge.write("idx\tlabel\tpredict\tradius\tcorrect\n")
10 | for sample_id in range(args.sample_num):
11 |
12 | n0_predictions_list = []
13 | n_predictions_list = []
14 | for i in range(args.majority_vote_num):
15 | id_file = open(str(args.datasets)+'-densepure-sample_num_'+str(args.N)+'-noise_'+str(args.sigma)+'-'+str(args.steps)+'-steps-'+str(i), 'r')  # file name carries N, matching --outfile in run_scripts/densepure_*.sh
16 | lines = id_file.readlines()
17 | line = lines[sample_id+1].split('\t')
18 | label = int(line[1])
19 |
20 | n0_predictions = np.load('exp/'+str(args.datasets)+'/'+str(args.sigma)+'-'+str(args.sample_id_list[sample_id])+'-'+str(i)+'-n0_predictions.npy')
21 | n_predictions = np.load('exp/'+str(args.datasets)+'/'+str(args.sigma)+'-'+str(args.sample_id_list[sample_id])+'-'+str(i)+'-n_predictions.npy')
22 |
23 | n0_predictions_list.append(n0_predictions)
24 | n_predictions_list.append(n_predictions)
25 |
26 | n0_predictions_list = np.array(n0_predictions_list).T
27 | n_predictions_list = np.array(n_predictions_list).T
28 | count_max_list = np.zeros(args.N0,dtype=int)
29 |
30 | for i in range(args.N0):
31 | count_max = max(list(n0_predictions_list[i]),key=list(n0_predictions_list[i]).count)
32 | count_max_list[i] = count_max
33 | counts = np.zeros(args.classes_num, dtype=int)
34 | for idx in count_max_list:
35 | counts[idx] += 1
36 | prediction = counts.argmax().item()
37 |
38 | count_max_list = np.zeros(args.N,dtype=int)
39 | for i in range(args.N):
40 | count_max = max(list(n_predictions_list[i]),key=list(n_predictions_list[i]).count)
41 | count_max_list[i] = count_max
42 | counts = np.zeros(args.classes_num, dtype=int)
43 | for idx in count_max_list:
44 | counts[idx] += 1
45 |
46 | nA = counts[prediction].item()
47 | pABar = proportion_confint(nA, args.N, alpha=2 * 0.001, method="beta")[0]
48 | if pABar < 0.5:
49 | prediction = -1
50 | radius = 0.0
51 | else:
52 | radius = args.sigma * norm.ppf(pABar)
53 |
54 | correct = int(prediction == label)
55 |
56 | file_merge.write("{}\t{}\t{}\t{:.3}\t{}".format(args.sample_id_list[sample_id], label, prediction, radius, correct))
57 | file_merge.write("\n")
58 |
59 |
60 | def parse_args():
61 | parser = argparse.ArgumentParser(description=globals()['__doc__'])
62 | parser.add_argument("--sample_id_list", type=int, nargs='+', default=[0], help="sample id for evaluation")
63 | parser.add_argument('--sample_num', type=int, default=100, help='sample numbers')
64 | parser.add_argument('--majority_vote_num', type=int, default=10, help='majority vote numbers')
65 | parser.add_argument("--N0", type=int, default=100)
66 | parser.add_argument("--N", type=int, default=100000, help="number of samples to use")
67 | parser.add_argument('--sigma', type=float, default=0.25, help='noise hyperparameter')
68 | parser.add_argument('--classes_num', type=int, default=10, help='classes numbers of datasets')
69 | parser.add_argument("--results_file", type=str, default='merge_results.txt', help="output file")
70 | parser.add_argument("--datasets", type=str, default='cifar10', help="cifar10 or imagenet")
71 | parser.add_argument("--steps", type=int, default=2)
72 |
73 | args = parser.parse_args()
74 | return args
75 |
76 |
77 |
78 | if __name__ == '__main__':
79 | args = parse_args()
80 | print(args)
81 | gain_results(args)
--------------------------------------------------------------------------------
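The certification arithmetic at the end of `gain_results` follows standard randomized smoothing: a one-sided Clopper-Pearson lower bound `pABar` on the top-class probability is computed from the `N` majority-voted predictions, the sample is marked abstain (`-1`) if `pABar < 0.5`, and otherwise the certified L2 radius is `sigma * Phi^{-1}(pABar)`. A self-contained sketch with made-up vote counts:

```python
from scipy.stats import norm
from statsmodels.stats.proportion import proportion_confint

# Hypothetical counts: the majority-vote class wins nA of N = 100000 noisy samples at sigma = 0.5.
nA, N, sigma = 99200, 100000, 0.5

# One-sided lower confidence bound on p_A (alpha = 2 * 0.001, "beta" = Clopper-Pearson), as above.
pABar = proportion_confint(nA, N, alpha=2 * 0.001, method="beta")[0]

if pABar < 0.5:
    prediction, radius = -1, 0.0      # abstain: the bound cannot certify anything
else:
    radius = sigma * norm.ppf(pABar)  # certified L2 radius

print(pABar, radius)
```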
/run_scripts/carlini22_cifar10.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | cd ..
3 |
4 | sigma=$1
5 |
6 | python eval_certified_densepure.py \
7 | --exp exp \
8 | --config cifar10.yml \
9 | -i cifar10-carlini22-sample_num_100000-noise_$sigma-1step \
10 | --domain cifar10 \
11 | --seed 0 \
12 | --diffusion_type ddpm \
13 | --lp_norm L2 \
14 | --outfile results/cifar10-carlini22-sample_num_100000-noise_$sigma-1step \
15 | --sigma $sigma \
16 | --N 100000 \
17 | --N0 100 \
18 | --certified_batch 100 \
19 | --sample_id $(seq -s ' ' 0 20 9980) \
20 | --use_id \
21 | --certify_mode purify \
22 | --advanced_classifier vit \
23 | --use_one_step
--------------------------------------------------------------------------------
/run_scripts/carlini22_imagenet.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | cd ..
3 |
4 | sigma=$1
5 |
6 | python eval_certified_densepure.py \
7 | --exp exp \
8 | --config imagenet.yml \
9 | -i imagenet-carlini22-sample_num_10000-noise_$sigma-1step \
10 | --domain imagenet \
11 | --seed 0 \
12 | --diffusion_type guided-ddpm \
13 | --lp_norm L2 \
14 | --outfile results/imagenet-carlini22-sample_num_10000-noise_$sigma-1step \
15 | --sigma $sigma \
16 | --N 10000 \
17 | --N0 100 \
18 | --certified_batch 16 \
19 | --sample_id $(seq -s ' ' 0 500 49500) \
20 | --use_id \
21 | --certify_mode purify \
22 | --advanced_classifier beit \
23 | --use_one_step
--------------------------------------------------------------------------------
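The two `carlini22_*.sh` scripts certify the one-step-denoising baseline (`--use_one_step`) rather than DensePure's multi-step majority vote. Each takes a single positional argument, the smoothing noise level, e.g. `bash carlini22_cifar10.sh 0.25` run from inside `run_scripts/` (both scripts `cd ..` before launching `eval_certified_densepure.py`), and assumes the pretrained diffusion checkpoints loaded by the runners below (`pretrained/cifar10_uncond_50M_500K.pt` for CIFAR-10, `pretrained/256x256_diffusion_uncond.pt` for ImageNet) are in place.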
/run_scripts/densepure_cifar10.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | cd ..
3 |
4 | sigma=$1
5 | steps=$2
6 | reverse_seed=$3
7 |
8 | python eval_certified_densepure.py \
9 | --exp exp/cifar10 \
10 | --config cifar10.yml \
11 | -i cifar10-densepure-sample_num_100000-noise_$sigma-$steps-$reverse_seed \
12 | --domain cifar10 \
13 | --seed 0 \
14 | --diffusion_type ddpm \
15 | --lp_norm L2 \
16 | --outfile results/cifar10-densepure-sample_num_100000-noise_$sigma-$steps-steps-$reverse_seed \
17 | --sigma $sigma \
18 | --N 100000 \
19 | --N0 100 \
20 | --certified_batch 100 \
21 | --sample_id $(seq -s ' ' 0 20 9980) \
22 | --use_id \
23 | --certify_mode purify \
24 | --advanced_classifier vit \
25 | --use_t_steps \
26 | --num_t_steps $steps \
27 | --save_predictions \
28 | --predictions_path exp/cifar10/$sigma- \
29 | --reverse_seed $reverse_seed
--------------------------------------------------------------------------------
/run_scripts/densepure_imagenet.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | cd ..
3 |
4 | sigma=$1
5 | steps=$2
6 | reverse_seed=$3
7 |
8 | python eval_certified_densepure.py \
9 | --exp exp/imagenet \
10 | --config imagenet.yml \
11 | -i imagenet-densepure-sample_num_10000-noise_$sigma-$steps-$reverse_seed \
12 | --domain imagenet \
13 | --seed 0 \
14 | --diffusion_type guided-ddpm \
15 | --lp_norm L2 \
16 | --outfile results/imagenet-densepure-sample_num_10000-noise_$sigma-$steps-steps-$reverse_seed \
17 | --sigma $sigma \
18 | --N 10000 \
19 | --N0 100 \
20 | --certified_batch 16 \
21 | --sample_id $(seq -s ' ' 0 500 49500) \
22 | --use_id \
23 | --certify_mode purify \
24 | --advanced_classifier beit \
25 | --use_t_steps \
26 | --num_t_steps $steps \
27 | --save_predictions \
28 | --predictions_path exp/imagenet/$sigma- \
29 | --reverse_seed $reverse_seed
--------------------------------------------------------------------------------
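The `densepure_*.sh` scripts take `sigma`, `steps`, and `reverse_seed` as positional arguments and certify one member of the majority-vote ensemble per invocation, saving per-sample predictions under `exp/<dataset>/`. To build a K-vote ensemble, repeat the run for `reverse_seed = 0 ... K-1` with the same `sigma` and `steps`, then merge with the scripts in `results/`. For example, a hypothetical 10-vote CIFAR-10 run at sigma = 0.5 with 10 reverse steps: `for s in $(seq 0 9); do bash densepure_cifar10.sh 0.5 10 $s; done`, followed by `bash merge_cifar10.sh 0.5 10 10`.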
/runners/diffpure_ddpm_densepure.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import numpy as np
5 |
6 | import torch
7 | import torchvision.utils as tvu
8 |
9 | from improved_diffusion import dist_util, logger
10 | from improved_diffusion.script_util import (
11 | NUM_CLASSES,
12 | model_and_diffusion_defaults,
13 | create_model_and_diffusion,
14 | add_dict_to_argparser,
15 | args_to_dict,
16 | )
17 | from improved_diffusion import gaussian_diffusion
18 |
19 | import math
20 |
21 |
22 | def get_beta_schedule(*, beta_start, beta_end, num_diffusion_timesteps):
23 | betas = np.linspace(beta_start, beta_end,
24 | num_diffusion_timesteps, dtype=np.float64)
25 | assert betas.shape == (num_diffusion_timesteps,)
26 | return betas
27 |
28 |
29 | def extract(a, t, x_shape):
30 | """Extract coefficients from a based on t and reshape to make it
31 | broadcastable with x_shape."""
32 | bs, = t.shape
33 | assert x_shape[0] == bs
34 | out = torch.gather(torch.tensor(a, dtype=torch.float, device=t.device), 0, t.long())
35 | assert out.shape == (bs,)
36 | out = out.reshape((bs,) + (1,) * (len(x_shape) - 1))
37 | return out
38 |
39 |
40 | def image_editing_denoising_step_flexible_mask(x, t, *, model, logvar, betas):
41 | """
42 | Sample from p(x_{t-1} | x_t)
43 | """
44 | alphas = 1.0 - betas
45 | alphas_cumprod = alphas.cumprod(dim=0)
46 |
47 | model_output = model(x, t)
48 | weighted_score = betas / torch.sqrt(1 - alphas_cumprod)
49 | mean = extract(1 / torch.sqrt(alphas), t, x.shape) * (x - extract(weighted_score, t, x.shape) * model_output)
50 |
51 | logvar = extract(logvar, t, x.shape)
52 | noise = torch.randn_like(x)
53 | mask = 1 - (t == 0).float()
54 | mask = mask.reshape((x.shape[0],) + (1,) * (len(x.shape) - 1))
55 | sample = mean + mask * torch.exp(0.5 * logvar) * noise
56 | sample = sample.float()
57 | return sample
58 |
59 |
60 | class Diffusion(torch.nn.Module):
61 | def __init__(self, args, config, device=None):
62 | super().__init__()
63 | self.args = args
64 | self.config = config
65 | if device is None:
66 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
67 | self.device = device
68 | self.reverse_state = None
69 | self.reverse_state_cuda = None
70 |
71 | print("Loading model")
72 | defaults = model_and_diffusion_defaults(self.args.t_total)
73 | model, diffusion = create_model_and_diffusion(**defaults)
74 | model.load_state_dict(
75 | dist_util.load_state_dict("pretrained/cifar10_uncond_50M_500K.pt", map_location="cpu")
76 | )
77 | model.to(self.device)
78 | model.eval()
79 |
80 | self.model = model
81 | self.diffusion = diffusion
82 | sigma = self.args.sigma
83 | a = 1/(1+(sigma*2)**2)
84 | self.scale = a**0.5
85 | sigma = sigma*2
86 | T = self.args.t_total
87 | self.t = T*(1-(2*1.008*math.asin(math.sin(math.pi/(2*1.008))/(1+sigma**2)**0.5))/math.pi)
88 |
89 |
90 | def image_editing_sample(self, img=None, bs_id=0, tag=None, sigma=0.0):
91 | assert isinstance(img, torch.Tensor)
92 | batch_size = img.shape[0]
93 |
94 | with torch.no_grad():
95 | if tag is None:
96 | tag = 'rnd' + str(random.randint(0, 10000))
97 | out_dir = os.path.join(self.args.log_dir, 'bs' + str(bs_id) + '_' + tag)
98 |
99 | assert img.ndim == 4, img.ndim
100 | x0 = img
101 |
102 | x0 = self.scale*(img)
103 | t = self.t
104 |
105 | if self.args.use_clustering:
106 | x0 = x0.unsqueeze(1).repeat(1,self.args.clustering_batch,1,1,1).view(batch_size*self.args.clustering_batch,3,32,32)
107 |
108 | if self.args.use_one_step:
109 | # one step denoise
110 | t = torch.tensor([round(t)] * x0.shape[0], device=self.device)
111 | out = self.diffusion.p_sample(
112 | self.model,
113 | x0,
114 | t+self.args.t_plus,
115 | clip_denoised=True,
116 | )
117 | x0 = out["pred_xstart"]
118 |
119 | elif self.args.use_t_steps:
120 |
121 | #save random state
122 | if self.args.save_predictions:
123 | global_seed_state = torch.random.get_rng_state()
124 | if torch.cuda.is_available():
125 | global_cuda_state = torch.cuda.random.get_rng_state_all()
126 |
127 | if self.reverse_state==None:
128 | torch.manual_seed(self.args.reverse_seed)
129 | if torch.cuda.is_available():
130 | torch.cuda.manual_seed_all(self.args.reverse_seed)
131 | else:
132 | torch.random.set_rng_state(self.reverse_state)
133 | if torch.cuda.is_available():
134 | torch.cuda.random.set_rng_state_all(self.reverse_state_cuda)
135 |
136 | # t steps denoise
137 | inter = t/self.args.num_t_steps
138 | indices_t_steps = [round(t-i*inter) for i in range(self.args.num_t_steps)]
139 |
140 | for i in range(len(indices_t_steps)):
141 | t = torch.tensor([len(indices_t_steps)-i-1] * x0.shape[0], device=self.device)
142 | real_t = torch.tensor([indices_t_steps[i]] * x0.shape[0], device=self.device)
143 | with torch.no_grad():
144 | out = self.diffusion.p_sample(
145 | self.model,
146 | x0,
147 | t,
148 | clip_denoised=True,
149 | indices_t_steps = indices_t_steps.copy(),
150 | T = self.args.t_total,
151 | step = len(indices_t_steps)-i,
152 | real_t = real_t
153 | )
154 | x0 = out["sample"]
155 |
156 | #load random state
157 | if self.args.save_predictions:
158 | self.reverse_state = torch.random.get_rng_state()
159 | if torch.cuda.is_available():
160 | self.reverse_state_cuda = torch.cuda.random.get_rng_state_all()
161 |
162 | torch.random.set_rng_state(global_seed_state)
163 | if torch.cuda.is_available():
164 | torch.cuda.random.set_rng_state_all(global_cuda_state)
165 |
166 | else:
167 | #save random state
168 | if self.args.save_predictions:
169 | global_seed_state = torch.random.get_rng_state()
170 | if torch.cuda.is_available():
171 | global_cuda_state = torch.cuda.random.get_rng_state_all()
172 |
173 | if self.reverse_state==None:
174 | torch.manual_seed(self.args.reverse_seed)
175 | if torch.cuda.is_available():
176 | torch.cuda.manual_seed_all(self.args.reverse_seed)
177 | else:
178 | torch.random.set_rng_state(self.reverse_state)
179 | if torch.cuda.is_available():
180 | torch.cuda.random.set_rng_state_all(self.reverse_state_cuda)
181 |
182 | # full steps denoise
183 | indices = list(range(round(t)))[::-1]
184 | for i in indices:
185 | t = torch.tensor([i] * x0.shape[0], device=self.device)
186 | with torch.no_grad():
187 | out = self.diffusion.p_sample(
188 | self.model,
189 | x0,
190 | t,
191 | clip_denoised=True,
192 | )
193 | x0 = out["sample"]
194 |
195 | #load random state
196 | if self.args.save_predictions:
197 | self.reverse_state = torch.random.get_rng_state()
198 | if torch.cuda.is_available():
199 | self.reverse_state_cuda = torch.cuda.random.get_rng_state_all()
200 |
201 | torch.random.set_rng_state(global_seed_state)
202 | if torch.cuda.is_available():
203 | torch.cuda.random.set_rng_state_all(global_cuda_state)
204 |
205 | return x0
206 |
--------------------------------------------------------------------------------
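In `__init__` above, `self.scale` and `self.t` map the randomized-smoothing noise onto the diffusion forward process: images live in [0, 1] while the diffusion operates in [-1, 1], so smoothing noise sigma becomes 2*sigma, the input is scaled by `sqrt(1/(1 + (2*sigma)^2)) = sqrt(alpha_bar_t)`, and `self.t` is the (fractional) timestep at which `sqrt((1 - alpha_bar_t)/alpha_bar_t) = 2*sigma`. The closed form on line 87 is this inversion for the improved-DDPM cosine schedule with offset s = 0.008 (hence the 1.008 constants). A quick numerical check, assuming that schedule:

```python
import math

T, sigma = 1000, 0.5  # illustrative values only
s2 = sigma * 2        # noise level after mapping images from [0, 1] to [-1, 1]

# Closed-form timestep from diffpure_ddpm_densepure.py.
t = T * (1 - (2 * 1.008 * math.asin(math.sin(math.pi / (2 * 1.008)) / (1 + s2 ** 2) ** 0.5)) / math.pi)

# Cosine schedule alpha_bar(t) = f(t) / f(0), with f(u) = cos^2(((u/T + 0.008) / 1.008) * pi / 2).
f = lambda u: math.cos(((u / T + 0.008) / 1.008) * math.pi / 2) ** 2
alpha_bar = f(t) / f(0)

print(t)                                             # ~496 for sigma = 0.5
print(math.sqrt((1 - alpha_bar) / alpha_bar))        # ~1.0 == 2 * sigma
print((1 / (1 + s2 ** 2)) ** 0.5, alpha_bar ** 0.5)  # the two scale factors agree (~0.707)
```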
/runners/diffpure_guided_densepure.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import torch
5 | import torchvision.utils as tvu
6 |
7 | from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
8 | import math
9 | import numpy as np
10 |
11 |
12 | class GuidedDiffusion(torch.nn.Module):
13 | def __init__(self, args, config, device=None, model_dir='pretrained'):
14 | super().__init__()
15 | self.args = args
16 | self.config = config
17 | if device is None:
18 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
19 | self.device = device
20 | self.reverse_state = None
21 | self.reverse_state_cuda = None
22 |
23 | # load model
24 | model_config = model_and_diffusion_defaults()
25 | model_config.update(vars(self.config.model))
26 | print(f'model_config: {model_config}')
27 | model, diffusion = create_model_and_diffusion(**model_config)
28 | model.load_state_dict(torch.load(f'{model_dir}/256x256_diffusion_uncond.pt', map_location='cpu'))
29 | model.requires_grad_(False).eval().to(self.device)
30 |
31 | if model_config['use_fp16']:
32 | model.convert_to_fp16()
33 |
34 | self.model = model
35 | self.model.eval()
36 | self.diffusion = diffusion
37 | self.betas = diffusion.betas
38 | alphas = 1.0 - self.betas
39 | self.alphas_cumprod = np.cumprod(alphas, axis=0)
40 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
41 |
42 | sigma = self.args.sigma
43 |
44 | a = 1/(1+(sigma*2)**2)
45 | self.scale = a**0.5
46 |
47 | sigma = sigma*2
48 | T = self.args.t_total
49 | for t in range(len(self.sqrt_recipm1_alphas_cumprod)-1):
50 | if self.sqrt_recipm1_alphas_cumprod[t]<sigma and self.sqrt_recipm1_alphas_cumprod[t+1]>=sigma:
51 | if sigma - self.sqrt_recipm1_alphas_cumprod[t] > self.sqrt_recipm1_alphas_cumprod[t+1] - sigma:
52 | self.t = t+1
53 | break
54 | else:
55 | self.t = t
56 | break
57 | self.t = len(diffusion.alphas_cumprod)-1
58 |
59 | def image_editing_sample(self, img=None, bs_id=0, tag=None, sigma=0.0):
60 | assert isinstance(img, torch.Tensor)
61 | batch_size = img.shape[0]
62 |
63 | with torch.no_grad():
64 | if tag is None:
65 | tag = 'rnd' + str(random.randint(0, 10000))
66 | out_dir = os.path.join(self.args.log_dir, 'bs' + str(bs_id) + '_' + tag)
67 |
68 | assert img.ndim == 4, img.ndim
69 | x0 = img
70 |
71 | x0 = self.scale*(img)
72 | t = self.t
73 |
74 | if self.args.use_clustering:
75 | x0 = x0.unsqueeze(1).repeat(1,self.args.clustering_batch,1,1,1).view(batch_size*self.args.clustering_batch,3,256,256)
76 | self.model.eval()
77 |
78 | if self.args.use_one_step:
79 | # one step denoise
80 | t = torch.tensor([round(t)] * x0.shape[0], device=self.device)
81 | out = self.diffusion.p_sample(
82 | self.model,
83 | x0,
84 | t+self.args.t_plus,
85 | clip_denoised=True,
86 | )
87 |
88 | x0 = out["pred_xstart"]
89 |
90 | elif self.args.use_t_steps:
91 | #save random state
92 | if self.args.save_predictions:
93 | global_seed_state = torch.random.get_rng_state()
94 | if torch.cuda.is_available():
95 | global_cuda_state = torch.cuda.random.get_rng_state_all()
96 |
97 | if self.reverse_state==None:
98 | torch.manual_seed(self.args.reverse_seed)
99 | if torch.cuda.is_available():
100 | torch.cuda.manual_seed_all(self.args.reverse_seed)
101 | else:
102 | torch.random.set_rng_state(self.reverse_state)
103 | if torch.cuda.is_available():
104 | torch.cuda.random.set_rng_state_all(self.reverse_state_cuda)
105 |
106 | # t steps denoise
107 | inter = t/self.args.num_t_steps
108 | indices_t_steps = [round(t-i*inter) for i in range(self.args.num_t_steps)]
109 |
110 | for i in range(len(indices_t_steps)):
111 | t = torch.tensor([len(indices_t_steps)-i-1] * x0.shape[0], device=self.device)
112 | real_t = torch.tensor([indices_t_steps[i]] * x0.shape[0], device=self.device)
113 | with torch.no_grad():
114 | out = self.diffusion.p_sample(
115 | self.model,
116 | x0,
117 | t,
118 | clip_denoised=True,
119 | indices_t_steps = indices_t_steps.copy(),
120 | T = self.args.t_total,
121 | step = len(indices_t_steps)-i,
122 | real_t = real_t
123 | )
124 | x0 = out["sample"]
125 |
126 | #load random state
127 | if self.args.save_predictions:
128 | self.reverse_state = torch.random.get_rng_state()
129 | if torch.cuda.is_available():
130 | self.reverse_state_cuda = torch.cuda.random.get_rng_state_all()
131 |
132 | torch.random.set_rng_state(global_seed_state)
133 | if torch.cuda.is_available():
134 | torch.cuda.random.set_rng_state_all(global_cuda_state)
135 |
136 | else:
137 | #save random state
138 | if self.args.save_predictions:
139 | global_seed_state = torch.random.get_rng_state()
140 | if torch.cuda.is_available():
141 | global_cuda_state = torch.cuda.random.get_rng_state_all()
142 |
143 | if self.reverse_state==None:
144 | torch.manual_seed(self.args.reverse_seed)
145 | if torch.cuda.is_available():
146 | torch.cuda.manual_seed_all(self.args.reverse_seed)
147 | else:
148 | torch.random.set_rng_state(self.reverse_state)
149 | if torch.cuda.is_available():
150 | torch.cuda.random.set_rng_state_all(self.reverse_state_cuda)
151 |
152 | # full steps denoise
153 | indices = list(range(round(t)))[::-1]
154 | for i in indices:
155 | t = torch.tensor([i] * x0.shape[0], device=self.device)
156 | with torch.no_grad():
157 | out = self.diffusion.p_sample(
158 | self.model,
159 | x0,
160 | t,
161 | clip_denoised=True,
162 | )
163 | x0 = out["sample"]
164 |
165 | #load random state
166 | if self.args.save_predictions:
167 | self.reverse_state = torch.random.get_rng_state()
168 | if torch.cuda.is_available():
169 | self.reverse_state_cuda = torch.cuda.random.get_rng_state_all()
170 |
171 | torch.random.set_rng_state(global_seed_state)
172 | if torch.cuda.is_available():
173 | torch.cuda.random.set_rng_state_all(global_cuda_state)
174 |
175 | return x0
--------------------------------------------------------------------------------
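The guided (ImageNet) runner uses the same `scale = sqrt(1/(1 + (2*sigma)^2))`, but because `guided_diffusion` exposes a discrete beta schedule rather than a closed-form cosine one, it picks `self.t` by scanning the precomputed `sqrt(1/alpha_bar_t - 1)` table for the entry whose effective noise level brackets and is closest to 2*sigma (lines 50-56 above). A compact equivalent, sketched here with a toy linear-beta schedule rather than the model's actual one:

```python
import numpy as np

def nearest_timestep(alphas_cumprod: np.ndarray, sigma: float) -> int:
    """Index t whose effective noise level sqrt(1/alpha_bar_t - 1) is closest to 2*sigma."""
    sqrt_recipm1 = np.sqrt(1.0 / alphas_cumprod - 1.0)
    return int(np.argmin(np.abs(sqrt_recipm1 - sigma * 2)))

# Toy linear-beta schedule for illustration (1000 steps, betas from 1e-4 to 0.02).
betas = np.linspace(1e-4, 0.02, 1000)
print(nearest_timestep(np.cumprod(1.0 - betas), sigma=0.5))
```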
/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | from typing import Any
4 | import torch
5 | import torch.nn as nn
6 | import torchvision.models as models
7 | from torch.utils.data import DataLoader
8 | import torchvision.transforms as transforms
9 | from architectures import get_architecture
10 | import data
11 |
12 |
13 | def compute_n_params(model, return_str=True):
14 | tot = 0
15 | for p in model.parameters():
16 | w = 1
17 | for x in p.shape:
18 | w *= x
19 | tot += w
20 | if return_str:
21 | if tot >= 1e6:
22 | return '{:.1f}M'.format(tot / 1e6)
23 | else:
24 | return '{:.1f}K'.format(tot / 1e3)
25 | else:
26 | return tot
27 |
28 |
29 | class Logger(object):
30 | """
31 | Redirect stderr to stdout, optionally print stdout to a file,
32 | and optionally force flushing on both stdout and the file.
33 | """
34 |
35 | def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
36 | self.file = None
37 |
38 | if file_name is not None:
39 | self.file = open(file_name, file_mode)
40 |
41 | self.should_flush = should_flush
42 | self.stdout = sys.stdout
43 | self.stderr = sys.stderr
44 |
45 | sys.stdout = self
46 | sys.stderr = self
47 |
48 | def __enter__(self) -> "Logger":
49 | return self
50 |
51 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
52 | self.close()
53 |
54 | def write(self, text: str) -> None:
55 | """Write text to stdout (and a file) and optionally flush."""
56 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
57 | return
58 |
59 | if self.file is not None:
60 | self.file.write(text)
61 |
62 | self.stdout.write(text)
63 |
64 | if self.should_flush:
65 | self.flush()
66 |
67 | def flush(self) -> None:
68 | """Flush written text to both stdout and a file, if open."""
69 | if self.file is not None:
70 | self.file.flush()
71 |
72 | self.stdout.flush()
73 |
74 | def close(self) -> None:
75 | """Flush, close possible files, and remove stdout/stderr mirroring."""
76 | self.flush()
77 |
78 | # if using multiple loggers, prevent closing in wrong order
79 | if sys.stdout is self:
80 | sys.stdout = self.stdout
81 | if sys.stderr is self:
82 | sys.stderr = self.stderr
83 |
84 | if self.file is not None:
85 | self.file.close()
86 |
87 |
88 | def dict2namespace(config):
89 | namespace = argparse.Namespace()
90 | for key, value in config.items():
91 | if isinstance(value, dict):
92 | new_value = dict2namespace(value)
93 | else:
94 | new_value = value
95 | setattr(namespace, key, new_value)
96 | return namespace
97 |
98 |
99 | def str2bool(v):
100 | if isinstance(v, bool):
101 | return v
102 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
103 | return True
104 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
105 | return False
106 | else:
107 | raise argparse.ArgumentTypeError('Boolean value expected.')
108 |
109 |
110 | def update_state_dict(state_dict, idx_start=9):
111 |
112 | from collections import OrderedDict
113 | new_state_dict = OrderedDict()
114 | for k, v in state_dict.items():
115 | name = k[idx_start:] # remove 'module.0.' of dataparallel
116 | new_state_dict[name]=v
117 |
118 | return new_state_dict
119 |
120 |
121 | # ------------------------------------------------------------------------
122 | def get_accuracy(model, x_orig, y_orig, bs=64, device=torch.device('cuda:0')):
123 | n_batches = (x_orig.shape[0] + bs - 1) // bs  # ceil, so the final partial batch is also counted
124 | acc = 0.
125 | for counter in range(n_batches):
126 | x = x_orig[counter * bs:min((counter + 1) * bs, x_orig.shape[0])].clone().to(device)
127 | y = y_orig[counter * bs:min((counter + 1) * bs, x_orig.shape[0])].clone().to(device)
128 | output = model(x)
129 | acc += (output.max(1)[1] == y).float().sum()
130 |
131 | return (acc / x_orig.shape[0]).item()
132 |
133 | def get_image_classifier_certified(classifier_path, dataset):
134 | checkpoint = torch.load(classifier_path)
135 | base_classifier = get_architecture(checkpoint["arch"], dataset)
136 | base_classifier.load_state_dict(checkpoint['state_dict'])
137 | return base_classifier
138 |
139 |
140 | def load_data(args, adv_batch_size):
141 | if 'imagenet' in args.domain:
142 | val_dir = '/home/data/imagenet/imagenet' # using imagenet lmdb data
143 | val_transform = data.get_transform(args.domain, 'imval', base_size=224)
144 | val_data = data.imagenet_lmdb_dataset_sub(val_dir, transform=val_transform,
145 | num_sub=args.num_sub, data_seed=args.data_seed)
146 | n_samples = len(val_data)
147 | val_loader = DataLoader(val_data, batch_size=n_samples, shuffle=False, pin_memory=True, num_workers=4)
148 | x_val, y_val = next(iter(val_loader))
149 | elif 'cifar10' in args.domain:
150 | data_dir = './dataset'
151 | transform = transforms.Compose([transforms.ToTensor()])
152 | val_data = data.cifar10_dataset_sub(data_dir, transform=transform,
153 | num_sub=args.num_sub, data_seed=args.data_seed)
154 | n_samples = len(val_data)
155 | val_loader = DataLoader(val_data, batch_size=n_samples, shuffle=False, pin_memory=True, num_workers=4)
156 | x_val, y_val = next(iter(val_loader))
157 | elif 'celebahq' in args.domain:
158 | data_dir = './dataset/celebahq'
159 | attribute = args.classifier_name.split('__')[-1] # `celebahq__Smiling`
160 | val_transform = data.get_transform('celebahq', 'imval')
161 | clean_dset = data.get_dataset('celebahq', 'val', attribute, root=data_dir, transform=val_transform,
162 | fraction=2, data_seed=args.data_seed) # data_seed randomizes here
163 | loader = DataLoader(clean_dset, batch_size=adv_batch_size, shuffle=False,
164 | pin_memory=True, num_workers=4)
165 | x_val, y_val = next(iter(loader)) # [0, 1], 256x256
166 | else:
167 | raise NotImplementedError(f'Unknown domain: {args.domain}!')
168 |
169 | print(f'x_val shape: {x_val.shape}')
170 | x_val, y_val = x_val.contiguous().requires_grad_(True), y_val.contiguous()
171 | print(f'x (min, max): ({x_val.min()}, {x_val.max()})')
172 |
173 | return x_val, y_val
174 |
--------------------------------------------------------------------------------
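`dict2namespace` is what turns the YAML files in `configs/` into the `config` object the runners read attributes from (e.g. `config.model` in `diffpure_guided_densepure.py`). A minimal sketch of that loading pattern, assuming it is executed from the repository root:

```python
import yaml
from utils import dict2namespace

# Parse configs/cifar10.yml into nested attribute access (config.<section>.<field>).
with open('configs/cifar10.yml') as f:
    config = dict2namespace(yaml.safe_load(f))

print(config)
```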
/zipdata.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import os.path as op
3 | from threading import local
4 | from zipfile import ZipFile, BadZipFile
5 |
6 | from PIL import Image
7 | from io import BytesIO
8 | import torch.utils.data as data
9 |
10 | _VALID_IMAGE_TYPES = ['.jpg', '.jpeg', '.tiff', '.bmp', '.png']
11 |
12 | class ZipData(data.Dataset):
13 | _IGNORE_ATTRS = {'_zip_file'}
14 |
15 | def __init__(self, path, map_file,
16 | transform=None, target_transform=None,
17 | extensions=None):
18 | self._path = path
19 | if not extensions:
20 | extensions = _VALID_IMAGE_TYPES
21 | self._zip_file = ZipFile(path)
22 | self.zip_dict = {}
23 | self.samples = []
24 | self.transform = transform
25 | self.target_transform = target_transform
26 | self.class_to_idx = {}
27 | with open(map_file, 'r') as f:
28 | for line in iter(f.readline, ""):
29 | line = line.strip()
30 | if not line:
31 | continue
32 | cls_idx = [l for l in line.split('\t') if l]
33 | if not cls_idx:
34 | continue
35 | assert len(cls_idx) >= 2, "invalid line: {}".format(line)
36 | idx = int(cls_idx[1])
37 | cls = cls_idx[0]
38 | del cls_idx
39 | at_idx = cls.find('@')
40 | assert at_idx >= 0, "invalid class: {}".format(cls)
41 | cls = cls[at_idx + 1:]
42 | if cls.startswith('/'):
43 | # Python ZipFile expects no root
44 | cls = cls[1:]
45 | assert cls, "invalid class in line {}".format(line)
46 | prev_idx = self.class_to_idx.get(cls)
47 | assert prev_idx is None or prev_idx == idx, "class: {} idx: {} previously had idx: {}".format(
48 | cls, idx, prev_idx
49 | )
50 | self.class_to_idx[cls] = idx
51 |
52 | for fst in self._zip_file.infolist():
53 | fname = fst.filename
54 | target = self.class_to_idx.get(fname)
55 | if target is None:
56 | continue
57 | if fname.endswith('/') or fname.startswith('.') or fst.file_size == 0:
58 | continue
59 | ext = op.splitext(fname)[1].lower()
60 | if ext in extensions:
61 | self.samples.append((fname, target))
62 | assert len(self), "No images found in: {} with map: {}".format(self._path, map_file)
63 |
64 | def __repr__(self):
65 | return 'ZipData({}, size={})'.format(self._path, len(self))
66 |
67 | def __getstate__(self):
68 | return {
69 | key: val if key not in self._IGNORE_ATTRS else None
70 | for key, val in self.__dict__.items()
71 | }
72 |
73 | def __getitem__(self, index):
74 | proc = multiprocessing.current_process()
75 | pid = proc.pid # get pid of this process.
76 | if pid not in self.zip_dict:
77 | self.zip_dict[pid] = ZipFile(self._path)
78 | zip_file = self.zip_dict[pid]
79 |
80 | if index >= len(self) or index < 0:
81 | raise KeyError("{} is invalid".format(index))
82 | path, target = self.samples[index]
83 | try:
84 | sample = Image.open(BytesIO(zip_file.read(path))).convert('RGB')
85 | except BadZipFile:
86 | print("bad zip file")
87 | return None, None
88 | if self.transform is not None:
89 | sample = self.transform(sample)
90 | if self.target_transform is not None:
91 | target = self.target_transform(target)
92 | return sample, target
93 |
94 | def __len__(self):
95 | return len(self.samples)
96 |
--------------------------------------------------------------------------------
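`ZipData` serves images straight out of a zip archive, keeping one open `ZipFile` handle per worker process, with labels taken from a tab-separated map file whose first field contains the in-archive path after an `@`. A usage sketch; the archive and map-file names below are placeholders, not files shipped with the repo:

```python
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from zipdata import ZipData

# Placeholder paths; the map file holds lines like "<tag>@<path/inside/zip.JPEG>\t<label index>".
dataset = ZipData(
    path='imagenet_val.zip',
    map_file='val_map.txt',
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]),
)
loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)
images, labels = next(iter(loader))
print(images.shape, labels[:8])
```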