├── .gitignore ├── FID.py ├── LICENSE ├── README.md ├── config.py ├── generate.py ├── model.py └── pytorch_fid ├── celeba64_train_stat.npy ├── cifar10_train_stat.npy ├── fid_score.py ├── inception.py └── lsun_bedroom_train_stat.npy /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | __pycache__/ 3 | */__pycache__/ 4 | generated/ 5 | 6 | -------------------------------------------------------------------------------- /FID.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | 5 | from pytorch_fid.fid_score import calculate_fid_given_paths 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser() 9 | # dataset and model 10 | parser.add_argument('-name', '--name', type=str, choices=["cifar10", "lsun_bedroom", "celeba64"], 11 | help='Name of experiment') 12 | parser.add_argument('-ema', '--ema', action='store_true', help='Whether use ema') 13 | 14 | # fast generation parameters 15 | parser.add_argument('-approxdiff', '--approxdiff', type=str, choices=['STD', 'STEP', 'VAR'], help='approximate diffusion process') 16 | parser.add_argument('-kappa', '--kappa', type=float, default=1.0, help='factor to be multiplied to sigma') 17 | parser.add_argument('-S', '--S', type=int, default=50, help='number of steps') 18 | parser.add_argument('-schedule', '--schedule', type=str, choices=['linear', 'quadratic'], help='noise level schedules') 19 | 20 | parser.add_argument('-gpu', '--gpu', type=int, default=0, help='gpu device') 21 | 22 | args = parser.parse_args() 23 | 24 | kwargs = {'batch_size': 50, 'device': torch.device('cuda:{}'.format(args.gpu)), 'dims': 2048} 25 | 26 | if args.approxdiff == 'STD': 27 | variance_schedule = '1000' 28 | else: 29 | variance_schedule = '{}{}'.format(args.S, args.schedule) 30 | folder = '{}{}_{}{}_kappa{}'.format('ema_' if args.ema else '', 31 | args.name, 32 | args.approxdiff, 33 | variance_schedule, 34 | args.kappa) 35 | if folder not in os.listdir('generated'): 36 | raise Exception('folder not found') 37 | 38 | paths = ['./generated/{}'.format(folder), 39 | './pytorch_fid/{}_train_stat.npy'.format(args.name)] 40 | fid = calculate_fid_given_paths(paths=paths, **kwargs) 41 | print('{}: FID = {}'.format(folder, fid)) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Zhifeng Kong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Official PyTorch implementation of "On Fast Sampling of Diffusion Probabilistic Models". 2 | FastDPM generation on CIFAR-10, CelebA, and LSUN datasets. See the paper via [this link](https://arxiv.org/abs/2106.00132). 3 | 4 | # Pretrained models 5 | Download checkpoints from [this link](https://heibox.uni-heidelberg.de/d/01207c3f6b8441779abf/) and [this link](https://drive.google.com/file/d/1R_H-fJYXSH79wfSKs9D-fuKQVan5L-GR/view?usp=sharing). Put them under ```checkpoints/ema_diffusion_${dataset_name}_model/model.ckpt```, where ```${dataset_name}``` is ```cifar10```, ```celeba64```, ```lsun_bedroom```, ```lsun_church```, or ```lsun_cat```. 6 | 7 | # Usage 8 | General command: ```python generate.py -ema -name ${dataset_name} -approxdiff ${approximate_diffusion_process} -kappa ${kappa} -S ${FastDPM_length} -schedule ${noise_level_schedule} -n ${number_to_generate} -bs ${batchsize} -gpu ${gpu_index}``` 9 | - ```${dataset_name}```: ```cifar10```, ```celeba64```, ```lsun_bedroom```, ```lsun_church```, or ```lsun_cat``` 10 | - ```${approximate_diffusion_process}```: ```VAR``` or ```STEP``` (```STD``` runs standard full-length DDPM sampling) 11 | - ```${kappa}```: a real value between 0 and 1 12 | - ```${FastDPM_length}```: an integer between 1 and 1000; 10, 20, 50, and 100 are used in the paper. 13 | - ```${noise_level_schedule}```: ```linear``` or ```quadratic``` 14 | 15 | ## CIFAR-10 16 | Below are commands to generate CIFAR-10 images. 17 | - Standard DDPM generation: ```python generate.py -ema -name cifar10 -approxdiff STD -n 16 -bs 16``` 18 | - FastDPM generation (STEP + DDPM-rev): ```python generate.py -ema -name cifar10 -approxdiff STEP -kappa 1.0 -S 50 -schedule quadratic -n 16 -bs 16``` 19 | - FastDPM generation (STEP + DDIM-rev): ```python generate.py -ema -name cifar10 -approxdiff STEP -kappa 0.0 -S 50 -schedule quadratic -n 16 -bs 16``` 20 | - FastDPM generation (VAR + DDPM-rev): ```python generate.py -ema -name cifar10 -approxdiff VAR -kappa 1.0 -S 50 -schedule quadratic -n 16 -bs 16``` 21 | - FastDPM generation (VAR + DDIM-rev): ```python generate.py -ema -name cifar10 -approxdiff VAR -kappa 0.0 -S 50 -schedule quadratic -n 16 -bs 16``` 22 | 23 | ## CelebA 24 | Below are commands to generate CelebA images. 25 | - Standard DDPM generation: ```python generate.py -ema -name celeba64 -approxdiff STD -n 16 -bs 16``` 26 | - FastDPM generation (STEP + DDPM-rev): ```python generate.py -ema -name celeba64 -approxdiff STEP -kappa 1.0 -S 50 -schedule linear -n 16 -bs 16``` 27 | - FastDPM generation (STEP + DDIM-rev): ```python generate.py -ema -name celeba64 -approxdiff STEP -kappa 0.0 -S 50 -schedule linear -n 16 -bs 16``` 28 | - FastDPM generation (VAR + DDPM-rev): ```python generate.py -ema -name celeba64 -approxdiff VAR -kappa 1.0 -S 50 -schedule linear -n 16 -bs 16``` 29 | - FastDPM generation (VAR + DDIM-rev): ```python generate.py -ema -name celeba64 -approxdiff VAR -kappa 0.0 -S 50 -schedule linear -n 16 -bs 16``` 30 | 31 | ## LSUN_bedroom 32 | Below are commands to generate LSUN bedroom images. 
33 | - Standard DDPM generation: ```python generate.py -ema -name lsun_bedroom -approxdiff STD -n 8 -bs 8``` 34 | - FastDPM generation (STEP + DDPM-rev): ```python generate.py -ema -name lsun_bedroom -approxdiff STEP -kappa 1.0 -S 50 -schedule linear -n 8 -bs 8``` 35 | - FastDPM generation (STEP + DDIM-rev): ```python generate.py -ema -name lsun_bedroom -approxdiff STEP -kappa 0.0 -S 50 -schedule linear -n 8 -bs 8``` 36 | - FastDPM generation (VAR + DDPM-rev): ```python generate.py -ema -name lsun_bedroom -approxdiff VAR -kappa 1.0 -S 50 -schedule linear -n 8 -bs 8``` 37 | - FastDPM generation (VAR + DDIM-rev): ```python generate.py -ema -name lsun_bedroom -approxdiff VAR -kappa 0.0 -S 50 -schedule linear -n 8 -bs 8``` 38 | 39 | ## Note 40 | To generate 50K samples, set ```-n 50000``` and a batch size (```-bs```) that evenly divides 50000. 41 | 42 | # Compute FID 43 | To compute the FID of generated samples, first make sure there are 50K images, and then run 44 | - ```python FID.py -ema -name cifar10 -approxdiff STEP -kappa 1.0 -S 50 -schedule quadratic``` 45 | 46 | # Code References 47 | - [DDPM TensorFlow official](https://github.com/hojonathanho/diffusion) 48 | - [DDPM PyTorch](https://github.com/pesser/pytorch_diffusion) 49 | - [DDPM CelebA-HQ](https://github.com/FengNiMa/pytorch_diffusion_model_celebahq) 50 | - [DDIM PyTorch](https://github.com/ermongroup/ddim) 51 | - [FID PyTorch](https://github.com/mseitzer/pytorch-fid) 52 | - [DiffWave PyTorch 1](https://github.com/lmnt-com/diffwave) 53 | - [DiffWave PyTorch 2](https://github.com/philsyn/DiffWave-Vocoder) 54 | - [DiffWave PyTorch 3](https://github.com/philsyn/DiffWave-unconditional) -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | cifar10_cfg = { 2 | "resolution": 32, 3 | "in_channels": 3, 4 | "out_ch": 3, 5 | "ch": 128, 6 | "ch_mult": (1,2,2,2), 7 | "num_res_blocks": 2, 8 | "attn_resolutions": (16,), 9 | "dropout": 0.1, 10 | } 11 | 12 | lsun_cfg = { 13 | "resolution": 256, 14 | "in_channels": 3, 15 | "out_ch": 3, 16 | "ch": 128, 17 | "ch_mult": (1,1,2,2,4,4), 18 | "num_res_blocks": 2, 19 | "attn_resolutions": (16,), 20 | "dropout": 0.0, 21 | } 22 | 23 | celeba64_cfg = { 24 | "resolution": 64, 25 | "in_channels": 3, 26 | "out_ch": 3, 27 | "ch": 128, 28 | "ch_mult": (1,2,2,2,4), 29 | "num_res_blocks": 2, 30 | "attn_resolutions": (16,), 31 | "dropout": 0.1, 32 | } 33 | 34 | model_config_map = { 35 | "cifar10": cifar10_cfg, 36 | "lsun_bedroom": lsun_cfg, 37 | "lsun_cat": lsun_cfg, 38 | "lsun_church": lsun_cfg, 39 | "celeba64": celeba64_cfg 40 | } 41 | 42 | diffusion_config = { 43 | "beta_0": 0.0001, 44 | "beta_T": 0.02, 45 | "T": 1000, 46 | } 47 | 48 | model_var_type_map = { 49 | "cifar10": "fixedlarge", 50 | "lsun_bedroom": "fixedsmall", 51 | "lsun_cat": "fixedsmall", 52 | "lsun_church": "fixedsmall", 53 | } 54 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | from tqdm import tqdm 5 | 6 | import numpy as np 7 | np.random.seed(0) 8 | 9 | import torch 10 | import torch.nn as nn 11 | torch.manual_seed(0) 12 | 13 | import torchvision 14 | import torchvision.transforms as transforms 15 | from torchvision.utils import save_image, make_grid 16 | 17 | from model import Model 18 | from config import diffusion_config 19 | 20 | 21 | def _map_gpu(gpu): 22 
| if gpu == 'cuda': 23 | return lambda x: x.cuda() 24 | else: 25 | return lambda x: x.to(torch.device('cuda:'+gpu)) 26 | 27 | 28 | def rescale(X, batch=True): 29 | if not batch: 30 | return (X - X.min()) / (X.max() - X.min()) 31 | else: 32 | for i in range(X.shape[0]): 33 | X[i] = rescale(X[i], batch=False) 34 | return X 35 | 36 | 37 | def std_normal(size): 38 | return map_gpu(torch.normal(0, 1, size=size)) 39 | 40 | 41 | def print_size(net): 42 | """ 43 | Print the number of parameters of a network 44 | """ 45 | if net is not None and isinstance(net, torch.nn.Module): 46 | module_parameters = filter(lambda p: p.requires_grad, net.parameters()) 47 | params = sum([np.prod(p.size()) for p in module_parameters]) 48 | print("{} Parameters: {:.6f}M".format( 49 | net.__class__.__name__, params / 1e6), flush=True) 50 | 51 | 52 | def calc_diffusion_hyperparams(T, beta_0, beta_T): 53 | """ 54 | Compute diffusion process hyperparameters 55 | 56 | Parameters: 57 | T (int): number of diffusion steps 58 | beta_0 and beta_T (float): beta schedule start/end value, 59 | where any beta_t in the middle is linearly interpolated 60 | 61 | Returns: 62 | a dictionary of diffusion hyperparameters including: 63 | T (int), Beta/Alpha/Alpha_bar/Sigma (torch.tensor on cpu, shape=(T, )) 64 | """ 65 | 66 | Beta = torch.linspace(beta_0, beta_T, T) 67 | Alpha = 1 - Beta 68 | Alpha_bar = Alpha + 0 69 | Beta_tilde = Beta + 0 70 | for t in range(1, T): 71 | Alpha_bar[t] *= Alpha_bar[t-1] 72 | Beta_tilde[t] *= (1-Alpha_bar[t-1]) / (1-Alpha_bar[t]) 73 | Sigma = torch.sqrt(Beta_tilde) 74 | 75 | _dh = {} 76 | _dh["T"], _dh["Beta"], _dh["Alpha"], _dh["Alpha_bar"], _dh["Sigma"] = T, Beta, Alpha, Alpha_bar, Sigma 77 | diffusion_hyperparams = _dh 78 | return diffusion_hyperparams 79 | 80 | 81 | def bisearch(f, domain, target, eps=1e-8): 82 | """ 83 | find smallest x such that f(x) > target 84 | 85 | Parameters: 86 | f (function): function 87 | domain (tuple): x in (left, right) 88 | target (float): target value 89 | 90 | Returns: 91 | x (float) 92 | """ 93 | # 94 | sign = -1 if target < 0 else 1 95 | left, right = domain 96 | for _ in range(1000): 97 | x = (left + right) / 2 98 | if f(x) < target: 99 | right = x 100 | elif f(x) > (1 + sign * eps) * target: 101 | left = x 102 | else: 103 | break 104 | return x 105 | 106 | 107 | def get_VAR_noise(S, schedule='linear'): 108 | """ 109 | Compute VAR noise levels 110 | 111 | Parameters: 112 | S (int): approximante diffusion process length 113 | schedule (str): linear or quadratic 114 | 115 | Returns: 116 | np array of noise levels, size = (S, ) 117 | """ 118 | target = np.prod(1 - np.linspace(diffusion_config["beta_0"], diffusion_config["beta_T"], diffusion_config["T"])) 119 | 120 | if schedule == 'linear': 121 | g = lambda x: np.linspace(diffusion_config["beta_0"], x, S) 122 | domain = (diffusion_config["beta_0"], 0.99) 123 | elif schedule == 'quadratic': 124 | g = lambda x: np.array([diffusion_config["beta_0"] * (1+i*x) ** 2 for i in range(S)]) 125 | domain = (0.0, 0.95 / np.sqrt(diffusion_config["beta_0"]) / S) 126 | else: 127 | raise NotImplementedError 128 | 129 | f = lambda x: np.prod(1 - g(x)) 130 | largest_var = bisearch(f, domain, target, eps=1e-4) 131 | return g(largest_var) 132 | 133 | 134 | def get_STEP_step(S, schedule='linear'): 135 | """ 136 | Compute STEP steps 137 | 138 | Parameters: 139 | S (int): approximante diffusion process length 140 | schedule (str): linear or quadratic 141 | 142 | Returns: 143 | np array of steps, size = (S, ) 144 | """ 145 | if schedule == 
'linear': 146 | c = (diffusion_config["T"] - 1.0) / (S - 1.0) 147 | list_tau = [np.floor(i * c) for i in range(S)] 148 | elif schedule == 'quadratic': 149 | list_tau = np.linspace(0, np.sqrt(diffusion_config["T"] * 0.8), S) ** 2 150 | else: 151 | raise NotImplementedError 152 | 153 | return [int(s) for s in list_tau] 154 | 155 | 156 | def _log_gamma(x): 157 | # Gamma(x+1) ~= sqrt(2\pi x) * (x/e)^x (1 + 1 / 12x) 158 | y = x - 1 159 | return np.log(2 * np.pi * y) / 2 + y * (np.log(y) - 1) + np.log(1 + 1 / (12 * y)) 160 | 161 | 162 | def _log_cont_noise(t, beta_0, beta_T, T): 163 | # We want log_cont_noise(t, beta_0, beta_T, T) ~= np.log(Alpha_bar[-1].numpy()) 164 | delta_beta = (beta_T - beta_0) / (T - 1) 165 | _c = (1.0 - beta_0) / delta_beta 166 | t_1 = t + 1 167 | return t_1 * np.log(delta_beta) + _log_gamma(_c + 1) - _log_gamma(_c - t_1 + 1) 168 | 169 | 170 | # Standard DDPM generation 171 | def STD_sampling(net, size, diffusion_hyperparams): 172 | """ 173 | Perform the complete sampling step according to DDPM 174 | 175 | Parameters: 176 | net (torch network): the model 177 | size (tuple): size of tensor to be generated, 178 | usually is (number of audios to generate, channels=1, length of audio) 179 | diffusion_hyperparams (dict): dictionary of diffusion hyperparameters returned by calc_diffusion_hyperparams 180 | note, the tensors need to be cuda tensors 181 | 182 | Returns: 183 | the generated images in torch.tensor, shape=size 184 | """ 185 | 186 | _dh = diffusion_hyperparams 187 | T, Alpha, Alpha_bar, Beta = _dh["T"], _dh["Alpha"], _dh["Alpha_bar"], _dh["Beta"] 188 | assert len(Alpha_bar) == T 189 | assert len(size) == 4 190 | 191 | Sigma = _dh["Sigma"] 192 | 193 | x = std_normal(size) 194 | with torch.no_grad(): 195 | for t in range(T-1, -1, -1): 196 | diffusion_steps = t * map_gpu(torch.ones(size[0])) 197 | epsilon_theta = net(x, diffusion_steps) 198 | x = (x - (1-Alpha[t])/torch.sqrt(1-Alpha_bar[t]) * epsilon_theta) / torch.sqrt(Alpha[t]) 199 | if t > 0: 200 | x = x + Sigma[t] * std_normal(size) 201 | return x 202 | 203 | 204 | # STEP 205 | def STEP_sampling(net, size, diffusion_hyperparams, user_defined_steps, kappa): 206 | """ 207 | Perform the complete sampling step according to https://arxiv.org/pdf/2010.02502.pdf 208 | official repo: https://github.com/ermongroup/ddim 209 | 210 | Parameters: 211 | net (torch network): the model 212 | size (tuple): size of tensor to be generated, 213 | usually is (number of audios to generate, channels=1, length of audio) 214 | diffusion_hyperparams (dict): dictionary of diffusion hyperparameters returned by calc_diffusion_hyperparams 215 | note, the tensors need to be cuda tensors 216 | user_defined_steps (int list): User defined steps (sorted) 217 | kappa (float): factor multipled over sigma, between 0 and 1 218 | 219 | Returns: 220 | the generated images in torch.tensor, shape=size 221 | """ 222 | _dh = diffusion_hyperparams 223 | T, Alpha, Alpha_bar, _ = _dh["T"], _dh["Alpha"], _dh["Alpha_bar"], _dh["Sigma"] 224 | assert len(Alpha_bar) == T 225 | assert len(size) == 4 226 | assert 0.0 <= kappa <= 1.0 227 | 228 | T_user = len(user_defined_steps) 229 | user_defined_steps = sorted(list(user_defined_steps), reverse=True) 230 | 231 | x = std_normal(size) 232 | with torch.no_grad(): 233 | for i, tau in enumerate(user_defined_steps): 234 | diffusion_steps = tau * map_gpu(torch.ones(size[0])) 235 | epsilon_theta = net(x, diffusion_steps) 236 | if i == T_user - 1: # the next step is to generate x_0 237 | assert tau == 0 238 | alpha_next = 
torch.tensor(1.0) 239 | sigma = torch.tensor(0.0) 240 | else: 241 | alpha_next = Alpha_bar[user_defined_steps[i+1]] 242 | sigma = kappa * torch.sqrt((1-alpha_next) / (1-Alpha_bar[tau]) * (1 - Alpha_bar[tau] / alpha_next)) 243 | x *= torch.sqrt(alpha_next / Alpha_bar[tau]) 244 | c = torch.sqrt(1 - alpha_next - sigma ** 2) - torch.sqrt(1 - Alpha_bar[tau]) * torch.sqrt(alpha_next / Alpha_bar[tau]) 245 | x += c * epsilon_theta + sigma * std_normal(size) 246 | return x 247 | 248 | 249 | # VAR 250 | def _precompute_VAR_steps(diffusion_hyperparams, user_defined_eta): 251 | _dh = diffusion_hyperparams 252 | T, Alpha, Alpha_bar, Beta = _dh["T"], _dh["Alpha"], _dh["Alpha_bar"], _dh["Beta"] 253 | assert len(Alpha_bar) == T 254 | 255 | # compute diffusion hyperparameters for user defined noise 256 | T_user = len(user_defined_eta) 257 | Beta_tilde = map_gpu(torch.from_numpy(user_defined_eta)).to(torch.float32) 258 | Gamma_bar = 1 - Beta_tilde 259 | for t in range(1, T_user): 260 | Gamma_bar[t] *= Gamma_bar[t-1] 261 | 262 | assert Gamma_bar[0] <= Alpha_bar[0] and Gamma_bar[-1] >= Alpha_bar[-1] 263 | 264 | continuous_steps = [] 265 | with torch.no_grad(): 266 | for t in range(T_user-1, -1, -1): 267 | t_adapted = None 268 | for i in range(T - 1): 269 | if Alpha_bar[i] >= Gamma_bar[t] > Alpha_bar[i+1]: 270 | t_adapted = bisearch(f=lambda _t: _log_cont_noise(_t, Beta[0].cpu().numpy(), Beta[-1].cpu().numpy(), T), 271 | domain=(i-0.01, i+1.01), 272 | target=np.log(Gamma_bar[t].cpu().numpy())) 273 | break 274 | if t_adapted is None: 275 | t_adapted = T - 1 276 | continuous_steps.append(t_adapted) # must be decreasing 277 | return continuous_steps 278 | 279 | 280 | def VAR_sampling(net, size, diffusion_hyperparams, user_defined_eta, kappa, continuous_steps): 281 | """ 282 | Perform the complete sampling step according to user defined variances 283 | 284 | Parameters: 285 | net (torch network): the model 286 | size (tuple): size of tensor to be generated, 287 | usually is (number of audios to generate, channels=1, length of audio) 288 | diffusion_hyperparams (dict): dictionary of diffusion hyperparameters returned by calc_diffusion_hyperparams 289 | note, the tensors need to be cuda tensors 290 | user_defined_eta (np.array): User defined noise 291 | kappa (float): factor multipled over sigma, between 0 and 1 292 | continuous_steps (list): continuous steps computed from user_defined_eta 293 | 294 | Returns: 295 | the generated images in torch.tensor, shape=size 296 | """ 297 | 298 | _dh = diffusion_hyperparams 299 | T, Alpha, Alpha_bar, Beta = _dh["T"], _dh["Alpha"], _dh["Alpha_bar"], _dh["Beta"] 300 | assert len(Alpha_bar) == T 301 | assert len(size) == 4 302 | assert 0.0 <= kappa <= 1.0 303 | 304 | # compute diffusion hyperparameters for user defined noise 305 | T_user = len(user_defined_eta) 306 | Beta_tilde = map_gpu(torch.from_numpy(user_defined_eta)).to(torch.float32) 307 | Gamma_bar = 1 - Beta_tilde 308 | for t in range(1, T_user): 309 | Gamma_bar[t] *= Gamma_bar[t-1] 310 | 311 | assert Gamma_bar[0] <= Alpha_bar[0] and Gamma_bar[-1] >= Alpha_bar[-1] 312 | 313 | # print('begin sampling, total number of reverse steps = %s' % T_user) 314 | 315 | x = std_normal(size) 316 | with torch.no_grad(): 317 | for i, tau in enumerate(continuous_steps): 318 | diffusion_steps = tau * map_gpu(torch.ones(size[0])) 319 | epsilon_theta = net(x, diffusion_steps) 320 | if i == T_user - 1: # the next step is to generate x_0 321 | assert abs(tau) < 0.1 322 | alpha_next = torch.tensor(1.0) 323 | sigma = torch.tensor(0.0) 324 | 
else: 325 | alpha_next = Gamma_bar[T_user-1-i - 1] 326 | sigma = kappa * torch.sqrt((1-alpha_next) / (1-Gamma_bar[T_user-1-i]) * (1 - Gamma_bar[T_user-1-i] / alpha_next)) 327 | x *= torch.sqrt(alpha_next / Gamma_bar[T_user-1-i]) 328 | c = torch.sqrt(1 - alpha_next - sigma ** 2) - torch.sqrt(1 - Gamma_bar[T_user-1-i]) * torch.sqrt(alpha_next / Gamma_bar[T_user-1-i]) 329 | x += c * epsilon_theta + sigma * std_normal(size) 330 | 331 | return x 332 | 333 | 334 | def generate(output_name, model_path, model_config, 335 | diffusion_config, approxdiff, generation_param, 336 | n_generate, batchsize, n_exist): 337 | """ 338 | Parameters: 339 | output_name (str): save generated images to this folder 340 | model_path (str): checkpoint file 341 | model_config (dict): dict of model config 342 | diffusion_config (dict): dict of diffusion config 343 | approxdiff (str): diffusion style: STD, STEP, or VAR 344 | generation_param (dict): user defined variances or user defined steps 345 | n_generate (int): number of samples to generate 346 | batchsize (int): batch size of generation 347 | n_exist (int): number of samples that already exist 348 | 349 | Returns: 350 | None; images of shape (C, H, W) with C = 3 are saved under generated/output_name 351 | """ 352 | if batchsize > n_generate: 353 | batchsize = n_generate 354 | assert n_generate % batchsize == 0 355 | 356 | if 'generated' not in os.listdir(): 357 | os.mkdir('generated') 358 | if output_name not in os.listdir('generated'): 359 | os.mkdir(os.path.join('generated', output_name)) 360 | 361 | # map diffusion hyperparameters to gpu 362 | diffusion_hyperparams = calc_diffusion_hyperparams(**diffusion_config) 363 | for key in diffusion_hyperparams: 364 | if key != "T": 365 | diffusion_hyperparams[key] = map_gpu(diffusion_hyperparams[key]) 366 | 367 | # predefine model 368 | net = Model(**model_config) 369 | print_size(net) 370 | 371 | # load checkpoint 372 | try: 373 | checkpoint = torch.load(model_path, map_location='cpu') 374 | net.load_state_dict(checkpoint) 375 | net = map_gpu(net) 376 | net.eval() 377 | print('checkpoint successfully loaded') 378 | except: 379 | raise Exception('No valid model found') 380 | 381 | # sampling 382 | C, H, W = model_config["in_channels"], model_config["resolution"], model_config["resolution"] 383 | for i in tqdm(range(n_exist // batchsize, n_generate // batchsize)): 384 | if approxdiff == 'STD': 385 | Xi = STD_sampling(net, (batchsize, C, H, W), diffusion_hyperparams) 386 | elif approxdiff == 'STEP': 387 | user_defined_steps = generation_param["user_defined_steps"] 388 | Xi = STEP_sampling(net, (batchsize, C, H, W), 389 | diffusion_hyperparams, 390 | user_defined_steps, 391 | kappa=generation_param["kappa"]) 392 | elif approxdiff == 'VAR': 393 | user_defined_eta = generation_param["user_defined_eta"] 394 | continuous_steps = _precompute_VAR_steps(diffusion_hyperparams, user_defined_eta) 395 | Xi = VAR_sampling(net, (batchsize, C, H, W), 396 | diffusion_hyperparams, 397 | user_defined_eta, 398 | kappa=generation_param["kappa"], 399 | continuous_steps=continuous_steps) 400 | 401 | # save image 402 | for j, x in enumerate(rescale(Xi)): 403 | index = i * batchsize + j 404 | save_image(x, fp=os.path.join('generated', output_name, '{}.jpg'.format(index))) 405 | save_image(make_grid(rescale(Xi)[:64]), fp=os.path.join('generated', '{}.jpg'.format(output_name))) 406 | 407 | 408 | if __name__ == '__main__': 409 | parser = argparse.ArgumentParser() 410 | # dataset and model 411 | parser.add_argument('-name', '--name', type=str, choices=["cifar10", "lsun_bedroom", 
"lsun_church", "lsun_cat", "celeba64"], 412 | help='Name of experiment') 413 | parser.add_argument('-ema', '--ema', action='store_true', help='Whether use ema') 414 | 415 | # fast generation parameters 416 | parser.add_argument('-approxdiff', '--approxdiff', type=str, choices=['STD', 'STEP', 'VAR'], help='approximate diffusion process') 417 | parser.add_argument('-kappa', '--kappa', type=float, default=1.0, help='factor to be multiplied to sigma') 418 | parser.add_argument('-S', '--S', type=int, default=50, help='number of steps') 419 | parser.add_argument('-schedule', '--schedule', type=str, choices=['linear', 'quadratic'], help='noise level schedules') 420 | 421 | # generation util 422 | parser.add_argument('-n', '--n_generate', type=int, help='Number of samples to generate') 423 | parser.add_argument('-bs', '--batchsize', type=int, default=256, help='Batchsize of generation') 424 | parser.add_argument('-gpu', '--gpu', type=str, default='cuda', choices=['cuda']+[str(i) for i in range(16)], help='gpu device') 425 | 426 | args = parser.parse_args() 427 | 428 | global map_gpu 429 | map_gpu = _map_gpu(args.gpu) 430 | 431 | from config import model_config_map 432 | model_config = model_config_map[args.name] 433 | 434 | 435 | kappa = args.kappa 436 | if args.approxdiff == 'STD': 437 | variance_schedule = '1000' 438 | generation_param = {"kappa": kappa} 439 | 440 | elif args.approxdiff == 'VAR': # user defined variance 441 | user_defined_eta = get_VAR_noise(args.S, args.schedule) 442 | generation_param = {"kappa": kappa, 443 | "user_defined_eta": user_defined_eta} 444 | variance_schedule = '{}{}'.format(args.S, args.schedule) 445 | 446 | elif args.approxdiff == 'STEP': # user defined step 447 | user_defined_steps = get_STEP_step(args.S, args.schedule) 448 | generation_param = {"kappa": kappa, 449 | "user_defined_steps": user_defined_steps} 450 | variance_schedule = '{}{}'.format(args.S, args.schedule) 451 | 452 | else: 453 | raise NotImplementedError 454 | 455 | output_name = '{}{}_{}{}_kappa{}'.format('ema_' if args.ema else '', 456 | args.name, 457 | args.approxdiff, 458 | variance_schedule, 459 | kappa) 460 | 461 | n_exist = 0 462 | if 'generated' in os.listdir() and output_name in os.listdir('generated'): 463 | if len(os.listdir(os.path.join('generated', output_name))) == args.n_generate: 464 | print('{} already finished'.format(output_name)) 465 | n_exist = args.n_generate 466 | else: 467 | n_exist = len(os.listdir(os.path.join('generated', output_name))) 468 | 469 | if n_exist < args.n_generate: 470 | if n_exist > 0: 471 | print('{} already generated, resuming'.format(n_exist)) 472 | else: 473 | print('start generating') 474 | model_path = os.path.join('checkpoints', 475 | '{}diffusion_{}_model'.format('ema_' if args.ema else '', args.name), 476 | 'model.ckpt') 477 | generate(output_name, model_path, model_config, 478 | diffusion_config, args.approxdiff, generation_param, 479 | args.n_generate, args.batchsize, n_exist) 480 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def get_timestep_embedding(timesteps, embedding_dim): 7 | """ 8 | This matches the implementation in Denoising Diffusion Probabilistic Models: 9 | From Fairseq. 10 | Build sinusoidal embeddings. 
11 | This matches the implementation in tensor2tensor, but differs slightly 12 | from the description in Section 3.5 of "Attention Is All You Need". 13 | """ 14 | assert len(timesteps.shape) == 1 15 | 16 | half_dim = embedding_dim // 2 17 | emb = math.log(10000) / (half_dim - 1) 18 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 19 | emb = emb.to(device=timesteps.device) 20 | emb = timesteps.float()[:, None] * emb[None, :] 21 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 22 | if embedding_dim % 2 == 1: # zero pad 23 | emb = torch.nn.functional.pad(emb, (0,1,0,0)) 24 | return emb 25 | 26 | 27 | def nonlinearity(x): 28 | # swish 29 | return x*torch.sigmoid(x) 30 | 31 | 32 | def Normalize(in_channels): 33 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 34 | 35 | 36 | class Upsample(nn.Module): 37 | def __init__(self, in_channels, with_conv): 38 | super().__init__() 39 | self.with_conv = with_conv 40 | if self.with_conv: 41 | self.conv = torch.nn.Conv2d(in_channels, 42 | in_channels, 43 | kernel_size=3, 44 | stride=1, 45 | padding=1) 46 | 47 | def forward(self, x): 48 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 49 | if self.with_conv: 50 | x = self.conv(x) 51 | return x 52 | 53 | 54 | class Downsample(nn.Module): 55 | def __init__(self, in_channels, with_conv): 56 | super().__init__() 57 | self.with_conv = with_conv 58 | if self.with_conv: 59 | # no asymmetric padding in torch conv, must do it ourselves 60 | self.conv = torch.nn.Conv2d(in_channels, 61 | in_channels, 62 | kernel_size=3, 63 | stride=2, 64 | padding=0) 65 | 66 | def forward(self, x): 67 | if self.with_conv: 68 | pad = (0,1,0,1) 69 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 70 | x = self.conv(x) 71 | else: 72 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 73 | return x 74 | 75 | 76 | class ResnetBlock(nn.Module): 77 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 78 | dropout, temb_channels=512): 79 | super().__init__() 80 | self.in_channels = in_channels 81 | out_channels = in_channels if out_channels is None else out_channels 82 | self.out_channels = out_channels 83 | self.use_conv_shortcut = conv_shortcut 84 | 85 | self.norm1 = Normalize(in_channels) 86 | self.conv1 = torch.nn.Conv2d(in_channels, 87 | out_channels, 88 | kernel_size=3, 89 | stride=1, 90 | padding=1) 91 | self.temb_proj = torch.nn.Linear(temb_channels, 92 | out_channels) 93 | self.norm2 = Normalize(out_channels) 94 | self.dropout = torch.nn.Dropout(dropout) 95 | self.conv2 = torch.nn.Conv2d(out_channels, 96 | out_channels, 97 | kernel_size=3, 98 | stride=1, 99 | padding=1) 100 | if self.in_channels != self.out_channels: 101 | if self.use_conv_shortcut: 102 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 103 | out_channels, 104 | kernel_size=3, 105 | stride=1, 106 | padding=1) 107 | else: 108 | self.nin_shortcut = torch.nn.Conv2d(in_channels, 109 | out_channels, 110 | kernel_size=1, 111 | stride=1, 112 | padding=0) 113 | 114 | def forward(self, x, temb): 115 | h = x 116 | h = self.norm1(h) 117 | h = nonlinearity(h) 118 | h = self.conv1(h) 119 | 120 | h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] 121 | 122 | h = self.norm2(h) 123 | h = nonlinearity(h) 124 | h = self.dropout(h) 125 | h = self.conv2(h) 126 | 127 | if self.in_channels != self.out_channels: 128 | if self.use_conv_shortcut: 129 | x = self.conv_shortcut(x) 130 | else: 131 | x = self.nin_shortcut(x) 132 | 133 | 
return x+h 134 | 135 | 136 | class AttnBlock(nn.Module): 137 | def __init__(self, in_channels): 138 | super().__init__() 139 | self.in_channels = in_channels 140 | 141 | self.norm = Normalize(in_channels) 142 | self.q = torch.nn.Conv2d(in_channels, 143 | in_channels, 144 | kernel_size=1, 145 | stride=1, 146 | padding=0) 147 | self.k = torch.nn.Conv2d(in_channels, 148 | in_channels, 149 | kernel_size=1, 150 | stride=1, 151 | padding=0) 152 | self.v = torch.nn.Conv2d(in_channels, 153 | in_channels, 154 | kernel_size=1, 155 | stride=1, 156 | padding=0) 157 | self.proj_out = torch.nn.Conv2d(in_channels, 158 | in_channels, 159 | kernel_size=1, 160 | stride=1, 161 | padding=0) 162 | 163 | 164 | def forward(self, x): 165 | h_ = x 166 | h_ = self.norm(h_) 167 | q = self.q(h_) 168 | k = self.k(h_) 169 | v = self.v(h_) 170 | 171 | # compute attention 172 | b,c,h,w = q.shape 173 | q = q.reshape(b,c,h*w) 174 | q = q.permute(0,2,1) # b,hw,c 175 | k = k.reshape(b,c,h*w) # b,c,hw 176 | w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 177 | w_ = w_ * (int(c)**(-0.5)) 178 | w_ = torch.nn.functional.softmax(w_, dim=2) 179 | 180 | # attend to values 181 | v = v.reshape(b,c,h*w) 182 | w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) 183 | h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 184 | h_ = h_.reshape(b,c,h,w) 185 | 186 | h_ = self.proj_out(h_) 187 | 188 | return x+h_ 189 | 190 | 191 | class Model(nn.Module): 192 | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, 193 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 194 | resolution): 195 | super().__init__() 196 | self.ch = ch 197 | self.temb_ch = self.ch*4 198 | self.num_resolutions = len(ch_mult) 199 | self.num_res_blocks = num_res_blocks 200 | self.resolution = resolution 201 | self.in_channels = in_channels 202 | 203 | # timestep embedding 204 | self.temb = nn.Module() 205 | self.temb.dense = nn.ModuleList([ 206 | torch.nn.Linear(self.ch, 207 | self.temb_ch), 208 | torch.nn.Linear(self.temb_ch, 209 | self.temb_ch), 210 | ]) 211 | 212 | # downsampling 213 | self.conv_in = torch.nn.Conv2d(in_channels, 214 | self.ch, 215 | kernel_size=3, 216 | stride=1, 217 | padding=1) 218 | 219 | curr_res = resolution 220 | in_ch_mult = (1,)+ch_mult 221 | self.down = nn.ModuleList() 222 | for i_level in range(self.num_resolutions): 223 | block = nn.ModuleList() 224 | attn = nn.ModuleList() 225 | block_in = ch*in_ch_mult[i_level] 226 | block_out = ch*ch_mult[i_level] 227 | for i_block in range(self.num_res_blocks): 228 | block.append(ResnetBlock(in_channels=block_in, 229 | out_channels=block_out, 230 | temb_channels=self.temb_ch, 231 | dropout=dropout)) 232 | block_in = block_out 233 | if curr_res in attn_resolutions: 234 | attn.append(AttnBlock(block_in)) 235 | down = nn.Module() 236 | down.block = block 237 | down.attn = attn 238 | if i_level != self.num_resolutions-1: 239 | down.downsample = Downsample(block_in, resamp_with_conv) 240 | curr_res = curr_res // 2 241 | self.down.append(down) 242 | 243 | # middle 244 | self.mid = nn.Module() 245 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 246 | out_channels=block_in, 247 | temb_channels=self.temb_ch, 248 | dropout=dropout) 249 | self.mid.attn_1 = AttnBlock(block_in) 250 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 251 | out_channels=block_in, 252 | temb_channels=self.temb_ch, 253 | dropout=dropout) 254 | 255 | # upsampling 256 | self.up = nn.ModuleList() 257 | for i_level in 
reversed(range(self.num_resolutions)): 258 | block = nn.ModuleList() 259 | attn = nn.ModuleList() 260 | block_out = ch*ch_mult[i_level] 261 | skip_in = ch*ch_mult[i_level] 262 | for i_block in range(self.num_res_blocks+1): 263 | if i_block == self.num_res_blocks: 264 | skip_in = ch*in_ch_mult[i_level] 265 | block.append(ResnetBlock(in_channels=block_in+skip_in, 266 | out_channels=block_out, 267 | temb_channels=self.temb_ch, 268 | dropout=dropout)) 269 | block_in = block_out 270 | if curr_res in attn_resolutions: 271 | attn.append(AttnBlock(block_in)) 272 | up = nn.Module() 273 | up.block = block 274 | up.attn = attn 275 | if i_level != 0: 276 | up.upsample = Upsample(block_in, resamp_with_conv) 277 | curr_res = curr_res * 2 278 | self.up.insert(0, up) # prepend to get consistent order 279 | 280 | # end 281 | self.norm_out = Normalize(block_in) 282 | self.conv_out = torch.nn.Conv2d(block_in, 283 | out_ch, 284 | kernel_size=3, 285 | stride=1, 286 | padding=1) 287 | 288 | 289 | def forward(self, x, t): 290 | assert x.shape[2] == x.shape[3] == self.resolution 291 | 292 | # timestep embedding 293 | temb = get_timestep_embedding(t, self.ch) 294 | temb = self.temb.dense[0](temb) 295 | temb = nonlinearity(temb) 296 | temb = self.temb.dense[1](temb) 297 | 298 | # downsampling 299 | hs = [self.conv_in(x)] 300 | for i_level in range(self.num_resolutions): 301 | for i_block in range(self.num_res_blocks): 302 | h = self.down[i_level].block[i_block](hs[-1], temb) 303 | if len(self.down[i_level].attn) > 0: 304 | h = self.down[i_level].attn[i_block](h) 305 | hs.append(h) 306 | if i_level != self.num_resolutions-1: 307 | hs.append(self.down[i_level].downsample(hs[-1])) 308 | 309 | # middle 310 | h = hs[-1] 311 | h = self.mid.block_1(h, temb) 312 | h = self.mid.attn_1(h) 313 | h = self.mid.block_2(h, temb) 314 | 315 | # upsampling 316 | for i_level in reversed(range(self.num_resolutions)): 317 | for i_block in range(self.num_res_blocks+1): 318 | h = self.up[i_level].block[i_block]( 319 | torch.cat([h, hs.pop()], dim=1), temb) 320 | if len(self.up[i_level].attn) > 0: 321 | h = self.up[i_level].attn[i_block](h) 322 | if i_level != 0: 323 | h = self.up[i_level].upsample(h) 324 | 325 | # end 326 | h = self.norm_out(h) 327 | h = nonlinearity(h) 328 | h = self.conv_out(h) 329 | return h -------------------------------------------------------------------------------- /pytorch_fid/celeba64_train_stat.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhifengkong/FastDPM_pytorch/6540c1cdac3799aff8a5f7b9de430269bbd0b7c3/pytorch_fid/celeba64_train_stat.npy -------------------------------------------------------------------------------- /pytorch_fid/cifar10_train_stat.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhifengkong/FastDPM_pytorch/6540c1cdac3799aff8a5f7b9de430269bbd0b7c3/pytorch_fid/cifar10_train_stat.npy -------------------------------------------------------------------------------- /pytorch_fid/fid_score.py: -------------------------------------------------------------------------------- 1 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs 2 | 3 | The FID metric calculates the distance between two distributions of images. 4 | Typically, we have summary statistics (mean & covariance matrix) of one 5 | of these distributions, while the 2nd distribution is given by a GAN. 
6 | 7 | When run as a stand-alone program, it compares the distribution of 8 | images that are stored as PNG/JPEG at a specified location with a 9 | distribution given by summary statistics (in pickle format). 10 | 11 | The FID is calculated by assuming that X_1 and X_2 are the activations of 12 | the pool_3 layer of the inception net for generated samples and real world 13 | samples respectively. 14 | 15 | See --help to see further details. 16 | 17 | Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 18 | of Tensorflow 19 | 20 | Copyright 2018 Institute of Bioinformatics, JKU Linz 21 | 22 | Licensed under the Apache License, Version 2.0 (the "License"); 23 | you may not use this file except in compliance with the License. 24 | You may obtain a copy of the License at 25 | 26 | http://www.apache.org/licenses/LICENSE-2.0 27 | 28 | Unless required by applicable law or agreed to in writing, software 29 | distributed under the License is distributed on an "AS IS" BASIS, 30 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 31 | See the License for the specific language governing permissions and 32 | limitations under the License. 33 | """ 34 | import os 35 | import pathlib 36 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 37 | from multiprocessing import cpu_count 38 | 39 | import numpy as np 40 | import torch 41 | import torch.nn.functional as F 42 | import torchvision.transforms as TF 43 | from PIL import Image 44 | from scipy import linalg 45 | from torch.nn.functional import adaptive_avg_pool2d 46 | 47 | try: 48 | from tqdm import tqdm 49 | except ImportError: 50 | # If tqdm is not available, provide a mock version of it 51 | def tqdm(x): 52 | return x 53 | 54 | try: 55 | from inception import InceptionV3 56 | except ImportError: 57 | from .inception import InceptionV3 58 | 59 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 60 | parser.add_argument('--batch-size', type=int, default=50, 61 | help='Batch size to use') 62 | parser.add_argument('--device', type=str, default=None, 63 | help='Device to use. Like cuda, cuda:0 or cpu') 64 | parser.add_argument('--dims', type=int, default=2048, 65 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 66 | help=('Dimensionality of Inception features to use. 
' 67 | 'By default, uses pool3 features')) 68 | parser.add_argument('path', type=str, nargs=2, 69 | help=('Paths to the generated images or ' 70 | 'to .npz statistic files')) 71 | 72 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 73 | 'tif', 'tiff', 'webp'} 74 | 75 | 76 | class Crop(object): 77 | def __init__(self, x1, x2, y1, y2): 78 | self.x1 = x1 79 | self.x2 = x2 80 | self.y1 = y1 81 | self.y2 = y2 82 | 83 | def __call__(self, img): 84 | return TF.functional.crop(img, self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1) 85 | 86 | def __repr__(self): 87 | return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format( 88 | self.x1, self.x2, self.y1, self.y2 89 | ) 90 | 91 | 92 | class ImagePathDataset(torch.utils.data.Dataset): 93 | def __init__(self, files, transforms=None): 94 | self.files = files 95 | self.transforms = transforms 96 | 97 | def __len__(self): 98 | return len(self.files) 99 | 100 | def __getitem__(self, i): 101 | path = self.files[i] 102 | img = Image.open(path).convert('RGB') 103 | if self.transforms is not None: 104 | img = self.transforms(img) 105 | return img 106 | 107 | 108 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', resize=0): 109 | """Calculates the activations of the pool_3 layer for all images. 110 | 111 | Params: 112 | -- files : List of image files paths 113 | -- model : Instance of inception model 114 | -- batch_size : Batch size of images for the model to process at once. 115 | Make sure that the number of samples is a multiple of 116 | the batch size, otherwise some samples are ignored. This 117 | behavior is retained to match the original FID score 118 | implementation. 119 | -- dims : Dimensionality of features returned by Inception 120 | -- device : Device to run calculations 121 | 122 | Returns: 123 | -- A numpy array of dimension (num images, dims) that contains the 124 | activations of the given tensor when feeding inception with the 125 | query tensor. 126 | """ 127 | model.eval() 128 | 129 | if batch_size > len(files): 130 | print(('Warning: batch size is bigger than the data size. ' 131 | 'Setting batch size to data size')) 132 | batch_size = len(files) 133 | 134 | if resize > 0: 135 | print('Resized to ({}, {})'.format(resize, resize)) 136 | dataset = ImagePathDataset(files, transforms=TF.Compose([TF.Resize(size=(resize, resize)), 137 | TF.ToTensor()])) 138 | else: 139 | dataset = ImagePathDataset(files, transforms=TF.ToTensor()) 140 | dataloader = torch.utils.data.DataLoader(dataset, 141 | batch_size=batch_size, 142 | shuffle=False, 143 | drop_last=False, 144 | num_workers=cpu_count()) 145 | 146 | pred_arr = np.empty((len(files), dims)) 147 | 148 | start_idx = 0 149 | 150 | for batch in tqdm(dataloader): 151 | batch = batch.to(device) 152 | 153 | with torch.no_grad(): 154 | pred = model(batch)[0] 155 | 156 | # If model output is not scalar, apply global spatial average pooling. 157 | # This happens if you choose a dimensionality not equal 2048. 158 | if pred.size(2) != 1 or pred.size(3) != 1: 159 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 160 | 161 | pred = pred.squeeze(3).squeeze(2).cpu().numpy() 162 | 163 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred 164 | 165 | start_idx = start_idx + pred.shape[0] 166 | 167 | return pred_arr 168 | 169 | 170 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 171 | """Numpy implementation of the Frechet Distance. 
172 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 173 | and X_2 ~ N(mu_2, C_2) is 174 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 175 | 176 | Stable version by Dougal J. Sutherland. 177 | 178 | Params: 179 | -- mu1 : Numpy array containing the activations of a layer of the 180 | inception net (like returned by the function 'get_predictions') 181 | for generated samples. 182 | -- mu2 : The sample mean over activations, precalculated on an 183 | representative data set. 184 | -- sigma1: The covariance matrix over activations for generated samples. 185 | -- sigma2: The covariance matrix over activations, precalculated on an 186 | representative data set. 187 | 188 | Returns: 189 | -- : The Frechet Distance. 190 | """ 191 | 192 | mu1 = np.atleast_1d(mu1) 193 | mu2 = np.atleast_1d(mu2) 194 | 195 | sigma1 = np.atleast_2d(sigma1) 196 | sigma2 = np.atleast_2d(sigma2) 197 | 198 | assert mu1.shape == mu2.shape, \ 199 | 'Training and test mean vectors have different lengths' 200 | assert sigma1.shape == sigma2.shape, \ 201 | 'Training and test covariances have different dimensions' 202 | 203 | diff = mu1 - mu2 204 | 205 | # Product might be almost singular 206 | covmean = linalg.sqrtm(sigma1.dot(sigma2)) 207 | if not np.isfinite(covmean).all(): 208 | msg = ('fid calculation produces singular product; ' 209 | 'adding %s to diagonal of cov estimates') % eps 210 | print(msg) 211 | offset = np.eye(sigma1.shape[0]) * eps 212 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 213 | 214 | # Numerical error might give slight imaginary component 215 | if np.iscomplexobj(covmean): 216 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 217 | m = np.max(np.abs(covmean.imag)) 218 | raise ValueError('Imaginary component {}'.format(m)) 219 | covmean = covmean.real 220 | 221 | tr_covmean = np.trace(covmean) 222 | 223 | return (diff.dot(diff) + np.trace(sigma1) 224 | + np.trace(sigma2) - 2 * tr_covmean) 225 | 226 | 227 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048, 228 | device='cpu', resize=0): 229 | """Calculation of the statistics used by the FID. 230 | Params: 231 | -- files : List of image files paths 232 | -- model : Instance of inception model 233 | -- batch_size : The images numpy array is split into batches with 234 | batch size batch_size. A reasonable batch size 235 | depends on the hardware. 236 | -- dims : Dimensionality of features returned by Inception 237 | -- device : Device to run calculations 238 | -- resize : resize image to this shape 239 | 240 | Returns: 241 | -- mu : The mean over samples of the activations of the pool_3 layer of 242 | the inception model. 243 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 244 | the inception model. 
245 | """ 246 | act = get_activations(files, model, batch_size, dims, device, resize) 247 | mu = np.mean(act, axis=0) 248 | sigma = np.cov(act, rowvar=False) 249 | return mu, sigma 250 | 251 | 252 | def compute_statistics_of_path(path, model, batch_size, dims, device, resize=0): 253 | if path.endswith('.npz') or path.endswith('.npy'): 254 | f = np.load(path, allow_pickle=True) 255 | try: 256 | m, s = f['mu'][:], f['sigma'][:] 257 | except: 258 | m, s = f.item()['mu'][:], f.item()['sigma'][:] 259 | else: 260 | path_str = path[:] 261 | path = pathlib.Path(path) 262 | files = sorted([file for ext in IMAGE_EXTENSIONS 263 | for file in path.glob('*.{}'.format(ext))]) 264 | m, s = calculate_activation_statistics(files, model, batch_size, 265 | dims, device, resize) 266 | return m, s 267 | 268 | 269 | def calculate_fid_given_paths(paths, batch_size, device, dims, resize=0): 270 | """Calculates the FID of two paths""" 271 | for p in paths: 272 | if not os.path.exists(p): 273 | raise RuntimeError('Invalid path: %s' % p) 274 | 275 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 276 | 277 | model = InceptionV3([block_idx]).to(device) 278 | 279 | m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, 280 | dims, device, resize) 281 | m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, 282 | dims, device, resize) 283 | 284 | del model 285 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 286 | return fid_value 287 | 288 | 289 | def compute_statistics_of_dataloader(ds, batch_size, dims, device): 290 | 291 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 292 | model = InceptionV3([block_idx]).to(device) 293 | model.eval() 294 | 295 | if ds == 'lsun_bedroom_train': 296 | from torchvision.datasets import LSUN 297 | # seed following https://github.com/ermongroup/ddim 298 | random_state = np.random.get_state() 299 | np.random.seed(2019) 300 | np.random.set_state(random_state) 301 | torch.manual_seed(1234) 302 | dataset = LSUN('/tmp2/LSUN', ['bedroom_train'], transform=TF.Compose([TF.Resize(256), 303 | TF.CenterCrop(256), 304 | TF.RandomHorizontalFlip(p=0.5), 305 | TF.ToTensor()])) 306 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=1) 307 | elif ds == 'celeba64_train': 308 | from torchvision.datasets import CelebA 309 | # seed and preprocess following https://github.com/ermongroup/ddim 310 | random_state = np.random.get_state() 311 | np.random.seed(2019) 312 | np.random.set_state(random_state) 313 | torch.manual_seed(1234) 314 | cx = 89 315 | cy = 121 316 | x1 = cy - 64 317 | x2 = cy + 64 318 | y1 = cx - 64 319 | y2 = cx + 64 320 | dataset = CelebA('/tmp2/celeba64', split='train', transform=TF.Compose([Crop(x1, x2, y1, y2), 321 | TF.Resize(size=(64, 64)), 322 | TF.RandomHorizontalFlip(), 323 | TF.ToTensor()])) 324 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=1) 325 | else: 326 | raise NotImplementedError 327 | 328 | pred_arr = np.empty((len(dataset), dims)) 329 | start_idx = 0 330 | for batch in tqdm(dataloader): 331 | if type(batch) in [tuple, list]: 332 | batch = batch[0] 333 | batch = batch.to(device) 334 | with torch.no_grad(): 335 | pred = model(batch)[0] 336 | 337 | # If model output is not scalar, apply global spatial average pooling. 338 | # This happens if you choose a dimensionality not equal 2048. 
339 | if pred.size(2) != 1 or pred.size(3) != 1: 340 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 341 | 342 | pred = pred.squeeze(3).squeeze(2).cpu().numpy() 343 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred 344 | start_idx = start_idx + pred.shape[0] 345 | act = pred_arr 346 | 347 | m = np.mean(act, axis=0) 348 | s = np.cov(act, rowvar=False) 349 | np.save('{}_stat.npy'.format(ds), 350 | {'mu': m, 'sigma': s}, 351 | allow_pickle=True) 352 | 353 | 354 | def main(): 355 | args = parser.parse_args() 356 | 357 | if args.device is None: 358 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') 359 | else: 360 | device = torch.device(args.device) 361 | 362 | fid_value = calculate_fid_given_paths(args.path, 363 | args.batch_size, 364 | device, 365 | args.dims) 366 | print('FID: ', fid_value) 367 | 368 | 369 | if __name__ == '__main__': 370 | main() -------------------------------------------------------------------------------- /pytorch_fid/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 14 | 15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output blocks indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling featurs 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=(DEFAULT_BLOCK_INDEX,), 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | 39 | Parameters 40 | ---------- 41 | output_blocks : list of int 42 | Indices of blocks to return features of. Possible values are: 43 | - 0: corresponds to output of first max pooling 44 | - 1: corresponds to output of second max pooling 45 | - 2: corresponds to output which is fed to aux classifier 46 | - 3: corresponds to output of final average pooling 47 | resize_input : bool 48 | If true, bilinearly resizes input to width and height 299 before 49 | feeding input to model. As the network without fully connected 50 | layers is fully convolutional, it should be able to handle inputs 51 | of arbitrary size, so resizing might not be strictly needed 52 | normalize_input : bool 53 | If true, scales the input from range (0, 1) to the range the 54 | pretrained Inception network expects, namely (-1, 1) 55 | requires_grad : bool 56 | If true, parameters of the model require gradients. Possibly useful 57 | for finetuning the network 58 | use_fid_inception : bool 59 | If true, uses the pretrained Inception model used in Tensorflow's 60 | FID implementation. If false, uses the pretrained Inception model 61 | available in torchvision. 
The FID Inception model has different 62 | weights and a slightly different structure from torchvision's 63 | Inception model. If you want to compute FID scores, you are 64 | strongly advised to set this parameter to true to get comparable 65 | results. 66 | """ 67 | super(InceptionV3, self).__init__() 68 | 69 | self.resize_input = resize_input 70 | self.normalize_input = normalize_input 71 | self.output_blocks = sorted(output_blocks) 72 | self.last_needed_block = max(output_blocks) 73 | 74 | assert self.last_needed_block <= 3, \ 75 | 'Last possible output block index is 3' 76 | 77 | self.blocks = nn.ModuleList() 78 | 79 | if use_fid_inception: 80 | inception = fid_inception_v3() 81 | else: 82 | inception = _inception_v3(pretrained=True) 83 | 84 | # Block 0: input to maxpool1 85 | block0 = [ 86 | inception.Conv2d_1a_3x3, 87 | inception.Conv2d_2a_3x3, 88 | inception.Conv2d_2b_3x3, 89 | nn.MaxPool2d(kernel_size=3, stride=2) 90 | ] 91 | self.blocks.append(nn.Sequential(*block0)) 92 | 93 | # Block 1: maxpool1 to maxpool2 94 | if self.last_needed_block >= 1: 95 | block1 = [ 96 | inception.Conv2d_3b_1x1, 97 | inception.Conv2d_4a_3x3, 98 | nn.MaxPool2d(kernel_size=3, stride=2) 99 | ] 100 | self.blocks.append(nn.Sequential(*block1)) 101 | 102 | # Block 2: maxpool2 to aux classifier 103 | if self.last_needed_block >= 2: 104 | block2 = [ 105 | inception.Mixed_5b, 106 | inception.Mixed_5c, 107 | inception.Mixed_5d, 108 | inception.Mixed_6a, 109 | inception.Mixed_6b, 110 | inception.Mixed_6c, 111 | inception.Mixed_6d, 112 | inception.Mixed_6e, 113 | ] 114 | self.blocks.append(nn.Sequential(*block2)) 115 | 116 | # Block 3: aux classifier to final avgpool 117 | if self.last_needed_block >= 3: 118 | block3 = [ 119 | inception.Mixed_7a, 120 | inception.Mixed_7b, 121 | inception.Mixed_7c, 122 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 123 | ] 124 | self.blocks.append(nn.Sequential(*block3)) 125 | 126 | for param in self.parameters(): 127 | param.requires_grad = requires_grad 128 | 129 | def forward(self, inp): 130 | """Get Inception feature maps 131 | 132 | Parameters 133 | ---------- 134 | inp : torch.autograd.Variable 135 | Input tensor of shape Bx3xHxW. Values are expected to be in 136 | range (0, 1) 137 | 138 | Returns 139 | ------- 140 | List of torch.autograd.Variable, corresponding to the selected output 141 | block, sorted ascending by index 142 | """ 143 | outp = [] 144 | x = inp 145 | 146 | if self.resize_input: 147 | x = F.interpolate(x, 148 | size=(299, 299), 149 | mode='bilinear', 150 | align_corners=False) 151 | 152 | if self.normalize_input: 153 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 154 | 155 | for idx, block in enumerate(self.blocks): 156 | x = block(x) 157 | if idx in self.output_blocks: 158 | outp.append(x) 159 | 160 | if idx == self.last_needed_block: 161 | break 162 | 163 | return outp 164 | 165 | 166 | def _inception_v3(*args, **kwargs): 167 | """Wraps `torchvision.models.inception_v3` 168 | 169 | Skips default weight inititialization if supported by torchvision version. 170 | See https://github.com/mseitzer/pytorch-fid/issues/28. 
171 | """ 172 | try: 173 | version = tuple(map(int, torchvision.__version__.split('.')[:2])) 174 | except ValueError: 175 | # Just a caution against weird version strings 176 | version = (0,) 177 | 178 | if version >= (0, 6): 179 | kwargs['init_weights'] = False 180 | 181 | return torchvision.models.inception_v3(*args, **kwargs) 182 | 183 | 184 | def fid_inception_v3(): 185 | """Build pretrained Inception model for FID computation 186 | 187 | The Inception model for FID computation uses a different set of weights 188 | and has a slightly different structure than torchvision's Inception. 189 | 190 | This method first constructs torchvision's Inception and then patches the 191 | necessary parts that are different in the FID Inception model. 192 | """ 193 | inception = _inception_v3(num_classes=1008, 194 | aux_logits=False, 195 | pretrained=False) 196 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 197 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 198 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 199 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 200 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 201 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 202 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 203 | inception.Mixed_7b = FIDInceptionE_1(1280) 204 | inception.Mixed_7c = FIDInceptionE_2(2048) 205 | 206 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 207 | inception.load_state_dict(state_dict) 208 | return inception 209 | 210 | 211 | class FIDInceptionA(torchvision.models.inception.InceptionA): 212 | """InceptionA block patched for FID computation""" 213 | def __init__(self, in_channels, pool_features): 214 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 215 | 216 | def forward(self, x): 217 | branch1x1 = self.branch1x1(x) 218 | 219 | branch5x5 = self.branch5x5_1(x) 220 | branch5x5 = self.branch5x5_2(branch5x5) 221 | 222 | branch3x3dbl = self.branch3x3dbl_1(x) 223 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 224 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 225 | 226 | # Patch: Tensorflow's average pool does not use the padded zero's in 227 | # its average calculation 228 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 229 | count_include_pad=False) 230 | branch_pool = self.branch_pool(branch_pool) 231 | 232 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 233 | return torch.cat(outputs, 1) 234 | 235 | 236 | class FIDInceptionC(torchvision.models.inception.InceptionC): 237 | """InceptionC block patched for FID computation""" 238 | def __init__(self, in_channels, channels_7x7): 239 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 240 | 241 | def forward(self, x): 242 | branch1x1 = self.branch1x1(x) 243 | 244 | branch7x7 = self.branch7x7_1(x) 245 | branch7x7 = self.branch7x7_2(branch7x7) 246 | branch7x7 = self.branch7x7_3(branch7x7) 247 | 248 | branch7x7dbl = self.branch7x7dbl_1(x) 249 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 250 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 251 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 252 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 253 | 254 | # Patch: Tensorflow's average pool does not use the padded zero's in 255 | # its average calculation 256 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 257 | count_include_pad=False) 258 | branch_pool = self.branch_pool(branch_pool) 259 | 260 | outputs = [branch1x1, branch7x7, 
branch7x7dbl, branch_pool] 261 | return torch.cat(outputs, 1) 262 | 263 | 264 | class FIDInceptionE_1(torchvision.models.inception.InceptionE): 265 | """First InceptionE block patched for FID computation""" 266 | def __init__(self, in_channels): 267 | super(FIDInceptionE_1, self).__init__(in_channels) 268 | 269 | def forward(self, x): 270 | branch1x1 = self.branch1x1(x) 271 | 272 | branch3x3 = self.branch3x3_1(x) 273 | branch3x3 = [ 274 | self.branch3x3_2a(branch3x3), 275 | self.branch3x3_2b(branch3x3), 276 | ] 277 | branch3x3 = torch.cat(branch3x3, 1) 278 | 279 | branch3x3dbl = self.branch3x3dbl_1(x) 280 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 281 | branch3x3dbl = [ 282 | self.branch3x3dbl_3a(branch3x3dbl), 283 | self.branch3x3dbl_3b(branch3x3dbl), 284 | ] 285 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 286 | 287 | # Patch: Tensorflow's average pool does not use the padded zero's in 288 | # its average calculation 289 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 290 | count_include_pad=False) 291 | branch_pool = self.branch_pool(branch_pool) 292 | 293 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 294 | return torch.cat(outputs, 1) 295 | 296 | 297 | class FIDInceptionE_2(torchvision.models.inception.InceptionE): 298 | """Second InceptionE block patched for FID computation""" 299 | def __init__(self, in_channels): 300 | super(FIDInceptionE_2, self).__init__(in_channels) 301 | 302 | def forward(self, x): 303 | branch1x1 = self.branch1x1(x) 304 | 305 | branch3x3 = self.branch3x3_1(x) 306 | branch3x3 = [ 307 | self.branch3x3_2a(branch3x3), 308 | self.branch3x3_2b(branch3x3), 309 | ] 310 | branch3x3 = torch.cat(branch3x3, 1) 311 | 312 | branch3x3dbl = self.branch3x3dbl_1(x) 313 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 314 | branch3x3dbl = [ 315 | self.branch3x3dbl_3a(branch3x3dbl), 316 | self.branch3x3dbl_3b(branch3x3dbl), 317 | ] 318 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 319 | 320 | # Patch: The FID Inception model uses max pooling instead of average 321 | # pooling. This is likely an error in this specific Inception 322 | # implementation, as other Inception models use average pooling here 323 | # (which matches the description in the paper). 324 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 325 | branch_pool = self.branch_pool(branch_pool) 326 | 327 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 328 | return torch.cat(outputs, 1) -------------------------------------------------------------------------------- /pytorch_fid/lsun_bedroom_train_stat.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhifengkong/FastDPM_pytorch/6540c1cdac3799aff8a5f7b9de430269bbd0b7c3/pytorch_fid/lsun_bedroom_train_stat.npy --------------------------------------------------------------------------------