├── README.md
├── celeba_lsun_codes
│   ├── core
│   │   ├── func
│   │   │   ├── __init__.py
│   │   │   ├── differential.py
│   │   │   └── functions.py
│   │   └── inference
│   │       ├── __init__.py
│   │       ├── ll
│   │       │   └── elbo.py
│   │       ├── sampler
│   │       │   ├── __init__.py
│   │       │   ├── reverse_ddim.py
│   │       │   └── reverse_ddpm.py
│   │       └── utils.py
│   ├── interface
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── cifar10.py
│   │   │   ├── dataset_factory.py
│   │   │   ├── other_dst
│   │   │   │   ├── __init__.py
│   │   │   │   ├── celeba.py
│   │   │   │   ├── ffhq.py
│   │   │   │   ├── lsun.py
│   │   │   │   ├── utils.py
│   │   │   │   └── vision.py
│   │   │   └── utils.py
│   │   ├── runner.py
│   │   ├── task_schedule.py
│   │   └── utils.py
│   ├── pytorch_diffusion
│   │   ├── __init__.py
│   │   ├── ckpt_util.py
│   │   ├── demo.py
│   │   ├── diffusion.py
│   │   └── model.py
│   ├── run_celeba.py
│   ├── run_celeba_rep.py
│   ├── run_lsun_bedroom.py
│   ├── scripts
│   │   ├── convert.py
│   │   └── pytorch_diffusion_demo
│   └── tools
│       ├── fid_score.py
│       └── inception.py
└── cifar_imagenet_codes
    ├── core
    │   ├── __init__.py
    │   ├── criterions
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   └── ddpm.py
    │   ├── evaluate
    │   │   ├── __init__.py
    │   │   ├── sample.py
    │   │   └── score.py
    │   ├── func
    │   │   ├── __init__.py
    │   │   ├── differential.py
    │   │   └── functions.py
    │   ├── inference
    │   │   ├── __init__.py
    │   │   ├── ll
    │   │   │   └── elbo.py
    │   │   ├── sampler
    │   │   │   ├── __init__.py
    │   │   │   ├── reverse_ddim.py
    │   │   │   └── reverse_ddpm.py
    │   │   └── utils.py
    │   ├── modules
    │   │   ├── __init__.py
    │   │   ├── fp16_util.py
    │   │   ├── nn.py
    │   │   └── unet.py
    │   └── utils
    │       ├── __init__.py
    │       ├── clip_grad.py
    │       ├── device_utils.py
    │       ├── diagnose.py
    │       ├── ema.py
    │       └── managers.py
    ├── dp_vb.py
    ├── interface
    │   ├── __init__.py
    │   ├── datasets
    │   │   ├── __init__.py
    │   │   ├── cifar10.py
    │   │   ├── dataset_factory.py
    │   │   ├── imagenet64.py
    │   │   └── utils.py
    │   ├── evaluators
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   ├── ddpm_evaluator.py
    │   │   └── utils.py
    │   ├── runner
    │   │   ├── __init__.py
    │   │   ├── fit.py
    │   │   ├── runner.py
    │   │   └── timing.py
    │   └── utils
    │       ├── __init__.py
    │       ├── ckpt.py
    │       ├── dict_utils.py
    │       ├── exp_templates.py
    │       ├── interact.py
    │       ├── misc.py
    │       ├── plot.py
    │       ├── profile_utils.py
    │       ├── reproducibility.py
    │       └── task_schedule.py
    ├── profiles
    │   ├── __init__.py
    │   ├── common.py
    │   └── ddpm
    │       ├── __init__.py
    │       ├── beta_schedules.py
    │       ├── cifar10
    │       │   ├── __init__.py
    │       │   ├── base.py
    │       │   ├── naive_evaluate.py
    │       │   └── train.py
    │       └── imagenet64
    │           ├── __init__.py
    │           ├── base.py
    │           └── evaluate.py
    ├── run_cifar10.py
    ├── run_imagenet64.py
    └── tools
        ├── eval.py
        ├── fid_score.py
        └── inception.py

/README.md:
--------------------------------------------------------------------------------

# Analytic-DPM: an Analytic Estimate of the Optimal Reverse Variance in Diffusion Probabilistic Models

Code for the paper [Analytic-DPM: an Analytic Estimate of the Optimal Reverse Variance in Diffusion Probabilistic Models](https://arxiv.org/abs/2201.06503)

News (May 18, 2022): We provide an extended codebase (https://github.com/baofff/Extended-Analytic-DPM) for Analytic-DPM:
* It reproduces all main results, and additionally applies Analytic-DPM to score-based SDEs.
* For easy reproduction, it provides pretrained DPMs converted to a format that can be used directly, as well as running commands and FID statistics.

News (Apr 22, 2022): Analytic-DPM received an *Outstanding Paper Award* at ICLR 2022!

## Requirements

pytorch=1.9.0

## Run experiments

You can change the `phase` variable in the code to select which experiment to run.

For example, setting `phase = "sample_analytic_ddpm"` runs sampling with the Analytic-DDPM.

You can find all available phases in `run_xxx.py`.
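For instance, a minimal sketch of selecting a phase (the placement of the `phase` variable below is hypothetical; check the actual `run_xxx.py` for where it is defined):

```python
# Inside e.g. run_cifar10.py (sketch): pick one of the phase strings
# defined in that file, then launch the script as shown below.
phase = "sample_analytic_ddpm"  # run Analytic-DDPM sampling
```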
### CIFAR10
```
$ cd cifar_imagenet_codes
$ python run_cifar10.py
```

### CelebA 64x64
```
$ cd celeba_lsun_codes
$ python run_celeba.py
```

### Imagenet 64x64
```
$ cd cifar_imagenet_codes
$ python run_imagenet64.py
```

### LSUN Bedroom
```
$ cd celeba_lsun_codes
$ python run_lsun_bedroom.py
```

## Pretrained models and precalculated statistics

* CIFAR10 model: [[checkpoint](https://drive.google.com/file/d/1WyoUFDQeJUJblAT85Tc1ntbqklvqMP3J/view?usp=sharing)] trained by ourselves

* CelebA 64x64 model: [[checkpoint](https://drive.google.com/file/d/1R_H-fJYXSH79wfSKs9D-fuKQVan5L-GR/view?usp=sharing)] from https://github.com/ermongroup/ddim

* Imagenet 64x64 model: [[checkpoint](https://openaipublic.blob.core.windows.net/diffusion/march-2021/imagenet64_uncond_100M_1500K.pt)] from https://github.com/openai/improved-diffusion

* LSUN Bedroom model: [[checkpoint](https://heibox.uni-heidelberg.de/d/01207c3f6b8441779abf/)] from https://github.com/pesser/pytorch_diffusion

* Precalculated Gamma vectors: [[link](https://drive.google.com/file/d/1pnwxNFY-0P_IZaTVP1zNBxzKb3T1QeD7/view?usp=sharing)]

* Precalculated FID statistics (calculated as described in Appendix F.2 of the paper): [[link](https://drive.google.com/drive/folders/1aqSXiJSFRqtqHBAsgUw4puZcRqrqOoHx?usp=sharing)]

## This implementation is based on / inspired by

* https://github.com/pesser/pytorch_diffusion (provides the model code for CelebA 64x64 and LSUN Bedroom)

* https://github.com/openai/improved-diffusion (provides the model code for CIFAR10 and Imagenet 64x64)

* https://github.com/mseitzer/pytorch-fid (ports the official FID implementation to PyTorch)

-------------------------------------------------------------------------------- /celeba_lsun_codes/core/func/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import * 2 | from .differential import * -------------------------------------------------------------------------------- /celeba_lsun_codes/core/func/differential.py: -------------------------------------------------------------------------------- 1 | 2 | __all__ = ["RequiresGradContext", "differential"] 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.autograd as autograd 8 | from typing import Union, List 9 | 10 | 11 | def judge_requires_grad(obj: Union[torch.Tensor, nn.Module]): 12 | if isinstance(obj, torch.Tensor): 13 | return obj.requires_grad 14 | elif isinstance(obj, nn.Module): 15 | return next(obj.parameters()).requires_grad 16 | else: 17 | raise TypeError 18 | 19 | 20 | class RequiresGradContext(object): 21 | def __init__(self, *objs: Union[torch.Tensor, nn.Module], requires_grad: Union[List[bool], bool]): 22 | self.objs = objs 23 | self.backups = [judge_requires_grad(obj) for obj in objs] 24 | if isinstance(requires_grad, bool): 25 | self.requires_grads = [requires_grad] * len(objs) 26 | elif isinstance(requires_grad, list): 27 | self.requires_grads = requires_grad 28 | else: 29 | raise TypeError 30 | assert len(self.objs) == len(self.requires_grads) 31 | 32 | def __enter__(self): 33 | for obj, requires_grad in zip(self.objs, self.requires_grads): 34 | obj.requires_grad_(requires_grad) 35 | 36 | def __exit__(self, exc_type, exc_val, exc_tb): 37 | for obj, backup in zip(self.objs,
self.backups): 38 | obj.requires_grad_(backup) 39 | 40 | 41 | def differential(fn, v, retain_graph=None, create_graph=False): 42 | r""" d fn / dv 43 | Args: 44 | fn: a batch of tensor -> a batch of scalar 45 | v: a batch of tensor 46 | retain_graph: see autograd.grad, default to create_graph 47 | create_graph: see autograd.grad 48 | """ 49 | if retain_graph is None: 50 | retain_graph = create_graph 51 | with RequiresGradContext(v, requires_grad=True): 52 | return autograd.grad(fn(v).sum(), v, retain_graph=retain_graph, create_graph=create_graph)[0] 53 | -------------------------------------------------------------------------------- /celeba_lsun_codes/core/func/functions.py: -------------------------------------------------------------------------------- 1 | 2 | __all__ = ["stp", "sos", "mos", "inner_product", "duplicate", "unsqueeze_like", "logsumexp", "log_discretized_normal", 3 | "binary_cross_entropy_with_logits", "log_bernoulli", "kl_between_normal"] 4 | 5 | 6 | import numpy as np 7 | import torch.nn.functional as F 8 | import torch 9 | 10 | 11 | def stp(s: np.ndarray, ts: torch.Tensor): # scalar tensor product 12 | s = torch.from_numpy(s).type_as(ts) 13 | extra_dims = (1,) * (ts.dim() - 1) 14 | return s.view(-1, *extra_dims) * ts 15 | 16 | 17 | def sos(a, start_dim=1): # sum of square 18 | return a.pow(2).flatten(start_dim=start_dim).sum(dim=-1) 19 | 20 | 21 | def mos(a, start_dim=1): # mean of square 22 | return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1) 23 | 24 | 25 | def inner_product(a, b, start_dim=1): 26 | return (a * b).flatten(start_dim=start_dim).sum(dim=-1) 27 | 28 | 29 | def duplicate(tensor, *size): 30 | return tensor.unsqueeze(dim=0).expand(*size, *tensor.shape) 31 | 32 | 33 | def unsqueeze_like(tensor, template, start="left"): 34 | if start == "left": 35 | tensor_dim = tensor.dim() 36 | template_dim = template.dim() 37 | assert tensor.shape == template.shape[:tensor_dim] 38 | return tensor.view(*tensor.shape, *([1] * (template_dim - tensor_dim))) 39 | elif start == "right": 40 | tensor_dim = tensor.dim() 41 | template_dim = template.dim() 42 | assert tensor.shape == template.shape[-tensor_dim:] 43 | return tensor.view(*([1] * (template_dim - tensor_dim)), *tensor.shape) 44 | else: 45 | raise ValueError 46 | 47 | 48 | def logsumexp(tensor, dim, keepdim=False): 49 | # the logsumexp of pytorch is not stable! 50 | tensor_max, _ = tensor.max(dim=dim, keepdim=True) 51 | ret = (tensor - tensor_max).exp().sum(dim=dim, keepdim=True).log() + tensor_max 52 | if not keepdim: 53 | ret.squeeze_(dim=dim) 54 | return ret 55 | 56 | 57 | def approx_standard_normal_cdf(x): 58 | """ 59 | A fast approximation of the cumulative distribution function of the 60 | standard normal. 61 | """ 62 | return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3)))) 63 | 64 | 65 | def log_discretized_normal(x, mu, var): # element-wise 66 | centered_x = x - mu 67 | std = var ** 0.5 68 | left = (centered_x - 1. / 255) / std 69 | right = (centered_x + 1. / 255) / std 70 | 71 | cdf_right = approx_standard_normal_cdf(right) 72 | cdf_left = approx_standard_normal_cdf(left) 73 | cdf_delta = cdf_right - cdf_left 74 | 75 | return torch.where( 76 | x < -0.999, 77 | cdf_right.clamp(min=1e-12).log(), 78 | torch.where(x > 0.999, (1. 
- cdf_left).clamp(min=1e-12).log(), cdf_delta.clamp(min=1e-12).log()), 79 | ) 80 | 81 | 82 | def binary_cross_entropy_with_logits(logits, inputs): 83 | r""" -inputs * log (sigmoid(logits)) - (1 - inputs) * log (1 - sigmoid(logits)) element wise 84 | with automatically expand dimensions 85 | """ 86 | if inputs.dim() < logits.dim(): 87 | inputs = inputs.expand_as(logits) 88 | else: 89 | logits = logits.expand_as(inputs) 90 | return F.binary_cross_entropy_with_logits(logits, inputs, reduction="none") 91 | 92 | 93 | def log_bernoulli(inputs, logits, n_data_dim): 94 | return -binary_cross_entropy_with_logits(logits, inputs).flatten(-n_data_dim).sum(dim=-1) 95 | 96 | 97 | def kl_between_normal(mu_0, var_0, mu_1, var_1): # element-wise 98 | tensor = None 99 | for obj in (mu_0, var_0, mu_1, var_1): 100 | if isinstance(obj, torch.Tensor): 101 | tensor = obj 102 | break 103 | assert tensor is not None 104 | 105 | var_0, var_1 = [ 106 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 107 | for x in (var_0, var_1) 108 | ] 109 | 110 | return 0.5 * (var_0 / var_1 + (mu_0 - mu_1).pow(2) / var_1 + var_1.log() - var_0.log() - 1.) 111 | -------------------------------------------------------------------------------- /celeba_lsun_codes/core/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baofff/Analytic-DPM/2d7a28c0bbd984a6d47744ab4f6440f3e79757db/celeba_lsun_codes/core/inference/__init__.py -------------------------------------------------------------------------------- /celeba_lsun_codes/core/inference/ll/elbo.py: -------------------------------------------------------------------------------- 1 | import core.func as func 2 | import numpy as np 3 | import torch 4 | from tqdm import tqdm 5 | from core.inference.utils import _x_0_pred, _choice_steps, _report_statistics 6 | import logging 7 | 8 | 9 | @ torch.no_grad() 10 | def nelbo_naive_ddpm(x_0, betas, small_sigma, clip_denoise, rescale_timesteps, eps_model=None, d_model=None, sample_steps=None): 11 | assert (eps_model is None and d_model is not None) or (eps_model is not None and d_model is None) 12 | assert isinstance(betas, np.ndarray) and betas[0] == 0 13 | N = len(betas) - 1 14 | sample_steps = sample_steps or N 15 | ns = _choice_steps(N, sample_steps, 'linear') 16 | assert ns[0] == 1 and ns[-1] == N and len(ns) == sample_steps 17 | alphas = 1. - betas 18 | cum_alphas = alphas.cumprod() 19 | cum_betas = 1. - cum_alphas 20 | 21 | logging.info("nelbo_naive_ddpm with {}, small_sigma={}, clip_denoise={}, sample_steps={}" 22 | .format("eps_model" if eps_model is not None else "d_model", small_sigma, clip_denoise, sample_steps)) 23 | 24 | nelbo = torch.zeros(x_0.size(0), device=x_0.device) 25 | rev_terms = [] 26 | 27 | mu_q = cum_alphas[N] ** 0.5 * x_0 28 | var_q = cum_betas[N] 29 | mu_p = torch.zeros_like(mu_q) 30 | var_p = 1. 31 | term = func.kl_between_normal(mu_q, var_q, mu_p, var_p).flatten(1).sum(1) 32 | nelbo += term 33 | rev_terms.append(term) 34 | 35 | for s, r in tqdm(list(zip([0] + ns, ns))[::-1]): 36 | skip_alpha = alphas[s + 1: r + 1].prod() 37 | skip_beta = 1. - skip_alpha 38 | cum_alpha_s, cum_alpha_r, cum_beta_r = cum_alphas[s], cum_alphas[r], cum_betas[r] 39 | eps = torch.randn_like(x_0) 40 | x_r = cum_alpha_r ** 0.5 * x_0 + cum_beta_r ** 0.5 * eps 41 | 42 | coeff1 = skip_beta * cum_alpha_s ** 0.5 / (1. - cum_alpha_r) 43 | coeff2 = skip_alpha ** 0.5 * (1. - cum_alpha_s) / (1. 
- cum_alpha_r) 44 | x_0_pred, eps_pred = _x_0_pred(x_r, r, cum_alphas, rescale_timesteps, eps_model, d_model) 45 | if clip_denoise: 46 | x_0_pred = x_0_pred.clamp(-1., 1.) 47 | mu_p = coeff1 * x_0_pred + coeff2 * x_r 48 | 49 | # if small_sigma: 50 | # var_p = skip_beta * (1. - cum_alpha_s) / (1. - cum_alpha_r) if s != 0 else var_p 51 | # else: 52 | # var_p = skip_beta 53 | 54 | if s != 0: 55 | var_p = skip_beta * (1. - cum_alpha_s) / (1. - cum_alpha_r) if small_sigma else skip_beta 56 | else: 57 | var_p = _sigma2_small(1, 2, alphas, cum_alphas) 58 | 59 | if s != 0: 60 | mu_q = coeff1 * x_0 + coeff2 * x_r 61 | var_q = skip_beta * (1. - cum_alpha_s) / (1. - cum_alpha_r) 62 | term = func.kl_between_normal(mu_q, var_q, mu_p, var_p).flatten(1).sum(1) 63 | else: 64 | term = -func.log_discretized_normal(x_0, mu_p, var_p).flatten(1).sum(1) 65 | nelbo += term 66 | rev_terms.append(term) 67 | 68 | return nelbo, rev_terms[::-1] 69 | 70 | 71 | def _sigma2_small(s, r, alphas, cum_alphas): 72 | skip_alpha = alphas[s + 1: r + 1].prod() 73 | skip_beta = 1. - skip_alpha 74 | cum_alpha_s, cum_alpha_r = cum_alphas[s], cum_alphas[r] 75 | sigma2_small = skip_beta * (1. - cum_alpha_s) / (1. - cum_alpha_r) 76 | return sigma2_small 77 | 78 | 79 | @ torch.no_grad() 80 | def nelbo_ms_eps_ddpm(x_0, betas, rescale_timesteps, steps_type='linear', eps_model=None, ms_eps=None, sample_steps=None): 81 | assert eps_model is not None and ms_eps is not None 82 | assert isinstance(betas, np.ndarray) and betas[0] == 0 83 | N = len(betas) - 1 84 | sample_steps = sample_steps or N 85 | ns = _choice_steps(N, sample_steps, steps_type, ms_eps=ms_eps, betas=betas) 86 | assert ns[0] == 1 and ns[-1] == N and len(ns) == sample_steps 87 | alphas = 1. - betas 88 | cum_alphas = alphas.cumprod() 89 | cum_betas = 1. - cum_alphas 90 | 91 | logging.info("nelbo_ms_eps_ddpm with eps_model, sample_steps={}, steps_type={}".format(sample_steps, steps_type)) 92 | 93 | nelbo = torch.zeros(x_0.size(0), device=x_0.device) 94 | rev_terms = [] 95 | 96 | mu_q = cum_alphas[N] ** 0.5 * x_0 97 | var_q = cum_betas[N] 98 | mu_p = torch.zeros_like(mu_q) 99 | var_p = 1. 100 | term = func.kl_between_normal(mu_q, var_q, mu_p, var_p).flatten(1).sum(1) 101 | nelbo += term 102 | rev_terms.append(term) 103 | 104 | for s, r in list(zip([0] + ns, ns))[::-1]: 105 | statistics = {} 106 | skip_alpha = alphas[s + 1: r + 1].prod() 107 | skip_beta = 1. - skip_alpha 108 | cum_alpha_s, cum_alpha_r, cum_beta_r = cum_alphas[s], cum_alphas[r], cum_betas[r] 109 | statistics['skip_beta'] = skip_beta 110 | statistics['cum_beta_r'] = cum_beta_r 111 | statistics['cum_alpha_s'] = cum_alpha_s 112 | 113 | eps = torch.randn_like(x_0) 114 | x_r = cum_alpha_r ** 0.5 * x_0 + cum_beta_r ** 0.5 * eps 115 | 116 | coeff1 = skip_beta * cum_alpha_s ** 0.5 / (1. - cum_alpha_r) 117 | coeff2 = skip_alpha ** 0.5 * (1. - cum_alpha_s) / (1. - cum_alpha_r) 118 | x_0_pred, eps_pred = _x_0_pred(x_r, r, cum_alphas, rescale_timesteps, eps_model) 119 | x_0_pred_clamp = x_0_pred.clamp(-1., 1.) 120 | mu_p = coeff1 * x_0_pred_clamp + coeff2 * x_r 121 | 122 | if s != 0: 123 | sigma2_small = skip_beta * (1. - cum_alpha_s) / (1. - cum_alpha_r) 124 | cov_x_0_pred = cum_beta_r / cum_alpha_r * (1. - ms_eps[r]) 125 | cov_x_0_pred_clamp = np.clip(cov_x_0_pred, 0., 1.) 
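# Analytic-DPM variance estimate (descriptive comment, cf. the paper):
# ms_eps[r] approximates (1/d) * E||eps_model(x_r, r)||^2, so cov_x_0_pred
# above approximates the average per-dimension variance of q(x_0 | x_r).
# Since x_0 lies in [-1, 1], that variance cannot exceed 1, hence the clip to
# [0, 1]; the lines below rescale it and add it to sigma2_small to form the
# analytic estimate of the optimal reverse variance var_p.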
126 | coeff_cov_x_0 = cum_alpha_s * skip_beta ** 2 / cum_beta_r ** 2 127 | offset = coeff_cov_x_0 * cov_x_0_pred_clamp 128 | var_p = sigma2_small + offset 129 | statistics['sigma2_small'] = sigma2_small 130 | statistics['cov_x_0'] = cov_x_0_pred.item() 131 | statistics['cov_x_0_clamp'] = cov_x_0_pred_clamp.item() 132 | statistics['coeff_cov_x_0'] = coeff_cov_x_0 133 | statistics['offset'] = offset.item() 134 | else: 135 | var_p = _sigma2_small(1, 2, alphas, cum_alphas) 136 | statistics['var_p'] = var_p.item() 137 | 138 | if s != 0: 139 | mu_q = coeff1 * x_0 + coeff2 * x_r 140 | var_q = skip_beta * (1. - cum_alpha_s) / (1. - cum_alpha_r) 141 | statistics['var_q'] = var_q 142 | term = func.kl_between_normal(mu_q, var_q, mu_p, var_p).flatten(1).sum(1) 143 | else: 144 | term = -func.log_discretized_normal(x_0, mu_p, var_p).flatten(1).sum(1) 145 | 146 | nelbo += term 147 | rev_terms.append(term) 148 | _report_statistics(s, r, statistics) 149 | 150 | return nelbo, rev_terms[::-1] 151 | -------------------------------------------------------------------------------- /celeba_lsun_codes/core/inference/sampler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baofff/Analytic-DPM/2d7a28c0bbd984a6d47744ab4f6440f3e79757db/celeba_lsun_codes/core/inference/sampler/__init__.py -------------------------------------------------------------------------------- /celeba_lsun_codes/core/inference/sampler/reverse_ddim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import logging 4 | import math 5 | from core.inference.utils import _choice_steps, _x_0_pred, _report_statistics 6 | 7 | 8 | @ torch.no_grad() 9 | def reverse_ddim_naive(x_init, betas, rescale_timesteps, eta=0., steps_type='linear', eps_model=None, sample_steps=None): 10 | assert eps_model is not None 11 | assert isinstance(betas, np.ndarray) and betas[0] == 0 12 | N = len(betas) - 1 13 | sample_steps = sample_steps or N 14 | ns = _choice_steps(N, sample_steps, typ=steps_type) 15 | alphas = 1. - betas 16 | cum_alphas = alphas.cumprod() 17 | cum_betas = 1. - cum_alphas 18 | 19 | logging.info("reverse_ddim_naive with eps_model, rescale_timesteps={}, eta={}, sample_steps={}, steps_type={}" 20 | .format(rescale_timesteps, eta, sample_steps, steps_type)) 21 | 22 | x = x_init 23 | for s, r in list(zip([0] + ns, ns))[::-1]: 24 | statistics = {} 25 | skip_alpha = alphas[s + 1: r + 1].prod() 26 | skip_beta = 1. - skip_alpha 27 | cum_alpha_s, cum_alpha_r, cum_beta_s, cum_beta_r = cum_alphas[s], cum_alphas[r], cum_betas[s], cum_betas[r] 28 | sigma2_small = skip_beta * cum_beta_s / cum_beta_r 29 | lamb2 = eta ** 2 * sigma2_small 30 | statistics['skip_beta'] = skip_beta 31 | statistics['cum_beta_s'] = cum_beta_s 32 | statistics['cum_beta_r'] = cum_beta_r 33 | statistics['cum_alpha_s'] = cum_alpha_s 34 | 35 | x_0_pred, eps_pred = _x_0_pred(x, r, cum_alphas, rescale_timesteps, eps_model=eps_model) 36 | x_0_pred_clamp = x_0_pred.clamp(-1., 1.) 
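# Generalized DDIM update: x_s = sqrt(cum_alpha_s) * x_0_pred
#   + sqrt(cum_beta_s - lamb2) * eps_pred + sqrt(lamb2) * noise,
# where lamb2 = eta**2 * sigma2_small (computed above). eta = 0 recovers the
# deterministic DDIM step; eta = 1 uses the small DDPM posterior variance.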
37 | coeff1 = cum_alpha_s ** 0.5 38 | coeff2 = (cum_beta_s - lamb2) ** 0.5 39 | x_mean = coeff1 * x_0_pred_clamp + coeff2 * eps_pred 40 | if s != 0: 41 | sigma2 = lamb2 42 | x = x_mean + sigma2 ** 0.5 * torch.randn_like(x) 43 | statistics['sigma2'] = sigma2 44 | else: 45 | x = x_mean 46 | _report_statistics(s, r, statistics) 47 | return x 48 | 49 | 50 | @ torch.no_grad() 51 | def reverse_ddim_ms_eps(x_init, betas, rescale_timesteps, steps_type='linear', eta=0., eps_model=None, 52 | ms_eps=None, sample_steps=None, clip_sigma_idx=0, clip_pixel=2): 53 | assert eps_model is not None and ms_eps is not None 54 | assert isinstance(betas, np.ndarray) and betas[0] == 0 55 | N = len(betas) - 1 56 | sample_steps = sample_steps or N 57 | ns = _choice_steps(N, sample_steps, typ=steps_type, ms_eps=ms_eps, betas=betas) 58 | alphas = 1. - betas 59 | cum_alphas = alphas.cumprod() 60 | cum_betas = 1. - cum_alphas 61 | 62 | logging.info("reverse_ddim_ms_eps with eps_model, rescale_timesteps={}, eta={}, sample_steps={}, steps_type={}, clip_sigma_idx={}, clip_pixel={}" 63 | .format(rescale_timesteps, eta, sample_steps, steps_type, clip_sigma_idx, clip_pixel)) 64 | 65 | x = x_init 66 | for s, r in list(zip([0] + ns, ns))[::-1]: 67 | statistics = {} 68 | skip_alpha = alphas[s + 1: r + 1].prod() 69 | skip_beta = 1. - skip_alpha 70 | cum_alpha_s, cum_alpha_r, cum_beta_s, cum_beta_r = cum_alphas[s], cum_alphas[r], cum_betas[s], cum_betas[r] 71 | sigma2_small = skip_beta * cum_beta_s / cum_beta_r 72 | lamb2 = eta ** 2 * sigma2_small 73 | statistics['skip_beta'] = skip_beta 74 | statistics['cum_beta_s'] = cum_beta_s 75 | statistics['cum_beta_r'] = cum_beta_r 76 | statistics['cum_alpha_s'] = cum_alpha_s 77 | 78 | x_0_pred, eps_pred = _x_0_pred(x, r, cum_alphas, rescale_timesteps, eps_model=eps_model) 79 | x_0_pred_clamp = x_0_pred.clamp(-1., 1.) 80 | coeff1 = cum_alpha_s ** 0.5 81 | coeff2 = (cum_beta_s - lamb2) ** 0.5 82 | x_mean = coeff1 * x_0_pred_clamp + coeff2 * eps_pred 83 | if s != 0: 84 | cov_x_0_pred = cum_beta_r / cum_alpha_r * (1. - ms_eps[r]) 85 | cov_x_0_pred_clamp = np.clip(cov_x_0_pred, 0., 1.) 86 | coeff_cov_x_0 = (cum_alpha_s ** 0.5 - ((cum_beta_s - lamb2) * cum_alpha_r / cum_beta_r) ** 0.5) ** 2 87 | offset = coeff_cov_x_0 * cov_x_0_pred_clamp 88 | sigma2 = lamb2 + offset 89 | if s < ns[clip_sigma_idx]: # clip_sigma_idx = 0 <=> not clip 90 | statistics['sigma2_unclip'] = sigma2.item() 91 | sigma2_threshold = (clip_pixel * 2. / 255. * (math.pi / 2.) 
** 0.5) ** 2 92 | sigma2 = np.clip(sigma2, 0., sigma2_threshold) 93 | statistics['sigma2_threshold'] = sigma2_threshold 94 | x = x_mean + sigma2 ** 0.5 * torch.randn_like(x) 95 | statistics['sigma2'] = sigma2 96 | else: 97 | x = x_mean 98 | _report_statistics(s, r, statistics) 99 | return x 100 | -------------------------------------------------------------------------------- /celeba_lsun_codes/core/inference/sampler/reverse_ddpm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import logging 4 | from core.inference.utils import _choice_steps, _x_0_pred, _report_statistics 5 | import math 6 | 7 | 8 | @ torch.no_grad() 9 | def reverse_ddpm_naive(x_init, betas, small_sigma, clip_denoise, rescale_timesteps, steps_type='linear', clip_sigma_idx=0, clip_pixel=2, 10 | eps_model=None, d_model=None, sample_steps=None): 11 | assert (eps_model is None and d_model is not None) or (eps_model is not None and d_model is None) 12 | assert isinstance(betas, np.ndarray) and betas[0] == 0 13 | N = len(betas) - 1 14 | sample_steps = sample_steps or N 15 | ns = _choice_steps(N, sample_steps, steps_type) 16 | alphas = 1. - betas 17 | cum_alphas = alphas.cumprod() 18 | cum_betas = 1. - cum_alphas 19 | 20 | logging.info("reverse_ddpm_naive with {}, small_sigma={}, clip_denoise={}, rescale_timesteps={}, " 21 | "sample_steps={}, steps_type={}, clip_sigma_idx={}, clip_pixel={}" 22 | .format("eps_model" if eps_model is not None else "d_model", 23 | small_sigma, clip_denoise, rescale_timesteps, sample_steps, steps_type, clip_sigma_idx, clip_pixel)) 24 | 25 | x = x_init 26 | for s, r in list(zip([0] + ns, ns))[::-1]: 27 | statistics = {} 28 | skip_alpha = alphas[s + 1: r + 1].prod() 29 | skip_beta = 1. - skip_alpha 30 | cum_alpha_s, cum_alpha_r, cum_beta_s, cum_beta_r = cum_alphas[s], cum_alphas[r], cum_betas[s], cum_betas[r] 31 | statistics['skip_beta'] = skip_beta 32 | statistics['cum_beta_s'] = cum_beta_s 33 | statistics['cum_beta_r'] = cum_beta_r 34 | statistics['cum_alpha_s'] = cum_alpha_s 35 | 36 | x_0_pred, eps_pred = _x_0_pred(x, r, cum_alphas, rescale_timesteps, eps_model, d_model) 37 | if clip_denoise: 38 | x_0_pred = x_0_pred.clamp(-1., 1.) 39 | coeff1 = skip_beta * cum_alpha_s ** 0.5 / cum_beta_r 40 | coeff2 = skip_alpha ** 0.5 * cum_beta_s / cum_beta_r 41 | x_mean = coeff1 * x_0_pred + coeff2 * x 42 | if s != 0: 43 | sigma2 = skip_beta 44 | if small_sigma: 45 | sigma2 *= cum_beta_s / cum_beta_r 46 | if s < ns[clip_sigma_idx]: # clip_sigma_idx = 0 <=> not clip 47 | statistics['sigma2_unclip'] = sigma2.item() 48 | sigma2_threshold = (clip_pixel * 2. / 255. * (math.pi / 2.) ** 0.5) ** 2 49 | sigma2 = np.clip(sigma2, 0., sigma2_threshold) 50 | statistics['sigma2_threshold'] = sigma2_threshold 51 | x = x_mean + sigma2 ** 0.5 * torch.randn_like(x) 52 | statistics['sigma2'] = sigma2 53 | else: 54 | x = x_mean 55 | _report_statistics(s, r, statistics) 56 | return x 57 | 58 | 59 | @ torch.no_grad() 60 | def reverse_ddpm_ms_eps(x_init, betas, rescale_timesteps, steps_type='linear', clip_sigma_idx=0, clip_pixel=2, eps_model=None, ms_eps=None, sample_steps=None): 61 | assert eps_model is not None and ms_eps is not None 62 | assert isinstance(betas, np.ndarray) and betas[0] == 0 63 | N = len(betas) - 1 64 | sample_steps = sample_steps or N 65 | ns = _choice_steps(N, sample_steps, steps_type, ms_eps=ms_eps, betas=betas) 66 | alphas = 1. - betas 67 | cum_alphas = alphas.cumprod() 68 | cum_betas = 1. 
- cum_alphas 69 | 70 | logging.info("reverse_ddpm_ms_eps with eps_model, rescale_timesteps={}, sample_steps={}, steps_type={}, clip_sigma_idx={}, clip_pixel={}" 71 | .format(rescale_timesteps, sample_steps, steps_type, clip_sigma_idx, clip_pixel)) 72 | 73 | x = x_init 74 | for s, r in list(zip([0] + ns, ns))[::-1]: 75 | statistics = {} 76 | skip_alpha = alphas[s + 1: r + 1].prod() 77 | skip_beta = 1. - skip_alpha 78 | cum_alpha_s, cum_alpha_r, cum_beta_s, cum_beta_r = cum_alphas[s], cum_alphas[r], cum_betas[s], cum_betas[r] 79 | statistics['skip_beta'] = skip_beta 80 | statistics['cum_beta_s'] = cum_beta_s 81 | statistics['cum_beta_r'] = cum_beta_r 82 | statistics['cum_alpha_s'] = cum_alpha_s 83 | 84 | x_0_pred, eps_pred = _x_0_pred(x, r, cum_alphas, rescale_timesteps, eps_model) 85 | x_0_pred_clamp = x_0_pred.clamp(-1., 1.) 86 | coeff1 = skip_beta * cum_alpha_s ** 0.5 / cum_beta_r 87 | coeff2 = skip_alpha ** 0.5 * cum_beta_s / cum_beta_r 88 | x_mean = coeff1 * x_0_pred_clamp + coeff2 * x 89 | if s != 0: 90 | sigma2_small = skip_beta * cum_beta_s / cum_beta_r 91 | cov_x_0_pred = cum_beta_r / cum_alpha_r * (1. - ms_eps[r]) 92 | cov_x_0_pred_clamp = np.clip(cov_x_0_pred, 0., 1.) 93 | coeff_cov_x_0 = cum_alpha_s * skip_beta ** 2 / cum_beta_r ** 2 94 | offset = coeff_cov_x_0 * cov_x_0_pred_clamp 95 | sigma2 = sigma2_small + offset 96 | statistics['sigma2_small'] = sigma2_small 97 | statistics['cov_x_0'] = cov_x_0_pred.item() 98 | statistics['cov_x_0_clamp'] = cov_x_0_pred_clamp.item() 99 | statistics['coeff_cov_x_0'] = coeff_cov_x_0 100 | statistics['offset'] = offset.item() 101 | statistics['sigma2_small/offset'] = statistics['sigma2_small'] / statistics['offset'] 102 | if s < ns[clip_sigma_idx]: # clip_sigma_idx = 0 <=> not clip 103 | statistics['sigma2_unclip'] = sigma2.item() 104 | sigma2_threshold = (clip_pixel * 2. / 255. * (math.pi / 2.) ** 0.5) ** 2 105 | sigma2 = np.clip(sigma2, 0., sigma2_threshold) 106 | statistics['sigma2_threshold'] = sigma2_threshold 107 | x = x_mean + sigma2 ** 0.5 * torch.randn_like(x) 108 | statistics['sigma2'] = sigma2 109 | else: 110 | x = x_mean 111 | _report_statistics(s, r, statistics) 112 | return x 113 | -------------------------------------------------------------------------------- /celeba_lsun_codes/core/inference/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import numpy as np 4 | import math 5 | 6 | 7 | def _rescale_timesteps(n, N, flag): 8 | if flag: 9 | return n * 1000.0 / float(N) 10 | return n 11 | 12 | 13 | def _report_statistics(s, r, statistics): 14 | statistics_str = {k: "{:.5e}".format(v) for k, v in statistics.items()} 15 | logging.info("[(s, r): ({}, {})] [{}]".format(s, r, statistics_str)) 16 | 17 | 18 | def _x_0_pred(x, n, cum_alphas, rescale_timesteps, eps_model=None, d_model=None): # estimate of E[x_0|x_n] w.r.t. q 19 | N = len(cum_alphas) - 1 20 | cum_alpha_n = cum_alphas[n] 21 | cum_beta_n = 1. - cum_alpha_n 22 | if eps_model is not None: 23 | eps_pred = eps_model(x, _rescale_timesteps(torch.tensor([n] * x.size(0)).type_as(x), N, rescale_timesteps)) 24 | x_0_pred = cum_alpha_n ** -0.5 * x - (1. / cum_alpha_n - 1.) ** 0.5 * eps_pred 25 | else: 26 | x_0_pred = d_model(x, _rescale_timesteps(torch.tensor([n] * x.size(0)).type_as(x), N, rescale_timesteps)) 27 | eps_pred = - (cum_alpha_n / cum_beta_n) ** 0.5 * x_0_pred + (1. 
/ cum_beta_n ** 0.5) * x 28 | return x_0_pred, eps_pred 29 | 30 | 31 | def _cov_x_0_pred(x, n, cum_alphas, cum_betas, rescale_timesteps, tau_model=None, eps_pred=None, kappa_model=None, x_0_pred=None): # estimate Cov[x_0|x_n] w.r.t. q 32 | N = len(cum_alphas) - 1 33 | cum_alpha_n, cum_beta_n = cum_alphas[n], cum_betas[n] 34 | if tau_model is not None: 35 | tau_pred = tau_model(x, _rescale_timesteps(torch.tensor([n] * x.size(0)).type_as(x), N, rescale_timesteps)) 36 | delta_pred = tau_pred - eps_pred.pow(2) 37 | cov_x_0_pred = cum_beta_n / cum_alpha_n * delta_pred 38 | else: 39 | x_0_2_pred = kappa_model(x, _rescale_timesteps(torch.tensor([n] * x.size(0)).type_as(x), N, rescale_timesteps)) 40 | cov_x_0_pred = x_0_2_pred - x_0_pred.pow(2) 41 | return cov_x_0_pred 42 | 43 | 44 | def _choice_steps_linear(N, sample_steps): 45 | assert sample_steps > 1 46 | frac_stride = (N - 1) / (sample_steps - 1) 47 | cur_idx = 1.0 48 | steps = [] 49 | for _ in range(sample_steps): 50 | steps.append(round(cur_idx)) 51 | cur_idx += frac_stride 52 | return steps 53 | 54 | 55 | def _choice_steps_linear_ddim(N, sample_steps): 56 | skip = N // sample_steps 57 | seq = list(range(1, N + 1, skip)) 58 | return seq 59 | 60 | 61 | def _choice_steps_quad_ddim(N, sample_steps): 62 | seq = np.linspace(0, np.sqrt(N * 0.8), sample_steps) ** 2 63 | seq = [int(s) + 1 for s in list(seq)] 64 | return seq 65 | 66 | 67 | def _split(ms_eps, N, K): 68 | idx_g1 = N + 1 69 | for n in range(1, N): # Theoretically, ms_eps <= 1. Remove points of poor estimation 70 | if ms_eps[n] > 1: 71 | idx_g1 = n 72 | break 73 | num_bad = 2 * (N - idx_g1 + 1) 74 | bad_ratio = num_bad / N 75 | 76 | N1 = N - num_bad 77 | K1 = math.ceil((1. - 0.8 * bad_ratio) * K) 78 | K2 = K - K1 79 | if K1 > N1: 80 | K1 = N1 81 | K2 = K - K1 82 | if K2 > num_bad: 83 | K2 = num_bad 84 | K1 = K - K2 85 | if num_bad > 0 and K2 == 0: 86 | K2 = 1 87 | K1 = K - K2 88 | assert num_bad <= N 89 | assert K1 <= N1 and K2 <= N - N1 90 | return K1, N1, K2, num_bad 91 | 92 | 93 | def _ms_score(ms_eps, betas): 94 | alphas = 1. - betas 95 | cum_alphas = alphas.cumprod() 96 | cum_betas = 1. - cum_alphas 97 | ms_score = np.zeros_like(ms_eps) 98 | ms_score[1:] = ms_eps[1:] / cum_betas[1:] 99 | return ms_score 100 | 101 | 102 | def _solve_fn_dp(fn, N, K): # F[st, ed] with 1 <= st < ed <= N, other elements is inf 103 | if N == K: 104 | return list(range(1, N + 1)) 105 | 106 | F = fn[: N + 1, : N + 1] 107 | 108 | C = np.full((K + 1, N + 1), float('inf')) # C[k, n] with 2 <= k <= K, k <= n <= N 109 | D = np.full((K + 1, N + 1), -1) # D[k, n] with 2 <= k <= K, k <= n <= N 110 | 111 | C[2, 2: N] = F[1, 2: N] 112 | D[2, 2: N] = 1 113 | 114 | for k in range(3, K + 1): 115 | # {C[k-1, s] + F[s, r]}_{0 <= s, r <= N} = {C[k-1, s] + F[s, r]}_{k-1 <= s < r <= N} 116 | tmp = C[k - 1, :].reshape(N + 1, 1) + F 117 | C[k, k: N + 1] = np.min(tmp, axis=0)[k: N + 1] 118 | D[k, k: N + 1] = np.argmin(tmp, axis=0)[k: N + 1] 119 | 120 | res = [N] 121 | n, k = N, K 122 | while k > 2: 123 | n = D[k, n] 124 | res.append(n) 125 | k -= 1 126 | res.append(1) 127 | return res[::-1] 128 | 129 | 130 | def _get_fn_m(ms_score, alphas, N): 131 | F = np.full((N + 1, N + 1), float('inf')) # F[st, ed] with 1 <= st < ed <= N 132 | for s in range(1, N + 1): 133 | skip_alphas = alphas[s + 1: N + 1].cumprod() 134 | skip_betas = 1. - skip_alphas 135 | before_log = 1. 
- skip_betas * ms_score[s + 1: N + 1] 136 | F[s, s + 1: N + 1] = np.log(before_log) 137 | return F 138 | 139 | 140 | def _dp_seg(ms_eps, betas, N, K): 141 | K1, N1, K2, num_bad = _split(ms_eps, N, K) 142 | 143 | alphas = 1. - betas 144 | ms_score = _ms_score(ms_eps, betas) 145 | F = _get_fn_m(ms_score, alphas, N1) 146 | 147 | steps1 = _solve_fn_dp(F, N1, K1) 148 | if K2 > 0: 149 | frac = (N - N1) / K2 150 | steps2 = [round(N - frac * k) for k in range(K2)][::-1] 151 | assert steps1[-1] < steps2[0] 152 | assert len(steps1) + len(steps2) == K 153 | assert steps1[0] == 1 and steps1[-1] == N1 154 | assert steps2[-1] == N 155 | else: 156 | steps2 = [] 157 | steps = steps1 + steps2 158 | assert steps[0] == 1 and steps[-1] == N 159 | assert all(steps[i] < steps[i + 1] for i in range(len(steps) - 1)) 160 | return steps 161 | 162 | 163 | def _choice_steps(N, sample_steps, typ, ms_eps=None, betas=None): 164 | if typ == 'linear': 165 | steps = _choice_steps_linear(N, sample_steps) 166 | elif typ == 'linear_ddim': 167 | steps = _choice_steps_linear_ddim(N, sample_steps) 168 | elif typ == 'quad_ddim': 169 | steps = _choice_steps_quad_ddim(N, sample_steps) 170 | elif typ == 'dp_seg': 171 | steps = _dp_seg(ms_eps, betas, N, sample_steps) 172 | else: 173 | raise NotImplementedError 174 | 175 | assert len(steps) == sample_steps and steps[0] == 1 176 | if typ not in ["linear_ddim", "quad_ddim"]: 177 | assert steps[-1] == N 178 | 179 | return steps 180 | -------------------------------------------------------------------------------- /celeba_lsun_codes/interface/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar10 import Cifar10 2 | from .other_dst import get_dataset 3 | 4 | 5 | def get_test_dataset(name): 6 | if name == 'cifar10': 7 | return Cifar10("workspace/datasets/cifar10/").get_test_data(False) 8 | elif name == 'celeba': 9 | return get_dataset('celeba')[1] 10 | 11 | 12 | def get_train_dataset(name): 13 | if name == 'cifar10': 14 | return Cifar10("workspace/datasets/cifar10/").get_train_val_data(False) 15 | elif name == 'celeba': 16 | return get_dataset('celeba')[0] 17 | elif name == 'lsun_bedroom': 18 | return get_dataset('lsun_bedroom')[0] 19 | elif name == 'lsun_church': 20 | return get_dataset('lsun_church')[0] 21 | else: 22 | raise NotImplementedError 23 | -------------------------------------------------------------------------------- /celeba_lsun_codes/interface/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Subset 2 | from torchvision import datasets 3 | import torchvision.transforms as transforms 4 | from .dataset_factory import DatasetFactory 5 | from .utils import * 6 | 7 | 8 | class Cifar10(DatasetFactory): 9 | r""" Cifar10 dataset 10 | 11 | Information of the raw dataset: 12 | train: 40,000 13 | val: 10,000 14 | test: 10,000 15 | shape: 3 * 32 * 32 16 | """ 17 | 18 | def __init__(self, data_path, gauss_noise=False, noise_std=0.01): 19 | super(Cifar10, self).__init__() 20 | self.data_path = data_path 21 | self.gauss_noise = gauss_noise 22 | self.noise_std = noise_std 23 | 24 | _transform = [transforms.ToTensor()] 25 | if self.gauss_noise: 26 | _transform.append(AddGaussNoise(self.noise_std)) 27 | im_transform = transforms.Compose(_transform) 28 | self.train_val = datasets.CIFAR10(self.data_path, train=True, transform=im_transform, download=True) 29 | self.train = Subset(self.train_val, list(range(40000))) 30 | self.val = 
Subset(self.train_val, list(range(40000, 50000))) 31 | self.test = datasets.CIFAR10(self.data_path, train=False, transform=im_transform, download=True) 32 | 33 | def affine_transform(self, dataset): 34 | return StandardizedDataset(dataset, mean=0.5, std=0.5) # scale to [-1, 1] 35 | 36 | def preprocess(self, v): 37 | return 2. * (v - 0.5) 38 | 39 | def unpreprocess(self, v): 40 | v = 0.5 * (v + 1.) 41 | v.clamp_(0., 1.) 42 | return v 43 | 44 | @property 45 | def data_shape(self): 46 | return 3, 32, 32 47 | 48 | @property 49 | def fid_stat(self): 50 | return 'workspace/fid_stats/fid_stats_cifar10_train_pytorch.npz' 51 | -------------------------------------------------------------------------------- /celeba_lsun_codes/interface/datasets/dataset_factory.py: -------------------------------------------------------------------------------- 1 | from .utils import is_labelled, UnlabeledDataset 2 | from torch.utils.data import ConcatDataset 3 | import numpy as np 4 | 5 | 6 | class DatasetFactory(object): 7 | r""" Output dataset after two transformations to the raw data: 8 | 1. distribution transform (e.g. binarized, adding noise), often irreversible, a part of which is implemented 9 | in distribution_transform 10 | 2. an affine transform (preprocess), which is bijective 11 | """ 12 | 13 | def __init__(self): 14 | self.train = None 15 | self.val = None 16 | self.test = None 17 | 18 | def allow_labelled(self): 19 | return is_labelled(self.train) 20 | 21 | def get_data(self, dataset, labelled): 22 | assert not (not is_labelled(dataset) and labelled) 23 | if is_labelled(dataset) and not labelled: 24 | dataset = UnlabeledDataset(dataset) 25 | return self.affine_transform(self.distribution_transform(dataset)) 26 | 27 | def get_train_data(self, labelled=False): 28 | return self.get_data(self.train, labelled=labelled) 29 | 30 | def get_val_data(self, labelled=False): 31 | return self.get_data(self.val, labelled=labelled) 32 | 33 | def get_train_val_data(self, labelled=False): 34 | train_val = ConcatDataset([self.train, self.val]) 35 | return self.get_data(train_val, labelled=labelled) 36 | 37 | def get_test_data(self, labelled=False): 38 | return self.get_data(self.test, labelled=labelled) 39 | 40 | def distribution_transform(self, dataset): 41 | return dataset 42 | 43 | def affine_transform(self, dataset): 44 | return dataset 45 | 46 | def preprocess(self, v): 47 | r""" The mathematical form of the affine transform 48 | """ 49 | return v 50 | 51 | def unpreprocess(self, v): 52 | r""" The mathematical form of the affine transform's inverse 53 | """ 54 | return v 55 | 56 | @property 57 | def data_shape(self): 58 | raise NotImplementedError 59 | 60 | @property 61 | def data_dim(self): 62 | return int(np.prod(self.data_shape)) 63 | 64 | @property 65 | def fid_stat(self): 66 | return None 67 | -------------------------------------------------------------------------------- /celeba_lsun_codes/interface/datasets/other_dst/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numbers 4 | import torchvision.transforms as transforms 5 | import torchvision.transforms.functional as F 6 | from torchvision.datasets import CIFAR10 7 | from .celeba import CelebA 8 | # from .ffhq import FFHQ 9 | from .lsun import LSUN 10 | from torch.utils.data import Subset 11 | import numpy as np 12 | from torch.utils.data import Dataset 13 | 14 | 15 | class Crop(object): 16 | def __init__(self, x1, x2, y1, y2): 17 | self.x1 = x1 18 | self.x2 = x2 19 | 
self.y1 = y1 20 | self.y2 = y2 21 | 22 | def __call__(self, img): 23 | return F.crop(img, self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1) 24 | 25 | def __repr__(self): 26 | return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format( 27 | self.x1, self.x2, self.y1, self.y2 28 | ) 29 | 30 | 31 | class StandardizedDataset(Dataset): 32 | def __init__(self, dataset, mean, std): 33 | self.dataset = dataset 34 | self.mean = mean 35 | self.std = std 36 | self.std_inv = 1. / std 37 | 38 | def __len__(self): 39 | return len(self.dataset) 40 | 41 | def __getitem__(self, item): 42 | x = self.dataset[item] 43 | return self.std_inv * (x - self.mean) 44 | 45 | 46 | def get_dataset(dataset): 47 | 48 | if dataset == "celeba": 49 | cx = 89 50 | cy = 121 51 | x1 = cy - 64 52 | x2 = cy + 64 53 | y1 = cx - 64 54 | y2 = cx + 64 55 | 56 | dataset = CelebA( 57 | root="workspace/datasets/celeba/", 58 | split="train", 59 | transform=transforms.Compose( 60 | [ 61 | Crop(x1, x2, y1, y2), 62 | transforms.Resize(64), 63 | transforms.RandomHorizontalFlip(), 64 | transforms.ToTensor(), 65 | ] 66 | ), 67 | download=True, 68 | ) 69 | 70 | test_dataset = CelebA( 71 | root=os.path.join("workspace/datasets/celeba/"), 72 | split="test", 73 | transform=transforms.Compose( 74 | [ 75 | Crop(x1, x2, y1, y2), 76 | transforms.Resize(64), 77 | transforms.ToTensor(), 78 | ] 79 | ), 80 | download=True, 81 | ) 82 | 83 | elif dataset == "lsun_bedroom": 84 | # if config.data.random_flip: 85 | # dataset = LSUN( 86 | # root="workspace/datasets/lsun_bedroom", 87 | # classes=["bedroom_train"], 88 | # transform=transforms.Compose( 89 | # [ 90 | # transforms.Resize(256), 91 | # transforms.CenterCrop(256), 92 | # transforms.RandomHorizontalFlip(p=0.5), 93 | # transforms.ToTensor(), 94 | # ] 95 | # ), 96 | # ) 97 | # else: 98 | dataset = LSUN( 99 | root="workspace/datasets/lsun_bedroom", 100 | classes=["bedroom_train"], 101 | transform=transforms.Compose( 102 | [ 103 | transforms.Resize(256), 104 | transforms.CenterCrop(256), 105 | transforms.ToTensor(), 106 | ] 107 | ), 108 | ) 109 | 110 | test_dataset = None 111 | 112 | # test_dataset = LSUN( 113 | # root="workspace/datasets/lsun_bedroom", 114 | # classes=["bedroom_val"], 115 | # transform=transforms.Compose( 116 | # [ 117 | # transforms.Resize(256), 118 | # transforms.CenterCrop(256), 119 | # transforms.ToTensor(), 120 | # ] 121 | # ), 122 | # ) 123 | 124 | elif dataset == "lsun_church": 125 | dataset = LSUN( 126 | root="workspace/datasets/lsun_church", 127 | classes=["church_outdoor_train"], 128 | transform=transforms.Compose( 129 | [ 130 | transforms.Resize(256), 131 | transforms.CenterCrop(256), 132 | transforms.ToTensor(), 133 | ] 134 | ), 135 | ) 136 | test_dataset = None 137 | 138 | elif dataset == "FFHQ": 139 | if config.data.random_flip: 140 | dataset = FFHQ( 141 | path=os.path.join(args.exp, "datasets", "FFHQ"), 142 | transform=transforms.Compose( 143 | [transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor()] 144 | ), 145 | resolution=config.data.image_size, 146 | ) 147 | else: 148 | dataset = FFHQ( 149 | path=os.path.join(args.exp, "datasets", "FFHQ"), 150 | transform=transforms.ToTensor(), 151 | resolution=config.data.image_size, 152 | ) 153 | 154 | num_items = len(dataset) 155 | indices = list(range(num_items)) 156 | random_state = np.random.get_state() 157 | np.random.seed(2019) 158 | np.random.shuffle(indices) 159 | np.random.set_state(random_state) 160 | train_indices, test_indices = ( 161 | indices[: int(num_items * 0.9)], 162 | indices[int(num_items 
* 0.9) :], 163 | ) 164 | test_dataset = Subset(dataset, test_indices) 165 | dataset = Subset(dataset, train_indices) 166 | else: 167 | dataset, test_dataset = None, None 168 | 169 | return StandardizedDataset(dataset, mean=0.5, std=0.5), StandardizedDataset(test_dataset, mean=0.5, std=0.5) 170 | 171 | 172 | def logit_transform(image, lam=1e-6): 173 | image = lam + (1 - 2 * lam) * image 174 | return torch.log(image) - torch.log1p(-image) 175 | 176 | 177 | def data_transform(config, X): 178 | if config.data.uniform_dequantization: 179 | X = X / 256.0 * 255.0 + torch.rand_like(X) / 256.0 180 | if config.data.gaussian_dequantization: 181 | X = X + torch.randn_like(X) * 0.01 182 | 183 | if config.data.rescaled: 184 | X = 2 * X - 1.0 185 | elif config.data.logit_transform: 186 | X = logit_transform(X) 187 | 188 | if hasattr(config, "image_mean"): 189 | return X - config.image_mean.to(X.device)[None, ...] 190 | 191 | return X 192 | 193 | 194 | def inverse_data_transform(config, X): 195 | if hasattr(config, "image_mean"): 196 | X = X + config.image_mean.to(X.device)[None, ...] 197 | 198 | if config.data.logit_transform: 199 | X = torch.sigmoid(X) 200 | elif config.data.rescaled: 201 | X = (X + 1.0) / 2.0 202 | 203 | return torch.clamp(X, 0.0, 1.0) 204 | -------------------------------------------------------------------------------- /celeba_lsun_codes/interface/datasets/other_dst/celeba.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import PIL 4 | from .vision import VisionDataset 5 | from .utils import download_file_from_google_drive, check_integrity 6 | 7 | 8 | class CelebA(VisionDataset): 9 | """`Large-scale CelebFaces Attributes (CelebA) Dataset `_ Dataset. 10 | 11 | Args: 12 | root (string): Root directory where images are downloaded to. 13 | split (string): One of {'train', 'valid', 'test'}. 14 | Accordingly dataset is selected. 15 | target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``, 16 | or ``landmarks``. Can also be a list to output a tuple with all specified target types. 17 | The targets represent: 18 | ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes 19 | ``identity`` (int): label for each person (data points with the same identity are the same person) 20 | ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height) 21 | ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x, 22 | righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y) 23 | Defaults to ``attr``. 24 | transform (callable, optional): A function/transform that takes in an PIL image 25 | and returns a transformed version. E.g, ``transforms.ToTensor`` 26 | target_transform (callable, optional): A function/transform that takes in the 27 | target and transforms it. 28 | download (bool, optional): If true, downloads the dataset from the internet and 29 | puts it in root directory. If dataset is already downloaded, it is not 30 | downloaded again. 31 | """ 32 | 33 | base_folder = "celeba" 34 | # There currently does not appear to be a easy way to extract 7z in python (without introducing additional 35 | # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available 36 | # right now. 
37 | file_list = [ 38 | # File ID MD5 Hash Filename 39 | ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"), 40 | # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"), 41 | # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"), 42 | ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"), 43 | ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"), 44 | ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"), 45 | ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"), 46 | # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"), 47 | ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"), 48 | ] 49 | 50 | def __init__(self, root, 51 | split="train", 52 | target_type="attr", 53 | transform=None, target_transform=None, 54 | download=False): 55 | import pandas 56 | super(CelebA, self).__init__(root) 57 | self.split = split 58 | if isinstance(target_type, list): 59 | self.target_type = target_type 60 | else: 61 | self.target_type = [target_type] 62 | self.transform = transform 63 | self.target_transform = target_transform 64 | 65 | if download: 66 | self.download() 67 | 68 | if not self._check_integrity(): 69 | raise RuntimeError('Dataset not found or corrupted.' + 70 | ' You can use download=True to download it') 71 | 72 | self.transform = transform 73 | self.target_transform = target_transform 74 | 75 | if split.lower() == "train": 76 | split = 0 77 | elif split.lower() == "valid": 78 | split = 1 79 | elif split.lower() == "test": 80 | split = 2 81 | else: 82 | raise ValueError('Wrong split entered! 
Please use split="train" ' 83 | 'or split="valid" or split="test"') 84 | 85 | with open(os.path.join(self.root, self.base_folder, "list_eval_partition.txt"), "r") as f: 86 | splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) 87 | 88 | with open(os.path.join(self.root, self.base_folder, "identity_CelebA.txt"), "r") as f: 89 | self.identity = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) 90 | 91 | with open(os.path.join(self.root, self.base_folder, "list_bbox_celeba.txt"), "r") as f: 92 | self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0) 93 | 94 | with open(os.path.join(self.root, self.base_folder, "list_landmarks_align_celeba.txt"), "r") as f: 95 | self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1) 96 | 97 | with open(os.path.join(self.root, self.base_folder, "list_attr_celeba.txt"), "r") as f: 98 | self.attr = pandas.read_csv(f, delim_whitespace=True, header=1) 99 | 100 | mask = (splits[1] == split) 101 | self.filename = splits[mask].index.values 102 | self.identity = torch.as_tensor(self.identity[mask].values) 103 | self.bbox = torch.as_tensor(self.bbox[mask].values) 104 | self.landmarks_align = torch.as_tensor(self.landmarks_align[mask].values) 105 | self.attr = torch.as_tensor(self.attr[mask].values) 106 | self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1} 107 | 108 | def _check_integrity(self): 109 | for (_, md5, filename) in self.file_list: 110 | fpath = os.path.join(self.root, self.base_folder, filename) 111 | _, ext = os.path.splitext(filename) 112 | # Allow original archive to be deleted (zip and 7z) 113 | # Only need the extracted images 114 | if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5): 115 | return False 116 | 117 | # Should check a hash of the images 118 | return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba")) 119 | 120 | def download(self): 121 | import zipfile 122 | 123 | if self._check_integrity(): 124 | print('Files already downloaded and verified') 125 | return 126 | 127 | for (file_id, md5, filename) in self.file_list: 128 | download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) 129 | 130 | with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f: 131 | f.extractall(os.path.join(self.root, self.base_folder)) 132 | 133 | def __getitem__(self, index): 134 | X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) 135 | 136 | # target = [] 137 | # for t in self.target_type: 138 | # if t == "attr": 139 | # target.append(self.attr[index, :]) 140 | # elif t == "identity": 141 | # target.append(self.identity[index, 0]) 142 | # elif t == "bbox": 143 | # target.append(self.bbox[index, :]) 144 | # elif t == "landmarks": 145 | # target.append(self.landmarks_align[index, :]) 146 | # else: 147 | # raise ValueError("Target type \"{}\" is not recognized.".format(t)) 148 | # target = tuple(target) if len(target) > 1 else target[0] 149 | 150 | if self.transform is not None: 151 | X = self.transform(X) 152 | 153 | # if self.target_transform is not None: 154 | # target = self.target_transform(target) 155 | 156 | return X #, target 157 | 158 | def __len__(self): 159 | return len(self.attr) 160 | 161 | def extra_repr(self): 162 | lines = ["Target type: {target_type}", "Split: {split}"] 163 | return '\n'.join(lines).format(**self.__dict__) 164 | 
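The `CelebA` class above only loads and (optionally) transforms raw images; the 64x64 face pipeline is assembled by `get_dataset("celeba")` in `other_dst/__init__.py`. A minimal sketch of that test-split pipeline, assuming the data already sits under `workspace/datasets/celeba/` and `celeba_lsun_codes` is the working directory:

```python
import torchvision.transforms as transforms
from interface.datasets.other_dst import Crop, StandardizedDataset
from interface.datasets.other_dst.celeba import CelebA

cx, cy = 89, 121  # face-centered crop coordinates used by get_dataset("celeba")
dataset = CelebA(
    root="workspace/datasets/celeba/",
    split="test",
    transform=transforms.Compose([
        Crop(cy - 64, cy + 64, cx - 64, cx + 64),  # 128x128 crop around the face
        transforms.Resize(64),                     # downscale to 64x64
        transforms.ToTensor(),                     # PIL image -> tensor in [0, 1]
    ]),
    download=False,
)
dataset = StandardizedDataset(dataset, mean=0.5, std=0.5)  # rescale to [-1, 1]

x = dataset[0]
print(x.shape)  # torch.Size([3, 64, 64]), values roughly in [-1, 1]
```

(The train split built by `get_dataset` additionally applies `transforms.RandomHorizontalFlip()`.)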
-------------------------------------------------------------------------------- /celeba_lsun_codes/interface/datasets/other_dst/ffhq.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import lmdb 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class FFHQ(Dataset): 9 | def __init__(self, path, transform, resolution=8): 10 | self.env = lmdb.open( 11 | path, 12 | max_readers=32, 13 | readonly=True, 14 | lock=False, 15 | readahead=False, 16 | meminit=False, 17 | ) 18 | 19 | if not self.env: 20 | raise IOError('Cannot open lmdb dataset', path) 21 | 22 | with self.env.begin(write=False) as txn: 23 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 24 | 25 | self.resolution = resolution 26 | self.transform = transform 27 | 28 | def __len__(self): 29 | return self.length 30 | 31 | def __getitem__(self, index): 32 | with self.env.begin(write=False) as txn: 33 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') 34 | img_bytes = txn.get(key) 35 | 36 | buffer = BytesIO(img_bytes) 37 | img = Image.open(buffer) 38 | img = self.transform(img) 39 | target = 0 40 | 41 | return img, target -------------------------------------------------------------------------------- /celeba_lsun_codes/interface/datasets/other_dst/lsun.py: -------------------------------------------------------------------------------- 1 | from .vision import VisionDataset 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import io 6 | from collections.abc import Iterable 7 | import pickle 8 | import torch 9 | 10 | 11 | def iterable_to_str(iterable: Iterable) -> str: 12 | return "'" + "', '".join([str(item) for item in iterable]) + "'" 13 | 14 | 15 | def verify_str_arg( 16 | value, arg=None, valid_values=None, custom_msg=None, 17 | ): 18 | if not isinstance(value, torch._six.string_classes): 19 | if arg is None: 20 | msg = "Expected type str, but got type {type}." 21 | else: 22 | msg = "Expected type str for argument {arg}, but got type {type}." 23 | msg = msg.format(type=type(value), arg=arg) 24 | raise ValueError(msg) 25 | 26 | if valid_values is None: 27 | return value 28 | 29 | if value not in valid_values: 30 | if custom_msg is not None: 31 | msg = custom_msg 32 | else: 33 | msg = ("Unknown value '{value}' for argument {arg}. 
" 34 | "Valid values are {{{valid_values}}}.") 35 | msg = msg.format(value=value, arg=arg, 36 | valid_values=iterable_to_str(valid_values)) 37 | raise ValueError(msg) 38 | 39 | return value 40 | 41 | 42 | class LSUNClass(VisionDataset): 43 | def __init__(self, root, transform=None, target_transform=None): 44 | import lmdb 45 | 46 | super(LSUNClass, self).__init__( 47 | root, transform=transform, target_transform=target_transform 48 | ) 49 | 50 | self.env = lmdb.open( 51 | root, 52 | max_readers=1, 53 | readonly=True, 54 | lock=False, 55 | readahead=False, 56 | meminit=False, 57 | ) 58 | with self.env.begin(write=False) as txn: 59 | self.length = txn.stat()["entries"] 60 | root_split = root.split("/") 61 | cache_file = os.path.join("/".join(root_split[:-1]), f"_cache_{root_split[-1]}") 62 | if os.path.isfile(cache_file): 63 | self.keys = pickle.load(open(cache_file, "rb")) 64 | else: 65 | with self.env.begin(write=False) as txn: 66 | self.keys = [key for key, _ in txn.cursor()] 67 | pickle.dump(self.keys, open(cache_file, "wb")) 68 | 69 | def __getitem__(self, index): 70 | img, target = None, None 71 | env = self.env 72 | with env.begin(write=False) as txn: 73 | imgbuf = txn.get(self.keys[index]) 74 | 75 | buf = io.BytesIO() 76 | buf.write(imgbuf) 77 | buf.seek(0) 78 | img = Image.open(buf).convert("RGB") 79 | 80 | if self.transform is not None: 81 | img = self.transform(img) 82 | 83 | if self.target_transform is not None: 84 | target = self.target_transform(target) 85 | 86 | return img, target 87 | 88 | def __len__(self): 89 | return self.length 90 | 91 | 92 | class LSUN(VisionDataset): 93 | """ 94 | `LSUN `_ dataset. 95 | 96 | Args: 97 | root (string): Root directory for the database files. 98 | classes (string or list): One of {'train', 'val', 'test'} or a list of 99 | categories to load. e,g. ['bedroom_train', 'church_outdoor_train']. 100 | transform (callable, optional): A function/transform that takes in an PIL image 101 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 102 | target_transform (callable, optional): A function/transform that takes in the 103 | target and transforms it. 104 | """ 105 | 106 | def __init__(self, root, classes="train", transform=None, target_transform=None): 107 | super(LSUN, self).__init__( 108 | root, transform=transform, target_transform=target_transform 109 | ) 110 | self.classes = self._verify_classes(classes) 111 | 112 | # for each class, create an LSUNClassDataset 113 | self.dbs = [] 114 | for c in self.classes: 115 | self.dbs.append( 116 | LSUNClass(root=root + "/" + c + "_lmdb", transform=transform) 117 | ) 118 | 119 | self.indices = [] 120 | count = 0 121 | for db in self.dbs: 122 | count += len(db) 123 | self.indices.append(count) 124 | 125 | self.length = count 126 | 127 | def _verify_classes(self, classes): 128 | categories = [ 129 | "bedroom", 130 | "bridge", 131 | "church_outdoor", 132 | "classroom", 133 | "conference_room", 134 | "dining_room", 135 | "kitchen", 136 | "living_room", 137 | "restaurant", 138 | "tower", 139 | ] 140 | dset_opts = ["train", "val", "test"] 141 | 142 | try: 143 | verify_str_arg(classes, "classes", dset_opts) 144 | if classes == "test": 145 | classes = [classes] 146 | else: 147 | classes = [c + "_" + classes for c in categories] 148 | except ValueError: 149 | if not isinstance(classes, Iterable): 150 | msg = ( 151 | "Expected type str or Iterable for argument classes, " 152 | "but got type {}." 
153 | ) 154 | raise ValueError(msg.format(type(classes))) 155 | 156 | classes = list(classes) 157 | msg_fmtstr = ( 158 | "Expected type str for elements in argument classes, " 159 | "but got type {}." 160 | ) 161 | for c in classes: 162 | verify_str_arg(c, custom_msg=msg_fmtstr.format(type(c))) 163 | c_short = c.split("_") 164 | category, dset_opt = "_".join(c_short[:-1]), c_short[-1] 165 | 166 | msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}." 167 | msg = msg_fmtstr.format( 168 | category, "LSUN class", iterable_to_str(categories) 169 | ) 170 | verify_str_arg(category, valid_values=categories, custom_msg=msg) 171 | 172 | msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts)) 173 | verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg) 174 | 175 | return classes 176 | 177 | def __getitem__(self, index): 178 | """ 179 | Args: 180 | index (int): Index 181 | 182 | Returns: 183 | tuple: Tuple (image, target) where target is the index of the target category. 184 | """ 185 | target = 0 186 | sub = 0 187 | for ind in self.indices: 188 | if index < ind: 189 | break 190 | target += 1 191 | sub = ind 192 | 193 | db = self.dbs[target] 194 | index = index - sub 195 | 196 | if self.target_transform is not None: 197 | target = self.target_transform(target) 198 | 199 | img, _ = db[index] 200 | return img #, target 201 | 202 | def __len__(self): 203 | return self.length 204 | 205 | def extra_repr(self): 206 | return "Classes: {classes}".format(**self.__dict__) 207 | -------------------------------------------------------------------------------- /celeba_lsun_codes/interface/datasets/other_dst/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import hashlib 4 | import errno 5 | from torch.utils.model_zoo import tqdm 6 | 7 | 8 | def gen_bar_updater(): 9 | pbar = tqdm(total=None) 10 | 11 | def bar_update(count, block_size, total_size): 12 | if pbar.total is None and total_size: 13 | pbar.total = total_size 14 | progress_bytes = count * block_size 15 | pbar.update(progress_bytes - pbar.n) 16 | 17 | return bar_update 18 | 19 | 20 | def check_integrity(fpath, md5=None): 21 | if md5 is None: 22 | return True 23 | if not os.path.isfile(fpath): 24 | return False 25 | md5o = hashlib.md5() 26 | with open(fpath, 'rb') as f: 27 | # read in 1MB chunks 28 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 29 | md5o.update(chunk) 30 | md5c = md5o.hexdigest() 31 | if md5c != md5: 32 | return False 33 | return True 34 | 35 | 36 | def makedir_exist_ok(dirpath): 37 | """ 38 | Python2 support for os.makedirs(.., exist_ok=True) 39 | """ 40 | try: 41 | os.makedirs(dirpath) 42 | except OSError as e: 43 | if e.errno == errno.EEXIST: 44 | pass 45 | else: 46 | raise 47 | 48 | 49 | def download_url(url, root, filename=None, md5=None): 50 | """Download a file from a url and place it in root. 51 | 52 | Args: 53 | url (str): URL to download file from 54 | root (str): Directory to place downloaded file in 55 | filename (str, optional): Name to save the file under. If None, use the basename of the URL 56 | md5 (str, optional): MD5 checksum of the download. 
If None, do not check 57 | """ 58 | from six.moves import urllib 59 | 60 | root = os.path.expanduser(root) 61 | if not filename: 62 | filename = os.path.basename(url) 63 | fpath = os.path.join(root, filename) 64 | 65 | makedir_exist_ok(root) 66 | 67 | # downloads file 68 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 69 | print('Using downloaded and verified file: ' + fpath) 70 | else: 71 | try: 72 | print('Downloading ' + url + ' to ' + fpath) 73 | urllib.request.urlretrieve( 74 | url, fpath, 75 | reporthook=gen_bar_updater() 76 | ) 77 | except OSError: 78 | if url[:5] == 'https': 79 | url = url.replace('https:', 'http:') 80 | print('Failed download. Trying https -> http instead.' 81 | ' Downloading ' + url + ' to ' + fpath) 82 | urllib.request.urlretrieve( 83 | url, fpath, 84 | reporthook=gen_bar_updater() 85 | ) 86 | 87 | 88 | def list_dir(root, prefix=False): 89 | """List all directories at a given root 90 | 91 | Args: 92 | root (str): Path to directory whose folders need to be listed 93 | prefix (bool, optional): If true, prepends the path to each result, otherwise 94 | only returns the name of the directories found 95 | """ 96 | root = os.path.expanduser(root) 97 | directories = list( 98 | filter( 99 | lambda p: os.path.isdir(os.path.join(root, p)), 100 | os.listdir(root) 101 | ) 102 | ) 103 | 104 | if prefix is True: 105 | directories = [os.path.join(root, d) for d in directories] 106 | 107 | return directories 108 | 109 | 110 | def list_files(root, suffix, prefix=False): 111 | """List all files ending with a suffix at a given root 112 | 113 | Args: 114 | root (str): Path to directory whose files need to be listed 115 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 116 | It uses the Python "str.endswith" method and is passed directly 117 | prefix (bool, optional): If true, prepends the path to each result, otherwise 118 | only returns the name of the files found 119 | """ 120 | root = os.path.expanduser(root) 121 | files = list( 122 | filter( 123 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 124 | os.listdir(root) 125 | ) 126 | ) 127 | 128 | if prefix is True: 129 | files = [os.path.join(root, d) for d in files] 130 | 131 | return files 132 | 133 | 134 | def download_file_from_google_drive(file_id, root, filename=None, md5=None): 135 | """Download a file from Google Drive and place it in root. 136 | 137 | Args: 138 | file_id (str): id of file to be downloaded 139 | root (str): Directory to place downloaded file in 140 | filename (str, optional): Name to save the file under. If None, use the id of the file. 141 | md5 (str, optional): MD5 checksum of the download.
If None, do not check 142 | """ 143 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url 144 | import requests 145 | url = "https://docs.google.com/uc?export=download" 146 | 147 | root = os.path.expanduser(root) 148 | if not filename: 149 | filename = file_id 150 | fpath = os.path.join(root, filename) 151 | 152 | makedir_exist_ok(root) 153 | 154 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 155 | print('Using downloaded and verified file: ' + fpath) 156 | else: 157 | session = requests.Session() 158 | 159 | response = session.get(url, params={'id': file_id}, stream=True) 160 | token = _get_confirm_token(response) 161 | 162 | if token: 163 | params = {'id': file_id, 'confirm': token} 164 | response = session.get(url, params=params, stream=True) 165 | 166 | _save_response_content(response, fpath) 167 | 168 | 169 | def _get_confirm_token(response): 170 | for key, value in response.cookies.items(): 171 | if key.startswith('download_warning'): 172 | return value 173 | 174 | return None 175 | 176 | 177 | def _save_response_content(response, destination, chunk_size=32768): 178 | with open(destination, "wb") as f: 179 | pbar = tqdm(total=None) 180 | progress = 0 181 | for chunk in response.iter_content(chunk_size): 182 | if chunk: # filter out keep-alive new chunks 183 | f.write(chunk) 184 | progress += len(chunk) 185 | pbar.update(progress - pbar.n) 186 | pbar.close() 187 | -------------------------------------------------------------------------------- /celeba_lsun_codes/interface/datasets/other_dst/vision.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | 6 | class VisionDataset(data.Dataset): 7 | _repr_indent = 4 8 | 9 | def __init__(self, root, transforms=None, transform=None, target_transform=None): 10 | if isinstance(root, torch._six.string_classes): 11 | root = os.path.expanduser(root) 12 | self.root = root 13 | 14 | has_transforms = transforms is not None 15 | has_separate_transform = transform is not None or target_transform is not None 16 | if has_transforms and has_separate_transform: 17 | raise ValueError("Only transforms or transform/target_transform can " 18 | "be passed as argument") 19 | 20 | # for backwards-compatibility 21 | self.transform = transform 22 | self.target_transform = target_transform 23 | 24 | if has_separate_transform: 25 | transforms = StandardTransform(transform, target_transform) 26 | self.transforms = transforms 27 | 28 | def __getitem__(self, index): 29 | raise NotImplementedError 30 | 31 | def __len__(self): 32 | raise NotImplementedError 33 | 34 | def __repr__(self): 35 | head = "Dataset " + self.__class__.__name__ 36 | body = ["Number of datapoints: {}".format(self.__len__())] 37 | if self.root is not None: 38 | body.append("Root location: {}".format(self.root)) 39 | body += self.extra_repr().splitlines() 40 | if hasattr(self, 'transform') and self.transform is not None: 41 | body += self._format_transform_repr(self.transform, 42 | "Transforms: ") 43 | if hasattr(self, 'target_transform') and self.target_transform is not None: 44 | body += self._format_transform_repr(self.target_transform, 45 | "Target transforms: ") 46 | lines = [head] + [" " * self._repr_indent + line for line in body] 47 | return '\n'.join(lines) 48 | 49 | def _format_transform_repr(self, transform, head): 50 | lines = transform.__repr__().splitlines() 51 | return (["{}{}".format(head, lines[0])] + 52 | 
["{}{}".format(" " * len(head), line) for line in lines[1:]]) 53 | 54 | def extra_repr(self): 55 | return "" 56 | 57 | 58 | class StandardTransform(object): 59 | def __init__(self, transform=None, target_transform=None): 60 | self.transform = transform 61 | self.target_transform = target_transform 62 | 63 | def __call__(self, input, target): 64 | if self.transform is not None: 65 | input = self.transform(input) 66 | if self.target_transform is not None: 67 | target = self.target_transform(target) 68 | return input, target 69 | 70 | def _format_transform_repr(self, transform, head): 71 | lines = transform.__repr__().splitlines() 72 | return (["{}{}".format(head, lines[0])] + 73 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 74 | 75 | def __repr__(self): 76 | body = [self.__class__.__name__] 77 | if self.transform is not None: 78 | body += self._format_transform_repr(self.transform, 79 | "Transform: ") 80 | if self.target_transform is not None: 81 | body += self._format_transform_repr(self.target_transform, 82 | "Target transform: ") 83 | 84 | return '\n'.join(body) 85 | -------------------------------------------------------------------------------- /celeba_lsun_codes/interface/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | 5 | def pad22pow(a): 6 | assert a % 2 == 0 7 | bits = a.bit_length() 8 | ub = 2 ** bits 9 | pad = (ub - a) // 2 10 | return pad, ub 11 | 12 | 13 | def is_labelled(dataset): 14 | labelled = False 15 | if isinstance(dataset[0], tuple) and len(dataset[0]) == 2: 16 | labelled = True 17 | return labelled 18 | 19 | 20 | class AddGaussNoise(object): 21 | def __init__(self, std): 22 | self.std = std 23 | 24 | def __call__(self, tensor): 25 | return tensor + self.std * torch.rand_like(tensor).to(tensor.device) 26 | 27 | 28 | class UnlabeledDataset(Dataset): 29 | def __init__(self, dataset): 30 | self.dataset = dataset 31 | 32 | def __len__(self): 33 | return len(self.dataset) 34 | 35 | def __getitem__(self, item): 36 | x, y = self.dataset[item] 37 | return x 38 | 39 | 40 | class StandardizedDataset(Dataset): 41 | def __init__(self, dataset, mean, std): 42 | self.dataset = dataset 43 | self.mean = mean 44 | self.std = std 45 | self.std_inv = 1. 
/ std 46 | self.labelled = is_labelled(dataset) 47 | 48 | def __len__(self): 49 | return len(self.dataset) 50 | 51 | def __getitem__(self, item): 52 | if self.labelled: 53 | x, y = self.dataset[item] 54 | return self.std_inv * (x - self.mean), y 55 | else: 56 | x = self.dataset[item] 57 | return self.std_inv * (x - self.mean) 58 | 59 | 60 | class QuickDataset(Dataset): 61 | def __init__(self, array): 62 | self.array = array 63 | 64 | def __len__(self): 65 | return len(self.array) 66 | 67 | def __getitem__(self, item): 68 | return self.array[item] 69 | -------------------------------------------------------------------------------- /celeba_lsun_codes/interface/task_schedule.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Process 2 | from typing import List 3 | import time 4 | from typing import Union, Tuple 5 | import os 6 | 7 | 8 | def get_gpu_memory_map(): 9 | raw = list(os.popen('nvidia-smi --query-gpu=memory.free --format=csv,nounits,noheader')) 10 | mem = [int(x.strip()) for x in raw] 11 | return dict(zip(range(len(mem)), mem)) 12 | 13 | 14 | def get_gpu_total_memory_map(): 15 | raw = list(os.popen('nvidia-smi --query-gpu=memory.total --format=csv,nounits,noheader')) 16 | mem = [int(x.strip()) for x in raw] 17 | return dict(zip(range(len(mem)), mem)) 18 | 19 | 20 | def gpu_memory_consumption(): 21 | devices = list(map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(','))) 22 | gpu_memory_map = get_gpu_memory_map() 23 | gpu_total_memory_map = get_gpu_total_memory_map() 24 | return sum([gpu_total_memory_map[device] - gpu_memory_map[device] for device in devices]) 25 | 26 | 27 | def available_devices(threshold=10000) -> List[int]: 28 | gpu_memory_map = get_gpu_memory_map() 29 | devices = [] 30 | for idx, mem in gpu_memory_map.items(): 31 | if mem > threshold: 32 | devices.append(idx) 33 | return devices 34 | 35 | 36 | def format_devices(devices: Union[int, List[int], Tuple[int]]): 37 | if isinstance(devices, int): 38 | return "{}".format(devices) 39 | elif isinstance(devices, tuple) or isinstance(devices, list): 40 | return ','.join(map(str, devices)) 41 | 42 | 43 | class Task(object): 44 | def __init__(self, process: Process, n_devices: int = 1): 45 | self.process = process 46 | self.n_devices = n_devices 47 | self.devices = None 48 | self.just_created = True 49 | 50 | def state(self): 51 | if self.just_created: 52 | return 'just_created' 53 | elif self.process.is_alive(): 54 | return 'is_alive' 55 | else: 56 | return 'finished' 57 | 58 | def start(self, devices): 59 | self.devices = devices 60 | os.environ["CUDA_VISIBLE_DEVICES"] = format_devices(devices) 61 | self.process.start() 62 | self.just_created = False 63 | 64 | 65 | class DevicesPool(object): 66 | def __init__(self, devices: List[int]): 67 | self.devices = devices.copy() 68 | 69 | def flow_out(self, n_devices: int): 70 | if len(self.devices) < n_devices: 71 | return None 72 | ret = [] 73 | for _ in range(n_devices): 74 | ret.append(self.devices.pop()) 75 | return ret 76 | 77 | def flow_in(self, devices: List[int]): 78 | for device in devices: 79 | self.devices.append(device) 80 | 81 | 82 | ################################################################################ 83 | # Run multiple tasks in parallel, each exclusively using its own devices 84 | # Suitable for tasks that consume a lot of GPU memory 85 | ################################################################################ 86 | 87 | def wait_schedule(tasks: List[Task], devices: List[int]): 88 | # assert
len(set(devices)) == len(devices) 89 | for task in tasks: 90 | assert task.n_devices <= len(devices) 91 | tasks = sorted(tasks, key=lambda x: x.n_devices, reverse=True) 92 | devices_pool = DevicesPool(devices) 93 | 94 | def linked_list_next(_idx: int, _lst: List): 95 | if _lst: 96 | return (_idx + 1) % len(_lst) 97 | else: 98 | return -1 99 | 100 | idx = 0 101 | while tasks: 102 | task = tasks[idx] 103 | state = task.state() 104 | if state == 'just_created': 105 | devices = devices_pool.flow_out(task.n_devices) 106 | if devices is not None: 107 | print("\033[1m start a task with {} devices".format(len(devices))) 108 | task.start(devices) 109 | elif state == 'finished': 110 | print("\033[1m a task with {} devices finished".format(len(task.devices))) 111 | devices_pool.flow_in(task.devices) 112 | task.process.close() 113 | tasks.pop(idx) 114 | idx -= 1 115 | idx = linked_list_next(idx, tasks) 116 | time.sleep(1) 117 | -------------------------------------------------------------------------------- /celeba_lsun_codes/interface/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from PIL import Image 4 | import torch 5 | import numpy as np 6 | from torch.utils.data import DataLoader, Dataset 7 | 8 | 9 | def global_device() -> torch.device: 10 | return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 11 | 12 | 13 | def set_logger(fname): 14 | os.makedirs(os.path.split(fname)[0], exist_ok=True) 15 | logger = logging.getLogger() 16 | logger.setLevel(level=logging.INFO) 17 | handler1 = logging.StreamHandler() 18 | handler2 = logging.FileHandler(fname, mode='w') 19 | formatter = logging.Formatter('%(asctime)s - %(message)s') 20 | handler1.setFormatter(formatter) 21 | handler2.setFormatter(formatter) 22 | logger.addHandler(handler1) 23 | logger.addHandler(handler2) 24 | 25 | 26 | def set_seed(seed: int): 27 | if seed is not None: 28 | torch.manual_seed(seed) 29 | np.random.seed(seed) 30 | 31 | 32 | def cnt_png(path): 33 | png_files = filter(lambda x: x.endswith(".png"), os.listdir(path)) 34 | return len(list(png_files)) 35 | 36 | 37 | def amortize(n_samples, batch_size): 38 | k = n_samples // batch_size 39 | r = n_samples % batch_size 40 | return k * [batch_size] if r == 0 else k * [batch_size] + [r] 41 | 42 | 43 | def sample2dir(path, n_samples, batch_size, sample_fn, unpreprocess_fn=None, persist=True): 44 | os.makedirs(path, exist_ok=True) 45 | idx = n_png = cnt_png(path) if persist else 0 46 | for _batch_size in amortize(n_samples - n_png, batch_size): 47 | samples = sample_fn(_batch_size) 48 | samples = unpreprocess_fn(samples) if unpreprocess_fn is not None else samples  # guard the None default 49 | for sample in samples: 50 | Image.fromarray(sample).save(os.path.join(path, "{}.png".format(idx))) 51 | idx += 1 52 | 53 | 54 | def score_on_dataset(dataset: Dataset, score_fn, batch_size): 55 | r""" 56 | Args: 57 | dataset: an instance of Dataset 58 | score_fn: a batch of data -> a batch of scalars 59 | batch_size: the batch size 60 | """ 61 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 62 | total_score = None 63 | tuple_output = None 64 | dataloader = DataLoader(dataset, batch_size=batch_size) 65 | for idx, v in enumerate(dataloader): 66 | v = v.to(device) 67 | score = score_fn(v) 68 | if idx == 0: 69 | tuple_output = isinstance(score, tuple) 70 | total_score = (0.,) * len(score) if tuple_output else 0.
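# [Editor's note: the following comment block is an annotation, not part of the original source.]
# score_on_dataset accepts a score_fn that returns either a tensor of per-sample scores
# or a tuple of such tensors (e.g. several metrics computed in one pass). On the first
# batch, total_score is initialised to a matching scalar or tuple of scalars; the branches
# below accumulate batch sums, and the final division by len(dataset) turns the sums into
# per-sample means.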
71 | if tuple_output: 72 | total_score = tuple([a + b.sum().detach().item() for a, b in zip(total_score, score)]) 73 | else: 74 | total_score += score.sum().detach().item() 75 | if tuple_output: 76 | mean_score = tuple([a / len(dataset) for a in total_score]) 77 | else: 78 | mean_score = total_score / len(dataset) 79 | return mean_score 80 | -------------------------------------------------------------------------------- /celeba_lsun_codes/pytorch_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_diffusion.diffusion import Diffusion 2 | from pytorch_diffusion.model import Model 3 | -------------------------------------------------------------------------------- /celeba_lsun_codes/pytorch_diffusion/ckpt_util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "cifar10": "https://heibox.uni-heidelberg.de/f/869980b53bf5416c8a28/?dl=1", 7 | "ema_cifar10": "https://heibox.uni-heidelberg.de/f/2e4f01e2d9ee49bab1d5/?dl=1", 8 | "lsun_bedroom": "https://heibox.uni-heidelberg.de/f/f179d4f21ebc4d43bbfe/?dl=1", 9 | "ema_lsun_bedroom": "https://heibox.uni-heidelberg.de/f/b95206528f384185889b/?dl=1", 10 | "lsun_cat": "https://heibox.uni-heidelberg.de/f/fac870bd988348eab88e/?dl=1", 11 | "ema_lsun_cat": "https://heibox.uni-heidelberg.de/f/0701aac3aa69457bbe34/?dl=1", 12 | "lsun_church": "https://heibox.uni-heidelberg.de/f/2711a6f712e34b06b9d8/?dl=1", 13 | "ema_lsun_church": "https://heibox.uni-heidelberg.de/f/44ccb50ef3c6436db52e/?dl=1", 14 | } 15 | CKPT_MAP = { 16 | "cifar10": "diffusion_cifar10_model/model-790000.ckpt", 17 | "ema_cifar10": "ema_diffusion_cifar10_model/model-790000.ckpt", 18 | "lsun_bedroom": "diffusion_lsun_bedroom_model/model-2388000.ckpt", 19 | "ema_lsun_bedroom": "ema_diffusion_lsun_bedroom_model/model-2388000.ckpt", 20 | "lsun_cat": "diffusion_lsun_cat_model/model-1761000.ckpt", 21 | "ema_lsun_cat": "ema_diffusion_lsun_cat_model/model-1761000.ckpt", 22 | "lsun_church": "diffusion_lsun_church_model/model-4432000.ckpt", 23 | "ema_lsun_church": "ema_diffusion_lsun_church_model/model-4432000.ckpt", 24 | } 25 | MD5_MAP = { 26 | "cifar10": "82ed3067fd1002f5cf4c339fb80c4669", 27 | "ema_cifar10": "1fa350b952534ae442b1d5235cce5cd3", 28 | "lsun_bedroom": "f70280ac0e08b8e696f42cb8e948ff1c", 29 | "ema_lsun_bedroom": "1921fa46b66a3665e450e42f36c2720f", 30 | "lsun_cat": "bbee0e7c3d7abfb6e2539eaf2fb9987b", 31 | "ema_lsun_cat": "646f23f4821f2459b8bafc57fd824558", 32 | "lsun_church": "eb619b8a5ab95ef80f94ce8a5488dae3", 33 | "ema_lsun_church": "fdc68a23938c2397caba4a260bc2445f", 34 | } 35 | 36 | def download(url, local_path, chunk_size = 1024): 37 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 38 | with requests.get(url, stream = True) as r: 39 | total_size = int(r.headers.get("content-length", 0)) 40 | with tqdm(total = total_size, unit = "B", unit_scale = True) as pbar: 41 | with open(local_path, "wb") as f: 42 | for data in r.iter_content(chunk_size = chunk_size): 43 | if data: 44 | f.write(data) 45 | pbar.update(chunk_size) 46 | 47 | def md5_hash(path): 48 | with open(path, "rb") as f: 49 | content = f.read() 50 | return hashlib.md5(content).hexdigest() 51 | 52 | def get_ckpt_path(name, root=None, check=False): 53 | if name == 'ema_celeba': 54 | return 'workspace/ckpts/ema_celeba.pth' 55 | assert name in URL_MAP 56 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) 
57 | root = root if root is not None else os.path.join(cachedir, "diffusion_models_converted") 58 | path = os.path.join(root, CKPT_MAP[name]) 59 | if not os.path.exists(path) or (check and not md5_hash(path)==MD5_MAP[name]): 60 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 61 | download(URL_MAP[name], path) 62 | md5 = md5_hash(path) 63 | assert md5==MD5_MAP[name], md5 64 | return path 65 | -------------------------------------------------------------------------------- /celeba_lsun_codes/pytorch_diffusion/demo.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import time 3 | from pytorch_diffusion.diffusion import Diffusion 4 | 5 | 6 | class tqdm(object): 7 | """ 8 | tqdm-like progress bar for streamlit, adapted from 9 | https://github.com/streamlit/streamlit/issues/160#issuecomment-534385137 10 | """ 11 | def __init__(self, iterable, total=None, pbar=None): 12 | if pbar is None: 13 | pbar = st.empty() 14 | self.prog_bar = pbar 15 | self.prog_bar.progress(0) 16 | self.iterable = iterable 17 | self.length = total if total is not None else len(iterable) 18 | self.i = 0 19 | 20 | def __iter__(self): 21 | for obj in self.iterable: 22 | yield obj 23 | self.i += 1 24 | current_prog = self.i / self.length 25 | self.prog_bar.progress(current_prog) 26 | 27 | @st.cache(allow_output_mutation=True) 28 | def get_state(name, ema): 29 | if ema: 30 | name = "ema_"+name 31 | diffusion = Diffusion.from_pretrained(name) 32 | state = {"x": diffusion.denoise(1, n_steps=0), 33 | "curr_step": diffusion.num_timesteps, 34 | "diffusion": diffusion} 35 | return state 36 | 37 | def main(): 38 | st.title("Diffusion Model Demo") 39 | 40 | name = st.sidebar.radio("Model", ("cifar10", "lsun_bedroom", "lsun_cat", "lsun_church")) 41 | ema = st.sidebar.checkbox("ema", value=True) 42 | state = get_state(name, ema=ema) 43 | 44 | diffusion = state["diffusion"] 45 | st.text("Running {} model on {}".format(name, diffusion.device)) 46 | 47 | clip = st.sidebar.checkbox("clip outputs", value=True) 48 | show_x0 = st.sidebar.checkbox("show predicted x0 during denoising", value=False) 49 | 50 | n_steps = st.sidebar.number_input("Number of steps", 51 | min_value=1, 52 | max_value=diffusion.num_timesteps, 53 | value=diffusion.num_timesteps) 54 | 55 | pbar = st.sidebar.empty() 56 | pbar.progress(0) 57 | def tqdm_factory(*args, **kwargs): 58 | return tqdm(*args, **kwargs, pbar=pbar) 59 | 60 | output = st.empty() 61 | step = st.empty() 62 | 63 | def callback(x, i, x0=None): 64 | if show_x0 and x0 is not None: 65 | x = x0 66 | output.image(diffusion.torch2hwcuint8(x, clip=clip)[0]) 67 | step.text("Current step: {}".format(i)) 68 | callback(state["x"], state["curr_step"]) 69 | 70 | denoise = st.sidebar.button("Denoise") 71 | if state["curr_step"] > 0 and denoise: 72 | x = diffusion.denoise(1, 73 | n_steps=n_steps, x=state["x"], 74 | curr_step=state["curr_step"], 75 | progress_bar=tqdm_factory, 76 | callback=callback) 77 | state["x"] = x 78 | state["curr_step"] = max(0, state["curr_step"]-n_steps) 79 | 80 | diffuse = st.sidebar.button("Diffuse") 81 | if state["curr_step"] < diffusion.num_timesteps and diffuse: 82 | x = diffusion.diffuse(1, 83 | n_steps=n_steps, x=state["x"], 84 | curr_step=state["curr_step"], 85 | progress_bar=tqdm_factory, 86 | callback=callback) 87 | state["x"] = x 88 | state["curr_step"] = min(diffusion.num_timesteps, state["curr_step"]+n_steps) 89 | 90 | 91 | 92 | 93 | if __name__ == "__main__": 94 | main() 95 | 
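A minimal usage sketch for the checkpoint utilities in pytorch_diffusion/ckpt_util.py above (an editorial addition, not a file in this repository): `get_ckpt_path` resolves a model name to a local checkpoint path, downloading the file on first use and verifying its MD5 when `check=True`.

```python
# Hypothetical usage sketch; get_ckpt_path is defined in pytorch_diffusion/ckpt_util.py above.
from pytorch_diffusion.ckpt_util import get_ckpt_path

# Downloads the EMA LSUN Bedroom checkpoint on first call (cached under
# ~/.cache/diffusion_models_converted by default) and verifies its MD5.
path = get_ckpt_path("ema_lsun_bedroom", check=True)
print("checkpoint at:", path)
```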
-------------------------------------------------------------------------------- /celeba_lsun_codes/run_celeba_rep.py: -------------------------------------------------------------------------------- 1 | import os 2 | from multiprocessing import Process 3 | from interface.runner import run_sample, run_sample_ms_eps, run_sample_ddim, run_sample_ddim_ms_eps 4 | from interface.task_schedule import Task, wait_schedule, available_devices 5 | batch_size_per_card = 125 6 | n_devices = 8 7 | 8 | 9 | # run experiments with a slightly different implementation of ET used in DDIM 10 | def add_tasks(phase, tasks): 11 | steps_type = "linear_ddim" 12 | seed = 1234 13 | 14 | if phase == "sample_analytic_ddpm": 15 | clip_sigma_idx = 1 16 | clip_pixel = 2 17 | n_samples = 50000 18 | for sample_steps in [10, 20, 50, 100]: 19 | tag = "clip_sigma_idx_{}_clip_pixel_{}_n_samples_{}_steps_type_{}_sample_steps_{}" \ 20 | .format(clip_sigma_idx, clip_pixel, n_samples, steps_type, sample_steps) 21 | pretrained_model = "ema_celeba" 22 | root = os.path.join("samples", pretrained_model) 23 | profile = { 24 | "fname_log": os.path.join(root, os.path.join("%s.log" % tag)), 25 | "seed": seed, 26 | "pretrained_model": pretrained_model, 27 | "sample_steps": sample_steps, 28 | "batch_size": batch_size_per_card * n_devices, 29 | "path": os.path.join(root, tag), 30 | "n_samples": n_samples, 31 | "fid_stat": "workspace/fid_stats/fid_stats_celeba64_train_50000_ddim.npz", 32 | "ms_eps_path": "ms_eps/ema_celeba_10000.pth", 33 | "steps_type": steps_type, 34 | "clip_sigma_idx": clip_sigma_idx, 35 | "clip_pixel": clip_pixel, 36 | } 37 | p = Process(target=run_sample_ms_eps, args=(profile,)) 38 | tasks.append(Task(p, n_devices=n_devices)) 39 | 40 | elif phase == "sample_analytic_ddim": 41 | clip_sigma_idx = 1 42 | clip_pixel = 1 43 | for n_samples in [50000]: 44 | for eta in [0.]: 45 | for sample_steps in [10, 20, 50, 100]: 46 | tag = "ddim_ms_eps_clip_sigma_idx_{}_clip_pixel_{}_n_samples_{}_eta_{}_steps_type_{}_sample_steps_{}" \ 47 | .format(clip_sigma_idx, clip_pixel, n_samples, eta, steps_type, sample_steps) 48 | pretrained_model = "ema_celeba" 49 | root = os.path.join("samples", pretrained_model) 50 | profile = { 51 | "fname_log": os.path.join(root, os.path.join("%s.log" % tag)), 52 | "seed": seed, 53 | "pretrained_model": pretrained_model, 54 | "eta": eta, 55 | "sample_steps": sample_steps, 56 | "steps_type": steps_type, 57 | "batch_size": batch_size_per_card * n_devices, 58 | "path": os.path.join(root, tag), 59 | "n_samples": n_samples, 60 | "fid_stat": "workspace/fid_stats/fid_stats_celeba64_train_50000_ddim.npz", 61 | "ms_eps_path": "ms_eps/ema_celeba_10000.pth", 62 | "clip_sigma_idx": clip_sigma_idx, 63 | "clip_pixel": clip_pixel, 64 | } 65 | p = Process(target=run_sample_ddim_ms_eps, args=(profile,)) 66 | tasks.append(Task(p, n_devices=n_devices)) 67 | 68 | elif phase == "sample_ddpm": 69 | for n_samples in [50000]: 70 | for small_sigma in [True, False]: 71 | for sample_steps in [10, 20, 50, 100]: 72 | tag = "n_samples_{}_small_sigma_{}_steps_type_{}_sample_steps_{}".format(n_samples, small_sigma, steps_type, sample_steps) 73 | pretrained_model = "ema_celeba" 74 | root = os.path.join("samples", pretrained_model) 75 | profile = { 76 | "fname_log": os.path.join(root, os.path.join("%s.log" % tag)), 77 | "seed": seed, 78 | "pretrained_model": pretrained_model, 79 | "small_sigma": small_sigma, 80 | "sample_steps": sample_steps, 81 | "batch_size": batch_size_per_card * n_devices, 82 | "path": os.path.join(root, tag), 83 | 
"n_samples": n_samples, 84 | "steps_type": steps_type, 85 | "fid_stat": "workspace/fid_stats/fid_stats_celeba64_train_50000_ddim.npz", 86 | } 87 | p = Process(target=run_sample, args=(profile,)) 88 | tasks.append(Task(p, n_devices=n_devices)) 89 | 90 | elif phase == "sample_ddim": 91 | for n_samples in [50000]: 92 | for eta in [0.]: 93 | for sample_steps in [10, 20, 50, 100]: 94 | tag = "ddim_n_samples_{}_eta_{}_steps_type_{}_sample_steps_{}".format(n_samples, eta, steps_type, sample_steps) 95 | pretrained_model = "ema_celeba" 96 | root = os.path.join("samples", pretrained_model) 97 | profile = { 98 | "fname_log": os.path.join(root, os.path.join("%s.log" % tag)), 99 | "seed": seed, 100 | "pretrained_model": pretrained_model, 101 | "eta": eta, 102 | "sample_steps": sample_steps, 103 | "steps_type": steps_type, 104 | "batch_size": batch_size_per_card * n_devices, 105 | "path": os.path.join(root, tag), 106 | "n_samples": n_samples, 107 | "fid_stat": "workspace/fid_stats/fid_stats_celeba64_train_50000_ddim.npz", 108 | } 109 | p = Process(target=run_sample_ddim, args=(profile,)) 110 | tasks.append(Task(p, n_devices=n_devices)) 111 | 112 | 113 | def main(): 114 | tasks = [] 115 | add_tasks("sample_analytic_ddim", tasks) 116 | add_tasks("sample_analytic_ddpm", tasks) 117 | add_tasks("sample_ddim", tasks) 118 | add_tasks("sample_ddpm", tasks) 119 | wait_schedule(tasks, devices=[] or available_devices()) 120 | 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /celeba_lsun_codes/run_lsun_bedroom.py: -------------------------------------------------------------------------------- 1 | import os 2 | from multiprocessing import Process 3 | from interface.runner import run_sample, run_nll, run_save_ms_eps, run_sample_ms_eps, run_nll_ms_eps, run_sample_ddim, run_sample_ddim_ms_eps 4 | from interface.task_schedule import Task, wait_schedule, available_devices 5 | batch_size_per_card = 10 6 | n_devices = 8 7 | 8 | 9 | if __name__ == "__main__": 10 | 11 | phase = "sample_analytic_ddpm" 12 | 13 | if phase == "save_ms_eps": 14 | n_samples = 1000 15 | pretrained_model = "ema_lsun_bedroom" 16 | profile = { 17 | "fname_log": os.path.join("ms_eps", "%s_%d.log" % (pretrained_model, n_samples)), 18 | "pretrained_model": pretrained_model, 19 | "train_dataset": "lsun_bedroom", 20 | "batch_size": batch_size_per_card * n_devices, 21 | "fname": os.path.join("ms_eps", "%s_%d.pth" % (pretrained_model, n_samples)), 22 | "n_samples": n_samples 23 | } 24 | p = Process(target=run_save_ms_eps, args=(profile,)) 25 | tasks = [Task(p, n_devices=n_devices)] 26 | 27 | wait_schedule(tasks, devices=[] or available_devices()) 28 | 29 | elif phase == "sample_analytic_ddpm": 30 | tasks = [] 31 | clip_sigma_idx = 1 32 | clip_pixel = 1 33 | n_samples = 50000 34 | steps_type = "linear" 35 | for sample_steps in [10, 25, 50, 100, 200]: 36 | tag = "clip_sigma_idx_{}_clip_pixel_{}_n_samples_{}_steps_type_{}_sample_steps_{}"\ 37 | .format(clip_sigma_idx, clip_pixel, n_samples, steps_type, sample_steps) 38 | pretrained_model = "ema_lsun_bedroom" 39 | root = os.path.join("samples", pretrained_model) 40 | profile = { 41 | "fname_log": os.path.join(root, os.path.join("%s.log" % tag)), 42 | "seed": 1234, 43 | "pretrained_model": pretrained_model, 44 | "sample_steps": sample_steps, 45 | "batch_size": batch_size_per_card * n_devices, 46 | "path": os.path.join(root, tag), 47 | "n_samples": n_samples, 48 | "fid_stat": 
"workspace/fid_stats/fid_stats_lsun_bedroom_train_50000_ddim.npz", 49 | "ms_eps_path": "ms_eps/ema_lsun_bedroom_1000.pth", 50 | "steps_type": steps_type, 51 | "clip_sigma_idx": clip_sigma_idx, 52 | "clip_pixel": clip_pixel, 53 | } 54 | p = Process(target=run_sample_ms_eps, args=(profile,)) 55 | tasks.append(Task(p, n_devices=n_devices)) 56 | wait_schedule(tasks, devices=[] or available_devices()) 57 | 58 | elif phase == "sample_ddpm": 59 | tasks = [] 60 | for n_samples in [50000]: 61 | for small_sigma in [True]: 62 | for sample_steps in [10, 25, 50, 100, 200]: 63 | tag = "n_samples_{}_small_sigma_{}_sample_steps_{}".format(n_samples, small_sigma, sample_steps) 64 | pretrained_model = "ema_lsun_bedroom" 65 | root = os.path.join("samples", pretrained_model) 66 | profile = { 67 | "fname_log": os.path.join(root, os.path.join("%s.log" % tag)), 68 | "seed": 1234, 69 | "pretrained_model": pretrained_model, 70 | "small_sigma": small_sigma, 71 | "sample_steps": sample_steps, 72 | "batch_size": batch_size_per_card * n_devices, 73 | "path": os.path.join(root, tag), 74 | "n_samples": n_samples, 75 | "fid_stat": "workspace/fid_stats/fid_stats_lsun_bedroom_train_50000_ddim.npz", 76 | } 77 | p = Process(target=run_sample, args=(profile,)) 78 | tasks.append(Task(p, n_devices=n_devices)) 79 | wait_schedule(tasks, devices=available_devices()) 80 | 81 | elif phase == "sample_ddim": 82 | tasks = [] 83 | for n_samples in [50000]: 84 | for eta in [0.]: 85 | for sample_steps in [10, 25, 50, 100, 200]: 86 | tag = "ddim_n_samples_{}_eta_{}_sample_steps_{}".format(n_samples, eta, sample_steps) 87 | pretrained_model = "ema_lsun_bedroom" 88 | root = os.path.join("samples", pretrained_model) 89 | profile = { 90 | "fname_log": os.path.join(root, os.path.join("%s.log" % tag)), 91 | "seed": None, 92 | "pretrained_model": pretrained_model, 93 | "eta": eta, 94 | "sample_steps": sample_steps, 95 | "batch_size": batch_size_per_card * n_devices, 96 | "path": os.path.join(root, tag), 97 | "n_samples": n_samples, 98 | "fid_stat": "workspace/fid_stats/fid_stats_lsun_bedroom_train_50000_ddim.npz", 99 | } 100 | p = Process(target=run_sample_ddim, args=(profile,)) 101 | tasks.append(Task(p, n_devices=n_devices)) 102 | 103 | wait_schedule(tasks, devices=[] or available_devices()) 104 | -------------------------------------------------------------------------------- /celeba_lsun_codes/scripts/pytorch_diffusion_demo: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | SRC=$(python -c "import os,inspect,pytorch_diffusion; print(os.path.dirname(inspect.getfile(pytorch_diffusion)))") 3 | streamlit run "${SRC}/demo.py" 4 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/__init__.py: -------------------------------------------------------------------------------- 1 | r""" Core provides algorithms and models 2 | """ 3 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | from .ddpm import * 2 | 3 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/criterions/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import core.utils.managers as managers 3 | from core.utils import global_device, diagnose 4 | 5 | 6 | class Criterion(object): 7 | def 
__init__(self, 8 | models: managers.ModelsManager, 9 | optimizers: managers.OptimizersManager, 10 | lr_schedulers: managers.LRSchedulersManager 11 | ): 12 | r""" Criterion does 13 | 1. calculating objectives 14 | 2. calculating gradients 15 | 3. updating parameters 16 | 17 | Args: 18 | models: an object of ModelsManager 19 | optimizers: an object of OptimizersManager 20 | lr_schedulers: an object of LRSchedulersManager 21 | """ 22 | self.statistics = {} 23 | self.models = models 24 | self.optimizers = optimizers 25 | self.lr_schedulers = lr_schedulers 26 | self.device = global_device() 27 | 28 | def objective(self, v, **kwargs): 29 | raise NotImplementedError 30 | 31 | def update(self, data_loader): 32 | raise NotImplementedError 33 | 34 | def default_val_fn(self, v): 35 | r""" Provide a default validation function 36 | """ 37 | return self.objective(v) 38 | 39 | def criterion_name(self): 40 | return self.__class__.__name__.lower() 41 | 42 | def record_grad_norm(self): 43 | for key, model in self.models.__dict__.items(): 44 | self.statistics["grad_norm2_%s" % key] = diagnose.grad_norm(model, 2.) 45 | self.statistics["grad_norminf_%s" % key] = diagnose.grad_norm(model, float('inf')) 46 | 47 | 48 | class MultilevelCriterion(Criterion): 49 | def __init__(self, 50 | models: managers.ModelsManager, 51 | optimizers: managers.OptimizersManager, 52 | lr_schedulers: managers.LRSchedulersManager, 53 | levels: list, 54 | level_n_steps: dict, 55 | level_model_keys: dict 56 | ): 57 | r""" Sometimes the optimization might have multiple levels, e.g., bilevel optimization 58 | MultilevelCriterion does 59 | for level in levels: 60 | 1. calculating objectives 61 | 2. calculating gradients 62 | 3. updating parameters 63 | 64 | Args: 65 | models: an object of ModelsManager 66 | optimizers: an object of OptimizersManager 67 | The optimizers inside are indexed by levels 68 | lr_schedulers: an object of LRSchedulersManager 69 | The schedulers inside are indexed by levels 70 | levels: the name of each level 71 | Example: levels = ["lower", "higher"] 72 | level_n_steps: the number of steps in each level 73 | Example: level_n_steps = {"lower": 5, "higher": 1} 74 | level_model_keys: the models to update in each level 75 | Example: level_model_keys = {"lower": ["discriminator"], "higher": ["generator"]} 76 | 77 | """ 78 | super().__init__(models, optimizers, lr_schedulers) 79 | self.levels = levels 80 | self.level_n_steps = level_n_steps 81 | self.level_model_keys = level_model_keys 82 | 83 | def update(self, data_loader): 84 | r""" A demo of the multi-level optimization 85 | """ 86 | v = next(data_loader).to(self.device) 87 | for level in self.levels: 88 | self.models.toggle_grad(*self.level_model_keys[level]) 89 | for i in range(self.level_n_steps[level]): 90 | objective = self.objective(v, level=level).mean() 91 | self.statistics[level] = objective.item() 92 | self.optimizers.get(level).zero_grad() 93 | objective.backward() 94 | self.optimizers.get(level).step() 95 | if level in self.lr_schedulers: 96 | self.lr_schedulers.get(level).step() 97 | 98 | 99 | class NaiveCriterion(Criterion): 100 | def update(self, data_loader): 101 | v = next(data_loader).to(self.device) 102 | loss = self.objective(v).mean() 103 | self.statistics[self.criterion_name()] = loss.item() 104 | self.optimizers.all.zero_grad() 105 | loss.backward() 106 | self.optimizers.all.step() 107 | if "all" in self.lr_schedulers: 108 | self.lr_schedulers.all.step() 109 | # self.record_grad_norm() 110 | 
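To make the Criterion/NaiveCriterion contract above concrete, here is a minimal hypothetical subclass (an editorial sketch, not a file in this repository; the model name `net` is an assumption). Only `objective` needs to be implemented and must return a per-sample loss; the inherited `update` draws a batch, averages the loss, back-propagates, and steps `optimizers.all`. `DDPMDSM` in the next file follows exactly this pattern.

```python
# Editorial sketch: a toy subclass of NaiveCriterion. Assumes a model was
# registered in the ModelsManager under the (hypothetical) attribute name `net`.
import core.func as func
from core.criterions.base import NaiveCriterion


class ToyCriterion(NaiveCriterion):
    def objective(self, v, **kwargs):
        recon = self.models.net(v)   # forward pass of the registered model
        return func.sos(v - recon)   # per-sample sum-of-squares reconstruction error
```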
-------------------------------------------------------------------------------- /cifar_imagenet_codes/core/criterions/ddpm.py: -------------------------------------------------------------------------------- 1 | r""" for n>=1, betas[n] is the variance of q(x_n|x_{n-1}) 2 | for n=0, betas[0]=0 3 | """ 4 | 5 | __all__ = ["DDPMDSM"] 6 | 7 | 8 | import torch 9 | import numpy as np 10 | from .base import NaiveCriterion 11 | import core.utils.managers as managers 12 | import core.func as func 13 | import torch.nn as nn 14 | import logging 15 | 16 | 17 | def _rescale_timesteps(n, N, flag): 18 | if flag: 19 | return n * 1000.0 / float(N) 20 | return n 21 | 22 | 23 | def _bipartition(ts): 24 | if ts.dim() == 4: # bs * 2c * w * w 25 | assert ts.size(1) % 2 == 0 26 | c = ts.size(1) // 2 27 | return ts.split(c, dim=1) 28 | else: 29 | raise NotImplementedError 30 | 31 | 32 | def _make_coeff(betas): 33 | assert betas[0] == 0 # betas[0] = 0 for convenience 34 | alphas = 1. - betas 35 | cum_alphas = alphas.cumprod() 36 | cum_betas = 1. - cum_alphas 37 | return alphas, cum_alphas, cum_betas 38 | 39 | 40 | def _sample(x_0, cum_alphas, cum_betas): 41 | N = len(cum_alphas) - 1 42 | n = np.random.choice(list(range(1, N + 1)), (len(x_0),)) 43 | eps = torch.randn_like(x_0) 44 | x_n = func.stp(cum_alphas[n] ** 0.5, x_0) + func.stp(cum_betas[n] ** 0.5, eps) 45 | return N, n, eps, x_n 46 | 47 | 48 | def _ddpm_dsm(x_0, eps_model, cum_alphas, cum_betas, rescale_timesteps): 49 | N, n, eps, x_n = _sample(x_0, cum_alphas, cum_betas) 50 | eps_pred = eps_model(x_n, _rescale_timesteps(torch.from_numpy(n).float().to(x_0.device), N, rescale_timesteps)) 51 | return func.sos(eps - eps_pred) 52 | 53 | 54 | def _ddpm_dsm_zero(x_0, d_model, cum_alphas, cum_betas, rescale_timesteps): 55 | N, n, eps, x_n = _sample(x_0, cum_alphas, cum_betas) 56 | d_pred = d_model(x_n, _rescale_timesteps(torch.from_numpy(n).float().to(x_0.device), N, rescale_timesteps)) 57 | return func.sos(x_0 - d_pred) 58 | 59 | 60 | def _ddpm_ddm(x_0, tau_model, cum_alphas, cum_betas, rescale_timesteps): 61 | N, n, eps, x_n = _sample(x_0, cum_alphas, cum_betas) 62 | tau_pred = tau_model(x_n, _rescale_timesteps(torch.from_numpy(n).float().to(x_0.device), N, rescale_timesteps)) 63 | return func.sos(eps.pow(2) - tau_pred) 64 | 65 | 66 | def _ddpm_ddm_zero(x_0, kappa_model, cum_alphas, cum_betas, rescale_timesteps): 67 | N, n, eps, x_n = _sample(x_0, cum_alphas, cum_betas) 68 | kappa_pred = kappa_model(x_n, _rescale_timesteps(torch.from_numpy(n).float().to(x_0.device), N, rescale_timesteps)) 69 | return func.sos(x_0.pow(2) - kappa_pred) 70 | 71 | 72 | def _ddpm_dsdm(x_0, eps_tau_model, cum_alphas, cum_betas, rescale_timesteps): 73 | N, n, eps, x_n = _sample(x_0, cum_alphas, cum_betas) 74 | eps_tau_pred = eps_tau_model(x_n, _rescale_timesteps(torch.from_numpy(n).float().to(x_0.device), N, rescale_timesteps)) 75 | eps_pred, tau_pred = _bipartition(eps_tau_pred) 76 | return func.sos(eps - eps_pred), func.sos(eps.pow(2) - tau_pred) 77 | 78 | 79 | class DDPMDSM(NaiveCriterion): 80 | def __init__(self, 81 | betas, 82 | rescale_timesteps, # todo: remove this argument 83 | models: managers.ModelsManager, 84 | optimizers: managers.OptimizersManager, 85 | lr_schedulers: managers.LRSchedulersManager, 86 | ): 87 | r""" Estimating the mean of the optimal Gaussian reverse process in DDPM is equivalent to denoising score matching (DSM) 88 | """ 89 | assert isinstance(betas, np.ndarray) and betas[0] == 0 90 | super().__init__(models, optimizers, lr_schedulers) 91 | self.eps_model = 
nn.DataParallel(models.eps_model) # predict noise 92 | self.betas = betas 93 | self.alphas, self.cum_alphas, self.cum_betas = _make_coeff(self.betas) 94 | self.rescale_timesteps = rescale_timesteps 95 | logging.info("DDPMDSM with rescale_timesteps={}".format(self.rescale_timesteps)) 96 | 97 | def objective(self, v, **kwargs): 98 | return _ddpm_dsm(v, self.eps_model, self.cum_alphas, self.cum_betas, self.rescale_timesteps) 99 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | from .sample import * 2 | from .score import * 3 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/evaluate/sample.py: -------------------------------------------------------------------------------- 1 | 2 | __all__ = ["grid_sample", "sample2dir"] 3 | 4 | import os 5 | import torch 6 | from torchvision.utils import make_grid, save_image 7 | from core.utils import amortize 8 | 9 | 10 | def grid_sample(fname, nrow, ncol, sample_fn, unpreprocess_fn=None): 11 | r""" Sample images in a grid 12 | Args: 13 | fname: the file name 14 | nrow: the number of rows of the grid 15 | ncol: the number of columns of the grid 16 | sample_fn: the sampling function, n_samples -> samples 17 | unpreprocess_fn: the function to unpreprocess data 18 | """ 19 | root, name = os.path.split(fname) 20 | os.makedirs(root, exist_ok=True) 21 | os.makedirs(os.path.join(root, "tensor"), exist_ok=True) 22 | n_samples = nrow * ncol 23 | samples = sample_fn(n_samples) 24 | if unpreprocess_fn is not None: 25 | samples = unpreprocess_fn(samples) 26 | grid = make_grid(samples, nrow) 27 | save_image(grid, fname) 28 | torch.save(samples, os.path.join(root, "tensor", "%s.pth" % name)) # save the tensor data 29 | 30 | 31 | def cnt_png(path): 32 | png_files = filter(lambda x: x.endswith(".png"), os.listdir(path)) 33 | return len(list(png_files)) 34 | 35 | 36 | def sample2dir(path, n_samples, batch_size, sample_fn, unpreprocess_fn=None, persist=True): 37 | os.makedirs(path, exist_ok=True) 38 | idx = n_png = cnt_png(path) if persist else 0 39 | for _batch_size in amortize(n_samples - n_png, batch_size): 40 | samples = sample_fn(_batch_size) 41 | samples = unpreprocess_fn(samples) if unpreprocess_fn is not None else samples  # guard the None default 42 | for sample in samples: 43 | save_image(sample, os.path.join(path, "{}.png".format(idx))) 44 | idx += 1 45 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/evaluate/score.py: -------------------------------------------------------------------------------- 1 | 2 | __all__ = ["score_on_dataset"] 3 | 4 | import torch 5 | from torch.utils.data import DataLoader, Dataset 6 | 7 | 8 | def score_on_dataset(dataset: Dataset, score_fn, batch_size): 9 | r""" 10 | Args: 11 | dataset: an instance of Dataset 12 | score_fn: a batch of data -> a batch of scalars 13 | batch_size: the batch size 14 | """ 15 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 16 | total_score = None 17 | tuple_output = None 18 | dataloader = DataLoader(dataset, batch_size=batch_size) 19 | for idx, v in enumerate(dataloader): 20 | v = v.to(device) 21 | score = score_fn(v) 22 | if idx == 0: 23 | tuple_output = isinstance(score, tuple) 24 | total_score = (0.,) * len(score) if tuple_output else 0.
25 | if tuple_output: 26 | total_score = tuple([a + b.sum().detach().item() for a, b in zip(total_score, score)]) 27 | else: 28 | total_score += score.sum().detach().item() 29 | if tuple_output: 30 | mean_score = tuple([a / len(dataset) for a in total_score]) 31 | else: 32 | mean_score = total_score / len(dataset) 33 | return mean_score 34 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/func/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import * 2 | from .differential import * 3 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/func/differential.py: -------------------------------------------------------------------------------- 1 | 2 | __all__ = ["RequiresGradContext", "differential"] 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.autograd as autograd 8 | from typing import Union, List 9 | 10 | 11 | def judge_requires_grad(obj: Union[torch.Tensor, nn.Module]): 12 | if isinstance(obj, torch.Tensor): 13 | return obj.requires_grad 14 | elif isinstance(obj, nn.Module): 15 | return next(obj.parameters()).requires_grad 16 | else: 17 | raise TypeError 18 | 19 | 20 | class RequiresGradContext(object): 21 | def __init__(self, *objs: Union[torch.Tensor, nn.Module], requires_grad: Union[List[bool], bool]): 22 | self.objs = objs 23 | self.backups = [judge_requires_grad(obj) for obj in objs] 24 | if isinstance(requires_grad, bool): 25 | self.requires_grads = [requires_grad] * len(objs) 26 | elif isinstance(requires_grad, list): 27 | self.requires_grads = requires_grad 28 | else: 29 | raise TypeError 30 | assert len(self.objs) == len(self.requires_grads) 31 | 32 | def __enter__(self): 33 | for obj, requires_grad in zip(self.objs, self.requires_grads): 34 | obj.requires_grad_(requires_grad) 35 | 36 | def __exit__(self, exc_type, exc_val, exc_tb): 37 | for obj, backup in zip(self.objs, self.backups): 38 | obj.requires_grad_(backup) 39 | 40 | 41 | def differential(fn, v, retain_graph=None, create_graph=False): 42 | r""" d fn / dv 43 | Args: 44 | fn: a batch of tensor -> a batch of scalar 45 | v: a batch of tensor 46 | retain_graph: see autograd.grad, default to create_graph 47 | create_graph: see autograd.grad 48 | """ 49 | if retain_graph is None: 50 | retain_graph = create_graph 51 | with RequiresGradContext(v, requires_grad=True): 52 | return autograd.grad(fn(v).sum(), v, retain_graph=retain_graph, create_graph=create_graph)[0] 53 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/func/functions.py: -------------------------------------------------------------------------------- 1 | 2 | __all__ = ["stp", "sos", "mos", "inner_product", "duplicate", "unsqueeze_like", "logsumexp", "log_discretized_normal", 3 | "binary_cross_entropy_with_logits", "log_bernoulli", "kl_between_normal"] 4 | 5 | 6 | import numpy as np 7 | import torch.nn.functional as F 8 | import torch 9 | 10 | 11 | def stp(s: np.ndarray, ts: torch.Tensor): # scalar tensor product 12 | s = torch.from_numpy(s).type_as(ts) 13 | extra_dims = (1,) * (ts.dim() - 1) 14 | return s.view(-1, *extra_dims) * ts 15 | 16 | 17 | def sos(a, start_dim=1): # sum of square 18 | return a.pow(2).flatten(start_dim=start_dim).sum(dim=-1) 19 | 20 | 21 | def mos(a, start_dim=1): # mean of square 22 | return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1) 23 | 24 | 25 | def inner_product(a, b, 
start_dim=1): 26 | return (a * b).flatten(start_dim=start_dim).sum(dim=-1) 27 | 28 | 29 | def duplicate(tensor, *size): 30 | return tensor.unsqueeze(dim=0).expand(*size, *tensor.shape) 31 | 32 | 33 | def unsqueeze_like(tensor, template, start="left"): 34 | if start == "left": 35 | tensor_dim = tensor.dim() 36 | template_dim = template.dim() 37 | assert tensor.shape == template.shape[:tensor_dim] 38 | return tensor.view(*tensor.shape, *([1] * (template_dim - tensor_dim))) 39 | elif start == "right": 40 | tensor_dim = tensor.dim() 41 | template_dim = template.dim() 42 | assert tensor.shape == template.shape[-tensor_dim:] 43 | return tensor.view(*([1] * (template_dim - tensor_dim)), *tensor.shape) 44 | else: 45 | raise ValueError 46 | 47 | 48 | def logsumexp(tensor, dim, keepdim=False): 49 | # the logsumexp of pytorch is not stable! 50 | tensor_max, _ = tensor.max(dim=dim, keepdim=True) 51 | ret = (tensor - tensor_max).exp().sum(dim=dim, keepdim=True).log() + tensor_max 52 | if not keepdim: 53 | ret.squeeze_(dim=dim) 54 | return ret 55 | 56 | 57 | def approx_standard_normal_cdf(x): 58 | """ 59 | A fast approximation of the cumulative distribution function of the 60 | standard normal. 61 | """ 62 | return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3)))) 63 | 64 | 65 | def log_discretized_normal(x, mu, var): # element-wise 66 | centered_x = x - mu 67 | std = var ** 0.5 68 | left = (centered_x - 1. / 255) / std 69 | right = (centered_x + 1. / 255) / std 70 | 71 | cdf_right = approx_standard_normal_cdf(right) 72 | cdf_left = approx_standard_normal_cdf(left) 73 | cdf_delta = cdf_right - cdf_left 74 | 75 | return torch.where( 76 | x < -0.999, 77 | cdf_right.clamp(min=1e-12).log(), 78 | torch.where(x > 0.999, (1. - cdf_left).clamp(min=1e-12).log(), cdf_delta.clamp(min=1e-12).log()), 79 | ) 80 | 81 | 82 | def binary_cross_entropy_with_logits(logits, inputs): 83 | r""" -inputs * log (sigmoid(logits)) - (1 - inputs) * log (1 - sigmoid(logits)) element wise 84 | with automatically expand dimensions 85 | """ 86 | if inputs.dim() < logits.dim(): 87 | inputs = inputs.expand_as(logits) 88 | else: 89 | logits = logits.expand_as(inputs) 90 | return F.binary_cross_entropy_with_logits(logits, inputs, reduction="none") 91 | 92 | 93 | def log_bernoulli(inputs, logits, n_data_dim): 94 | return -binary_cross_entropy_with_logits(logits, inputs).flatten(-n_data_dim).sum(dim=-1) 95 | 96 | 97 | def kl_between_normal(mu_0, var_0, mu_1, var_1): # element-wise 98 | tensor = None 99 | for obj in (mu_0, var_0, mu_1, var_1): 100 | if isinstance(obj, torch.Tensor): 101 | tensor = obj 102 | break 103 | assert tensor is not None 104 | 105 | var_0, var_1 = [ 106 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 107 | for x in (var_0, var_1) 108 | ] 109 | 110 | return 0.5 * (var_0 / var_1 + (mu_0 - mu_1).pow(2) / var_1 + var_1.log() - var_0.log() - 1.) 
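# [Editor's note: the following comment block is an annotation, not part of the original source.]
# A quick sanity check of kl_between_normal above: the KL divergence between identical
# Gaussians is 0, and KL(N(0, 1) || N(1, 1)) = 0.5 * (1 + 1 + 0 - 0 - 1) = 0.5 element-wise:
#   mu, var = torch.zeros(3), torch.ones(3)
#   kl_between_normal(mu, var, mu, var)       # -> tensor([0., 0., 0.])
#   kl_between_normal(mu, var, mu + 1., var)  # -> tensor([0.5000, 0.5000, 0.5000])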
111 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baofff/Analytic-DPM/2d7a28c0bbd984a6d47744ab4f6440f3e79757db/cifar_imagenet_codes/core/inference/__init__.py -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/inference/sampler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baofff/Analytic-DPM/2d7a28c0bbd984a6d47744ab4f6440f3e79757db/cifar_imagenet_codes/core/inference/sampler/__init__.py -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/inference/sampler/reverse_ddim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import logging 4 | import math 5 | from core.inference.utils import _choice_steps, _x_0_pred, _report_statistics 6 | 7 | 8 | @ torch.no_grad() 9 | def reverse_ddim_naive(x_init, betas, rescale_timesteps, eta=0., steps_type='linear', eps_model=None, sample_steps=None, shift1=False): 10 | assert eps_model is not None 11 | assert isinstance(betas, np.ndarray) and betas[0] == 0 12 | N = len(betas) - 1 13 | sample_steps = sample_steps or N 14 | ns = _choice_steps(N, sample_steps, typ=steps_type) 15 | alphas = 1. - betas 16 | cum_alphas = alphas.cumprod() 17 | cum_betas = 1. - cum_alphas 18 | 19 | logging.info("reverse_ddim_naive with eps_model, rescale_timesteps={}, eta={}, " 20 | "sample_steps={}, steps_type={}, shift1={}" 21 | .format(rescale_timesteps, eta, sample_steps, steps_type, shift1)) 22 | 23 | x = x_init 24 | for s, r in list(zip([0] + ns, ns))[::-1]: 25 | statistics = {} 26 | skip_alpha = alphas[s + 1: r + 1].prod() 27 | skip_beta = 1. - skip_alpha 28 | cum_alpha_s, cum_alpha_r, cum_beta_s, cum_beta_r = cum_alphas[s], cum_alphas[r], cum_betas[s], cum_betas[r] 29 | sigma2_small = skip_beta * cum_beta_s / cum_beta_r 30 | lamb2 = eta ** 2 * sigma2_small 31 | statistics['skip_beta'] = skip_beta 32 | statistics['cum_beta_s'] = cum_beta_s 33 | statistics['cum_beta_r'] = cum_beta_r 34 | statistics['cum_alpha_s'] = cum_alpha_s 35 | 36 | x_0_pred, eps_pred = _x_0_pred(x, r, cum_alphas, rescale_timesteps, eps_model=eps_model, shift1=shift1) 37 | x_0_pred_clamp = x_0_pred.clamp(-1., 1.) 38 | coeff1 = cum_alpha_s ** 0.5 39 | coeff2 = (cum_beta_s - lamb2) ** 0.5 40 | x_mean = coeff1 * x_0_pred_clamp + coeff2 * eps_pred 41 | if s != 0: 42 | sigma2 = lamb2 43 | x = x_mean + sigma2 ** 0.5 * torch.randn_like(x) 44 | statistics['sigma2'] = sigma2 45 | else: 46 | x = x_mean 47 | _report_statistics(s, r, statistics) 48 | return x 49 | 50 | 51 | @ torch.no_grad() 52 | def reverse_ddim_ms_eps(x_init, betas, rescale_timesteps, eta=0., eps_model=None, 53 | ms_eps=None, sample_steps=None, clip_sigma_idx=0, clip_pixel=2, shift1=False): 54 | assert eps_model is not None and ms_eps is not None 55 | assert isinstance(betas, np.ndarray) and betas[0] == 0 56 | N = len(betas) - 1 57 | sample_steps = sample_steps or N 58 | ns = _choice_steps(N, sample_steps, typ='linear', ms_eps=ms_eps, betas=betas) 59 | alphas = 1. - betas 60 | cum_alphas = alphas.cumprod() 61 | cum_betas = 1. 
- cum_alphas 62 | 63 | logging.info("reverse_ddim_ms_eps with eps_model, rescale_timesteps={}, eta={}, sample_steps={}, clip_sigma_idx={}, clip_pixel={}, shift1={}" 64 | .format(rescale_timesteps, eta, sample_steps, clip_sigma_idx, clip_pixel, shift1)) 65 | 66 | x = x_init 67 | for s, r in list(zip([0] + ns, ns))[::-1]: 68 | statistics = {} 69 | skip_alpha = alphas[s + 1: r + 1].prod() 70 | skip_beta = 1. - skip_alpha 71 | cum_alpha_s, cum_alpha_r, cum_beta_s, cum_beta_r = cum_alphas[s], cum_alphas[r], cum_betas[s], cum_betas[r] 72 | sigma2_small = skip_beta * cum_beta_s / cum_beta_r 73 | lamb2 = eta ** 2 * sigma2_small 74 | statistics['skip_beta'] = skip_beta 75 | statistics['cum_beta_s'] = cum_beta_s 76 | statistics['cum_beta_r'] = cum_beta_r 77 | statistics['cum_alpha_s'] = cum_alpha_s 78 | 79 | x_0_pred, eps_pred = _x_0_pred(x, r, cum_alphas, rescale_timesteps, eps_model=eps_model, shift1=shift1) 80 | x_0_pred_clamp = x_0_pred.clamp(-1., 1.) 81 | coeff1 = cum_alpha_s ** 0.5 82 | coeff2 = (cum_beta_s - lamb2) ** 0.5 83 | x_mean = coeff1 * x_0_pred_clamp + coeff2 * eps_pred 84 | if s != 0: 85 | cov_x_0_pred = cum_beta_r / cum_alpha_r * (1. - ms_eps[r]) 86 | cov_x_0_pred_clamp = np.clip(cov_x_0_pred, 0., 1.) 87 | coeff_cov_x_0 = (cum_alpha_s ** 0.5 - ((cum_beta_s - lamb2) * cum_alpha_r / cum_beta_r) ** 0.5) ** 2 88 | offset = coeff_cov_x_0 * cov_x_0_pred_clamp 89 | sigma2 = lamb2 + offset 90 | if s < ns[clip_sigma_idx]: # clip_sigma_idx = 0 <=> not clip 91 | statistics['sigma2_unclip'] = sigma2.item() 92 | sigma2_threshold = (clip_pixel * 2. / 255. * (math.pi / 2.) ** 0.5) ** 2 93 | sigma2 = np.clip(sigma2, 0., sigma2_threshold) 94 | statistics['sigma2_threshold'] = sigma2_threshold 95 | x = x_mean + sigma2 ** 0.5 * torch.randn_like(x) 96 | statistics['sigma2'] = sigma2 97 | else: 98 | x = x_mean 99 | _report_statistics(s, r, statistics) 100 | return x 101 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/inference/sampler/reverse_ddpm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import logging 4 | import math 5 | from core.inference.utils import _choice_steps, _x_0_pred, _report_statistics 6 | 7 | 8 | @ torch.no_grad() 9 | def reverse_ddpm_naive(x_init, betas, small_sigma, clip_denoise, rescale_timesteps, steps_type='linear', clip_sigma_idx=0, clip_pixel=2, 10 | eps_model=None, d_model=None, sample_steps=None, shift1=False): 11 | assert (eps_model is None and d_model is not None) or (eps_model is not None and d_model is None) 12 | assert isinstance(betas, np.ndarray) and betas[0] == 0 13 | N = len(betas) - 1 14 | sample_steps = sample_steps or N 15 | ns = _choice_steps(N, sample_steps, steps_type) 16 | alphas = 1. - betas 17 | cum_alphas = alphas.cumprod() 18 | cum_betas = 1. - cum_alphas 19 | 20 | logging.info("reverse_ddpm_naive with {}, small_sigma={}, clip_denoise={}, rescale_timesteps={}, " 21 | "sample_steps={}, steps_type={}, clip_sigma_idx={}, clip_pixel={}, shift1={}" 22 | .format("eps_model" if eps_model is not None else "d_model", 23 | small_sigma, clip_denoise, rescale_timesteps, sample_steps, steps_type, clip_sigma_idx, clip_pixel, shift1)) 24 | 25 | x = x_init 26 | for s, r in list(zip([0] + ns, ns))[::-1]: 27 | statistics = {} 28 | skip_alpha = alphas[s + 1: r + 1].prod() 29 | skip_beta = 1. 
- skip_alpha 30 | cum_alpha_s, cum_alpha_r, cum_beta_s, cum_beta_r = cum_alphas[s], cum_alphas[r], cum_betas[s], cum_betas[r] 31 | statistics['skip_beta'] = skip_beta 32 | statistics['cum_beta_s'] = cum_beta_s 33 | statistics['cum_beta_r'] = cum_beta_r 34 | statistics['cum_alpha_s'] = cum_alpha_s 35 | 36 | x_0_pred, eps_pred = _x_0_pred(x, r, cum_alphas, rescale_timesteps, eps_model=eps_model, d_model=d_model, shift1=shift1) 37 | if clip_denoise: 38 | x_0_pred = x_0_pred.clamp(-1., 1.) 39 | coeff1 = skip_beta * cum_alpha_s ** 0.5 / cum_beta_r 40 | coeff2 = skip_alpha ** 0.5 * cum_beta_s / cum_beta_r 41 | x_mean = coeff1 * x_0_pred + coeff2 * x 42 | if s != 0: 43 | sigma2 = skip_beta 44 | if small_sigma: 45 | sigma2 *= cum_beta_s / cum_beta_r 46 | if s < ns[clip_sigma_idx]: # clip_sigma_idx = 0 <=> not clip 47 | statistics['sigma2_unclip'] = sigma2.item() 48 | sigma2_threshold = (clip_pixel * 2. / 255. * (math.pi / 2.) ** 0.5) ** 2 49 | sigma2 = np.clip(sigma2, 0., sigma2_threshold) 50 | statistics['sigma2_threshold'] = sigma2_threshold 51 | x = x_mean + sigma2 ** 0.5 * torch.randn_like(x) 52 | statistics['sigma2'] = sigma2 53 | else: 54 | x = x_mean 55 | _report_statistics(s, r, statistics) 56 | return x 57 | 58 | 59 | @ torch.no_grad() 60 | def reverse_ddpm_ms_eps(x_init, betas, rescale_timesteps, steps_type='linear', clip_sigma_idx=0, clip_pixel=2, eps_model=None, 61 | ms_eps=None, sample_steps=None, shift1=False): 62 | assert eps_model is not None and ms_eps is not None 63 | assert isinstance(betas, np.ndarray) and betas[0] == 0 64 | N = len(betas) - 1 65 | sample_steps = sample_steps or N 66 | ns = _choice_steps(N, sample_steps, steps_type, ms_eps=ms_eps, betas=betas) 67 | alphas = 1. - betas 68 | cum_alphas = alphas.cumprod() 69 | cum_betas = 1. - cum_alphas 70 | 71 | logging.info("reverse_ddpm_ms_eps with eps_model, rescale_timesteps={}, sample_steps={}, steps_type={}, clip_sigma_idx={}, clip_pixel={}, shift1={}" 72 | .format(rescale_timesteps, sample_steps, steps_type, clip_sigma_idx, clip_pixel, shift1)) 73 | 74 | x = x_init 75 | for s, r in list(zip([0] + ns, ns))[::-1]: 76 | statistics = {} 77 | skip_alpha = alphas[s + 1: r + 1].prod() 78 | skip_beta = 1. - skip_alpha 79 | cum_alpha_s, cum_alpha_r, cum_beta_s, cum_beta_r = cum_alphas[s], cum_alphas[r], cum_betas[s], cum_betas[r] 80 | statistics['skip_beta'] = skip_beta 81 | statistics['cum_beta_s'] = cum_beta_s 82 | statistics['cum_beta_r'] = cum_beta_r 83 | statistics['cum_alpha_s'] = cum_alpha_s 84 | 85 | x_0_pred, eps_pred = _x_0_pred(x, r, cum_alphas, rescale_timesteps, eps_model=eps_model, shift1=shift1) 86 | x_0_pred_clamp = x_0_pred.clamp(-1., 1.) 87 | coeff1 = skip_beta * cum_alpha_s ** 0.5 / cum_beta_r 88 | coeff2 = skip_alpha ** 0.5 * cum_beta_s / cum_beta_r 89 | x_mean = coeff1 * x_0_pred_clamp + coeff2 * x 90 | if s != 0: 91 | sigma2_small = skip_beta * cum_beta_s / cum_beta_r 92 | cov_x_0_pred = cum_beta_r / cum_alpha_r * (1. - ms_eps[r]) 93 | cov_x_0_pred_clamp = np.clip(cov_x_0_pred, 0., 1.) 
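# Analytic-DPM optimal reverse variance (a sketch matching the code below):
#   sigma2 = sigma2_small + coeff_cov_x_0 * Cov[x_0 | x_r],
# where Cov[x_0 | x_r] is estimated per dimension from the precomputed ms_eps[r]
# (the average squared norm of the predicted noise at step r):
#   cov_x_0_pred = cum_beta_r / cum_alpha_r * (1 - ms_eps[r]),
# clipped to [0, 1] since the data is scaled to [-1, 1].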
94 | coeff_cov_x_0 = cum_alpha_s * skip_beta ** 2 / cum_beta_r ** 2 95 | offset = coeff_cov_x_0 * cov_x_0_pred_clamp 96 | sigma2 = sigma2_small + offset 97 | statistics['sigma2_small'] = sigma2_small 98 | statistics['cov_x_0'] = cov_x_0_pred.item() 99 | statistics['cov_x_0_clamp'] = cov_x_0_pred_clamp.item() 100 | statistics['coeff_cov_x_0'] = coeff_cov_x_0 101 | statistics['offset'] = offset.item() 102 | statistics['sigma2_small/offset'] = statistics['sigma2_small'] / statistics['offset'] 103 | if s < ns[clip_sigma_idx]: # clip_sigma_idx = 0 <=> not clip 104 | statistics['sigma2_unclip'] = sigma2.item() 105 | sigma2_threshold = (clip_pixel * 2. / 255. * (math.pi / 2.) ** 0.5) ** 2 106 | sigma2 = np.clip(sigma2, 0., sigma2_threshold) 107 | statistics['sigma2_threshold'] = sigma2_threshold 108 | x = x_mean + sigma2 ** 0.5 * torch.randn_like(x) 109 | statistics['sigma2'] = sigma2 110 | else: 111 | x = x_mean 112 | _report_statistics(s, r, statistics) 113 | return x 114 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/inference/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import numpy as np 4 | import math 5 | from core.criterions.ddpm import _rescale_timesteps, _bipartition 6 | 7 | 8 | def _report_statistics(s, r, statistics): 9 | statistics_str = {k: "{:.5e}".format(v) for k, v in statistics.items()} 10 | logging.info("[(s, r): ({}, {})] [{}]".format(s, r, statistics_str)) 11 | 12 | 13 | def _x_0_pred(x, n, cum_alphas, rescale_timesteps, eps_model=None, d_model=None, shift1=False): # estimate of E[x_0|x_n] w.r.t. q 14 | N = len(cum_alphas) - 1 15 | cum_alpha_n = cum_alphas[n] 16 | cum_beta_n = 1. - cum_alpha_n 17 | input_n = n - 1 if shift1 else n # for compatibility of pretrained models 18 | if eps_model is not None: 19 | eps_pred = eps_model(x, _rescale_timesteps(torch.tensor([input_n] * x.size(0)).type_as(x), N, rescale_timesteps)) 20 | x_0_pred = cum_alpha_n ** -0.5 * x - (1. / cum_alpha_n - 1.) ** 0.5 * eps_pred 21 | else: 22 | x_0_pred = d_model(x, _rescale_timesteps(torch.tensor([input_n] * x.size(0)).type_as(x), N, rescale_timesteps)) 23 | eps_pred = - (cum_alpha_n / cum_beta_n) ** 0.5 * x_0_pred + (1. 
/ cum_beta_n ** 0.5) * x 24 | return x_0_pred, eps_pred 25 | 26 | 27 | def _choice_steps_linear(N, sample_steps): 28 | assert sample_steps > 1 29 | frac_stride = (N - 1) / (sample_steps - 1) 30 | cur_idx = 1.0 31 | steps = [] 32 | for _ in range(sample_steps): 33 | steps.append(round(cur_idx)) 34 | cur_idx += frac_stride 35 | return steps 36 | 37 | 38 | def _choice_steps_quad_ddim(N, sample_steps): 39 | seq = np.linspace(0, np.sqrt(N * 0.8), sample_steps) ** 2 40 | seq = [int(s) + 1 for s in list(seq)] 41 | return seq 42 | 43 | 44 | def _round_and_remove_dup(seq): 45 | seq = [round(item) for item in seq] 46 | val_old = -float('inf') 47 | for idx, val in enumerate(seq): 48 | if val <= val_old: 49 | seq[idx] = int(val_old) + 1 50 | val_old = seq[idx] 51 | return seq 52 | 53 | 54 | def _choice_steps_rfn(N, sample_steps, rfn): # rfn: reverse of a function fn, s.t., fn(0)=1 and fn(1)=0 55 | assert sample_steps > 1 56 | ys = [k / (sample_steps - 1) for k in range(sample_steps)][::-1] 57 | xs = [rfn(y) for y in ys] 58 | assert xs[0] == 0 and xs[-1] == 1 59 | steps = [(N - 1) * x + 1 for x in xs] 60 | assert steps[0] == 1 and steps[-1] == N 61 | steps = _round_and_remove_dup(steps) 62 | assert steps[0] == 1 and steps[-1] == N 63 | assert all(steps[i] < steps[i+1] for i in range(len(steps) - 1)) 64 | return steps 65 | 66 | 67 | def _split(ms_eps, N, K): 68 | idx_g1 = N + 1 69 | for n in range(1, N + 1): # Theoretically, ms_eps <= 1. Remove points of poor estimation 70 | if ms_eps[n] > 1: 71 | idx_g1 = n 72 | break 73 | num_bad = 2 * (N - idx_g1 + 1) 74 | bad_ratio = num_bad / N 75 | 76 | N1 = N - num_bad 77 | K1 = math.ceil((1. - 0.8 * bad_ratio) * K) 78 | K2 = K - K1 79 | if K1 > N1: 80 | K1 = N1 81 | K2 = K - K1 82 | if K2 > num_bad: 83 | K2 = num_bad 84 | K1 = K - K2 85 | if num_bad > 0 and K2 == 0: 86 | K2 = 1 87 | K1 = K - K2 88 | assert num_bad <= N 89 | assert K1 <= N1 and K2 <= N - N1 90 | return K1, N1, K2, num_bad 91 | 92 | 93 | def _ms_score(ms_eps, betas): 94 | alphas = 1. - betas 95 | cum_alphas = alphas.cumprod() 96 | cum_betas = 1. 
- cum_alphas 97 | ms_score = np.zeros_like(ms_eps) 98 | ms_score[1:] = ms_eps[1:] / cum_betas[1:] 99 | return ms_score 100 | 101 | 102 | def _solve_fn_dp(fn, N, K): # F[st, ed] with 1 <= st < ed <= N, other elements is inf 103 | if N == K: 104 | return list(range(1, N + 1)) 105 | 106 | F = fn[: N + 1, : N + 1] 107 | 108 | C = np.full((K + 1, N + 1), float('inf')) # C[k, n] with 2 <= k <= K, k <= n <= N 109 | D = np.full((K + 1, N + 1), -1) # D[k, n] with 2 <= k <= K, k <= n <= N 110 | 111 | C[2, 2: N] = F[1, 2: N] 112 | D[2, 2: N] = 1 113 | 114 | for k in range(3, K + 1): 115 | # {C[k-1, s] + F[s, r]}_{0 <= s, r <= N} = {C[k-1, s] + F[s, r]}_{k-1 <= s < r <= N} 116 | tmp = C[k - 1, :].reshape(N + 1, 1) + F 117 | C[k, k: N + 1] = np.min(tmp, axis=0)[k: N + 1] 118 | D[k, k: N + 1] = np.argmin(tmp, axis=0)[k: N + 1] 119 | 120 | res = [N] 121 | n, k = N, K 122 | while k > 2: 123 | n = D[k, n] 124 | res.append(n) 125 | k -= 1 126 | res.append(1) 127 | return res[::-1] 128 | 129 | 130 | def _solve_fn_dp_general(fn, a, b, K): # F[st, ed] with a <= st < ed <= b, other elements is inf 131 | N = b - a + 1 132 | F = np.full((N + 1, N + 1), float('inf')) # F[st, ed] with 1 <= st < ed <= N 133 | F[1: N + 1, 1: N + 1] = fn[a: b + 1, a: b + 1] 134 | res = _solve_fn_dp(F, N, K) 135 | return [idx + a - 1 for idx in res] 136 | 137 | 138 | def _get_fn_m(ms_score, alphas, N): 139 | F = np.full((N + 1, N + 1), float('inf')) # F[st, ed] with 1 <= st < ed <= N 140 | for s in range(1, N + 1): 141 | skip_alphas = alphas[s + 1: N + 1].cumprod() 142 | skip_betas = 1. - skip_alphas 143 | before_log = 1. - skip_betas * ms_score[s + 1: N + 1] 144 | F[s, s + 1: N + 1] = np.log(before_log) 145 | return F 146 | 147 | 148 | def _dp_seg(ms_eps, betas, N, K): 149 | K1, N1, K2, num_bad = _split(ms_eps, N, K) 150 | 151 | alphas = 1. - betas 152 | ms_score = _ms_score(ms_eps, betas) 153 | F = _get_fn_m(ms_score, alphas, N1) 154 | 155 | steps1 = _solve_fn_dp(F, N1, K1) 156 | if K2 > 0: 157 | frac = (N - N1) / K2 158 | steps2 = [round(N - frac * k) for k in range(K2)][::-1] 159 | assert steps1[-1] < steps2[0] 160 | assert len(steps1) + len(steps2) == K 161 | assert steps1[0] == 1 and steps1[-1] == N1 162 | assert steps2[-1] == N 163 | else: 164 | steps2 = [] 165 | steps = steps1 + steps2 166 | assert steps[0] == 1 and steps[-1] == N 167 | assert all(steps[i] < steps[i + 1] for i in range(len(steps) - 1)) 168 | return steps 169 | 170 | 171 | def _choice_steps(N, sample_steps, typ, ms_eps=None, betas=None): 172 | if typ == 'linear': 173 | steps = _choice_steps_linear(N, sample_steps) 174 | elif typ.startswith('power'): 175 | power = int(typ.split('power')[1]) 176 | steps = _choice_steps_rfn(N, sample_steps, rfn=lambda y: 1 - y ** (1. 
/ power)) 177 | elif typ == 'quad_ddim': 178 | steps = _choice_steps_quad_ddim(N, sample_steps) 179 | elif typ == 'dp_seg': 180 | steps = _dp_seg(ms_eps, betas, N, sample_steps) 181 | else: 182 | raise NotImplementedError 183 | 184 | assert len(steps) == sample_steps and steps[0] == 1 185 | if typ not in ["linear_ddim", "quad_ddim"]: 186 | assert steps[-1] == N 187 | 188 | return steps 189 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet import * 2 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/modules/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import torch.nn as nn 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 7 | 8 | 9 | def convert_module_to_f16(l): 10 | """ 11 | Convert primitive modules to float16. 12 | """ 13 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 14 | l.weight.data = l.weight.data.half() 15 | l.bias.data = l.bias.data.half() 16 | 17 | 18 | def convert_module_to_f32(l): 19 | """ 20 | Convert primitive modules to float32, undoing convert_module_to_f16(). 21 | """ 22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 23 | l.weight.data = l.weight.data.float() 24 | l.bias.data = l.bias.data.float() 25 | 26 | 27 | def make_master_params(model_params): 28 | """ 29 | Copy model parameters into a (differently-shaped) list of full-precision 30 | parameters. 31 | """ 32 | master_params = _flatten_dense_tensors( 33 | [param.detach().float() for param in model_params] 34 | ) 35 | master_params = nn.Parameter(master_params) 36 | master_params.requires_grad = True 37 | return [master_params] 38 | 39 | 40 | def model_grads_to_master_grads(model_params, master_params): 41 | """ 42 | Copy the gradients from the model parameters into the master parameters 43 | from make_master_params(). 44 | """ 45 | master_params[0].grad = _flatten_dense_tensors( 46 | [param.grad.data.detach().float() for param in model_params] 47 | ) 48 | 49 | 50 | def master_params_to_model_params(model_params, master_params): 51 | """ 52 | Copy the master parameter data back into the model parameters. 53 | """ 54 | # Without copying to a list, if a generator is passed, this will 55 | # silently not copy any parameters. 56 | model_params = list(model_params) 57 | 58 | for param, master_param in zip( 59 | model_params, unflatten_master_params(model_params, master_params) 60 | ): 61 | param.detach().copy_(master_param) 62 | 63 | 64 | def unflatten_master_params(model_params, master_params): 65 | """ 66 | Unflatten the master parameters to look like model_params. 67 | """ 68 | return _unflatten_dense_tensors(master_params[0].detach(), model_params) 69 | 70 | 71 | def zero_grad(model_params): 72 | for param in model_params: 73 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 74 | if param.grad is not None: 75 | param.grad.detach_() 76 | param.grad.zero_() 77 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/modules/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 
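Provides a SiLU activation for older PyTorch versions, a float32 GroupNorm wrapper, dimension-agnostic convolution/pooling factories, sinusoidal timestep embeddings, and gradient checkpointing helpers.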
3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def zero_module(module): 56 | """ 57 | Zero out the parameters of a module and return it. 58 | """ 59 | for p in module.parameters(): 60 | p.detach().zero_() 61 | return module 62 | 63 | 64 | def scale_module(module, scale): 65 | """ 66 | Scale the parameters of a module and return it. 67 | """ 68 | for p in module.parameters(): 69 | p.detach().mul_(scale) 70 | return module 71 | 72 | 73 | def mean_flat(tensor): 74 | """ 75 | Take the mean over all non-batch dimensions. 76 | """ 77 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 78 | 79 | 80 | def normalization(channels): 81 | """ 82 | Make a standard normalization layer. 83 | 84 | :param channels: number of input channels. 85 | :return: an nn.Module for normalization. 86 | """ 87 | return GroupNorm32(32, channels) 88 | 89 | 90 | def timestep_embedding(timesteps, dim, max_period=10000): 91 | """ 92 | Create sinusoidal timestep embeddings. 93 | 94 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 95 | These may be fractional. 96 | :param dim: the dimension of the output. 97 | :param max_period: controls the minimum frequency of the embeddings. 98 | :return: an [N x dim] Tensor of positional embeddings. 99 | """ 100 | half = dim // 2 101 | freqs = th.exp( 102 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 103 | ).to(device=timesteps.device) 104 | args = timesteps[:, None].float() * freqs[None] 105 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 106 | if dim % 2: 107 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 108 | return embedding 109 | 110 | 111 | def checkpoint(func, inputs, params, flag): 112 | """ 113 | Evaluate a function without caching intermediate activations, allowing for 114 | reduced memory at the expense of extra compute in the backward pass. 115 | 116 | :param func: the function to evaluate. 117 | :param inputs: the argument sequence to pass to `func`. 118 | :param params: a sequence of parameters `func` depends on but does not 119 | explicitly take as arguments. 120 | :param flag: if False, disable gradient checkpointing. 
121 | """ 122 | if flag: 123 | args = tuple(inputs) + tuple(params) 124 | return CheckpointFunction.apply(func, len(inputs), *args) 125 | else: 126 | return func(*inputs) 127 | 128 | 129 | class CheckpointFunction(th.autograd.Function): 130 | @staticmethod 131 | def forward(ctx, run_function, length, *args): 132 | ctx.run_function = run_function 133 | ctx.input_tensors = list(args[:length]) 134 | ctx.input_params = list(args[length:]) 135 | with th.no_grad(): 136 | output_tensors = ctx.run_function(*ctx.input_tensors) 137 | return output_tensors 138 | 139 | @staticmethod 140 | def backward(ctx, *output_grads): 141 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 142 | with th.enable_grad(): 143 | # Fixes a bug where the first op in run_function modifies the 144 | # Tensor storage in place, which is not allowed for detach()'d 145 | # Tensors. 146 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 147 | output_tensors = ctx.run_function(*shallow_copies) 148 | input_grads = th.autograd.grad( 149 | output_tensors, 150 | ctx.input_tensors + ctx.input_params, 151 | output_grads, 152 | allow_unused=True, 153 | ) 154 | del ctx.input_tensors 155 | del ctx.input_params 156 | del output_tensors 157 | return (None, None) + input_grads 158 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import diagnose 2 | from . import managers 3 | from .clip_grad import * 4 | from .device_utils import * 5 | 6 | 7 | def amortize(n_samples, batch_size): 8 | k = n_samples // batch_size 9 | r = n_samples % batch_size 10 | return k * [batch_size] if r == 0 else k * [batch_size] + [r] 11 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/utils/clip_grad.py: -------------------------------------------------------------------------------- 1 | 2 | __all__ = ["clip_grad_norm_", "clip_grad_element_wise_"] 3 | 4 | 5 | import torch 6 | from typing import List, Union 7 | import math 8 | 9 | 10 | def clip_grad_norm_(grads: Union[torch.Tensor, List[torch.Tensor]], max_norm: float, norm_type: float = 2.): 11 | if isinstance(grads, torch.Tensor): 12 | grads = [grads] 13 | max_norm = float(max_norm) 14 | norm_type = float(norm_type) 15 | if norm_type == math.inf: 16 | total_norm = max(grad.data.abs().max() for grad in grads) 17 | else: 18 | total_norm = 0 19 | for grad in grads: 20 | param_norm = grad.data.norm(norm_type) 21 | total_norm += param_norm.item() ** norm_type 22 | total_norm = total_norm ** (1. 
/ norm_type) 23 | clip_coef = max_norm / (total_norm + 1e-6) 24 | if clip_coef < 1: 25 | for grad in grads: 26 | grad.data.mul_(clip_coef) 27 | return total_norm 28 | 29 | 30 | def clip_grad_element_wise_(grads: Union[torch.Tensor, List[torch.Tensor]], max_norm: float): 31 | if isinstance(grads, torch.Tensor): 32 | grads = [grads] 33 | for grad in grads: 34 | grad.data.clamp_(-max_norm, max_norm) 35 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/utils/device_utils.py: -------------------------------------------------------------------------------- 1 | 2 | __all__ = ["device_of", "global_device"] 3 | 4 | 5 | import torch.nn as nn 6 | import torch 7 | from typing import Union 8 | from .managers import ModelsManager 9 | 10 | 11 | def device_of(inputs: Union[nn.Module, torch.Tensor, ModelsManager]) -> torch.device: 12 | if isinstance(inputs, nn.Module): 13 | return next(inputs.parameters()).device 14 | elif isinstance(inputs, torch.Tensor): 15 | return inputs.device 16 | elif isinstance(inputs, ModelsManager): 17 | return device_of(next(iter(inputs.__dict__.values())))  # iterate the manager instance, not the class 18 | else: 19 | raise TypeError 20 | 21 | 22 | def global_device() -> torch.device: 23 | return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 24 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/utils/diagnose.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from typing import Union, Iterator 3 | import torch 4 | from .device_utils import device_of, global_device 5 | 6 | 7 | def grad_norm_inf(inputs: Union[nn.Module, Iterator[torch.Tensor]]) -> float: 8 | if isinstance(inputs, nn.Module): 9 | inputs = inputs.parameters() 10 | s = float("-inf") 11 | for p in inputs: 12 | if p.grad is not None: 13 | s = max(s, p.grad.data.abs().max().item()) 14 | return s 15 | 16 | 17 | def grad_norm(inputs: Union[nn.Module, Iterator[torch.Tensor]], norm_type: float = 2.) -> float: 18 | if norm_type == float('inf'): 19 | return grad_norm_inf(inputs) 20 | 21 | if isinstance(inputs, nn.Module): 22 | inputs = inputs.parameters() 23 | inputs = list(inputs); s = torch.tensor(0., device=inputs[0].device if inputs else global_device())  # materialize first so the device probe does not consume the first parameter 24 | for p in inputs: 25 | if p.grad is not None: 26 | s += p.grad.data.abs().pow(norm_type).sum()  # abs() keeps odd norm_type values well-defined 27 | return s.pow(1. 
/ norm_type).item() 28 | 29 | 30 | def probe_output_shape(model: nn.Module, input_shape): 31 | inputs = torch.ones(1, *input_shape, device=device_of(model)) 32 | return model(inputs).shape[1:] 33 | 34 | 35 | def make_xyz(fn, left, right, bottom, top, steps, exp): 36 | assert left < right 37 | assert bottom < top 38 | xs = torch.linspace(left, right, steps=steps) 39 | ys = torch.linspace(bottom, top, steps=steps) 40 | xs, ys = torch.meshgrid([xs, ys]) 41 | xs, ys = xs.flatten().unsqueeze(dim=-1), ys.flatten().unsqueeze(dim=-1) 42 | inputs = torch.cat([xs, ys], dim=-1).to(global_device()) 43 | zs = fn(inputs) 44 | if exp: 45 | zs = (zs - zs.max()).exp() 46 | xs, ys = xs.view(steps, steps).detach().cpu().numpy(), ys.view(steps, steps).detach().cpu().numpy() 47 | zs = zs.view(steps, steps).detach().cpu().numpy() 48 | return xs, ys, zs 49 | 50 | 51 | def tensor_info(ts, label): 52 | print(label, '{:.4f}'.format(ts.max().item()), '{:.4f}'.format(ts.min().item()), '{:.4f}'.format(ts.mean().item())) 53 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/utils/ema.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def ema(model_dest: nn.Module, model_src: nn.Module, rate): 5 | param_dict_src = dict(model_src.named_parameters()) 6 | for p_name, p_dest in model_dest.named_parameters(): 7 | p_src = param_dict_src[p_name] 8 | assert p_src is not p_dest 9 | p_dest.data.mul_(rate).add_((1 - rate) * p_src.data) 10 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/core/utils/managers.py: -------------------------------------------------------------------------------- 1 | r""" We sometimes need to manage multiple PyTorch objects in a script, e.g., multiple models and multiple optimizers. 2 | Managers provide an interface to handle them together. 3 | """ 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from .ema import ema 7 | import logging 8 | 9 | 10 | class Manager(object): 11 | def __init__(self, **kwargs): 12 | r""" Manage a dict of objects 13 | """ 14 | for key, obj in kwargs.items(): 15 | self.__setattr__(key, obj) 16 | 17 | def __contains__(self, key): 18 | return key in self.__dict__ 19 | 20 | def get(self, key: str): 21 | assert isinstance(key, str) 22 | return self.__getattribute__(key) 23 | 24 | def load_state(self, key: str, state): 25 | r""" 26 | Args: 27 | key: the key of the object 28 | state: the state of the object 29 | """ 30 | assert isinstance(key, str) 31 | logging.info("load {}".format(key)) 32 | self.__dict__[key].load_state_dict(state) 33 | 34 | def get_state(self, key: str): 35 | assert isinstance(key, str) 36 | return self.__dict__[key].state_dict() 37 | 38 | def load_states(self, states: dict, *keys): 39 | r""" 40 | Args: 41 | states: a dict of states of objects 42 | keys: the keys of objects 43 | If empty, load states for all objects 44 | """ 45 | assert all(map(lambda x: isinstance(x, str), keys)) 46 | if len(keys) == 0: 47 | keys = list(self.__dict__.keys()) 48 | for key in keys: 49 | if key in states: 50 | self.load_state(key, states[key]) 51 | 52 | def get_states(self, *keys): 53 | r""" 54 | Args: 55 | keys: the keys of objects 56 | If empty, return the states of all objects 57 | """ 58 | assert all(map(lambda x: isinstance(x, str), keys)) 59 | if len(keys) == 0: 60 | keys = list(self.__dict__.keys()) 61 | states = {} 62 | for key in keys: 63 | states[key] = self.get_state(key) 64 | return 
states 65 | 66 | 67 | class ModelsManager(Manager): 68 | def __init__(self, **kwargs): 69 | r""" Manage a dict of models (nn.Modules) 70 | """ 71 | for key, model in kwargs.items(): 72 | assert isinstance(model, nn.Module) 73 | super(ModelsManager, self).__init__(**kwargs) 74 | 75 | def parameters(self, *keys): 76 | r""" Return the parameters of models corresponding to keys 77 | If keys are empty, return the parameters of all models 78 | 79 | Args: 80 | keys: the keys of models 81 | If empty, return the parameters of all models 82 | """ 83 | assert all(map(lambda x: isinstance(x, str), keys)) 84 | if len(keys) == 0: 85 | keys = list(self.__dict__.keys()) 86 | params = [] 87 | for key in keys: 88 | params += self.__dict__[key].parameters() 89 | return params 90 | 91 | def toggle_grad(self, *keys): 92 | r""" Open the gradient of models corresponding to keys 93 | Others' gradients will be closed 94 | 95 | Args: 96 | keys: the keys of models 97 | """ 98 | assert all(map(lambda x: isinstance(x, str), keys)) 99 | for key, model in self.__dict__.items(): 100 | model.requires_grad_(key in keys) 101 | 102 | def to(self, device): 103 | for key, model in self.__dict__.items(): 104 | model.to(device) 105 | 106 | def train(self): 107 | for key, model in self.__dict__.items(): 108 | model.train() 109 | 110 | def eval(self): 111 | for key, model in self.__dict__.items(): 112 | model.eval() 113 | 114 | def ema(self, src, *keys, rate): 115 | r""" Exponential moving average 116 | theta <- beta * theta + (1 - beta) * theta_src 117 | 118 | Args: 119 | src: the source model 120 | keys: the keys of models 121 | If empty, update parameters of all models 122 | rate: theta <- rate * theta + (1 - rate) * theta_src 123 | """ 124 | assert isinstance(src, ModelsManager) 125 | assert all(map(lambda x: isinstance(x, str), keys)) 126 | if len(keys) == 0: 127 | keys = list(self.__dict__.keys()) 128 | for key in keys: 129 | ema(self.__dict__[key], src.__dict__[key], rate) 130 | 131 | 132 | class OptimizersManager(Manager): 133 | def __init__(self, **kwargs): 134 | r""" Manage a dict of optimizers (optim.Optimizer) 135 | """ 136 | assert all(map(lambda obj: isinstance(obj, optim.Optimizer), kwargs.values())) 137 | super(OptimizersManager, self).__init__(**kwargs) 138 | 139 | 140 | class LRSchedulersManager(Manager): 141 | def __init__(self, **kwargs): 142 | r""" Manage a dict of lr_schedulers (optim.lr_scheduler._LRScheduler) 143 | """ 144 | assert all(map(lambda obj: isinstance(obj, optim.lr_scheduler._LRScheduler), kwargs.values())) 145 | super(LRSchedulersManager, self).__init__(**kwargs) 146 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/dp_vb.py: -------------------------------------------------------------------------------- 1 | # Get NLL for the baseline (OT, DDPM, $\sigma_n^2 = \beta_n$) 2 | 3 | import torch 4 | import os 5 | import math 6 | import numpy as np 7 | 8 | 9 | def make_inf(F): # make F[s, t] = inf for s >= t 10 | return np.triu(F, 1) + np.tril(np.full(F.shape, float('inf'))) 11 | 12 | 13 | def vectorized_dp(F, N): # F[s, t] with 0 <= s < t <= N 14 | F = make_inf(F[: N + 1, : N + 1]) 15 | 16 | C = np.full((N + 1, N + 1), float('inf')) 17 | D = np.full((N + 1, N + 1), -1) 18 | 19 | C[0, 0] = 0 20 | for k in range(1, N + 1): 21 | bpds = C[k - 1, :].reshape(N + 1, 1) + F 22 | C[k] = np.min(bpds, axis=0) 23 | D[k] = np.argmin(bpds, axis=0) 24 | 25 | return D 26 | 27 | 28 | def fetch_path(D, N, K): # find a path of length K (K+1 nodes) 29 | optpath = [] 
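# Backtrack through the DP table: D[k, t] stores the best predecessor of node t on an
# optimal path with k segments, so following the back-pointers K times (starting from
# the terminal node N) reconstructs the optimal trajectory in reverse order.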
30 | t = N 31 | for k in reversed(range(K + 1)): 32 | optpath.append(t) 33 | t = D[k, t] 34 | return optpath[::-1] 35 | 36 | 37 | @ torch.no_grad() 38 | def nelbo_dp_ddpm(D_train, test_nll_terms, N, K, trajectory): 39 | if trajectory == "dp": 40 | ns = fetch_path(D_train, N, K) 41 | else: 42 | raise NotImplementedError 43 | 44 | nelbo = 0. 45 | rev_terms = [] 46 | 47 | term = test_nll_terms['last_term'] 48 | nelbo += term 49 | rev_terms.append(term) 50 | 51 | for s, r in list(zip(ns, ns[1:]))[::-1]: 52 | term = test_nll_terms['F'][s, r] 53 | nelbo += term 54 | rev_terms.append(term) 55 | 56 | return nelbo, rev_terms[::-1] 57 | 58 | 59 | def main(): 60 | dataset = "celeba" 61 | trajectory = "dp" 62 | 63 | if dataset == "cifar10_ls": 64 | root = 'workspace/runner/cifar10/ddpm_dsm/betas_2021-08-19-23-55-32/beta_schedule_linear_num_diffusion_1000/train/nll_terms' 65 | train_nll_terms_name = '400000_small_sigma_False_n_samples_None_partition_train.nll_terms.pth' 66 | test_nll_terms_name = '400000_small_sigma_False_n_samples_None_partition_test.nll_terms.pth' 67 | c = 3 * 32 * 32 * math.log(2.) 68 | elif dataset == "cifar10_cs": 69 | root = 'workspace/runner/cifar10/ddpm_dsm/betas_2021-08-19-23-55-32/beta_schedule_cosine_num_diffusion_1000/train/nll_terms' 70 | train_nll_terms_name = '160000_small_sigma_False_n_samples_None_partition_train.nll_terms.pth' 71 | test_nll_terms_name = '160000_small_sigma_False_n_samples_None_partition_test.nll_terms.pth' 72 | c = 3 * 32 * 32 * math.log(2.) 73 | elif dataset == "imagenet": 74 | root = "workspace/runner/imagenet64/improved_diffusion/L_hybrid_2021-08-29-20-26-00/cosine4000/train/nll_terms" 75 | train_nll_terms_name = 'imagenet64_uncond_100M_1500K_small_sigma_False_n_samples_16384_partition_train.nll_terms.pth' 76 | test_nll_terms_name = 'imagenet64_uncond_100M_1500K_small_sigma_False_n_samples_None_partition_test.nll_terms.pth' 77 | c = 3 * 64 * 64 * math.log(2.) 78 | elif dataset == 'celeba': 79 | root = '../celeba_lsun_codes/nelbo_terms/ema_celeba' 80 | train_nll_terms_name = 'partition_train_n_samples_16384.pth' 81 | test_nll_terms_name = 'partition_test_n_samples_None.pth' 82 | c = 3 * 64 * 64 * math.log(2.) 
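# c converts a NELBO measured in nats into bits per dimension: each dataset sets
# c = (number of pixels) * log(2), so nelbo / c below is the reported bpd.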
83 | else: 84 | raise ValueError 85 | 86 | train_nll_terms = torch.load(os.path.join(root, train_nll_terms_name)) 87 | F_train = train_nll_terms['F'] 88 | N = len(F_train) - 1 89 | if trajectory == "dp": 90 | D_train = vectorized_dp(F_train, N) 91 | else: 92 | D_train = None 93 | 94 | for sample_steps in sorted({10, 25, 50, 100, 200, 400, 1000, N}): 95 | test_nll_terms = torch.load(os.path.join(root, test_nll_terms_name)) 96 | 97 | nelbo, terms = nelbo_dp_ddpm(D_train, test_nll_terms, N, sample_steps, trajectory) 98 | nelbo_bpd = nelbo / c 99 | terms_bpd = [a / c for a in terms] 100 | 101 | print('sample_steps', sample_steps, 'bpd/continuous_part/discrete_part', 102 | '{:.2f}/{:.2f}/{:.2f}'.format(nelbo_bpd, sum(terms_bpd[1:]), terms_bpd[0])) 103 | 104 | 105 | if __name__ == "__main__": 106 | main() 107 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/interface/__init__.py: -------------------------------------------------------------------------------- 1 | r""" An interface to use algorithms and models in core 2 | """ 3 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/interface/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar10 import Cifar10 2 | from .imagenet64 import Imagenet64 3 | from .dataset_factory import DatasetFactory 4 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/interface/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Subset 2 | from torchvision import datasets 3 | import torchvision.transforms as transforms 4 | from .dataset_factory import DatasetFactory 5 | from .utils import * 6 | 7 | 8 | class Cifar10(DatasetFactory): 9 | r""" Cifar10 dataset 10 | 11 | Information of the raw dataset: 12 | train: 40,000 13 | val: 10,000 14 | test: 10,000 15 | shape: 3 * 32 * 32 16 | """ 17 | 18 | def __init__(self, data_path, gauss_noise=False, noise_std=0.01): 19 | super(Cifar10, self).__init__() 20 | self.data_path = data_path 21 | self.gauss_noise = gauss_noise 22 | self.noise_std = noise_std 23 | 24 | _transform = [transforms.ToTensor()] 25 | if self.gauss_noise: 26 | _transform.append(AddGaussNoise(self.noise_std)) 27 | im_transform = transforms.Compose(_transform) 28 | self.train_val = datasets.CIFAR10(self.data_path, train=True, transform=im_transform, download=True) 29 | self.train = Subset(self.train_val, list(range(40000))) 30 | self.val = Subset(self.train_val, list(range(40000, 50000))) 31 | self.test = datasets.CIFAR10(self.data_path, train=False, transform=im_transform, download=True) 32 | 33 | def affine_transform(self, dataset): 34 | return StandardizedDataset(dataset, mean=0.5, std=0.5) # scale to [-1, 1] 35 | 36 | def preprocess(self, v): 37 | return 2. * (v - 0.5) 38 | 39 | def unpreprocess(self, v): 40 | v = 0.5 * (v + 1.) 41 | v.clamp_(0., 1.) 
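# unpreprocess inverts the affine preprocess map from [-1, 1] back to [0, 1] pixel
# space; clamping guards against generated samples that overshoot the valid range.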
42 | return v 43 | 44 | @property 45 | def data_shape(self): 46 | return 3, 32, 32 47 | 48 | @property 49 | def fid_stat(self): 50 | return 'workspace/fid_stats/fid_stats_cifar10_train_pytorch.npz' 51 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/interface/datasets/dataset_factory.py: -------------------------------------------------------------------------------- 1 | from .utils import is_labelled, UnlabeledDataset 2 | from torch.utils.data import ConcatDataset 3 | import numpy as np 4 | 5 | 6 | class DatasetFactory(object): 7 | r""" Output dataset after two transformations to the raw data: 8 | 1. distribution transform (e.g. binarized, adding noise), often irreversible, a part of which is implemented 9 | in distribution_transform 10 | 2. an affine transform (preprocess), which is bijective 11 | """ 12 | 13 | def __init__(self): 14 | self.train = None 15 | self.val = None 16 | self.test = None 17 | 18 | def allow_labelled(self): 19 | return is_labelled(self.train) 20 | 21 | def get_data(self, dataset, labelled): 22 | assert not (not is_labelled(dataset) and labelled) 23 | if is_labelled(dataset) and not labelled: 24 | dataset = UnlabeledDataset(dataset) 25 | return self.affine_transform(self.distribution_transform(dataset)) 26 | 27 | def get_train_data(self, labelled=False): 28 | return self.get_data(self.train, labelled=labelled) 29 | 30 | def get_val_data(self, labelled=False): 31 | return self.get_data(self.val, labelled=labelled) 32 | 33 | def get_train_val_data(self, labelled=False): 34 | train_val = ConcatDataset([self.train, self.val]) 35 | return self.get_data(train_val, labelled=labelled) 36 | 37 | def get_test_data(self, labelled=False): 38 | return self.get_data(self.test, labelled=labelled) 39 | 40 | def distribution_transform(self, dataset): 41 | return dataset 42 | 43 | def affine_transform(self, dataset): 44 | return dataset 45 | 46 | def preprocess(self, v): 47 | r""" The mathematical form of the affine transform 48 | """ 49 | return v 50 | 51 | def unpreprocess(self, v): 52 | r""" The mathematical form of the affine transform's inverse 53 | """ 54 | return v 55 | 56 | @property 57 | def data_shape(self): 58 | raise NotImplementedError 59 | 60 | @property 61 | def data_dim(self): 62 | return int(np.prod(self.data_shape)) 63 | 64 | @property 65 | def fid_stat(self): 66 | return None 67 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/interface/datasets/imagenet64.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | import torchvision.transforms as transforms 4 | from .dataset_factory import DatasetFactory 5 | from .utils import * 6 | 7 | 8 | class ImageDataset(Dataset): 9 | def __init__(self, path): 10 | super().__init__() 11 | names = os.listdir(path) 12 | self.local_images = [os.path.join(path, name) for name in names] 13 | self._transform = transforms.ToTensor() 14 | 15 | def __len__(self): 16 | return len(self.local_images) 17 | 18 | def __getitem__(self, idx): 19 | X = Image.open(self.local_images[idx]) 20 | X = self._transform(X) 21 | return X 22 | 23 | 24 | class Imagenet64(DatasetFactory): 25 | r""" Imagenet64 dataset 26 | 27 | Information of the raw dataset: 28 | train: 1,281,149 29 | test: 49,999 30 | shape: 3 * 64 * 64 31 | """ 32 | 33 | def __init__(self, path): 34 | super().__init__() 35 | self.train = ImageDataset(os.path.join(path, 'train_64x64')) 36 | self.test = 
ImageDataset(os.path.join(path, 'valid_64x64')) 37 | 38 | def affine_transform(self, dataset): 39 | return StandardizedDataset(dataset, mean=0.5, std=0.5) # scale to [-1, 1] 40 | 41 | def preprocess(self, v): 42 | return 2. * (v - 0.5) 43 | 44 | def unpreprocess(self, v): 45 | v = 0.5 * (v + 1.) 46 | v.clamp_(0., 1.) 47 | return v 48 | 49 | @property 50 | def data_shape(self): 51 | return 3, 64, 64 52 | 53 | @property 54 | def fid_stat(self): 55 | return 'workspace/fid_stats/fid_stats_imagenet64_train.npz' 56 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/interface/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | 5 | def pad22pow(a): 6 | assert a % 2 == 0 7 | bits = a.bit_length() 8 | ub = 2 ** bits 9 | pad = (ub - a) // 2 10 | return pad, ub 11 | 12 | 13 | def is_labelled(dataset): 14 | labelled = False 15 | if isinstance(dataset[0], tuple) and len(dataset[0]) == 2: 16 | labelled = True 17 | return labelled 18 | 19 | 20 | class AddGaussNoise(object): 21 | def __init__(self, std): 22 | self.std = std 23 | 24 | def __call__(self, tensor): 25 | return tensor + self.std * torch.randn_like(tensor)  # randn_like draws Gaussian noise (rand_like would be uniform) and matches the input's device 26 | 27 | 28 | class UnlabeledDataset(Dataset): 29 | def __init__(self, dataset): 30 | self.dataset = dataset 31 | 32 | def __len__(self): 33 | return len(self.dataset) 34 | 35 | def __getitem__(self, item): 36 | x, y = self.dataset[item] 37 | return x 38 | 39 | 40 | class StandardizedDataset(Dataset): 41 | def __init__(self, dataset, mean, std): 42 | self.dataset = dataset 43 | self.mean = mean 44 | self.std = std 45 | self.std_inv = 1. / std 46 | self.labelled = is_labelled(dataset) 47 | 48 | def __len__(self): 49 | return len(self.dataset) 50 | 51 | def __getitem__(self, item): 52 | if self.labelled: 53 | x, y = self.dataset[item] 54 | return self.std_inv * (x - self.mean), y 55 | else: 56 | x = self.dataset[item] 57 | return self.std_inv * (x - self.mean) 58 | 59 | 60 | class QuickDataset(Dataset): 61 | def __init__(self, array): 62 | self.array = array 63 | 64 | def __len__(self): 65 | return len(self.array) 66 | 67 | def __getitem__(self, item): 68 | return self.array[item] 69 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/interface/evaluators/__init__.py: -------------------------------------------------------------------------------- 1 | from .ddpm_evaluator import DDPMNaiveEvaluator, ImprovedDDPMEvaluator 2 | from .base import Evaluator 3 | 4 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/interface/evaluators/base.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class Evaluator(object): 4 | def __init__(self, options: dict): 5 | r""" Evaluate models 6 | """ 7 | self.options = options 8 | 9 | def evaluate_train(self, it): 10 | r""" Evaluate during training 11 | Args: 12 | it: the iteration of training 13 | """ 14 | for fn, val in self.options.items(): 15 | period = val["period"] 16 | kwargs = val.get("kwargs", {}) 17 | if it % period == 0: 18 | eval("self.%s" % fn)(it=it, **kwargs) 19 | 20 | def evaluate(self, it=None): 21 | r""" 22 | Args: 23 | it: the iteration at which the evaluated models were saved 24 | """ 25 | if it is None: 26 | it = 0 27 | for fn, val in self.options.items(): 28 | kwargs = val.get("kwargs", {}) 29 | eval("self.%s" % 
fn)(it=it, **kwargs) 30 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/interface/evaluators/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def linear_interpolate(a, b, steps): 5 | a_shape = a.shape 6 | a = a.detach().cpu().view(-1) 7 | b = b.detach().cpu().view(-1) 8 | res = [] 9 | for aa, bb in zip(a, b): 10 | res.append(torch.linspace(aa, bb, steps=steps).unsqueeze(dim=1)) 11 | res = torch.cat(res, dim=1) 12 | res = res.view(len(res), *a_shape) 13 | return res 14 | 15 | 16 | def rect_interpolate(a, b, c, steps): 17 | a = a.detach().cpu() 18 | b = b.detach().cpu() 19 | c = c.detach().cpu() 20 | ab = linear_interpolate(a, b, steps) - a 21 | ac = linear_interpolate(a, c, steps) - a 22 | res = [] 23 | for st in ac: 24 | res.append(ab + st) 25 | res = torch.cat(res, dim=0) + a 26 | return res 27 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/interface/runner/__init__.py: -------------------------------------------------------------------------------- 1 | from .runner import * 2 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/interface/runner/fit.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from interface.utils.ckpt import CKPT 3 | import os 4 | import logging 5 | import math 6 | from core.utils.managers import ModelsManager 7 | 8 | 9 | def infinite_loader(dataset, batch_size): 10 | loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) 11 | while True: 12 | for data in loader: 13 | yield data 14 | 15 | 16 | def check_anomaly(statistics: dict, it: int): 17 | for k, v in statistics.items(): 18 | if math.isnan(float(v)): 19 | statistics_str = {k: "{:.5e}".format(float(statistics[k])) for k in statistics} 20 | logging.info('Exit at it {}, {}'.format(it, statistics_str)) 21 | exit(0) 22 | 23 | 24 | def naive_fit(criterion, train_dataset, batch_size, n_its, n_ckpts, ckpt_root, interact, 25 | evaluator=None, val_dataset=None, val_fn=None, ckpt=None, 26 | ema_models: ModelsManager = None, ema_rate=None): 27 | r""" Loops of Learning 28 | Args: 29 | criterion: a Criterion instance 30 | train_dataset: the training dataset 31 | batch_size: the batch size of training 32 | n_its: the number of iterations 33 | n_ckpts: the number of ckpts to save 34 | ckpt_root: the directory root of ckpts 35 | interact: an Interact instance 36 | evaluator: an Evaluator instance 37 | val_dataset: the validation dataset 38 | val_fn: the function used for validation 39 | ckpt: a CKPT instance 40 | ema_models: the exponential moving average models 41 | ema_rate: theta <- rate * theta + (1 - rate) * theta_src 42 | """ 43 | os.makedirs(ckpt_root, exist_ok=True) 44 | criterion.models.to(criterion.device) 45 | if ema_models is not None: 46 | ema_models.to(criterion.device) 47 | 48 | it = 0 49 | best_val_loss = float('inf') # the smaller the better 50 | if ckpt is not None: 51 | it = ckpt.it 52 | best_val_loss = ckpt.best_val_loss 53 | ckpt.to_criterion(criterion) 54 | if ema_models is not None: 55 | ckpt.to_ema_models(ema_models) 56 | del ckpt 57 | 58 | logging.info("Start fitting, it=%d" % it) 59 | 60 | train_dataset_loader = infinite_loader(train_dataset, batch_size=batch_size) 61 | period = n_its // n_ckpts # the period of saving ckpts 62 | 63 | while it < n_its: 64 | 
criterion.models.train() 65 | criterion.update(train_dataset_loader) 66 | if ema_models is not None: # exponential moving average 67 | ema_models.ema(criterion.models, rate=ema_rate) 68 | it += 1 69 | 70 | interact.report_train(criterion.statistics, it) 71 | check_anomaly(criterion.statistics, it) 72 | if evaluator is not None: 73 | criterion.models.eval() 74 | evaluator.evaluate_train(it) 75 | 76 | if it % period == 0 or it == n_its: 77 | criterion.models.eval() 78 | if val_dataset is not None and val_fn is not None: # validation 79 | loss = val_fn(dataset=val_dataset) 80 | interact.report_val(loss, it) 81 | if loss < best_val_loss: 82 | CKPT(models_states=criterion.models.get_states()). \ 83 | save(os.path.join(ckpt_root, "best.pth")) # update the best model 84 | best_val_loss = loss 85 | CKPT(it=it, best_val_loss=best_val_loss, ema_models_states=None if ema_models is None else ema_models.get_states()).\ 86 | from_criterion(criterion).save(os.path.join(ckpt_root, "%d.ckpt.pth" % it)) # save ckpt 87 | 88 | logging.info("Finish fitting, it=%d" % it) 89 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/interface/runner/runner.py: -------------------------------------------------------------------------------- 1 | 2 | __all__ = ["run_train_profile", "run_evaluate_profile", "run_timing_profile"] 3 | 4 | 5 | from interface.runner.fit import naive_fit 6 | from interface.runner.timing import timing 7 | from interface.utils import set_seed, set_deterministic, backup_codes, backup_profile 8 | from interface.utils import ckpt, profile_utils, dict_utils 9 | from interface.utils.dict_utils import get_val 10 | from core.utils import global_device 11 | from core.utils.managers import ModelsManager 12 | 13 | 14 | def merge_models(dest: ModelsManager, src: ModelsManager): 15 | if src is None: 16 | return dest 17 | _dict = {key: dest.__dict__[key] for key in dest.__dict__.keys()} 18 | for key, val in src.__dict__.items(): 19 | _dict[key] = val 20 | return ModelsManager(**_dict) 21 | 22 | 23 | def run_train_profile(profile: dict): 24 | r""" 25 | Args: 26 | profile: a parsed or unparsed profile 27 | """ 28 | profile = dict_utils.parse_self_ref_dict(profile) 29 | set_seed(get_val(profile, "seed", default=None)) 30 | set_deterministic(get_val(profile, "deterministic", default=False)) 31 | backup_codes(profile["backup_root"]) 32 | backup_profile(profile, profile["backup_root"]) 33 | 34 | interact = profile_utils.create_interact(profile["interact"]) 35 | interact.report_machine() 36 | 37 | models = profile_utils.create_models(profile["models"]) 38 | ema_models, ema_rate = profile_utils.create_ema(profile, models) 39 | optimizers = profile_utils.create_optimizers(profile["optimizers"], models) 40 | lr_schedulers = profile_utils.create_lr_schedulers(profile.get("lr_schedulers", {}), optimizers) 41 | criterion = profile_utils.create_criterion(profile["criterion"], models, optimizers, lr_schedulers) 42 | 43 | dataset = profile_utils.create_dataset(profile["dataset"]) 44 | if profile["dataset"].get("use_val", True): 45 | train_dataset = dataset.get_train_data() 46 | val_dataset = dataset.get_val_data() 47 | else: 48 | train_dataset = dataset.get_train_val_data() 49 | val_dataset = None 50 | 51 | evaluator = None 52 | if "evaluator" in profile: 53 | evaluator = profile_utils.create_evaluator(profile["evaluator"], merge_models(models, ema_models), dataset, interact) 54 | 55 | naive_fit(criterion=criterion, 56 | train_dataset=train_dataset, 57 | 
batch_size=get_val(profile, "training", "batch_size"), 58 | n_its=get_val(profile, "training", "n_its"), 59 | n_ckpts=get_val(profile, "training", "n_ckpts", default=10), 60 | ckpt_root=profile["ckpt_root"], 61 | interact=interact, 62 | evaluator=evaluator, 63 | val_dataset=val_dataset, 64 | val_fn=profile_utils.create_val_fn(profile, criterion), 65 | ckpt=ckpt.get_ckpt_by_it(profile["ckpt_root"]), 66 | ema_models=ema_models, 67 | ema_rate=ema_rate 68 | ) 69 | 70 | 71 | def run_evaluate_profile(profile: dict): 72 | r""" 73 | Args: 74 | profile: a parsed or unparsed profile 75 | """ 76 | profile = dict_utils.parse_self_ref_dict(profile) 77 | set_seed(get_val(profile, "seed", default=None)) 78 | set_deterministic(get_val(profile, "deterministic", default=False)) 79 | backup_codes(profile["backup_root"]) 80 | backup_profile(profile, profile["backup_root"]) 81 | 82 | interact = profile_utils.create_interact(profile["interact"]) 83 | interact.report_machine() 84 | 85 | models = profile_utils.create_models(profile["models"]) 86 | ckpt_path = profile['ckpt_path'] if isinstance(profile['ckpt_path'], list) else [profile['ckpt_path']] 87 | for path in ckpt_path: 88 | if profile.get("ema", False): 89 | ckpt.CKPT().load(path).to_ema_models(models) 90 | else: 91 | ckpt.CKPT().load(path).to_models(models) 92 | dataset = profile_utils.create_dataset(profile["dataset"]) 93 | 94 | evaluator = profile_utils.create_evaluator(profile["evaluator"], models, dataset, interact) 95 | models.to(global_device()) 96 | models.eval() 97 | evaluator.evaluate() 98 | 99 | 100 | def run_timing_profile(profile: dict): 101 | r""" 102 | Args: 103 | profile: a parsed or unparsed profile 104 | """ 105 | profile = dict_utils.parse_self_ref_dict(profile) 106 | set_seed(get_val(profile, "seed", default=None)) 107 | set_deterministic(get_val(profile, "deterministic", default=False)) 108 | backup_codes(profile["backup_root"]) 109 | backup_profile(profile, profile["backup_root"]) 110 | 111 | interact = profile_utils.create_interact(profile["interact"]) 112 | interact.report_machine() 113 | 114 | models = profile_utils.create_models(profile["models"]) 115 | optimizers = profile_utils.create_optimizers(profile["optimizers"], models) 116 | lr_schedulers = profile_utils.create_lr_schedulers(profile.get("lr_schedulers", {}), optimizers) 117 | criterion = profile_utils.create_criterion(profile["criterion"], models, optimizers, lr_schedulers) 118 | 119 | dataset = profile_utils.create_dataset(profile["dataset"]) 120 | if profile["dataset"].get("use_val", True): 121 | train_dataset = dataset.get_train_data() 122 | else: 123 | train_dataset = dataset.get_train_val_data() 124 | 125 | timing(criterion=criterion, 126 | train_dataset=train_dataset, 127 | batch_size=get_val(profile, "training", "batch_size"), 128 | n_its=get_val(profile, "training", "n_its") 129 | ) 130 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/interface/runner/timing.py: -------------------------------------------------------------------------------- 1 | from .fit import infinite_loader 2 | import logging 3 | import time 4 | from interface.utils.task_schedule import gpu_memory_consumption 5 | 6 | 7 | def timing(criterion, train_dataset, batch_size, n_its): 8 | r""" Loops of Learning 9 | Args: 10 | criterion: a Criterion instance 11 | train_dataset: the training dataset 12 | batch_size: the batch size of training 13 | n_its: the number of iterations 14 | """ 15 | criterion.models.to(criterion.device) 16 | 17 | it = 0 
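# The wall-clock measurement below covers the full training step, including data
# loading and any host-device synchronization performed inside criterion.update.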
18 | train_dataset_loader = infinite_loader(train_dataset, batch_size=batch_size) 19 | 20 | st = time.time() 21 | while it < n_its: 22 | criterion.models.train() 23 | criterion.update(train_dataset_loader) 24 | it += 1 25 | ed = time.time() 26 | 27 | logging.info("%d iterations take %.2f s" % (n_its, ed - st)) 28 | logging.info("Taking %d MB" % gpu_memory_consumption()) 29 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/interface/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import dict_utils 2 | from . import profile_utils 3 | from . import ckpt 4 | from . import interact 5 | from .reproducibility import * 6 | from . import misc 7 | from . import exp_templates 8 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/interface/utils/ckpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import core.utils.managers as managers 4 | import logging 5 | 6 | 7 | class CKPT(object): 8 | def __init__(self, it=None, best_val_loss=None, models_states=None, optimizers_states=None, lr_schedulers_states=None, 9 | ema_models_states=None): 10 | r""" Record the states of training 11 | Args: 12 | it: iteration 13 | best_val_loss: the best validation loss 14 | models_states: a dict of state_dicts of models 15 | optimizers_states: a dict of state_dicts of optimizers 16 | lr_schedulers_states: a dict of state_dicts of lr_schedulers 17 | ema_models_states: a dict of state_dicts of ema_models 18 | """ 19 | self.it = it 20 | self.best_val_loss = best_val_loss 21 | self.models_states = models_states 22 | self.optimizers_states = optimizers_states 23 | self.lr_schedulers_states = lr_schedulers_states 24 | self.ema_models_states = ema_models_states 25 | 26 | def save(self, fname): 27 | logging.info("save ckpt to {}".format(fname)) 28 | torch.save(self.__dict__, fname) 29 | 30 | def load(self, fname): 31 | logging.info("load ckpt from {}".format(fname)) 32 | ckpt = torch.load(fname) 33 | for k, val in ckpt.items(): 34 | self.__dict__[k] = val 35 | return self 36 | 37 | def from_criterion(self, criterion): 38 | self.models_states = criterion.models.get_states() 39 | self.optimizers_states = criterion.optimizers.get_states() 40 | self.lr_schedulers_states = criterion.lr_schedulers.get_states() 41 | return self 42 | 43 | def to_criterion(self, criterion): 44 | criterion.models.load_states(self.models_states) 45 | criterion.optimizers.load_states(self.optimizers_states) 46 | criterion.lr_schedulers.load_states(self.lr_schedulers_states) 47 | 48 | def to_models(self, models: managers.ModelsManager): 49 | logging.info("load models_states") 50 | models.load_states(self.models_states) 51 | 52 | def to_ema_models(self, ema_models: managers.ModelsManager): 53 | logging.info("load ema_models_states") 54 | ema_models.load_states(self.ema_models_states) 55 | 56 | 57 | def list_ckpts(ckpt_root): 58 | fnames = list(filter(lambda x: x.endswith(".ckpt.pth"), os.listdir(ckpt_root))) 59 | fnames = sorted(fnames, key=lambda x: int(x.split(".")[0])) 60 | return fnames 61 | 62 | 63 | def get_ckpt_by_it(ckpt_root, it=None): 64 | r""" Get the ckpt at a iteration 'it' from ckpt_root 65 | If 'it' is None, try to get the latest ckpt 66 | If 'it' is None and there is no ckpt, return None 67 | 68 | Args: 69 | ckpt_root: the root of ckpts 70 | it: the iteration 71 | """ 72 | if not os.path.exists(ckpt_root): 
73 | return None 74 | if it is None: 75 | fnames = list_ckpts(ckpt_root) 76 | if fnames: 77 | return CKPT().load(os.path.join(ckpt_root, fnames[-1])) 78 | else: 79 | return None 80 | else: 81 | return CKPT().load(os.path.join(ckpt_root, "%d.ckpt.pth" % it)) 82 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/interface/utils/dict_utils.py: -------------------------------------------------------------------------------- 1 | 2 | __all__ = ['merge_dict_', 'merge_dict', 'get_val', 'parse_self_ref_dict', 'single_chain_dict'] 3 | 4 | import copy 5 | import re 6 | import os 7 | import datetime 8 | 9 | 10 | ################################################################################ 11 | # Recursively merge dict 12 | ################################################################################ 13 | 14 | def merge_dict_(dest: dict, src: dict): 15 | r""" Merge src to dest, inplace 16 | """ 17 | for k, val in src.items(): 18 | if isinstance(val, dict) and k in dest.keys() and isinstance(dest[k], dict): 19 | merge_dict_(dest[k], src[k]) 20 | else: 21 | dest[k] = val 22 | 23 | 24 | def merge_dict(dest: dict, src: dict): 25 | r""" Merge src to dest 26 | """ 27 | dest = copy.deepcopy(dest) 28 | merge_dict_(dest, src) 29 | return dest 30 | 31 | 32 | ################################################################################ 33 | # Recursively get value in a dict 34 | ################################################################################ 35 | 36 | def get_val(dct: dict, *fields, **kwargs): 37 | r""" 38 | Args: 39 | dct: a dict 40 | fields: the keys 41 | """ 42 | try: 43 | cur = dct 44 | for field in fields: 45 | cur = cur[field] 46 | return cur 47 | except Exception as e: 48 | if "default" in kwargs.keys(): 49 | return kwargs["default"] 50 | else: 51 | raise e 52 | 53 | 54 | ################################################################################ 55 | # Parse a self-reference dict 56 | ################################################################################ 57 | 58 | def _is_reference(s: str): 59 | return isinstance(s, str) and '$' in s 60 | 61 | 62 | def _parse_term(term: str, dct_name: str): 63 | r""" '$(a.b.c.d)' -> "dct_name['a']['b']['c']['d']" 64 | Args: 65 | term: a term, e.g., '$(a)', '$(a.b.c.d)' 66 | dct_name: the name of the dict 67 | """ 68 | assert isinstance(term, str) and isinstance(dct_name, str) 69 | keys = term[2:-1].split('.') 70 | res = "" 71 | for key in keys: 72 | res += "['%s']" % key 73 | return dct_name + res 74 | 75 | 76 | def _parse_reference(ref: str, dct_name: str): 77 | r""" '$(a.b) // 10' -> "dct_name['a']['b'] // 10" 78 | Args: 79 | ref: a reference, e.g., '$(a.b) // 10', '$(a) + $(b.c)' 80 | dct_name: the name of the dict 81 | """ 82 | assert isinstance(ref, str) and isinstance(dct_name, str) 83 | matches = list(re.finditer("\$\([^$()]*\)", ref)) 84 | for match in matches[::-1]: 85 | span = match.span() 86 | term = match.group() 87 | parsed_term = _parse_term(term, dct_name) 88 | ref = ref[:span[0]] + parsed_term + ref[span[1]:] 89 | return ref 90 | 91 | 92 | def _parse_self_ref_dict(local_dct: dict, global_dct: dict): 93 | for k, val in local_dct.items(): 94 | if _is_reference(val): 95 | local_dct[k] = eval(_parse_reference(val, 'global_dct')) 96 | elif isinstance(val, dict): 97 | _parse_self_ref_dict(val, global_dct) 98 | 99 | 100 | def parse_self_ref_dict(dct: dict): 101 | r""" Parse the auto-reference in a dict 102 | Only allow one-level reference 103 | 104 | Args: 105 | dct: a 
--------------------------------------------------------------------------------
/cifar_imagenet_codes/interface/utils/interact.py:
--------------------------------------------------------------------------------
import os
import logging
from torch.utils.tensorboard import SummaryWriter
import socket


def set_logger(fname):
    logger = logging.getLogger()
    logger.setLevel(level=logging.INFO)
    handler1 = logging.StreamHandler()
    handler2 = logging.FileHandler(fname, mode='w')
    formatter = logging.Formatter('%(asctime)s - %(message)s')
    handler1.setFormatter(formatter)
    handler2.setFormatter(formatter)
    logger.addHandler(handler1)
    logger.addHandler(handler2)


class Interact(object):
    def __init__(self, fname_log, summary_root=None, period=None, reported_keys=None):
        r"""
        Args:
            period: the period (in iterations) at which training statistics are reported
        """
        self.fname_log = fname_log
        self.summary_root = summary_root
        self.period = period
        self.reported_keys = reported_keys
        os.makedirs(os.path.dirname(self.fname_log), exist_ok=True)
        set_logger(self.fname_log)

        self.writer = None
        if self.summary_root is not None:
            os.makedirs(self.summary_root, exist_ok=True)
            self.writer = SummaryWriter(self.summary_root)

    def report_train(self, statistics, it):
        if it % self.period == 0:
            if self.writer is not None:
                for k, v in statistics.items():
                    self.writer.add_scalar(k, v, global_step=it)
            reported_keys = statistics.keys() if self.reported_keys is None else self.reported_keys
            statistics_str = {k: "{:.5e}".format(float(statistics[k])) for k in reported_keys}
            logging.info("[train] [it: {}] [{}]".format(it, statistics_str))

    def report_val(self, scalar, it):
        self.report_scalar(scalar, it, "val")

    def report_scalar(self, scalar, it, tag):
        if self.writer is not None:
            self.writer.add_scalar(tag, scalar, global_step=it)
        logging.info("[{}] [it: {}] [{:.5e}]".format(tag, it, scalar))

    def report_machine(self):
        logging.info("running @ {}".format(socket.gethostname()))
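A minimal, illustrative use of `Interact`; the paths are placeholders:

```
from interface.utils.interact import Interact

interact = Interact(fname_log="workspace/demo/train/logs/run.log",
                    summary_root="workspace/demo/train/summary",
                    period=10)
interact.report_machine()                      # logs the hostname
interact.report_train({"loss": 0.123}, it=10)  # it % period == 0, so this is reported
```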
--------------------------------------------------------------------------------
/cifar_imagenet_codes/interface/utils/misc.py:
--------------------------------------------------------------------------------

__all__ = ["get_root_by_time", "sample_from_dataset"]


import os
import datetime
import re
import random


def valid_prefix_time(prefix_time: str, prefix: str):
    res = re.search(r"%s_\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}" % prefix, prefix_time)
    return res is not None and res.span() == (0, len(prefix_time))


def get_root_by_time(path, prefix, time=None, strategy=None):
    r"""
    Args:
        path: the root is path/prefix_time
        prefix: the root is path/prefix_time
        time: a time tag with the format %Y-%m-%d-%H-%M-%S
        strategy: how to infer the time ("latest" or "now"); only used when time is None
    """
    assert strategy is None or strategy in ["latest", "now"]
    if time is None:
        if strategy == "latest":
            _all = filter(lambda s: valid_prefix_time(s, prefix), os.listdir(path))
            latest = sorted(_all)[-1]
            prefix = os.path.join(path, latest)
        elif strategy == "now":
            time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
            prefix = os.path.join(path, prefix + "_" + time)
        else:
            raise ValueError
    else:
        prefix = os.path.join(path, prefix + "_" + time)
    return prefix


def sample_from_dataset(dataset):
    idx = random.sample(range(len(dataset)), 1)[0]
    return dataset[idx]
--------------------------------------------------------------------------------
/cifar_imagenet_codes/interface/utils/plot.py:
--------------------------------------------------------------------------------
import matplotlib
matplotlib.use('Agg')  # select the non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
import seaborn as sns
import os
import torch


class PlotContext(object):
    def __init__(self):
        fig, ax = plt.subplots(figsize=(10, 10))
        fig.subplots_adjust(left=0.1, right=0.99, bottom=0.06, top=1.)
        ax.axis('equal')
        ax.margins(0)
        ax.tick_params(axis="both", labelsize=40)
        ax.locator_params(axis="both", nbins=3)
        self.cmap = plt.get_cmap('GnBu')
        self.fig = fig
        self.ax = ax

    def __enter__(self):
        return self.fig, self.ax, self.cmap

    def __exit__(self, exc_type, exc_val, exc_tb):
        plt.close(self.fig)


def plot_density(xs, ys, density, fname):
    root, name = os.path.split(fname)
    os.makedirs(root, exist_ok=True)
    os.makedirs(os.path.join(root, "tensor"), exist_ok=True)
    torch.save((xs, ys, density), os.path.join(root, "tensor", "%s.pth" % name))
    with PlotContext() as (fig, ax, cmap):
        ax.pcolormesh(xs, ys, density, cmap=cmap)
        fig.savefig(fname)


def plot_scatter(samples, fname):
    root, name = os.path.split(fname)
    os.makedirs(root, exist_ok=True)
    os.makedirs(os.path.join(root, "tensor"), exist_ok=True)
    torch.save(samples, os.path.join(root, "tensor", "%s.pth" % name))
    with PlotContext() as (fig, ax, cmap):
        ax.scatter(samples[:, 0], samples[:, 1], cmap=cmap)
        fig.savefig(fname)


def plot_kde(samples, fname):
    root, name = os.path.split(fname)
    os.makedirs(root, exist_ok=True)
    os.makedirs(os.path.join(root, "tensor"), exist_ok=True)
    torch.save(samples, os.path.join(root, "tensor", "%s.pth" % name))
    with PlotContext() as (fig, ax, cmap):
        sns.kdeplot(samples[:, 0], samples[:, 1], shade=True, cmap=cmap, ax=ax)
        fig.savefig(fname)
--------------------------------------------------------------------------------
/cifar_imagenet_codes/interface/utils/profile_utils.py:
--------------------------------------------------------------------------------
import copy
import traceback
import core.utils.managers as managers
import torch.optim as optim
from .interact import Interact
from .dict_utils import get_val
from core.evaluate import score_on_dataset
import functools
import torch


def _is_instance_profile(profile):
    r""" Judge whether the profile defines an instance of a class
    """
    return isinstance(profile, dict) and "class" in profile


def _create_instance_recursively(profile):
    assert _is_instance_profile(profile)
    kwargs = profile.get("kwargs", {})
    for k, val in kwargs.items():
        if _is_instance_profile(val):
            kwargs[k] = _create_instance_recursively(val)
    try:
        return profile["class"](**kwargs)
    except TypeError:
        traceback.print_exc()
        print(profile["class"])
        exit(1)


################################################################################
# Create models from a profile
################################################################################

def create_model(profile: dict):
    r""" Create an instance of the model described in the profile
    Args:
        profile: a parsed profile describing the model
    """
    return _create_instance_recursively(copy.deepcopy(profile))


def create_models(profile: dict):
    r""" Create models (an instance of ModelsManager) described in the profile
    Args:
        profile: a parsed profile describing models
    """
    profile = copy.deepcopy(profile)
    models = {}
    for k, val in profile.items():
        models[k] = create_model(val)
        if "init_ckpt_path" in val:
            path = val["init_ckpt_path"]
            models[k].load_state_dict(torch.load(path)["models_states"][k])
    return managers.ModelsManager(**models)


################################################################################
# Create optimizers from a profile
################################################################################

def create_optimizer(profile: dict, models: managers.ModelsManager):
    r""" Create an instance of the optimizer described in the profile
    Args:
        profile: a parsed profile describing the optimizer
            Example: { "class": optim.Adam,
                       "model_keys": ["lvm", "q"],
                       "kwargs": { "lr": 0.0001 } }
            If 'model_keys' is missing, the optimizer will include all parameters
        models: an object of ModelsManager
    """
    assert _is_instance_profile(profile)
    params = models.parameters(*profile.get("model_keys", []))
    try:
        return profile["class"](params, **profile["kwargs"])
    except TypeError:
        traceback.print_exc()
        print(profile["class"])
        exit(1)


def create_optimizers(profile: dict, models: managers.ModelsManager):
    r""" Create optimizers (an instance of OptimizersManager) described in the profile
    Args:
        profile: a parsed profile describing optimizers
        models: an object of ModelsManager
    """
    profile = copy.deepcopy(profile)
    optimizers = {}
    for k, val in profile.items():
        optimizers[k] = create_optimizer(val, models)
    return managers.OptimizersManager(**optimizers)


################################################################################
# Create lr_schedulers from a profile
################################################################################

def create_lr_scheduler(profile: dict, optimizer: optim.Optimizer):
    r""" Create an instance of the lr_scheduler described in the profile
    Args:
        profile: a parsed profile describing the lr_scheduler
        optimizer: the optimizer to apply
    """
    assert _is_instance_profile(profile)
    try:
        return profile["class"](optimizer, **profile["kwargs"])
    except TypeError:
        traceback.print_exc()
        print(profile["class"])
        exit(1)


def create_lr_schedulers(profile: dict, optimizers: managers.OptimizersManager):
    r""" Create lr_schedulers (an instance of LRSchedulersManager) described in the profile
    Args:
        profile: a parsed profile describing lr_schedulers
        optimizers: an object of OptimizersManager
    """
    profile = copy.deepcopy(profile)
    lr_schedulers = {}
    for k, val in profile.items():
        lr_schedulers[k] = create_lr_scheduler(val, optimizers.get(k))
    return managers.LRSchedulersManager(**lr_schedulers)


################################################################################
# Create criterion from a profile
################################################################################

def create_criterion(profile: dict,
                     models: managers.ModelsManager,
                     optimizers: managers.OptimizersManager,
                     lr_schedulers: managers.LRSchedulersManager):
    r""" Create an instance of the criterion described in the profile
    Args:
        profile: a parsed profile describing the criterion
        models: an object of ModelsManager
        optimizers: an object of OptimizersManager
        lr_schedulers: an object of LRSchedulersManager
    """
    assert _is_instance_profile(profile)
    try:
        return profile["class"](**profile.get("kwargs", {}),
                                models=models,
                                optimizers=optimizers,
                                lr_schedulers=lr_schedulers)
    except TypeError:
        traceback.print_exc()
        print(profile["class"])
        exit(1)


################################################################################
# Create dataset from a profile
################################################################################

def create_dataset(profile):
    assert _is_instance_profile(profile)
    try:
        return profile["class"](**profile["kwargs"])
    except TypeError:
        traceback.print_exc()
        print(profile["class"])
        exit(1)


################################################################################
# Create evaluator from a profile
################################################################################

def create_evaluator(profile, models, dataset, interact):
    assert _is_instance_profile(profile)
    try:
        return profile["class"](**profile["kwargs"], models=models, dataset=dataset, interact=interact)
    except TypeError:
        traceback.print_exc()
        print(profile["class"])
        exit(1)


################################################################################
# Create interact from a profile
################################################################################

def create_interact(profile: dict) -> Interact:
    return Interact(**profile)


################################################################################
# Create the validation function
################################################################################

def create_val_fn(profile, criterion):
    if profile.get("disable_val_fn", False):  # no val_fn
        return None
    elif "val_fn" not in profile:  # default val_fn
        return functools.partial(score_on_dataset, score_fn=criterion.default_val_fn,
                                 batch_size=get_val(profile, "training", "batch_size"))
    else:
        profile_val_fn = profile["val_fn"]
        batch_size = get_val(profile_val_fn, "batch_size", default=get_val(profile, "training", "batch_size"))
        kwargs = profile_val_fn.get("kwargs", {})
        apply_to = profile_val_fn.get("apply_to", "tensor")
        if apply_to == "tensor":
            def score_fn(v):
                return profile_val_fn["fn"](models=criterion.models, v=v, **kwargs)
            return functools.partial(score_on_dataset, score_fn=score_fn, batch_size=batch_size)
        elif apply_to == "dataset":
            return functools.partial(profile_val_fn["fn"], models=criterion.models, batch_size=batch_size)
        else:
            raise ValueError


################################################################################
# Create ema
################################################################################

def create_ema(profile, models):
    if profile.get("disable_ema", False):
        return None, None
    else:
        ema_keys = get_val(profile, "ema", "keys", default=list(models.__dict__.keys()))
        ema_rate = get_val(profile, "ema", "rate", default=0.9999)
        ema_models = create_models({key: profile["models"][key] for key in ema_keys})
        ema_models.ema(models, rate=0)
        return ema_models, ema_rate
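A sketch of the profile convention these factories consume. The model and numbers below are toy stand-ins (the real profiles, e.g. under `profiles/ddpm/`, use `UNetModel`):

```
import torch.nn as nn
import torch.optim as optim
from interface.utils.profile_utils import create_models, create_optimizers

model_profiles = {
    "eps_model": {"class": nn.Linear, "kwargs": {"in_features": 8, "out_features": 8}},
}
optimizer_profiles = {
    "all": {"class": optim.AdamW, "model_keys": ["eps_model"], "kwargs": {"lr": 1e-4}},
}
models = create_models(model_profiles)                      # a ModelsManager
optimizers = create_optimizers(optimizer_profiles, models)  # an OptimizersManager
```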
--------------------------------------------------------------------------------
/cifar_imagenet_codes/interface/utils/reproducibility.py:
--------------------------------------------------------------------------------

__all__ = ["set_seed", "set_deterministic", "backup_codes", "backup_profile"]


import torch
import numpy as np
import os
import datetime
import shutil
import pprint


def set_seed(seed: int):
    if seed is not None:
        # note: this seeds torch and numpy, but not Python's built-in random module
        torch.manual_seed(seed)
        np.random.seed(seed)


def set_deterministic(flag: bool):
    if flag:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def backup_codes(path):
    current_path = os.path.dirname(os.path.realpath(__file__))
    root_path = os.path.realpath(os.path.join(current_path, os.pardir, os.pardir))

    path = os.path.join(path, "codes_{}".format(datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")))
    os.makedirs(path, exist_ok=True)

    names = ["core", "interface", "profiles", "develop", "experiments", "tools"]
    for name in names:
        if os.path.exists(os.path.join(root_path, name)):
            shutil.copytree(os.path.join(root_path, name), os.path.join(path, name))

    pyfiles = filter(lambda x: x.endswith(".py"), os.listdir(root_path))
    for pyfile in pyfiles:
        shutil.copy(os.path.join(root_path, pyfile), os.path.join(path, pyfile))


def backup_profile(profile: dict, path):
    os.makedirs(path, exist_ok=True)
    path = os.path.join(path, "profile_{}.txt".format(datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")))
    s = pprint.pformat(profile)
    with open(path, 'w') as f:
        f.write(s)
--------------------------------------------------------------------------------
/cifar_imagenet_codes/interface/utils/task_schedule.py:
--------------------------------------------------------------------------------
from multiprocessing import Process
from typing import List
import time
from typing import Union, Tuple
import os


def get_gpu_memory_map():
    raw = list(os.popen('nvidia-smi --query-gpu=memory.free --format=csv,nounits,noheader'))
    mem = [int(x.strip()) for x in raw]
    return dict(zip(range(len(mem)), mem))


def get_gpu_total_memory_map():
    raw = list(os.popen('nvidia-smi --query-gpu=memory.total --format=csv,nounits,noheader'))
    mem = [int(x.strip()) for x in raw]
    return dict(zip(range(len(mem)), mem))


def gpu_memory_consumption():
    devices = list(map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(',')))
    gpu_memory_map = get_gpu_memory_map()
    gpu_total_memory_map = get_gpu_total_memory_map()
    return sum([gpu_total_memory_map[device] - gpu_memory_map[device] for device in devices])


def available_devices(threshold=10000) -> List[int]:
    gpu_memory_map = get_gpu_memory_map()
    devices = []
    for idx, mem in gpu_memory_map.items():
        if mem > threshold:
            devices.append(idx)
    return devices


def format_devices(devices: Union[int, List[int], Tuple[int]]):
    if isinstance(devices, int):
        return "{}".format(devices)
    elif isinstance(devices, tuple) or isinstance(devices, list):
        return ','.join(map(str, devices))


class Task(object):
    def __init__(self, process: Process, n_devices: int = 1):
        self.process = process
        self.n_devices = n_devices
        self.devices = None
        self.just_created = True

    def state(self):
        if self.just_created:
            return 'just_created'
        elif self.process.is_alive():
            return 'is_alive'
        else:
            return 'finished'

    def start(self, devices):
        self.devices = devices
        os.environ["CUDA_VISIBLE_DEVICES"] = format_devices(devices)
        self.process.start()
        self.just_created = False


class DevicesPool(object):
    def __init__(self, devices: List[int]):
        self.devices = devices.copy()

    def flow_out(self, n_devices: int):
        if len(self.devices) < n_devices:
            return None
        ret = []
        for _ in range(n_devices):
            ret.append(self.devices.pop())
        return ret

    def flow_in(self, devices: List[int]):
        for device in devices:
            self.devices.append(device)


################################################################################
# Run multiple tasks in parallel, each using its devices exclusively
# Suitable for tasks with a high GPU memory consumption
################################################################################

def wait_schedule(tasks: List[Task], devices: List[int]):
    # assert len(set(devices)) == len(devices)
    for task in tasks:
        assert task.n_devices <= len(devices)
    tasks = sorted(tasks, key=lambda x: x.n_devices, reverse=True)
    devices_pool = DevicesPool(devices)

    def linked_list_next(_idx: int, _lst: List):
        if _lst:
            return (_idx + 1) % len(_lst)
        else:
            return -1

    idx = 0
    while tasks:
        task = tasks[idx]
        state = task.state()
        if state == 'just_created':
            devices = devices_pool.flow_out(task.n_devices)
            if devices is not None:
                print("\033[1m start a task with {} devices".format(len(devices)))
                task.start(devices)
        elif state == 'finished':
            print("\033[1m a task with {} devices finished".format(len(task.devices)))
            devices_pool.flow_in(task.devices)
            task.process.close()
            tasks.pop(idx)
            idx -= 1
        idx = linked_list_next(idx, tasks)
        time.sleep(1)
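A sketch of how the scheduler above might be driven. The job names and entry function are made up, and `available_devices` shells out to `nvidia-smi`, so a GPU machine is assumed:

```
from multiprocessing import Process
from interface.utils.task_schedule import Task, available_devices, wait_schedule


def run_job(name):
    # placeholder for a real entry point, e.g. one phase of run_cifar10.py
    print("running", name)


if __name__ == "__main__":
    tasks = [
        Task(Process(target=run_job, args=("job_a",)), n_devices=1),
        Task(Process(target=run_job, args=("job_b",)), n_devices=2),
    ]
    # each task waits until enough GPUs have > 10000 MiB free, then runs exclusively
    wait_schedule(tasks, devices=available_devices(threshold=10000))
```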
--------------------------------------------------------------------------------
/cifar_imagenet_codes/profiles/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baofff/Analytic-DPM/2d7a28c0bbd984a6d47744ab4f6440f3e79757db/cifar_imagenet_codes/profiles/__init__.py
--------------------------------------------------------------------------------
/cifar_imagenet_codes/profiles/common.py:
--------------------------------------------------------------------------------
import torch.optim as optim


################################################################################
# The commonly used interact
################################################################################

def interact_datetime_train(period: int):
    return {
        "fname_log": "os.path.join($(workspace_root), 'train/logs/" +
                     "{}.log'.format(datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')))",
        "summary_root": "os.path.join($(workspace_root), 'train/summary/')",
        "period": period,
    }


def interact_datetime_evaluate():
    return {
        "fname_log": "os.path.join($(workspace_root), 'evaluate/logs/" +
                     "{}.log'.format(datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')))",
    }
--------------------------------------------------------------------------------
/cifar_imagenet_codes/profiles/ddpm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baofff/Analytic-DPM/2d7a28c0bbd984a6d47744ab4f6440f3e79757db/cifar_imagenet_codes/profiles/ddpm/__init__.py
--------------------------------------------------------------------------------
/cifar_imagenet_codes/profiles/ddpm/beta_schedules.py:
--------------------------------------------------------------------------------
import numpy as np
import math


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Get a pre-defined beta schedule for the given name.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al., extended to work for any number of
        # diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)
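A quick, illustrative sanity check of the relation the docstring above states, beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta (the cap only binds at the final step of the cosine schedule):

```
import math
import numpy as np
from profiles.ddpm.beta_schedules import get_named_beta_schedule

N = 4000
betas = get_named_beta_schedule("cosine", N)
alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

# the first beta comes straight from the defining relation (no clipping there)
assert abs(betas[0] - (1 - alpha_bar(1 / N) / alpha_bar(0))) < 1e-12

# the profiles below prepend a leading 0. so that betas[t] lines up with timestep t
padded = np.append(0., betas)
print(padded.shape)  # (4001,)
```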
43 | """ 44 | betas = [] 45 | for i in range(num_diffusion_timesteps): 46 | t1 = i / num_diffusion_timesteps 47 | t2 = (i + 1) / num_diffusion_timesteps 48 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 49 | return np.array(betas) 50 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/profiles/ddpm/cifar10/__init__.py: -------------------------------------------------------------------------------- 1 | from .train import * 2 | from . import naive_evaluate 3 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/profiles/ddpm/cifar10/base.py: -------------------------------------------------------------------------------- 1 | import core.modules as modules 2 | import interface.datasets as datasets 3 | import interface.evaluators as evaluators 4 | import numpy as np 5 | import profiles.ddpm.beta_schedules as beta_schedules 6 | from interface.utils import dict_utils 7 | 8 | 9 | default_betas = np.append(0., beta_schedules.get_named_beta_schedule("cosine", 4000)) 10 | 11 | 12 | unet_model = { 13 | "class": modules.UNetModel, 14 | "kwargs": { 15 | "in_channels": 3, 16 | "model_channels": 128, 17 | "out_channels": 3, 18 | "num_res_blocks": 3, 19 | "attention_resolutions": (32 // 16, 32 // 8), 20 | "dropout": 0.3, 21 | "channel_mult": (1, 2, 2, 2), 22 | "conv_resample": True, 23 | "dims": 2, 24 | "num_classes": None, 25 | "use_checkpoint": False, 26 | "num_heads": 4, 27 | "num_heads_upsample": -1, 28 | "use_scale_shift_norm": True, 29 | } 30 | } 31 | 32 | 33 | unet_model_double = dict_utils.merge_dict(unet_model, { 34 | "kwargs": { 35 | "out_channels": 6 36 | } 37 | }) 38 | 39 | 40 | dataset = { 41 | "use_val": False, 42 | "class": datasets.Cifar10, 43 | "kwargs": { 44 | "data_path": "workspace/datasets/cifar10/", 45 | } 46 | } 47 | 48 | 49 | ddpm_naive_evaluator_train = { 50 | "class": evaluators.DDPMNaiveEvaluator, 51 | "kwargs": { 52 | "options": { 53 | "grid_sample": { 54 | "period": 5000, 55 | "kwargs": { # fast sampling 56 | "betas": "$(betas)", 57 | "small_sigma": True, # small sample steps require small sigma 58 | "clip_denoise": True, # must be true since it will improve the sample quality 59 | "rescale_timesteps": "$(rescale_timesteps)", 60 | "sample_steps": 50, 61 | "path": "os.path.join($(workspace_root), 'train/evaluator/grid_sample')" 62 | } 63 | }, 64 | } 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /cifar_imagenet_codes/profiles/ddpm/cifar10/train.py: -------------------------------------------------------------------------------- 1 | 2 | __all__ = ["train", "train_ddpm_dsm"] 3 | 4 | 5 | import core.criterions as criterions 6 | from interface.utils import dict_utils 7 | import profiles.common as common 8 | from .base import unet_model, dataset, default_betas, ddpm_naive_evaluator_train 9 | import torch.optim as optim 10 | 11 | 12 | train = { 13 | "seed": 1234, 14 | "betas": default_betas, 15 | "rescale_timesteps": True, 16 | "ckpt_root": "os.path.join($(workspace_root), 'train/ckpts/')", 17 | "backup_root": "os.path.join($(workspace_root), 'train/reproducibility/')", 18 | "training": { 19 | "n_ckpts": 50, 20 | "n_its": 500000, 21 | "batch_size": 128, 22 | }, 23 | "ema": { 24 | "rate": 0.9999 25 | }, 26 | "dataset": dataset, 27 | "optimizers": { 28 | "all": { 29 | "class": optim.AdamW, 30 | "kwargs": { 31 | "lr": 0.0001, 32 | "weight_decay": 0. 
--------------------------------------------------------------------------------
/cifar_imagenet_codes/profiles/ddpm/imagenet64/__init__.py:
--------------------------------------------------------------------------------
from . import evaluate
--------------------------------------------------------------------------------
/cifar_imagenet_codes/profiles/ddpm/imagenet64/base.py:
--------------------------------------------------------------------------------
import core.modules as modules
import numpy as np
import profiles.ddpm.beta_schedules as beta_schedules
import interface.datasets as datasets


default_betas = np.append(0., beta_schedules.get_named_beta_schedule("cosine", 4000))


unet_model = {
    "class": modules.UNetModel,
    "kwargs": {
        "in_channels": 3,
        "model_channels": 128,
        "out_channels": 6,
        "num_res_blocks": 3,
        "attention_resolutions": (64 // 16, 64 // 8),
        "dropout": 0.0,
        "channel_mult": (1, 2, 3, 4),
        "conv_resample": True,
        "dims": 2,
        "num_classes": None,
        "use_checkpoint": False,
        "num_heads": 4,
        "num_heads_upsample": -1,
        "use_scale_shift_norm": True,
    }
}


dataset = {
    "use_val": False,
    "class": datasets.Imagenet64,
    "kwargs": {
        "path": "workspace/datasets/imagenet64"
    }
}
--------------------------------------------------------------------------------
/cifar_imagenet_codes/tools/eval.py:
--------------------------------------------------------------------------------
from interface.utils.misc import get_root_by_time
import os
from .fid_score import calculate_fid_given_paths


def fid_ckpts(stat, path, prefix, names, ckpts, dirname, device=None, batch_size=200, time=None):
    path_prefix_time = get_root_by_time(path, prefix, time, "latest")
    for name in names:
        root = os.path.join(path_prefix_time, name, 'evaluate/evaluator/sample2dir')
        with open(os.path.join(root, "%s_fid.txt" % dirname), 'w') as f:
            for ckpt in ckpts:
                samples_dir = os.path.join(root, ckpt, dirname)
                fid = calculate_fid_given_paths((stat, samples_dir), device=device, batch_size=batch_size)
                f.write("{}: {}\n".format(ckpt, fid))


def fid_one_ckpt(stat, path, prefix, names, ckpts, dirname, device=None, batch_size=200, time=None):
    path_prefix_time = get_root_by_time(path, prefix, time, "latest")
    for name, ckpt in zip(names, ckpts):
        root = os.path.join(path_prefix_time, name, 'evaluate/evaluator/sample2dir')
        with open(os.path.join(root, "%s_fid.txt" % dirname), 'w') as f:
            samples_dir = os.path.join(root, ckpt, dirname)
            fid = calculate_fid_given_paths((stat, samples_dir), device=device, batch_size=batch_size)
            f.write("{}: {}\n".format(ckpt, fid))


def fid_ckpts_cifar10(path, prefix, names, ckpts, dirname, device=None, batch_size=200, time=None):
    fid_ckpts('workspace/fid_stats/fid_stats_cifar10_train_pytorch.npz', path, prefix, names, ckpts, dirname, device, batch_size, time)


def fid_one_ckpt_cifar10(path, prefix, names, ckpts, dirname, device=None, batch_size=200, time=None):
    fid_one_ckpt('workspace/fid_stats/fid_stats_cifar10_train_pytorch.npz', path, prefix, names, ckpts, dirname, device, batch_size, time)
--------------------------------------------------------------------------------
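For completeness, a hypothetical invocation of the FID helpers above; all names and directories are placeholders and must match an actual timestamped run layout:

```
from tools.eval import fid_ckpts_cifar10

fid_ckpts_cifar10(
    path="workspace/runs",            # experiments live under path/prefix_time
    prefix="cifar10",                 # "latest" picks the newest matching timestamped run
    names=["train_ddpm_dsm"],
    ckpts=["500000.ckpt.pth"],
    dirname="analytic_ddpm_50steps",  # subdirectory holding the generated samples
    device="cuda:0",
)
```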