├── assets
│   ├── bedroom.jpeg
│   ├── church.jpeg
│   ├── ffhq_256.jpg
│   ├── ffhq_1024.jpeg
│   ├── schematic.jpg
│   ├── celebahq_256.jpg
│   └── ffhq_samples.jpg
├── op
│   ├── __init__.py
│   ├── fused_bias_act.cpp
│   ├── upfirdn2d.cpp
│   ├── fused_act.py
│   ├── fused_bias_act_kernel.cu
│   ├── upfirdn2d.py
│   └── upfirdn2d_kernel.cu
├── requirements.txt
├── .gitignore
├── models
│   ├── __init__.py
│   ├── ema.py
│   ├── utils.py
│   ├── ddpm.py
│   ├── normalization.py
│   ├── up_or_down_sampling.py
│   └── layerspp.py
├── utils.py
├── debug.py
├── configs
│   ├── ve
│   │   ├── cifar10_ddpm.py
│   │   ├── ncsnv2
│   │   │   ├── cifar10.py
│   │   │   ├── celeba.py
│   │   │   └── bedroom.py
│   │   ├── ncsn
│   │   │   ├── celeba_124.py
│   │   │   ├── celeba_1245.py
│   │   │   ├── cifar10_124.py
│   │   │   ├── celeba.py
│   │   │   ├── cifar10.py
│   │   │   ├── cifar10_5.py
│   │   │   ├── celeba_5.py
│   │   │   └── cifar10_1245.py
│   │   ├── cifar10_ncsnpp_continuous.py
│   │   ├── cifar10_ncsnpp.py
│   │   ├── celeba_ncsnpp.py
│   │   ├── cifar10_ncsnpp_deep_continuous.py
│   │   ├── bedroom_ncsnpp_continuous.py
│   │   ├── church_ncsnpp_continuous.py
│   │   ├── ffhq_256_ncsnpp_continuous.py
│   │   ├── celebahq_256_ncsnpp_continuous.py
│   │   ├── celebahq_ncsnpp_continuous.py
│   │   └── ffhq_ncsnpp_continuous.py
│   ├── vp
│   │   ├── ddpm
│   │   │   ├── cifar10_continuous.py
│   │   │   ├── cifar10.py
│   │   │   ├── cifar10_unconditional.py
│   │   │   ├── bedroom.py
│   │   │   ├── church.py
│   │   │   └── celebahq.py
│   │   ├── cifar10_ncsnpp.py
│   │   ├── cifar10_ncsnpp_continuous.py
│   │   ├── cifar10_ddpmpp.py
│   │   ├── cifar10_ddpmpp_continuous.py
│   │   ├── cifar10_ncsnpp_deep_continuous.py
│   │   └── cifar10_ddpmpp_deep_continuous.py
│   ├── subvp
│   │   ├── cifar10_ddpm_continuous.py
│   │   ├── cifar10_ncsnpp_continuous.py
│   │   ├── cifar10_ddpmpp_continuous.py
│   │   ├── cifar10_ncsnpp_deep_continuous.py
│   │   └── cifar10_ddpmpp_deep_continuous.py
│   ├── default_lsun_configs.py
│   ├── default_celeba_configs.py
│   └── default_cifar10_configs.py
├── main.py
├── likelihood.py
├── evaluation.py
├── datasets.py
├── sde_lib.py
├── controllable_generation.py
├── losses.py
└── LICENSE

/assets/bedroom.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/timokoesters/testing/original/assets/bedroom.jpeg
--------------------------------------------------------------------------------

/assets/church.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/timokoesters/testing/original/assets/church.jpeg
--------------------------------------------------------------------------------

/assets/ffhq_256.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/timokoesters/testing/original/assets/ffhq_256.jpg
--------------------------------------------------------------------------------

/assets/ffhq_1024.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/timokoesters/testing/original/assets/ffhq_1024.jpeg
--------------------------------------------------------------------------------

/assets/schematic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/timokoesters/testing/original/assets/schematic.jpg
--------------------------------------------------------------------------------

/assets/celebahq_256.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/timokoesters/testing/original/assets/celebahq_256.jpg
--------------------------------------------------------------------------------

/assets/ffhq_samples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/timokoesters/testing/original/assets/ffhq_samples.jpg
--------------------------------------------------------------------------------

/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 | 
--------------------------------------------------------------------------------

/requirements.txt:
--------------------------------------------------------------------------------
1 | ml-collections==0.1.0
2 | tensorflow-gan==2.0.0
3 | tensorflow_io
4 | tensorflow_datasets==3.1.0
5 | tensorflow==2.4.0
6 | tensorflow-addons==0.12.0
7 | tensorboard==2.4.0
8 | absl-py==0.10.0
9 | torch>=1.7.0
10 | torchvision
11 | ninja
--------------------------------------------------------------------------------

/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled python modules.
2 | *.pyc
3 | 
4 | # Byte-compiled
5 | __pycache__/
6 | .cache/
7 | .idea/
8 | 
9 | # Python egg metadata, regenerated from source files by setuptools.
10 | /*.egg-info
11 | .eggs/
12 | 
13 | # PyPI distribution artifacts.
14 | build/
15 | dist/
16 | 
17 | # Tests
18 | .pytest_cache/
19 | 
20 | # Other
21 | *.DS_Store
22 | 
--------------------------------------------------------------------------------

/models/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
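# [Editor's note -- hedged sketch] This __init__.py carries only the license
# header in this dump. Elsewhere the package is consumed as
# `from models import utils as mutils` (see debug.py below); models/utils.py is
# listed in the tree but its contents are not included here, so the following
# registration pattern, which is what lets `config.model.name` select an
# architecture, is an assumption rather than the repository's verbatim code:
#
#   _MODELS = {}
#
#   def register_model(cls=None, *, name=None):
#     """Record a model class under a string name."""
#     def _register(c):
#       _MODELS[name if name is not None else c.__name__] = c
#       return c
#     return _register if cls is None else _register(cls)
#
#   def get_model(name):
#     return _MODELS[name]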
15 | 
16 | 
--------------------------------------------------------------------------------

/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | 
3 | 
4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5 |                                 int act, int grad, float alpha, float scale);
6 | 
7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10 | 
11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12 |                              int act, int grad, float alpha, float scale) {
13 |   CHECK_CUDA(input);
14 |   CHECK_CUDA(bias);
15 | 
16 |   return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17 | }
18 | 
19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20 |   m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21 | }
--------------------------------------------------------------------------------

/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | 
3 | 
4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5 |                            int up_x, int up_y, int down_x, int down_y,
6 |                            int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7 | 
8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11 | 
12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13 |                         int up_x, int up_y, int down_x, int down_y,
14 |                         int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15 |   CHECK_CUDA(input);
16 |   CHECK_CUDA(kernel);
17 | 
18 |   return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19 | }
20 | 
21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22 |   m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23 | }
--------------------------------------------------------------------------------

/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | import tensorflow as tf
4 | import os
5 | import logging
6 | 
7 | def eprint(*args, **kwargs):
8 |   print(*args, file=sys.stderr, **kwargs)
9 | 
10 | def restore_checkpoint(ckpt_dir, state, device):
11 |   if not tf.io.gfile.exists(ckpt_dir):
12 |     tf.io.gfile.makedirs(os.path.dirname(ckpt_dir))
13 |     logging.warning(f"No checkpoint found at {ckpt_dir}. "
" 14 | f"Returned the same state as input") 15 | return state 16 | else: 17 | loaded_state = torch.load(ckpt_dir, map_location=device) 18 | state['optimizer'].load_state_dict(loaded_state['optimizer']) 19 | state['model'].load_state_dict(loaded_state['model'], strict=False) 20 | state['ema'].load_state_dict(loaded_state['ema']) 21 | state['step'] = loaded_state['step'] 22 | return state 23 | 24 | 25 | def save_checkpoint(ckpt_dir, state): 26 | saved_state = { 27 | 'optimizer': state['optimizer'].state_dict(), 28 | 'model': state['model'].state_dict(), 29 | 'ema': state['ema'].state_dict(), 30 | 'step': state['step'] 31 | } 32 | torch.save(saved_state, ckpt_dir) -------------------------------------------------------------------------------- /debug.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import matplotlib.pyplot as plt 3 | import io 4 | import csv 5 | import numpy as np 6 | import pandas as pd 7 | import seaborn as sns 8 | import matplotlib 9 | import importlib 10 | import os 11 | import functools 12 | import itertools 13 | import torch 14 | 15 | import torch.nn as nn 16 | import numpy as np 17 | import tensorflow as tf 18 | import tensorflow_datasets as tfds 19 | import tensorflow_gan as tfgan 20 | import tqdm 21 | import io 22 | import inspect 23 | sns.set(font_scale=2) 24 | sns.set(style="whitegrid") 25 | 26 | import models 27 | from models import utils as mutils 28 | from models import ncsnv2 29 | from models import ncsnpp 30 | from models import ddpm as ddpm_model 31 | from models import layerspp 32 | from models import layers 33 | from models import normalization 34 | 35 | #from configs.ncsnpp import cifar10_continuous_ve as configs 36 | from configs.ddpm import cifar10_continuous_vp as configs 37 | config = configs.get_config() 38 | 39 | checkpoint = torch.load('exp/ddpm_continuous_vp.pth') 40 | 41 | #score_model = ncsnpp.NCSNpp(config) 42 | score_model = ddpm_model.DDPM(config) 43 | score_model.load_state_dict(checkpoint) 44 | score_model = score_model.eval() 45 | x = torch.ones(8, 3, 32, 32) 46 | y = torch.tensor([1] * 8) 47 | breakpoint() 48 | with torch.no_grad(): 49 | score = score_model(x, y) -------------------------------------------------------------------------------- /configs/ve/cifar10_ddpm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Train the original DDPM model with SMLD.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'reverse_diffusion' 33 | sampling.corrector = 'langevin' 34 | 35 | # model 36 | model = config.model 37 | model.name = 'ddpm' 38 | model.scale_by_sigma = True 39 | model.ema_rate = 0.999 40 | model.normalization = 'GroupNorm' 41 | model.nonlinearity = 'swish' 42 | model.nf = 128 43 | model.ch_mult = (1, 2, 2, 2) 44 | model.num_res_blocks = 2 45 | model.attn_resolutions = (16,) 46 | model.resamp_with_conv = True 47 | model.conditional = True 48 | model.conv_size = 3 49 | 50 | return config 51 | -------------------------------------------------------------------------------- /configs/vp/ddpm/cifar10_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training DDPM with VP SDE.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | 25 | # training 26 | training = config.training 27 | training.sde = 'vpsde' 28 | training.continuous = True 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'euler_maruyama' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.centered = True 40 | 41 | # model 42 | model = config.model 43 | model.name = 'ddpm' 44 | model.scale_by_sigma = False 45 | model.ema_rate = 0.9999 46 | model.normalization = 'GroupNorm' 47 | model.nonlinearity = 'swish' 48 | model.nf = 128 49 | model.ch_mult = (1, 2, 2, 2) 50 | model.num_res_blocks = 2 51 | model.attn_resolutions = (16,) 52 | model.resamp_with_conv = True 53 | model.conditional = True 54 | 55 | return config 56 | -------------------------------------------------------------------------------- /configs/subvp/cifar10_ddpm_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training DDPM with sub-VP SDE.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | 25 | # training 26 | training = config.training 27 | training.sde = 'subvpsde' 28 | training.continuous = True 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'euler_maruyama' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.centered = True 40 | 41 | # model 42 | model = config.model 43 | model.name = 'ddpm' 44 | model.scale_by_sigma = False 45 | model.ema_rate = 0.9999 46 | model.normalization = 'GroupNorm' 47 | model.nonlinearity = 'swish' 48 | model.nf = 128 49 | model.ch_mult = (1, 2, 2, 2) 50 | model.num_res_blocks = 2 51 | model.attn_resolutions = (16,) 52 | model.resamp_with_conv = True 53 | model.conditional = True 54 | 55 | return config 56 | -------------------------------------------------------------------------------- /configs/vp/ddpm/cifar10.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 
16 | # Lint as: python3
17 | """Config file for reproducing the results of DDPM on CIFAR-10."""
18 | 
19 | from configs.default_cifar10_configs import get_default_configs
20 | 
21 | 
22 | def get_config():
23 |   config = get_default_configs()
24 | 
25 |   # training
26 |   training = config.training
27 |   training.sde = 'vpsde'
28 |   training.continuous = False
29 |   training.reduce_mean = True
30 | 
31 |   # sampling
32 |   sampling = config.sampling
33 |   sampling.method = 'pc'
34 |   sampling.predictor = 'ancestral_sampling'
35 |   sampling.corrector = 'none'
36 | 
37 |   # data
38 |   data = config.data
39 |   data.centered = True
40 | 
41 |   # model
42 |   model = config.model
43 |   model.name = 'ddpm'
44 |   model.scale_by_sigma = False
45 |   model.ema_rate = 0.9999
46 |   model.normalization = 'GroupNorm'
47 |   model.nonlinearity = 'swish'
48 |   model.nf = 128
49 |   model.ch_mult = (1, 2, 2, 2)
50 |   model.num_res_blocks = 2
51 |   model.attn_resolutions = (16,)
52 |   model.resamp_with_conv = True
53 |   model.conditional = True
54 | 
55 |   return config
56 | 
--------------------------------------------------------------------------------

/configs/vp/ddpm/cifar10_unconditional.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | # Lint as: python3
17 | """Training DDPM on CIFAR-10 without explicitly conditioning on time steps. (NCSNv2 technique 3)"""
18 | 
19 | from configs.default_cifar10_configs import get_default_configs
20 | 
21 | 
22 | def get_config():
23 |   config = get_default_configs()
24 | 
25 |   # training
26 |   training = config.training
27 |   training.sde = 'vpsde'
28 |   training.continuous = False
29 |   training.reduce_mean = True
30 | 
31 |   # sampling
32 |   sampling = config.sampling
33 |   sampling.method = 'pc'
34 |   sampling.predictor = 'ancestral_sampling'
35 |   sampling.corrector = 'none'
36 | 
37 |   # data
38 |   data = config.data
39 |   data.centered = True
40 | 
41 |   # model
42 |   model = config.model
43 |   model.name = 'ddpm'
44 |   model.scale_by_sigma = False
45 |   model.ema_rate = 0.9999
46 |   model.normalization = 'GroupNorm'
47 |   model.nonlinearity = 'swish'
48 |   model.nf = 128
49 |   model.ch_mult = (1, 2, 2, 2)
50 |   model.num_res_blocks = 2
51 |   model.attn_resolutions = (16,)
52 |   model.resamp_with_conv = True
53 |   model.conditional = False
54 | 
55 |   return config
56 | 
--------------------------------------------------------------------------------

/configs/ve/ncsnv2/cifar10.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | # Lint as: python3
17 | """Config file for training NCSNv2 on CIFAR-10."""
18 | 
19 | from configs.default_cifar10_configs import get_default_configs
20 | 
21 | 
22 | def get_config():
23 |   config = get_default_configs()
24 |   # training
25 |   training = config.training
26 |   training.sde = 'vesde'
27 |   training.continuous = False
28 |   # sampling
29 |   sampling = config.sampling
30 |   sampling.method = 'pc'
31 |   sampling.predictor = 'none'
32 |   sampling.corrector = 'ald'
33 |   sampling.n_steps_each = 5
34 |   sampling.snr = 0.176
35 |   # model
36 |   model = config.model
37 |   model.name = 'ncsnv2_64'
38 |   model.scale_by_sigma = True
39 |   model.num_scales = 232
40 |   model.ema_rate = 0.999
41 |   model.normalization = 'InstanceNorm++'
42 |   model.nonlinearity = 'elu'
43 |   model.nf = 128
44 |   model.interpolation = 'bilinear'
45 |   # optim
46 |   optim = config.optim
47 |   optim.weight_decay = 0
48 |   optim.optimizer = 'Adam'
49 |   optim.lr = 1e-4
50 |   optim.beta1 = 0.9
51 |   optim.amsgrad = False
52 |   optim.eps = 1e-8
53 |   optim.warmup = 0
54 |   optim.grad_clip = -1.
55 | 
56 |   return config
57 | 
--------------------------------------------------------------------------------

/configs/ve/ncsn/celeba_124.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | # Lint as: python3
17 | """Config file for training NCSN with techniques 1, 2, and 4 only."""
18 | 
19 | from configs.default_celeba_configs import get_default_configs
20 | 
21 | 
22 | def get_config():
23 |   config = get_default_configs()
24 |   # training
25 |   training = config.training
26 |   training.sde = 'vesde'
27 |   training.continuous = False
28 |   # sampling
29 |   sampling = config.sampling
30 |   sampling.method = 'pc'
31 |   sampling.predictor = 'none'
32 |   sampling.corrector = 'ald'
33 |   sampling.n_steps_each = 5
34 |   sampling.snr = 0.128
35 |   # model
36 |   model = config.model
37 |   model.name = 'ncsn'
38 |   model.scale_by_sigma = False
39 |   model.num_scales = 500
40 |   model.ema_rate = 0.
41 |   model.normalization = 'InstanceNorm++'
42 |   model.nonlinearity = 'elu'
43 |   model.nf = 128
44 |   model.interpolation = 'bilinear'
45 |   # optim
46 |   optim = config.optim
47 |   optim.weight_decay = 0
48 |   optim.optimizer = 'Adam'
49 |   optim.lr = 1e-3
50 |   optim.beta1 = 0.9
51 |   optim.amsgrad = False
52 |   optim.eps = 1e-8
53 |   optim.warmup = 0
54 |   optim.grad_clip = -1.
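  # [Editor's note -- hedged] grad_clip = -1. conventionally disables gradient
  # clipping in this codebase; losses.py is not part of this dump, so the guard
  # in the training step is a sketch rather than the verbatim code:
  #
  #   if config.optim.grad_clip > 0:
  #     torch.nn.utils.clip_grad_norm_(model.parameters(), config.optim.grad_clip)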
55 | 
56 |   return config
57 | 
--------------------------------------------------------------------------------

/configs/ve/ncsn/celeba_1245.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | # Lint as: python3
17 | """Config file for training NCSN with techniques 1, 2, 4, and 5 only."""
18 | 
19 | from configs.default_celeba_configs import get_default_configs
20 | 
21 | 
22 | def get_config():
23 |   config = get_default_configs()
24 |   # training
25 |   training = config.training
26 |   training.sde = 'vesde'
27 |   training.continuous = False
28 |   # sampling
29 |   sampling = config.sampling
30 |   sampling.method = 'pc'
31 |   sampling.predictor = 'none'
32 |   sampling.corrector = 'ald'
33 |   sampling.n_steps_each = 5
34 |   sampling.snr = 0.128
35 |   # model
36 |   model = config.model
37 |   model.name = 'ncsn'
38 |   model.scale_by_sigma = False
39 |   model.num_scales = 500
40 |   model.ema_rate = 0.999
41 |   model.normalization = 'InstanceNorm++'
42 |   model.nonlinearity = 'elu'
43 |   model.nf = 128
44 |   model.interpolation = 'bilinear'
45 |   # optim
46 |   optim = config.optim
47 |   optim.weight_decay = 0
48 |   optim.optimizer = 'Adam'
49 |   optim.lr = 1e-3
50 |   optim.beta1 = 0.9
51 |   optim.amsgrad = False
52 |   optim.eps = 1e-8
53 |   optim.warmup = 0
54 |   optim.grad_clip = -1.
55 | 
56 |   return config
57 | 
--------------------------------------------------------------------------------

/configs/ve/ncsn/cifar10_124.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | # Lint as: python3
17 | """Config file for training NCSN with techniques 1, 2, and 4 only."""
18 | 
19 | from configs.default_cifar10_configs import get_default_configs
20 | 
21 | 
22 | def get_config():
23 |   config = get_default_configs()
24 |   # training
25 |   training = config.training
26 |   training.sde = 'vesde'
27 |   training.continuous = False
28 |   # sampling
29 |   sampling = config.sampling
30 |   sampling.method = 'pc'
31 |   sampling.predictor = 'none'
32 |   sampling.corrector = 'ald'
33 |   sampling.n_steps_each = 5
34 |   sampling.snr = 0.176
35 |   # model
36 |   model = config.model
37 |   model.name = 'ncsn'
38 |   model.scale_by_sigma = False
39 |   model.num_scales = 232
40 |   model.ema_rate = 0.
41 |   model.normalization = 'InstanceNorm++'
42 |   model.nonlinearity = 'elu'
43 |   model.nf = 128
44 |   model.interpolation = 'bilinear'
45 |   # optim
46 |   optim = config.optim
47 |   optim.weight_decay = 0
48 |   optim.optimizer = 'Adam'
49 |   optim.lr = 1e-3
50 |   optim.beta1 = 0.9
51 |   optim.amsgrad = False
52 |   optim.eps = 1e-8
53 |   optim.warmup = 0
54 |   optim.grad_clip = -1.
55 | 
56 |   return config
57 | 
--------------------------------------------------------------------------------

/configs/ve/ncsn/celeba.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | # Lint as: python3
17 | """Config file for reproducing NCSNv1 on CelebA."""
18 | 
19 | from configs.default_celeba_configs import get_default_configs
20 | 
21 | 
22 | def get_config():
23 |   config = get_default_configs()
24 |   # training
25 |   training = config.training
26 |   training.sde = 'vesde'
27 |   training.continuous = False
28 |   # sampling
29 |   sampling = config.sampling
30 |   sampling.method = 'pc'
31 |   sampling.predictor = 'none'
32 |   sampling.corrector = 'ald'
33 |   sampling.n_steps_each = 100
34 |   sampling.snr = 0.316
35 |   # model
36 |   model = config.model
37 |   model.name = 'ncsn'
38 |   model.scale_by_sigma = False
39 |   model.sigma_max = 1
40 |   model.num_scales = 10
41 |   model.ema_rate = 0.
42 |   model.normalization = 'InstanceNorm++'
43 |   model.nonlinearity = 'elu'
44 |   model.nf = 128
45 |   model.interpolation = 'bilinear'
46 |   # optim
47 |   optim = config.optim
48 |   optim.weight_decay = 0
49 |   optim.optimizer = 'Adam'
50 |   optim.lr = 1e-3
51 |   optim.beta1 = 0.9
52 |   optim.amsgrad = False
53 |   optim.eps = 1e-8
54 |   optim.warmup = 0
55 |   optim.grad_clip = -1.
56 | 
57 |   return config
58 | 
--------------------------------------------------------------------------------

/configs/ve/ncsn/cifar10.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
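# [Editor's note -- hedged sketch] With predictor = 'none' and corrector = 'ald'
# the PC sampler below reduces to NCSNv1-style annealed Langevin dynamics. Per
# noise level sigma_i, each of the n_steps_each corrector updates is,
# schematically (the exact step-size rule lives in the sampling code, which is
# not part of this dump):
#
#   eps = 2 * (snr * ||z|| / ||score(x, sigma_i)||)**2
#   x   = x + eps * score(x, sigma_i) + sqrt(2 * eps) * z,   z ~ N(0, I)
#
# which is what sampling.snr and sampling.n_steps_each parameterize.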
15 | 16 | # Lint as: python3 17 | """Config file for reproducing NCSNv1 on CIFAR-10.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'none' 32 | sampling.corrector = 'ald' 33 | sampling.n_steps_each = 100 34 | sampling.snr = 0.316 35 | # model 36 | model = config.model 37 | model.name = 'ncsn' 38 | model.scale_by_sigma = False 39 | model.sigma_max = 1 40 | model.num_scales = 10 41 | model.ema_rate = 0. 42 | model.normalization = 'InstanceNorm++' 43 | model.nonlinearity = 'elu' 44 | model.nf = 128 45 | model.interpolation = 'bilinear' 46 | # optim 47 | optim = config.optim 48 | optim.weight_decay = 0 49 | optim.optimizer = 'Adam' 50 | optim.lr = 1e-3 51 | optim.beta1 = 0.9 52 | optim.amsgrad = False 53 | optim.eps = 1e-8 54 | optim.warmup = 0 55 | optim.grad_clip = -1. 56 | 57 | return config 58 | -------------------------------------------------------------------------------- /configs/ve/ncsn/cifar10_5.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for training NCSN with technique 5 only.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'none' 32 | sampling.corrector = 'ald' 33 | sampling.snr = 0.316 34 | sampling.n_steps_each = 100 35 | # model 36 | model = config.model 37 | model.name = 'ncsn' 38 | model.scale_by_sigma = False 39 | model.sigma_max = 1 40 | model.num_scales = 10 41 | model.ema_rate = 0.999 42 | model.normalization = 'InstanceNorm++' 43 | model.nonlinearity = 'elu' 44 | model.nf = 128 45 | model.interpolation = 'bilinear' 46 | # optim 47 | optim = config.optim 48 | optim.weight_decay = 0 49 | optim.optimizer = 'Adam' 50 | optim.lr = 1e-3 51 | optim.beta1 = 0.9 52 | optim.amsgrad = False 53 | optim.eps = 1e-8 54 | optim.warmup = 0 55 | optim.grad_clip = -1. 56 | 57 | return config 58 | -------------------------------------------------------------------------------- /configs/ve/ncsn/celeba_5.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for training NCSNv1 model with technique 5 only.""" 18 | 19 | from configs.default_celeba_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'none' 32 | sampling.corrector = 'ald' 33 | sampling.n_steps_each = 100 34 | sampling.snr = 0.316 35 | # model 36 | model = config.model 37 | model.name = 'ncsn' 38 | model.scale_by_sigma = False 39 | model.sigma_max = 1. 40 | model.num_scales = 10 41 | model.ema_rate = 0.999 42 | model.normalization = 'InstanceNorm++' 43 | model.nonlinearity = 'elu' 44 | model.nf = 128 45 | model.interpolation = 'bilinear' 46 | # optim 47 | optim = config.optim 48 | optim.weight_decay = 0 49 | optim.optimizer = 'Adam' 50 | optim.lr = 1e-3 51 | optim.beta1 = 0.9 52 | optim.amsgrad = False 53 | optim.eps = 1e-8 54 | optim.warmup = 0 55 | optim.grad_clip = -1. 56 | 57 | return config 58 | -------------------------------------------------------------------------------- /configs/vp/ddpm/bedroom.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
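# [Editor's note -- hedged sketch] predictor = 'ancestral_sampling' below is the
# original DDPM reverse update. With the discrete VP schedule beta_i, one
# ancestral step is, schematically,
#
#   x_{i-1} = (x_i + beta_i * score(x_i, i)) / sqrt(1 - beta_i) + sqrt(beta_i) * z,
#   z ~ N(0, I),
#
# i.e. the score-function form of DDPM's epsilon-prediction posterior sampling.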
15 | 16 | # Lint as: python3 17 | """Config file for reproducing the results of DDPM on bedrooms.""" 18 | 19 | from configs.default_lsun_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | 25 | # training 26 | training = config.training 27 | training.sde = 'vpsde' 28 | training.continuous = False 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'ancestral_sampling' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.category = 'bedroom' 40 | data.centered = True 41 | 42 | # model 43 | model = config.model 44 | model.name = 'ddpm' 45 | model.scale_by_sigma = False 46 | model.num_scales = 1000 47 | model.ema_rate = 0.9999 48 | model.normalization = 'GroupNorm' 49 | model.nonlinearity = 'swish' 50 | model.nf = 128 51 | model.ch_mult = (1, 1, 2, 2, 4, 4) 52 | model.num_res_blocks = 2 53 | model.attn_resolutions = (16,) 54 | model.resamp_with_conv = True 55 | model.conditional = True 56 | 57 | # optim 58 | optim = config.optim 59 | optim.lr = 2e-5 60 | 61 | return config 62 | -------------------------------------------------------------------------------- /configs/vp/ddpm/church.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for reproducing the results of DDPM on church_outdoor.""" 18 | 19 | from configs.default_lsun_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | 25 | # training 26 | training = config.training 27 | training.sde = 'vpsde' 28 | training.continuous = False 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'ancestral_sampling' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.category = 'church_outdoor' 40 | data.centered = True 41 | 42 | # model 43 | model = config.model 44 | model.name = 'ddpm' 45 | model.scale_by_sigma = False 46 | model.num_scales = 1000 47 | model.ema_rate = 0.9999 48 | model.normalization = 'GroupNorm' 49 | model.nonlinearity = 'swish' 50 | model.nf = 128 51 | model.ch_mult = (1, 1, 2, 2, 4, 4) 52 | model.num_res_blocks = 2 53 | model.attn_resolutions = (16,) 54 | model.resamp_with_conv = True 55 | model.conditional = True 56 | 57 | # optim 58 | optim = config.optim 59 | optim.lr = 2e-5 60 | 61 | return config 62 | -------------------------------------------------------------------------------- /configs/ve/ncsnv2/celeba.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for training NCSNv2 on CelebA.""" 18 | 19 | from configs.default_celeba_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # shared configs for sample generation 29 | step_size = 0.0000033 30 | n_steps_each = 5 31 | ckpt_id = 210000 32 | final_only = True 33 | noise_removal = False 34 | # sampling 35 | sampling = config.sampling 36 | sampling.method = 'pc' 37 | sampling.predictor = 'none' 38 | sampling.corrector = 'ald' 39 | sampling.n_steps_each = 5 40 | sampling.snr = 0.128 41 | # model 42 | model = config.model 43 | model.name = 'ncsnv2_64' 44 | model.scale_by_sigma = True 45 | model.num_scales = 500 46 | model.ema_rate = 0.999 47 | model.normalization = 'InstanceNorm++' 48 | model.nonlinearity = 'elu' 49 | model.nf = 128 50 | model.interpolation = 'bilinear' 51 | # optim 52 | optim = config.optim 53 | optim.weight_decay = 0 54 | optim.optimizer = 'Adam' 55 | optim.lr = 1e-4 56 | optim.beta1 = 0.9 57 | optim.amsgrad = False 58 | optim.eps = 1e-8 59 | optim.warmup = 0 60 | optim.grad_clip = -1. 61 | 62 | return config 63 | -------------------------------------------------------------------------------- /configs/vp/ddpm/celebahq.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 
16 | # Lint as: python3
17 | """Config file for reproducing the results of DDPM on CelebA-HQ."""
18 | 
19 | from configs.default_lsun_configs import get_default_configs
20 | 
21 | 
22 | def get_config():
23 |   config = get_default_configs()
24 | 
25 |   # training
26 |   training = config.training
27 |   training.sde = 'vpsde'
28 |   training.continuous = False
29 |   training.reduce_mean = True
30 | 
31 |   # sampling
32 |   sampling = config.sampling
33 |   sampling.method = 'pc'
34 |   sampling.predictor = 'ancestral_sampling'
35 |   sampling.corrector = 'none'
36 | 
37 |   # data
38 |   data = config.data
39 |   data.dataset = 'CelebAHQ'
40 |   data.centered = True
41 |   data.tfrecords_path = '/atlas/u/yangsong/celeba_hq/-r10.tfrecords'
42 |   data.image_size = 256
43 | 
44 |   # model
45 |   model = config.model
46 |   model.name = 'ddpm'
47 |   model.scale_by_sigma = False
48 |   model.num_scales = 1000
49 |   model.ema_rate = 0.9999
50 |   model.normalization = 'GroupNorm'
51 |   model.nonlinearity = 'swish'
52 |   model.nf = 128
53 |   model.ch_mult = (1, 1, 2, 2, 4, 4)
54 |   model.num_res_blocks = 2
55 |   model.attn_resolutions = (16,)
56 |   model.resamp_with_conv = True
57 |   model.conditional = True
58 | 
59 |   # optim
60 |   optim = config.optim
61 |   optim.lr = 2e-5
62 | 
63 |   return config
64 | 
--------------------------------------------------------------------------------

/configs/ve/cifar10_ncsnpp_continuous.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | # Lint as: python3
17 | """Training NCSN++ on CIFAR-10 with VE SDE."""
18 | from configs.default_cifar10_configs import get_default_configs
19 | 
20 | 
21 | def get_config():
22 |   config = get_default_configs()
23 |   # training
24 |   training = config.training
25 |   training.sde = 'vesde'
26 |   training.continuous = True
27 | 
28 |   # sampling
29 |   sampling = config.sampling
30 |   sampling.method = 'pc'
31 |   sampling.predictor = 'reverse_diffusion'
32 |   sampling.corrector = 'langevin'
33 | 
34 |   # model
35 |   model = config.model
36 |   model.name = 'ncsnpp'
37 |   model.scale_by_sigma = True
38 |   model.ema_rate = 0.999
39 |   model.normalization = 'GroupNorm'
40 |   model.nonlinearity = 'swish'
41 |   model.nf = 128
42 |   model.ch_mult = (1, 2, 2, 2)
43 |   model.num_res_blocks = 4
44 |   model.attn_resolutions = (16,)
45 |   model.resamp_with_conv = True
46 |   model.conditional = True
47 |   model.fir = True
48 |   model.fir_kernel = [1, 3, 3, 1]
49 |   model.skip_rescale = True
50 |   model.resblock_type = 'biggan'
51 |   model.progressive = 'none'
52 |   model.progressive_input = 'residual'
53 |   model.progressive_combine = 'sum'
54 |   model.attention_type = 'ddpm'
55 |   model.init_scale = 0.
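  # [Editor's note -- hedged] init_scale = 0. zero-initializes the final
  # convolution of each residual branch so new blocks start near the identity,
  # and fourier_scale below sets the bandwidth of the Gaussian Fourier time
  # embedding; schematically (layerspp.py's exact code is not shown in this
  # dump):
  #
  #   W ~ N(0, fourier_scale**2), fixed;  emb(t) = [sin(2*pi*W*t), cos(2*pi*W*t)]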
56 |   model.fourier_scale = 16
57 |   model.conv_size = 3
58 | 
59 |   return config
60 | 
--------------------------------------------------------------------------------

/configs/ve/ncsn/cifar10_1245.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | # Lint as: python3
17 | """Config file for training NCSN with techniques 1, 2, 4, and 5 only."""
18 | 
19 | from configs.default_cifar10_configs import get_default_configs
20 | 
21 | 
22 | def get_config():
23 |   config = get_default_configs()
24 |   # training
25 |   training = config.training
26 |   training.sde = 'vesde'
27 |   training.continuous = False
28 |   # shared configs for sample generation
29 |   step_size = 0.0000062
30 |   n_steps_each = 5
31 |   ckpt_id = 300000
32 |   final_only = True
33 |   noise_removal = False
34 |   # sampling
35 |   sampling = config.sampling
36 |   sampling.method = 'pc'
37 |   sampling.predictor = 'none'
38 |   sampling.corrector = 'ald'
39 |   sampling.n_steps_each = 5
40 |   sampling.snr = 0.176
41 |   # model
42 |   model = config.model
43 |   model.name = 'ncsn'
44 |   model.scale_by_sigma = False
45 |   model.num_scales = 232
46 |   model.ema_rate = 0.999
47 |   model.normalization = 'InstanceNorm++'
48 |   model.nonlinearity = 'elu'
49 |   model.nf = 128
50 |   model.interpolation = 'bilinear'
51 |   # optim
52 |   optim = config.optim
53 |   optim.weight_decay = 0
54 |   optim.optimizer = 'Adam'
55 |   optim.lr = 1e-3
56 |   optim.beta1 = 0.9
57 |   optim.amsgrad = False
58 |   optim.eps = 1e-8
59 |   optim.warmup = 0
60 |   optim.grad_clip = -1.
61 | 
62 |   return config
63 | 
--------------------------------------------------------------------------------

/configs/ve/cifar10_ncsnpp.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
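# [Editor's note -- hedged sketch] training.continuous = False below means
# SMLD-style training on a discrete geometric ladder of noise levels. With
# sigma_min, sigma_max and model.num_scales from the defaults, the ladder is
# conventionally
#
#   sigmas = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_scales))
#
# (largest first), which the continuous VE SDE generalizes to
# sigma(t) = sigma_min * (sigma_max / sigma_min)**t.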
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10 with SMLD.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'reverse_diffusion' 33 | sampling.corrector = 'langevin' 34 | 35 | # model 36 | model = config.model 37 | model.name = 'ncsnpp' 38 | model.scale_by_sigma = True 39 | model.ema_rate = 0.999 40 | model.normalization = 'GroupNorm' 41 | model.nonlinearity = 'swish' 42 | model.nf = 128 43 | model.ch_mult = (1, 2, 2, 2) 44 | model.num_res_blocks = 4 45 | model.attn_resolutions = (16,) 46 | model.resamp_with_conv = True 47 | model.conditional = True 48 | model.fir = True 49 | model.fir_kernel = [1, 3, 3, 1] 50 | model.skip_rescale = True 51 | model.resblock_type = 'biggan' 52 | model.progressive = 'none' 53 | model.progressive_input = 'residual' 54 | model.progressive_combine = 'sum' 55 | model.attention_type = 'ddpm' 56 | model.init_scale = 0.0 57 | model.embedding_type = 'positional' 58 | model.conv_size = 3 59 | 60 | return config 61 | -------------------------------------------------------------------------------- /configs/ve/ncsnv2/bedroom.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
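# [Editor's note -- hedged] sigma_max = 190 and num_scales = 1086 below follow
# the NCSNv2 recipe: sigma_max is set near the maximum pairwise distance between
# training images (technique 1), and the number of scales keeps the ratio
# between adjacent sigmas fixed (technique 2). Since the ladder is geometric,
# for a target per-step ratio gamma,
#
#   num_scales = 1 + log(sigma_max / sigma_min) / log(gamma)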
15 | 
16 | # Lint as: python3
17 | """Config file for training NCSNv2 on bedroom."""
18 | 
19 | from configs.default_lsun_configs import get_default_configs
20 | 
21 | 
22 | def get_config():
23 |   config = get_default_configs()
24 |   # training
25 |   training = config.training
26 |   training.batch_size = 128
27 |   training.sde = 'vesde'
28 |   training.continuous = False
29 |   # sampling
30 |   sampling = config.sampling
31 |   sampling.method = 'pc'
32 |   sampling.predictor = 'none'
33 |   sampling.corrector = 'ald'
34 |   sampling.n_steps_each = 3
35 |   sampling.snr = 0.095
36 |   # data
37 |   data = config.data
38 |   data.category = 'bedroom'
39 |   data.image_size = 128
40 |   # model
41 |   model = config.model
42 |   model.name = 'ncsnv2_128'
43 |   model.scale_by_sigma = True
44 |   model.sigma_max = 190
45 |   model.num_scales = 1086
46 |   model.ema_rate = 0.9999
47 |   model.sigma_min = 0.01
48 |   model.normalization = 'InstanceNorm++'
49 |   model.nonlinearity = 'elu'
50 |   model.nf = 128
51 |   model.interpolation = 'bilinear'
52 |   # optim
53 |   optim = config.optim
54 |   optim.weight_decay = 0
55 |   optim.optimizer = 'Adam'
56 |   optim.lr = 1e-4
57 |   optim.beta1 = 0.9
58 |   optim.amsgrad = False
59 |   optim.eps = 1e-8
60 |   optim.warmup = 0
61 |   optim.grad_clip = -1
62 | 
63 |   return config
64 | 
--------------------------------------------------------------------------------

/configs/ve/celeba_ncsnpp.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CelebA with SMLD.""" 18 | 19 | from configs.default_celeba_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'reverse_diffusion' 33 | sampling.corrector = 'langevin' 34 | 35 | # model 36 | model = config.model 37 | model.name = 'ncsnpp' 38 | model.scale_by_sigma = True 39 | model.sigma_begin = 90 40 | model.ema_rate = 0.999 41 | model.normalization = 'GroupNorm' 42 | model.nonlinearity = 'swish' 43 | model.nf = 128 44 | model.ch_mult = (1, 2, 2, 2) 45 | model.num_res_blocks = 4 46 | model.attn_resolutions = (16,) 47 | model.resamp_with_conv = True 48 | model.conditional = True 49 | model.fir = True 50 | model.fir_kernel = [1, 3, 3, 1] 51 | model.skip_rescale = True 52 | model.resblock_type = 'biggan' 53 | model.progressive = 'none' 54 | model.progressive_input = 'residual' 55 | model.progressive_combine = 'sum' 56 | model.attention_type = 'ddpm' 57 | model.init_scale = 0.0 58 | model.conv_size = 3 59 | model.embedding_type = 'positional' 60 | 61 | return config 62 | -------------------------------------------------------------------------------- /configs/ve/cifar10_ncsnpp_deep_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10 with VE SDE.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = True 28 | training.n_iters = 950001 29 | 30 | # sampling 31 | sampling = config.sampling 32 | sampling.method = 'pc' 33 | sampling.predictor = 'reverse_diffusion' 34 | sampling.corrector = 'langevin' 35 | 36 | # model 37 | model = config.model 38 | model.name = 'ncsnpp' 39 | model.fourier_scale = 16 40 | model.scale_by_sigma = True 41 | model.ema_rate = 0.999 42 | model.normalization = 'GroupNorm' 43 | model.nonlinearity = 'swish' 44 | model.nf = 128 45 | model.ch_mult = (1, 2, 2, 2) 46 | model.num_res_blocks = 8 47 | model.attn_resolutions = (16,) 48 | model.resamp_with_conv = True 49 | model.conditional = True 50 | model.fir = True 51 | model.fir_kernel = [1, 3, 3, 1] 52 | model.skip_rescale = True 53 | model.resblock_type = 'biggan' 54 | model.progressive = 'none' 55 | model.progressive_input = 'residual' 56 | model.progressive_combine = 'sum' 57 | model.attention_type = 'ddpm' 58 | model.init_scale = 0.0 59 | model.conv_size = 3 60 | 61 | return config 62 | -------------------------------------------------------------------------------- /configs/ve/bedroom_ncsnpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training NCSN++ on bedroom with VE SDE.""" 18 | 19 | from configs.default_lsun_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = True 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'reverse_diffusion' 33 | sampling.corrector = 'langevin' 34 | 35 | # data 36 | data = config.data 37 | data.category = 'bedroom' 38 | 39 | # model 40 | model = config.model 41 | model.name = 'ncsnpp' 42 | model.scale_by_sigma = True 43 | model.ema_rate = 0.999 44 | model.normalization = 'GroupNorm' 45 | model.nonlinearity = 'swish' 46 | model.nf = 128 47 | model.ch_mult = (1, 1, 2, 2, 2, 2, 2) 48 | model.num_res_blocks = 2 49 | model.attn_resolutions = (16,) 50 | model.resamp_with_conv = True 51 | model.conditional = True 52 | model.fir = True 53 | model.fir_kernel = [1, 3, 3, 1] 54 | model.skip_rescale = True 55 | model.resblock_type = 'biggan' 56 | model.progressive = 'output_skip' 57 | model.progressive_input = 'input_skip' 58 | model.progressive_combine = 'sum' 59 | model.attention_type = 'ddpm' 60 | model.init_scale = 0. 
61 | model.fourier_scale = 16 62 | model.conv_size = 3 63 | 64 | return config 65 | -------------------------------------------------------------------------------- /configs/vp/cifar10_ncsnpp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10 with DDPM.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vpsde' 27 | training.continuous = False 28 | training.reduce_mean = True 29 | 30 | # sampling 31 | sampling = config.sampling 32 | sampling.method = 'pc' 33 | sampling.predictor = 'reverse_diffusion' 34 | sampling.corrector = 'none' 35 | 36 | # data 37 | data = config.data 38 | data.centered = True 39 | 40 | # model 41 | model = config.model 42 | model.name = 'ncsnpp' 43 | model.scale_by_sigma = False 44 | model.ema_rate = 0.9999 45 | model.normalization = 'GroupNorm' 46 | model.nonlinearity = 'swish' 47 | model.nf = 128 48 | model.ch_mult = (1, 2, 2, 2) 49 | model.num_res_blocks = 4 50 | model.attn_resolutions = (16,) 51 | model.resamp_with_conv = True 52 | model.conditional = True 53 | model.fir = True 54 | model.fir_kernel = [1, 3, 3, 1] 55 | model.skip_rescale = True 56 | model.resblock_type = 'biggan' 57 | model.progressive = 'none' 58 | model.progressive_input = 'residual' 59 | model.progressive_combine = 'sum' 60 | model.attention_type = 'ddpm' 61 | model.init_scale = 0.0 62 | model.embedding_type = 'positional' 63 | model.conv_size = 3 64 | 65 | return config 66 | -------------------------------------------------------------------------------- /configs/ve/church_ncsnpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
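The VE configs are parameterized by model.sigma_min, model.sigma_max and model.num_scales, which for SMLD/VE models define a geometric ladder of noise levels. A short sketch of that standard construction (an illustration under the default LSUN values, not this repo's code):

import numpy as np

def get_sigmas(sigma_min=0.01, sigma_max=378.0, num_scales=2000):
    # Geometric sequence from sigma_max down to sigma_min (SMLD/VE noise levels).
    return np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_scales))

sigmas = get_sigmas()
print(sigmas[0], sigmas[-1])  # 378.0 ... 0.01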
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on Church with VE SDE.""" 18 | 19 | from configs.default_lsun_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = True 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'reverse_diffusion' 33 | sampling.corrector = 'langevin' 34 | 35 | # data 36 | data = config.data 37 | data.category = 'church_outdoor' 38 | 39 | # model 40 | model = config.model 41 | model.name = 'ncsnpp' 42 | model.sigma_max = 380 43 | model.scale_by_sigma = True 44 | model.ema_rate = 0.999 45 | model.normalization = 'GroupNorm' 46 | model.nonlinearity = 'swish' 47 | model.nf = 128 48 | model.ch_mult = (1, 1, 2, 2, 2, 2, 2) 49 | model.num_res_blocks = 2 50 | model.attn_resolutions = (16,) 51 | model.resamp_with_conv = True 52 | model.conditional = True 53 | model.fir = True 54 | model.fir_kernel = [1, 3, 3, 1] 55 | model.skip_rescale = True 56 | model.resblock_type = 'biggan' 57 | model.progressive = 'output_skip' 58 | model.progressive_input = 'input_skip' 59 | model.progressive_combine = 'sum' 60 | model.attention_type = 'ddpm' 61 | model.init_scale = 0. 62 | model.fourier_scale = 16 63 | model.conv_size = 3 64 | 65 | return config 66 | -------------------------------------------------------------------------------- /configs/vp/cifar10_ncsnpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
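The VP configs below set training.sde = 'vpsde' with model.beta_min/model.beta_max taken from the defaults (0.1 and 20). Under a linear beta(t) on [0, 1], the perturbation kernel p_t(x | x_0) has a closed form; a sketch of those standard VP marginals (assumed behavior of an SDE helper like sde_lib.py, which is not shown here):

import numpy as np

beta_0, beta_1 = 0.1, 20.0  # model.beta_min, model.beta_max

def vp_marginal(t):
    # exp(log_mean_coeff) multiplies x_0; the second value is the noise std.
    log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
    return np.exp(log_mean_coeff), np.sqrt(1.0 - np.exp(2.0 * log_mean_coeff))

print(vp_marginal(1.0))  # close to (0, 1): x_1 is approximately standard normal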
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10 with VP SDE.""" 18 | from configs.default_cifar10_configs import get_default_configs 19 | 20 | 21 | def get_config(): 22 | config = get_default_configs() 23 | # training 24 | training = config.training 25 | training.sde = 'vpsde' 26 | training.continuous = True 27 | training.reduce_mean = True 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'euler_maruyama' 33 | sampling.corrector = 'none' 34 | 35 | # data 36 | data = config.data 37 | data.centered = True 38 | 39 | # model 40 | model = config.model 41 | model.name = 'ncsnpp' 42 | model.scale_by_sigma = False 43 | model.ema_rate = 0.9999 44 | model.normalization = 'GroupNorm' 45 | model.nonlinearity = 'swish' 46 | model.nf = 128 47 | model.ch_mult = (1, 2, 2, 2) 48 | model.num_res_blocks = 4 49 | model.attn_resolutions = (16,) 50 | model.resamp_with_conv = True 51 | model.conditional = True 52 | model.fir = True 53 | model.fir_kernel = [1, 3, 3, 1] 54 | model.skip_rescale = True 55 | model.resblock_type = 'biggan' 56 | model.progressive = 'none' 57 | model.progressive_input = 'residual' 58 | model.progressive_combine = 'sum' 59 | model.attention_type = 'ddpm' 60 | model.embedding_type = 'positional' 61 | model.init_scale = 0. 62 | model.fourier_scale = 16 63 | model.conv_size = 3 64 | 65 | return config 66 | -------------------------------------------------------------------------------- /configs/subvp/cifar10_ncsnpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
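The sub-VP variant below differs from the VP config only in training.sde. Its marginal standard deviation is the square of the VP one, so the sample variance never exceeds that of the prior; a sketch under the same linear-beta assumption:

import numpy as np

beta_0, beta_1 = 0.1, 20.0

def subvp_std(t):
    log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
    # sub-VP std = (VP std)^2 = 1 - exp(2 * log_mean_coeff)
    return 1.0 - np.exp(2.0 * log_mean_coeff)

print(subvp_std(0.5), subvp_std(1.0))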
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10 with sub-VP SDE.""" 18 | from configs.default_cifar10_configs import get_default_configs 19 | 20 | 21 | def get_config(): 22 | config = get_default_configs() 23 | # training 24 | training = config.training 25 | training.sde = 'subvpsde' 26 | training.continuous = True 27 | training.reduce_mean = True 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'euler_maruyama' 33 | sampling.corrector = 'none' 34 | 35 | # data 36 | data = config.data 37 | data.centered = True 38 | 39 | # model 40 | model = config.model 41 | model.name = 'ncsnpp' 42 | model.scale_by_sigma = False 43 | model.ema_rate = 0.9999 44 | model.normalization = 'GroupNorm' 45 | model.nonlinearity = 'swish' 46 | model.nf = 128 47 | model.ch_mult = (1, 2, 2, 2) 48 | model.num_res_blocks = 4 49 | model.attn_resolutions = (16,) 50 | model.resamp_with_conv = True 51 | model.conditional = True 52 | model.fir = True 53 | model.fir_kernel = [1, 3, 3, 1] 54 | model.skip_rescale = True 55 | model.resblock_type = 'biggan' 56 | model.progressive = 'none' 57 | model.progressive_input = 'residual' 58 | model.progressive_combine = 'sum' 59 | model.attention_type = 'ddpm' 60 | model.embedding_type = 'positional' 61 | model.init_scale = 0. 62 | model.fourier_scale = 16 63 | model.conv_size = 3 64 | 65 | return config 66 | -------------------------------------------------------------------------------- /configs/vp/cifar10_ddpmpp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
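The VP and sub-VP configs set data.centered = True, i.e. images are mapped from [0, 1] to [-1, 1] before reaching the model. A sketch of the scaler pair such a flag typically selects (mirroring the role of datasets.py, whose contents are not shown here, so this is an assumption):

def get_data_scaler(centered=True):
    # With data.centered, rescale [0, 1] images to [-1, 1].
    return (lambda x: x * 2.0 - 1.0) if centered else (lambda x: x)

def get_data_inverse_scaler(centered=True):
    return (lambda x: (x + 1.0) / 2.0) if centered else (lambda x: x)

scaler = get_data_scaler(True)
print(scaler(0.0), scaler(1.0))  # -1.0 1.0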
15 | 
16 | # Lint as: python3
17 | """Training DDPM++ on CIFAR-10 with DDPM."""
18 | 
19 | from configs.default_cifar10_configs import get_default_configs
20 | 
21 | 
22 | def get_config():
23 |   config = get_default_configs()
24 |   # training
25 |   training = config.training
26 |   training.sde = 'vpsde'
27 |   training.continuous = False
28 |   training.reduce_mean = True
29 | 
30 |   # sampling
31 |   sampling = config.sampling
32 |   sampling.method = 'pc'
33 |   sampling.predictor = 'ancestral_sampling'
34 |   sampling.corrector = 'none'
35 | 
36 |   # data
37 |   data = config.data
38 |   data.centered = True
39 | 
40 |   # model
41 |   model = config.model
42 |   model.name = 'ncsnpp'
43 |   model.scale_by_sigma = False
44 |   model.ema_rate = 0.9999
45 |   model.normalization = 'GroupNorm'
46 |   model.nonlinearity = 'swish'
47 |   model.nf = 128
48 |   model.ch_mult = (1, 2, 2, 2)
49 |   model.num_res_blocks = 4
50 |   model.attn_resolutions = (16,)
51 |   model.resamp_with_conv = True
52 |   model.conditional = True
53 |   model.fir = False
54 |   model.fir_kernel = [1, 3, 3, 1]
55 |   model.skip_rescale = True
56 |   model.resblock_type = 'biggan'
57 |   model.progressive = 'none'
58 |   model.progressive_input = 'none'
59 |   model.progressive_combine = 'sum'
60 |   model.attention_type = 'ddpm'
61 |   model.init_scale = 0.
62 |   model.embedding_type = 'positional'
63 |   model.fourier_scale = 16
64 |   model.conv_size = 3
65 | 
66 |   return config
67 | 
--------------------------------------------------------------------------------
/configs/vp/cifar10_ddpmpp_continuous.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
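The discrete DDPM++ config above samples with sampling.predictor = 'ancestral_sampling'. Written in terms of the score, one ancestral step of the discrete VP chain looks like the following (a sketch of the standard DDPM update, not this repo's sampler; beta_i is the discrete beta at the current step):

import math
import torch

def ancestral_step_vp(x, score, beta_i):
    # DDPM ancestral sampling, score parameterization:
    # mean = (x + beta_i * score) / sqrt(1 - beta_i), then add sqrt(beta_i) noise.
    mean = (x + beta_i * score) / math.sqrt(1.0 - beta_i)
    return mean + math.sqrt(beta_i) * torch.randn_like(x)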
15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vpsde' 27 | training.continuous = True 28 | training.reduce_mean = True 29 | 30 | # sampling 31 | sampling = config.sampling 32 | sampling.method = 'pc' 33 | sampling.predictor = 'euler_maruyama' 34 | sampling.corrector = 'none' 35 | 36 | # data 37 | data = config.data 38 | data.centered = True 39 | 40 | # model 41 | model = config.model 42 | model.name = 'ncsnpp' 43 | model.scale_by_sigma = False 44 | model.ema_rate = 0.9999 45 | model.normalization = 'GroupNorm' 46 | model.nonlinearity = 'swish' 47 | model.nf = 128 48 | model.ch_mult = (1, 2, 2, 2) 49 | model.num_res_blocks = 4 50 | model.attn_resolutions = (16,) 51 | model.resamp_with_conv = True 52 | model.conditional = True 53 | model.fir = False 54 | model.fir_kernel = [1, 3, 3, 1] 55 | model.skip_rescale = True 56 | model.resblock_type = 'biggan' 57 | model.progressive = 'none' 58 | model.progressive_input = 'none' 59 | model.progressive_combine = 'sum' 60 | model.attention_type = 'ddpm' 61 | model.init_scale = 0. 62 | model.embedding_type = 'positional' 63 | model.fourier_scale = 16 64 | model.conv_size = 3 65 | 66 | return config 67 | -------------------------------------------------------------------------------- /configs/subvp/cifar10_ddpmpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
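The continuous configs use sampling.predictor = 'euler_maruyama'. One reverse-time Euler-Maruyama step has the generic form below (a sketch; drift_fn, diffusion_fn and score_fn are assumed callables for the chosen SDE and trained model, and dt is negative when integrating backwards in time):

import torch

def euler_maruyama_step(x, t, dt, drift_fn, diffusion_fn, score_fn):
    # Reverse-time SDE: dx = [f(x, t) - g(t)^2 * score(x, t)] dt + g(t) dw
    f, g = drift_fn(x, t), diffusion_fn(t)
    rev_drift = f - (g ** 2) * score_fn(x, t)
    z = torch.randn_like(x)
    return x + rev_drift * dt + g * (abs(dt) ** 0.5) * z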
15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'subvpsde' 27 | training.continuous = True 28 | training.reduce_mean = True 29 | 30 | # sampling 31 | sampling = config.sampling 32 | sampling.method = 'pc' 33 | sampling.predictor = 'euler_maruyama' 34 | sampling.corrector = 'none' 35 | 36 | # data 37 | data = config.data 38 | data.centered = True 39 | 40 | # model 41 | model = config.model 42 | model.name = 'ncsnpp' 43 | model.scale_by_sigma = False 44 | model.ema_rate = 0.9999 45 | model.normalization = 'GroupNorm' 46 | model.nonlinearity = 'swish' 47 | model.nf = 128 48 | model.ch_mult = (1, 2, 2, 2) 49 | model.num_res_blocks = 4 50 | model.attn_resolutions = (16,) 51 | model.resamp_with_conv = True 52 | model.conditional = True 53 | model.fir = False 54 | model.fir_kernel = [1, 3, 3, 1] 55 | model.skip_rescale = True 56 | model.resblock_type = 'biggan' 57 | model.progressive = 'none' 58 | model.progressive_input = 'none' 59 | model.progressive_combine = 'sum' 60 | model.attention_type = 'ddpm' 61 | model.init_scale = 0. 62 | model.embedding_type = 'positional' 63 | model.fourier_scale = 16 64 | model.conv_size = 3 65 | 66 | return config 67 | -------------------------------------------------------------------------------- /configs/vp/cifar10_ncsnpp_deep_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
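Every config sets model.ema_rate. A sketch of how the ExponentialMovingAverage helper from models/ema.py (shown further down in this listing) is typically driven during training; the Linear layer here is only a stand-in for the score network:

import torch
from models.ema import ExponentialMovingAverage

net = torch.nn.Linear(4, 4)  # stand-in for the score model
ema = ExponentialMovingAverage(net.parameters(), decay=0.9999)

ema.update(net.parameters())   # call after every optimizer step

ema.store(net.parameters())    # evaluate/sample with averaged weights...
ema.copy_to(net.parameters())
# ... run sampling or evaluation here ...
ema.restore(net.parameters())  # ...then restore the live training weights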
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vpsde' 27 | training.continuous = True 28 | training.n_iters = 950001 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'euler_maruyama' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.centered = True 40 | 41 | # model 42 | model = config.model 43 | model.name = 'ncsnpp' 44 | model.fourier_scale = 16 45 | model.scale_by_sigma = False 46 | model.ema_rate = 0.9999 47 | model.normalization = 'GroupNorm' 48 | model.nonlinearity = 'swish' 49 | model.nf = 128 50 | model.ch_mult = (1, 2, 2, 2) 51 | model.num_res_blocks = 8 52 | model.attn_resolutions = (16,) 53 | model.resamp_with_conv = True 54 | model.conditional = True 55 | model.fir = True 56 | model.fir_kernel = [1, 3, 3, 1] 57 | model.skip_rescale = True 58 | model.resblock_type = 'biggan' 59 | model.progressive = 'none' 60 | model.progressive_input = 'residual' 61 | model.progressive_combine = 'sum' 62 | model.attention_type = 'ddpm' 63 | model.embedding_type = 'positional' 64 | model.init_scale = 0.0 65 | model.conv_size = 3 66 | 67 | return config 68 | -------------------------------------------------------------------------------- /configs/subvp/cifar10_ncsnpp_deep_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
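The VE configs pair the reverse_diffusion predictor with sampling.corrector = 'langevin', whose step size is derived from the target sampling.snr. A sketch of one corrector step for the VE case (an illustration of the predictor-corrector scheme from the paper, not this repo's implementation):

import torch

def langevin_corrector_step(x, t, score_fn, snr=0.16):
    grad = score_fn(x, t)
    noise = torch.randn_like(x)
    grad_norm = grad.reshape(grad.shape[0], -1).norm(dim=-1).mean()
    noise_norm = noise.reshape(noise.shape[0], -1).norm(dim=-1).mean()
    # Match the signal-to-noise ratio of the Langevin update to `snr`.
    step_size = 2 * (snr * noise_norm / grad_norm) ** 2
    return x + step_size * grad + torch.sqrt(2 * step_size) * noise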
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'subvpsde' 27 | training.continuous = True 28 | training.n_iters = 950001 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'euler_maruyama' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.centered = True 40 | 41 | # model 42 | model = config.model 43 | model.name = 'ncsnpp' 44 | model.fourier_scale = 16 45 | model.scale_by_sigma = False 46 | model.ema_rate = 0.9999 47 | model.normalization = 'GroupNorm' 48 | model.nonlinearity = 'swish' 49 | model.nf = 128 50 | model.ch_mult = (1, 2, 2, 2) 51 | model.num_res_blocks = 8 52 | model.attn_resolutions = (16,) 53 | model.resamp_with_conv = True 54 | model.conditional = True 55 | model.fir = True 56 | model.fir_kernel = [1, 3, 3, 1] 57 | model.skip_rescale = True 58 | model.resblock_type = 'biggan' 59 | model.progressive = 'none' 60 | model.progressive_input = 'residual' 61 | model.progressive_combine = 'sum' 62 | model.attention_type = 'ddpm' 63 | model.embedding_type = 'positional' 64 | model.init_scale = 0.0 65 | model.conv_size = 3 66 | 67 | return config 68 | -------------------------------------------------------------------------------- /configs/vp/cifar10_ddpmpp_deep_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
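model.embedding_type switches between 'positional' (transformer-style) and 'fourier' time embeddings, with model.fourier_scale controlling the random-feature bandwidth. A sketch of a Gaussian Fourier projection in the spirit of models/layerspp.py (that file's contents are not shown here, so the class and names are assumptions):

import math
import torch

class GaussianFourierProjection(torch.nn.Module):
    # Fixed random features of the (log-)time input; `scale` plays the role
    # of model.fourier_scale.
    def __init__(self, embedding_size=256, scale=16.0):
        super().__init__()
        self.W = torch.nn.Parameter(torch.randn(embedding_size) * scale,
                                    requires_grad=False)

    def forward(self, t):
        t_proj = t[:, None] * self.W[None, :] * 2 * math.pi
        return torch.cat([torch.sin(t_proj), torch.cos(t_proj)], dim=-1)

print(GaussianFourierProjection()(torch.rand(8)).shape)  # torch.Size([8, 512])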
15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vpsde' 27 | training.continuous = True 28 | training.reduce_mean = True 29 | training.n_iters = 950001 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'euler_maruyama' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.centered = True 40 | 41 | # model 42 | model = config.model 43 | model.name = 'ncsnpp' 44 | model.scale_by_sigma = False 45 | model.ema_rate = 0.9999 46 | model.normalization = 'GroupNorm' 47 | model.nonlinearity = 'swish' 48 | model.nf = 128 49 | model.ch_mult = (1, 2, 2, 2) 50 | model.num_res_blocks = 8 51 | model.attn_resolutions = (16,) 52 | model.resamp_with_conv = True 53 | model.conditional = True 54 | model.fir = False 55 | model.fir_kernel = [1, 3, 3, 1] 56 | model.skip_rescale = True 57 | model.resblock_type = 'biggan' 58 | model.progressive = 'none' 59 | model.progressive_input = 'none' 60 | model.progressive_combine = 'sum' 61 | model.attention_type = 'ddpm' 62 | model.init_scale = 0. 63 | model.embedding_type = 'positional' 64 | model.fourier_scale = 16 65 | model.conv_size = 3 66 | 67 | return config 68 | -------------------------------------------------------------------------------- /configs/subvp/cifar10_ddpmpp_deep_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
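sampling.probability_flow (and likelihood.py further down) rely on the probability flow ODE, the deterministic process that shares the SDE's marginals. Its drift is the reverse-SDE drift with half the score term and no diffusion; a sketch with the same assumed callables as above:

def probability_flow_drift(x, t, drift_fn, diffusion_fn, score_fn):
    # dx = [f(x, t) - 0.5 * g(t)^2 * score(x, t)] dt  (no noise term)
    f, g = drift_fn(x, t), diffusion_fn(t)
    return f - 0.5 * (g ** 2) * score_fn(x, t)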
15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'subvpsde' 27 | training.continuous = True 28 | training.reduce_mean = True 29 | training.n_iters = 950001 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'euler_maruyama' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.centered = True 40 | 41 | # model 42 | model = config.model 43 | model.name = 'ncsnpp' 44 | model.scale_by_sigma = False 45 | model.ema_rate = 0.9999 46 | model.normalization = 'GroupNorm' 47 | model.nonlinearity = 'swish' 48 | model.nf = 128 49 | model.ch_mult = (1, 2, 2, 2) 50 | model.num_res_blocks = 8 51 | model.attn_resolutions = (16,) 52 | model.resamp_with_conv = True 53 | model.conditional = True 54 | model.fir = False 55 | model.fir_kernel = [1, 3, 3, 1] 56 | model.skip_rescale = True 57 | model.resblock_type = 'biggan' 58 | model.progressive = 'none' 59 | model.progressive_input = 'none' 60 | model.progressive_combine = 'sum' 61 | model.attention_type = 'ddpm' 62 | model.init_scale = 0. 63 | model.embedding_type = 'positional' 64 | model.fourier_scale = 16 65 | model.conv_size = 3 66 | 67 | return config 68 | -------------------------------------------------------------------------------- /configs/ve/ffhq_256_ncsnpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
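The VE configs below enable StyleGAN2-style FIR resampling (model.fir with model.fir_kernel = [1, 3, 3, 1], consumed by op/upfirdn2d.py). The 1-D taps are expanded into a normalized separable 2-D filter; a sketch of that conventional construction (assumed, not copied from the op/ sources):

import numpy as np

def make_fir_kernel(taps=(1, 3, 3, 1)):
    k = np.asarray(taps, dtype=np.float32)
    k2d = np.outer(k, k)      # separable 2-D filter
    return k2d / k2d.sum()    # normalize so resampling preserves mean intensity

print(make_fir_kernel().sum())  # 1.0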
15 | 
16 | # Lint as: python3
17 | """Training NCSN++ on FFHQ 256px with VE SDE."""
18 | 
19 | from configs.default_lsun_configs import get_default_configs
20 | 
21 | 
22 | def get_config():
23 |   config = get_default_configs()
24 |   # training
25 |   training = config.training
26 |   training.sde = 'vesde'
27 |   training.continuous = True
28 | 
29 |   # sampling
30 |   sampling = config.sampling
31 |   sampling.method = 'pc'
32 |   sampling.predictor = 'reverse_diffusion'
33 |   sampling.corrector = 'langevin'
34 | 
35 |   # data
36 |   data = config.data
37 |   data.dataset = 'FFHQ'
38 |   data.image_size = 256
39 |   data.tfrecords_path = '/home/yangsong/ncsc/ffhq/ffhq-r08.tfrecords'
40 | 
41 | 
42 |   # model
43 |   model = config.model
44 |   model.name = 'ncsnpp'
45 |   model.sigma_max = 348
46 |   model.scale_by_sigma = True
47 |   model.ema_rate = 0.999
48 |   model.normalization = 'GroupNorm'
49 |   model.nonlinearity = 'swish'
50 |   model.nf = 128
51 |   model.ch_mult = (1, 1, 2, 2, 2, 2, 2)
52 |   model.num_res_blocks = 2
53 |   model.attn_resolutions = (16,)
54 |   model.resamp_with_conv = True
55 |   model.conditional = True
56 |   model.fir = True
57 |   model.fir_kernel = [1, 3, 3, 1]
58 |   model.skip_rescale = True
59 |   model.resblock_type = 'biggan'
60 |   model.progressive = 'output_skip'
61 |   model.progressive_input = 'input_skip'
62 |   model.progressive_combine = 'sum'
63 |   model.attention_type = 'ddpm'
64 |   model.init_scale = 0.
65 |   model.fourier_scale = 16
66 |   model.conv_size = 3
67 | 
68 |   return config
69 | 
--------------------------------------------------------------------------------
/configs/ve/celebahq_256_ncsnpp_continuous.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | # Lint as: python3
17 | """Training NCSN++ on CelebAHQ 256px with VE SDE."""
18 | 
19 | from configs.default_lsun_configs import get_default_configs
20 | 
21 | 
22 | def get_config():
23 |   config = get_default_configs()
24 |   # training
25 |   training = config.training
26 |   training.sde = 'vesde'
27 |   training.continuous = True
28 | 
29 |   # sampling
30 |   sampling = config.sampling
31 |   sampling.method = 'pc'
32 |   sampling.predictor = 'reverse_diffusion'
33 |   sampling.corrector = 'langevin'
34 | 
35 |   # data
36 |   data = config.data
37 |   data.dataset = 'CelebAHQ'
38 |   data.image_size = 256
39 |   data.tfrecords_path = '/home/yangsong/ncsc/celebahq/r08.tfrecords'
40 | 
41 | 
42 |   # model
43 |   model = config.model
44 |   model.name = 'ncsnpp'
45 |   model.sigma_max = 348
46 |   model.scale_by_sigma = True
47 |   model.ema_rate = 0.999
48 |   model.normalization = 'GroupNorm'
49 |   model.nonlinearity = 'swish'
50 |   model.nf = 128
51 |   model.ch_mult = (1, 1, 2, 2, 2, 2, 2)
52 |   model.num_res_blocks = 2
53 |   model.attn_resolutions = (16,)
54 |   model.resamp_with_conv = True
55 |   model.conditional = True
56 |   model.fir = True
57 |   model.fir_kernel = [1, 3, 3, 1]
58 |   model.skip_rescale = True
59 |   model.resblock_type = 'biggan'
60 |   model.progressive = 'output_skip'
61 |   model.progressive_input = 'input_skip'
62 |   model.progressive_combine = 'sum'
63 |   model.attention_type = 'ddpm'
64 |   model.init_scale = 0.
65 |   model.fourier_scale = 16
66 |   model.conv_size = 3
67 | 
68 |   return config
69 | 
--------------------------------------------------------------------------------
/configs/default_lsun_configs.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 | import torch
3 | 
4 | 
5 | def get_default_configs():
6 |   config = ml_collections.ConfigDict()
7 |   # training
8 |   config.training = training = ml_collections.ConfigDict()
9 |   config.training.batch_size = 64
10 |   training.n_iters = 2400001
11 |   training.snapshot_freq = 50000
12 |   training.log_freq = 50
13 |   training.eval_freq = 100
14 |   ## store additional checkpoints for preemption in cloud computing environments
15 |   training.snapshot_freq_for_preemption = 5000
16 |   ## produce samples at each snapshot.
17 |   training.snapshot_sampling = True
18 |   training.likelihood_weighting = False
19 |   training.continuous = True
20 |   training.reduce_mean = False
21 | 
22 |   # sampling
23 |   config.sampling = sampling = ml_collections.ConfigDict()
24 |   sampling.n_steps_each = 1
25 |   sampling.noise_removal = True
26 |   sampling.probability_flow = False
27 |   sampling.snr = 0.075
28 | 
29 |   # evaluation
30 |   config.eval = evaluate = ml_collections.ConfigDict()
31 |   evaluate.begin_ckpt = 50
32 |   evaluate.end_ckpt = 96
33 |   evaluate.batch_size = 512
34 |   evaluate.enable_sampling = True
35 |   evaluate.num_samples = 50000
36 |   evaluate.enable_loss = True
37 |   evaluate.enable_bpd = False
38 |   evaluate.bpd_dataset = 'test'
39 | 
40 |   # data
41 |   config.data = data = ml_collections.ConfigDict()
42 |   data.dataset = 'LSUN'
43 |   data.image_size = 256
44 |   data.random_flip = True
45 |   data.uniform_dequantization = False
46 |   data.centered = False
47 |   data.num_channels = 3
48 | 
49 |   # model
50 |   config.model = model = ml_collections.ConfigDict()
51 |   model.sigma_max = 378
52 |   model.sigma_min = 0.01
53 |   model.num_scales = 2000
54 |   model.beta_min = 0.1
55 |   model.beta_max = 20.
56 |   model.dropout = 0.
57 | model.embedding_type = 'fourier' 58 | 59 | # optimization 60 | config.optim = optim = ml_collections.ConfigDict() 61 | optim.weight_decay = 0 62 | optim.optimizer = 'Adam' 63 | optim.lr = 2e-4 64 | optim.beta1 = 0.9 65 | optim.eps = 1e-8 66 | optim.warmup = 5000 67 | optim.grad_clip = 1. 68 | 69 | config.seed = 42 70 | config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 71 | 72 | return config -------------------------------------------------------------------------------- /configs/default_celeba_configs.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | import torch 3 | 4 | 5 | def get_default_configs(): 6 | config = ml_collections.ConfigDict() 7 | # training 8 | config.training = training = ml_collections.ConfigDict() 9 | config.training.batch_size = 128 10 | training.n_iters = 1300001 11 | training.snapshot_freq = 50000 12 | training.log_freq = 50 13 | training.eval_freq = 100 14 | ## store additional checkpoints for preemption in cloud computing environments 15 | training.snapshot_freq_for_preemption = 10000 16 | ## produce samples at each snapshot. 17 | training.snapshot_sampling = True 18 | training.likelihood_weighting = False 19 | training.continuous = True 20 | training.reduce_mean = False 21 | 22 | # sampling 23 | config.sampling = sampling = ml_collections.ConfigDict() 24 | sampling.n_steps_each = 1 25 | sampling.noise_removal = True 26 | sampling.probability_flow = False 27 | sampling.snr = 0.17 28 | 29 | # evaluation 30 | config.eval = evaluate = ml_collections.ConfigDict() 31 | evaluate.begin_ckpt = 1 32 | evaluate.end_ckpt = 26 33 | evaluate.batch_size = 1024 34 | evaluate.enable_sampling = True 35 | evaluate.num_samples = 50000 36 | evaluate.enable_loss = True 37 | evaluate.enable_bpd = False 38 | evaluate.bpd_dataset = 'test' 39 | 40 | # data 41 | config.data = data = ml_collections.ConfigDict() 42 | data.dataset = 'CELEBA' 43 | data.image_size = 64 44 | data.random_flip = True 45 | data.uniform_dequantization = False 46 | data.centered = False 47 | data.num_channels = 3 48 | 49 | # model 50 | config.model = model = ml_collections.ConfigDict() 51 | model.sigma_max = 90. 52 | model.sigma_min = 0.01 53 | model.num_scales = 1000 54 | model.beta_min = 0.1 55 | model.beta_max = 20. 56 | model.dropout = 0.1 57 | model.embedding_type = 'fourier' 58 | 59 | # optimization 60 | config.optim = optim = ml_collections.ConfigDict() 61 | optim.weight_decay = 0 62 | optim.optimizer = 'Adam' 63 | optim.lr = 2e-4 64 | optim.beta1 = 0.9 65 | optim.eps = 1e-8 66 | optim.warmup = 5000 67 | optim.grad_clip = 1. 68 | 69 | config.seed = 42 70 | config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 71 | 72 | return config -------------------------------------------------------------------------------- /configs/default_cifar10_configs.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | import torch 3 | 4 | 5 | def get_default_configs(): 6 | config = ml_collections.ConfigDict() 7 | # training 8 | config.training = training = ml_collections.ConfigDict() 9 | config.training.batch_size = 1 10 | training.n_iters = 1300001 11 | training.snapshot_freq = 1 12 | training.log_freq = 1 13 | training.eval_freq = 1 14 | ## store additional checkpoints for preemption in cloud computing environments 15 | training.snapshot_freq_for_preemption = 538 16 | ## produce samples at each snapshot. 
17 |   training.snapshot_sampling = True
18 |   training.likelihood_weighting = False
19 |   training.continuous = True
20 |   training.reduce_mean = False
21 | 
22 |   # sampling
23 |   config.sampling = sampling = ml_collections.ConfigDict()
24 |   sampling.n_steps_each = 1
25 |   sampling.noise_removal = True
26 |   sampling.probability_flow = False
27 |   sampling.snr = 0.16
28 | 
29 |   # evaluation
30 |   config.eval = evaluate = ml_collections.ConfigDict()
31 |   evaluate.begin_ckpt = 9
32 |   evaluate.end_ckpt = 26
33 |   evaluate.batch_size = 1024
34 |   evaluate.enable_sampling = False
35 |   evaluate.num_samples = 50000
36 |   evaluate.enable_loss = True
37 |   evaluate.enable_bpd = False
38 |   evaluate.bpd_dataset = 'test'
39 | 
40 |   # data
41 |   config.data = data = ml_collections.ConfigDict()
42 |   data.dataset = 'CIFAR10'
43 |   data.image_size = 32
44 |   data.random_flip = True
45 |   data.centered = False
46 |   data.uniform_dequantization = False
47 |   data.num_channels = 3
48 | 
49 |   # model
50 |   config.model = model = ml_collections.ConfigDict()
51 |   model.sigma_min = 0.01
52 |   model.sigma_max = 50
53 |   model.num_scales = 1000
54 |   model.beta_min = 0.1
55 |   model.beta_max = 20.
56 |   model.dropout = 0.1
57 |   model.embedding_type = 'fourier'
58 | 
59 |   # optimization
60 |   config.optim = optim = ml_collections.ConfigDict()
61 |   optim.weight_decay = 0
62 |   optim.optimizer = 'Adam'
63 |   optim.lr = 2e-4
64 |   optim.beta1 = 0.9
65 |   optim.eps = 1e-8
66 |   optim.warmup = 5000
67 |   optim.grad_clip = 1.
68 | 
69 |   config.seed = 42
70 |   config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
71 | 
72 |   return config
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
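main.py below is the absl entry point: --config points at one of the config files above, --workdir receives checkpoints and samples, and --mode selects the pipeline. Typical invocations (the paths are placeholders, not paths from this repo):

python main.py --config configs/vp/cifar10_ddpmpp_continuous.py --workdir ./runs/cifar10 --mode train
python main.py --config configs/vp/cifar10_ddpmpp_continuous.py --workdir ./runs/cifar10 --mode eval --eval_folder eval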
15 | 16 | """Training and evaluation""" 17 | 18 | import run_lib 19 | from absl import app 20 | from absl import flags 21 | from ml_collections.config_flags import config_flags 22 | import logging 23 | import os 24 | import tensorflow as tf 25 | 26 | FLAGS = flags.FLAGS 27 | 28 | config_flags.DEFINE_config_file( 29 | "config", None, "Training configuration.", lock_config=True) 30 | flags.DEFINE_string("workdir", None, "Work directory.") 31 | flags.DEFINE_enum("mode", None, ["train", "eval"], "Running mode: train or eval") 32 | flags.DEFINE_string("eval_folder", "eval", 33 | "The folder name for storing evaluation results") 34 | flags.mark_flags_as_required(["workdir", "config", "mode"]) 35 | 36 | 37 | def main(argv): 38 | if FLAGS.mode == "train": 39 | # Create the working directory 40 | tf.io.gfile.makedirs(FLAGS.workdir) 41 | # Set logger so that it outputs to both console and file 42 | # Make logging work for both disk and Google Cloud Storage 43 | gfile_stream = open(os.path.join(FLAGS.workdir, 'stdout.txt'), 'w') 44 | handler = logging.StreamHandler(gfile_stream) 45 | formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s') 46 | handler.setFormatter(formatter) 47 | logger = logging.getLogger() 48 | logger.addHandler(handler) 49 | logger.setLevel('INFO') 50 | # Run the training pipeline 51 | run_lib.train(FLAGS.config, FLAGS.workdir) 52 | elif FLAGS.mode == "eval": 53 | # Run the evaluation pipeline 54 | run_lib.evaluate(FLAGS.config, FLAGS.workdir, FLAGS.eval_folder) 55 | else: 56 | raise ValueError(f"Mode {FLAGS.mode} not recognized.") 57 | 58 | 59 | if __name__ == "__main__": 60 | app.run(main) 61 | -------------------------------------------------------------------------------- /op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output, empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | grad_bias = grad_input.sum(dim).detach() 39 | 40 | return grad_input, grad_bias 41 | 42 | @staticmethod 43 | def backward(ctx, gradgrad_input, gradgrad_bias): 44 | out, = ctx.saved_tensors 45 | gradgrad_out = fused.fused_bias_act( 46 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 47 | ) 48 | 49 | return gradgrad_out, None, None, None 50 | 51 | 52 | class FusedLeakyReLUFunction(Function): 53 | @staticmethod 54 | def forward(ctx, input, bias, negative_slope, scale): 55 | empty = input.new_empty(0) 56 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 57 | ctx.save_for_backward(out) 58 | ctx.negative_slope = negative_slope 59 | ctx.scale = scale 60 | 61 | return out 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 
65 |         out, = ctx.saved_tensors
66 | 
67 |         grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
68 |             grad_output, out, ctx.negative_slope, ctx.scale
69 |         )
70 | 
71 |         return grad_input, grad_bias, None, None
72 | 
73 | 
74 | class FusedLeakyReLU(nn.Module):
75 |     def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
76 |         super().__init__()
77 | 
78 |         self.bias = nn.Parameter(torch.zeros(channel))
79 |         self.negative_slope = negative_slope
80 |         self.scale = scale
81 | 
82 |     def forward(self, input):
83 |         return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
84 | 
85 | 
86 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
87 |     if input.device.type == "cpu":
88 |         rest_dim = [1] * (input.ndim - bias.ndim - 1)
89 |         return (
90 |             F.leaky_relu(
91 |                 input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
92 |             )
93 |             * scale
94 |         )
95 | 
96 |     else:
97 |         return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
98 | 
--------------------------------------------------------------------------------
/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 | 
7 | #include <torch/types.h>
8 | 
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 | 
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 | 
17 | 
18 | template <typename scalar_t>
19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20 |     int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21 |   int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22 | 
23 |   scalar_t zero = 0.0;
24 | 
25 |   for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26 |     scalar_t x = p_x[xi];
27 | 
28 |     if (use_bias) {
29 |       x += p_b[(xi / step_b) % size_b];
30 |     }
31 | 
32 |     scalar_t ref = use_ref ? p_ref[xi] : zero;
33 | 
34 |     scalar_t y;
35 | 
36 |     switch (act * 10 + grad) {
37 |       default:
38 |       case 10: y = x; break;
39 |       case 11: y = x; break;
40 |       case 12: y = 0.0; break;
41 | 
42 |       case 30: y = (x > 0.0) ? x : x * alpha; break;
43 |       case 31: y = (ref > 0.0) ? x : x * alpha; break;
44 |       case 32: y = 0.0; break;
45 |     }
46 | 
47 |     out[xi] = y * scale;
48 |   }
49 | }
50 | 
51 | 
52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53 |     int act, int grad, float alpha, float scale) {
54 |   int curDevice = -1;
55 |   cudaGetDevice(&curDevice);
56 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57 | 
58 |   auto x = input.contiguous();
59 |   auto b = bias.contiguous();
60 |   auto ref = refer.contiguous();
61 | 
62 |   int use_bias = b.numel() ? 1 : 0;
63 |   int use_ref = ref.numel() ? 1 : 0;
64 | 
65 |   int size_x = x.numel();
66 |   int size_b = b.numel();
67 |   int step_b = 1;
68 | 
69 |   for (int i = 1 + 1; i < x.dim(); i++) {
70 |     step_b *= x.size(i);
71 |   }
72 | 
73 |   int loop_x = 4;
74 |   int block_size = 4 * 32;
75 |   int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76 | 
77 |   auto y = torch::empty_like(x);
78 | 
79 |   AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80 |     fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81 |       y.data_ptr<scalar_t>(),
82 |       x.data_ptr<scalar_t>(),
83 |       b.data_ptr<scalar_t>(),
84 |       ref.data_ptr<scalar_t>(),
85 |       act,
86 |       grad,
87 |       alpha,
88 |       scale,
89 |       loop_x,
90 |       size_x,
91 |       step_b,
92 |       size_b,
93 |       use_bias,
94 |       use_ref
95 |     );
96 |   });
97 | 
98 |   return y;
99 | }
--------------------------------------------------------------------------------
/configs/ve/celebahq_ncsnpp_continuous.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | # Lint as: python3
17 | """Training NCSN++ on CelebAHQ with VE SDE."""
18 | 
19 | import ml_collections
20 | import torch
21 | 
22 | 
23 | def get_config():
24 |   config = ml_collections.ConfigDict()
25 |   # training
26 |   config.training = training = ml_collections.ConfigDict()
27 |   training.batch_size = 8
28 |   training.n_iters = 2400001
29 |   training.snapshot_freq = 50000
30 |   training.log_freq = 50
31 |   training.eval_freq = 100
32 |   training.snapshot_freq_for_preemption = 5000
33 |   training.snapshot_sampling = True
34 |   training.sde = 'vesde'
35 |   training.continuous = True
36 |   training.likelihood_weighting = False
37 |   training.reduce_mean = False
38 | 
39 |   # sampling
40 |   config.sampling = sampling = ml_collections.ConfigDict()
41 |   sampling.method = 'pc'
42 |   sampling.predictor = 'reverse_diffusion'
43 |   sampling.corrector = 'langevin'
44 |   sampling.probability_flow = False
45 |   sampling.snr = 0.15
46 |   sampling.n_steps_each = 1
47 |   sampling.noise_removal = True
48 | 
49 |   # eval
50 |   config.eval = evaluate = ml_collections.ConfigDict()
51 |   evaluate.batch_size = 1024
52 |   evaluate.num_samples = 50000
53 |   evaluate.begin_ckpt = 1
54 |   evaluate.end_ckpt = 96
55 | 
56 |   # data
57 |   config.data = data = ml_collections.ConfigDict()
58 |   data.dataset = 'CelebAHQ'
59 |   data.image_size = 1024
60 |   data.centered = False
61 |   data.random_flip = True
62 |   data.uniform_dequantization = False
63 |   data.num_channels = 3
64 |   data.tfrecords_path = '/atlas/u/yangsong/celeba_hq/-r10.tfrecords'
65 | 
66 |   # model
67 |   config.model = model = ml_collections.ConfigDict()
68 |   model.name = 'ncsnpp'
69 |   model.scale_by_sigma = True
70 |   model.sigma_max = 1348
71 |   model.num_scales = 2000
72 |   model.ema_rate = 0.9999
73 |   model.sigma_min = 0.01
74 |   model.normalization = 'GroupNorm'
75 |   model.nonlinearity = 'swish'
76 |   model.nf = 16
77 |   model.ch_mult = (1, 2, 4, 8, 16, 32, 32, 32)
78 |   model.num_res_blocks = 1
79 |   model.attn_resolutions = (16,)
80 | 
model.dropout = 0. 81 | model.resamp_with_conv = True 82 | model.conditional = True 83 | model.fir = True 84 | model.fir_kernel = [1, 3, 3, 1] 85 | model.skip_rescale = True 86 | model.resblock_type = 'biggan' 87 | model.progressive = 'output_skip' 88 | model.progressive_input = 'input_skip' 89 | model.progressive_combine = 'sum' 90 | model.attention_type = 'ddpm' 91 | model.init_scale = 0. 92 | model.fourier_scale = 16 93 | model.conv_size = 3 94 | model.embedding_type = 'fourier' 95 | 96 | # optim 97 | config.optim = optim = ml_collections.ConfigDict() 98 | optim.weight_decay = 0 99 | optim.optimizer = 'Adam' 100 | optim.lr = 2e-4 101 | optim.beta1 = 0.9 102 | optim.amsgrad = False 103 | optim.eps = 1e-8 104 | optim.warmup = 5000 105 | optim.grad_clip = 1. 106 | 107 | config.seed = 42 108 | config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 109 | 110 | return config 111 | -------------------------------------------------------------------------------- /models/ema.py: -------------------------------------------------------------------------------- 1 | # Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py 2 | 3 | from __future__ import division 4 | from __future__ import unicode_literals 5 | 6 | import torch 7 | 8 | 9 | # Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py 10 | class ExponentialMovingAverage: 11 | """ 12 | Maintains (exponential) moving average of a set of parameters. 13 | """ 14 | 15 | def __init__(self, parameters, decay, use_num_updates=True): 16 | """ 17 | Args: 18 | parameters: Iterable of `torch.nn.Parameter`; usually the result of 19 | `model.parameters()`. 20 | decay: The exponential decay. 21 | use_num_updates: Whether to use number of updates when computing 22 | averages. 23 | """ 24 | if decay < 0.0 or decay > 1.0: 25 | raise ValueError('Decay must be between 0 and 1') 26 | self.decay = decay 27 | self.num_updates = 0 if use_num_updates else None 28 | self.shadow_params = [p.clone().detach() 29 | for p in parameters if p.requires_grad] 30 | self.collected_params = [] 31 | 32 | def update(self, parameters): 33 | """ 34 | Update currently maintained parameters. 35 | 36 | Call this every time the parameters are updated, such as the result of 37 | the `optimizer.step()` call. 38 | 39 | Args: 40 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of 41 | parameters used to initialize this object. 42 | """ 43 | decay = self.decay 44 | if self.num_updates is not None: 45 | self.num_updates += 1 46 | decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates)) 47 | one_minus_decay = 1.0 - decay 48 | with torch.no_grad(): 49 | parameters = [p for p in parameters if p.requires_grad] 50 | for s_param, param in zip(self.shadow_params, parameters): 51 | s_param.sub_(one_minus_decay * (s_param - param)) 52 | 53 | def copy_to(self, parameters): 54 | """ 55 | Copy current parameters into given collection of parameters. 56 | 57 | Args: 58 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 59 | updated with the stored moving averages. 60 | """ 61 | parameters = [p for p in parameters if p.requires_grad] 62 | for s_param, param in zip(self.shadow_params, parameters): 63 | if param.requires_grad: 64 | param.data.copy_(s_param.data) 65 | 66 | def store(self, parameters): 67 | """ 68 | Save the current parameters for restoring later. 
69 | 70 | Args: 71 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 72 | temporarily stored. 73 | """ 74 | self.collected_params = [param.clone() for param in parameters] 75 | 76 | def restore(self, parameters): 77 | """ 78 | Restore the parameters stored with the `store` method. 79 | Useful to validate the model with EMA parameters without affecting the 80 | original optimization process. Store the parameters before the 81 | `copy_to` method. After validation (or model saving), use this to 82 | restore the former parameters. 83 | 84 | Args: 85 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 86 | updated with the stored parameters. 87 | """ 88 | for c_param, param in zip(self.collected_params, parameters): 89 | param.data.copy_(c_param.data) 90 | 91 | def state_dict(self): 92 | return dict(decay=self.decay, num_updates=self.num_updates, 93 | shadow_params=self.shadow_params) 94 | 95 | def load_state_dict(self, state_dict): 96 | self.decay = state_dict['decay'] 97 | self.num_updates = state_dict['num_updates'] 98 | self.shadow_params = state_dict['shadow_params'] -------------------------------------------------------------------------------- /configs/ve/ffhq_ncsnpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training NCSN++ on FFHQ with VE SDEs.""" 18 | 19 | import ml_collections 20 | import torch 21 | 22 | def get_config(): 23 | config = ml_collections.ConfigDict() 24 | # training 25 | config.training = training = ml_collections.ConfigDict() 26 | training.batch_size = 8 27 | training.n_iters = 2400001 28 | training.snapshot_freq = 50000 29 | training.log_freq = 50 30 | training.eval_freq = 100 31 | training.snapshot_freq_for_preemption = 5000 32 | training.snapshot_sampling = True 33 | training.sde = 'vesde' 34 | training.continuous = True 35 | training.likelihood_weighting = False 36 | training.reduce_mean = True 37 | 38 | # sampling 39 | config.sampling = sampling = ml_collections.ConfigDict() 40 | sampling.method = 'pc' 41 | sampling.predictor = 'reverse_diffusion' 42 | sampling.corrector = 'langevin' 43 | sampling.probability_flow = False 44 | sampling.snr = 0.15 45 | sampling.n_steps_each = 1 46 | sampling.noise_removal = True 47 | 48 | # eval 49 | config.eval = evaluate = ml_collections.ConfigDict() 50 | evaluate.batch_size = 1024 51 | evaluate.num_samples = 50000 52 | evaluate.begin_ckpt = 1 53 | evaluate.end_ckpt = 96 54 | 55 | # data 56 | config.data = data = ml_collections.ConfigDict() 57 | data.dataset = 'FFHQ' 58 | data.image_size = 1024 59 | data.centered = False 60 | data.random_flip = True 61 | data.uniform_dequantization = False 62 | data.num_channels = 3 63 | # Plug in your own path to the tfrecords file. 
64 |   data.tfrecords_path = '/raid/song/ffhq-dataset/ffhq/ffhq-r10.tfrecords'
65 | 
66 |   # model
67 |   config.model = model = ml_collections.ConfigDict()
68 |   model.name = 'ncsnpp'
69 |   model.scale_by_sigma = True
70 |   model.sigma_max = 1348
71 |   model.num_scales = 2000
72 |   model.ema_rate = 0.9999
73 |   model.sigma_min = 0.01
74 |   model.normalization = 'GroupNorm'
75 |   model.nonlinearity = 'swish'
76 |   model.nf = 16
77 |   model.ch_mult = (1, 2, 4, 8, 16, 32, 32, 32)
78 |   model.num_res_blocks = 1
79 |   model.attn_resolutions = (16,)
80 |   model.dropout = 0.
81 |   model.resamp_with_conv = True
82 |   model.conditional = True
83 |   model.fir = True
84 |   model.fir_kernel = [1, 3, 3, 1]
85 |   model.skip_rescale = True
86 |   model.resblock_type = 'biggan'
87 |   model.progressive = 'output_skip'
88 |   model.progressive_input = 'input_skip'
89 |   model.progressive_combine = 'sum'
90 |   model.attention_type = 'ddpm'
91 |   model.init_scale = 0.
92 |   model.fourier_scale = 16
93 |   model.conv_size = 3
94 |   model.embedding_type = 'fourier'
95 | 
96 |   # optim
97 |   config.optim = optim = ml_collections.ConfigDict()
98 |   optim.weight_decay = 0
99 |   optim.optimizer = 'Adam'
100 |   optim.lr = 2e-4
101 |   optim.beta1 = 0.9
102 |   optim.amsgrad = False
103 |   optim.eps = 1e-8
104 |   optim.warmup = 5000
105 |   optim.grad_clip = 1.
106 | 
107 |   config.seed = 42
108 |   config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
109 | 
110 |   return config
111 | 
--------------------------------------------------------------------------------
/likelihood.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | # pylint: skip-file
17 | # pytype: skip-file
18 | """Computing log-likelihoods with the probability flow ODE."""
19 | 
20 | import torch
21 | import numpy as np
22 | from scipy import integrate
23 | from models import utils as mutils
24 | 
25 | 
26 | def get_div_fn(fn):
27 |   """Create the divergence function of `fn` using the Hutchinson-Skilling trace estimator."""
28 | 
29 |   def div_fn(x, t, eps):
30 |     with torch.enable_grad():
31 |       x.requires_grad_(True)
32 |       fn_eps = torch.sum(fn(x, t) * eps)
33 |       grad_fn_eps = torch.autograd.grad(fn_eps, x)[0]
34 |     x.requires_grad_(False)
35 |     return torch.sum(grad_fn_eps * eps, dim=tuple(range(1, len(x.shape))))
36 | 
37 |   return div_fn
38 | 
39 | 
40 | def get_likelihood_fn(sde, inverse_scaler, hutchinson_type='Rademacher',
41 |                       rtol=1e-5, atol=1e-5, method='RK45', eps=1e-5):
42 |   """Create a function to compute the unbiased log-likelihood estimate of a given data point.
43 | 
44 |   Args:
45 |     sde: A `sde_lib.SDE` object that represents the forward SDE.
46 |     inverse_scaler: The inverse data normalizer.
47 |     hutchinson_type: "Rademacher" or "Gaussian". The type of noise for the Hutchinson-Skilling trace estimator.
48 |     rtol: A `float` number. The relative tolerance level of the black-box ODE solver.
49 |     atol: A `float` number. The absolute tolerance level of the black-box ODE solver.
50 |     method: A `str`. The algorithm for the black-box ODE solver.
51 |       See documentation for `scipy.integrate.solve_ivp`.
52 |     eps: A `float` number. The probability flow ODE is integrated to `eps` for numerical stability.
53 | 
54 |   Returns:
55 |     A function that takes a batch of data points and returns the log-likelihoods in bits/dim,
56 |       the latent code, and the number of function evaluations used by the computation.
57 |   """
58 | 
59 |   def drift_fn(model, x, t):
60 |     """The drift function of the reverse-time SDE."""
61 |     score_fn = mutils.get_score_fn(sde, model, train=False, continuous=True)
62 |     # Probability flow ODE is a special case of Reverse SDE
63 |     rsde = sde.reverse(score_fn, probability_flow=True)
64 |     return rsde.sde(x, t)[0]
65 | 
66 |   def div_fn(model, x, t, noise):
67 |     return get_div_fn(lambda xx, tt: drift_fn(model, xx, tt))(x, t, noise)
68 | 
69 |   def likelihood_fn(model, data):
70 |     """Compute an unbiased estimate to the log-likelihood in bits/dim.
71 | 
72 |     Args:
73 |       model: A score model.
74 |       data: A PyTorch tensor.
75 | 
76 |     Returns:
77 |       bpd: A PyTorch tensor of shape [batch size]. The log-likelihoods on `data` in bits/dim.
78 |       z: A PyTorch tensor of the same shape as `data`. The latent representation of `data` under the
79 |         probability flow ODE.
80 |       nfe: An integer. The number of function evaluations used for running the black-box ODE solver.
81 |     """
82 |     with torch.no_grad():
83 |       shape = data.shape
84 |       if hutchinson_type == 'Gaussian':
85 |         epsilon = torch.randn_like(data)
86 |       elif hutchinson_type == 'Rademacher':
87 |         epsilon = torch.randint_like(data, low=0, high=2).float() * 2 - 1.
88 |       else:
89 |         raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.")
90 | 
91 |       def ode_func(t, x):
92 |         sample = mutils.from_flattened_numpy(x[:-shape[0]], shape).to(data.device).type(torch.float32)
93 |         vec_t = torch.ones(sample.shape[0], device=sample.device) * t
94 |         drift = mutils.to_flattened_numpy(drift_fn(model, sample, vec_t))
95 |         logp_grad = mutils.to_flattened_numpy(div_fn(model, sample, vec_t, epsilon))
96 |         return np.concatenate([drift, logp_grad], axis=0)
97 | 
98 |       init = np.concatenate([mutils.to_flattened_numpy(data), np.zeros((shape[0],))], axis=0)
99 |       solution = integrate.solve_ivp(ode_func, (eps, sde.T), init, rtol=rtol, atol=atol, method=method)
100 |       nfe = solution.nfev
101 |       zp = solution.y[:, -1]
102 |       z = mutils.from_flattened_numpy(zp[:-shape[0]], shape).to(data.device).type(torch.float32)
103 |       delta_logp = mutils.from_flattened_numpy(zp[-shape[0]:], (shape[0],)).to(data.device).type(torch.float32)
104 |       prior_logp = sde.prior_logp(z)
105 |       bpd = -(prior_logp + delta_logp) / np.log(2)
106 |       N = np.prod(shape[1:])
107 |       bpd = bpd / N
108 |       # A hack to convert log-likelihoods to bits/dim
109 |       offset = 7. - inverse_scaler(-1.)
110 |       bpd = bpd + offset
111 |       return bpd, z, nfe
112 | 
113 |   return likelihood_fn
114 | 
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Utility functions for computing FID/Inception scores."""
17 | 
18 | import jax
19 | import numpy as np
20 | import six
21 | import tensorflow as tf
22 | import tensorflow_gan as tfgan
23 | import tensorflow_hub as tfhub
24 | 
25 | INCEPTION_TFHUB = 'https://tfhub.dev/tensorflow/tfgan/eval/inception/1'
26 | INCEPTION_OUTPUT = 'logits'
27 | INCEPTION_FINAL_POOL = 'pool_3'
28 | _DEFAULT_DTYPES = {
29 |   INCEPTION_OUTPUT: tf.float32,
30 |   INCEPTION_FINAL_POOL: tf.float32
31 | }
32 | INCEPTION_DEFAULT_IMAGE_SIZE = 299
33 | 
34 | 
35 | def get_inception_model(inceptionv3=False):
36 |   if inceptionv3:
37 |     return tfhub.load(
38 |       'https://tfhub.dev/google/imagenet/inception_v3/feature_vector/4')
39 |   else:
40 |     return tfhub.load(INCEPTION_TFHUB)
41 | 
42 | 
43 | def load_dataset_stats(config):
44 |   """Load the pre-computed dataset statistics."""
45 |   if config.data.dataset == 'CIFAR10':
46 |     filename = 'assets/stats/cifar10_stats.npz'
47 |   elif config.data.dataset == 'CELEBA':
48 |     filename = 'assets/stats/celeba_stats.npz'
49 |   elif config.data.dataset == 'LSUN':
50 |     filename = f'assets/stats/lsun_{config.data.category}_{config.data.image_size}_stats.npz'
51 |   else:
52 |     raise ValueError(f'Dataset {config.data.dataset} stats not found.')
53 | 
54 |   with tf.io.gfile.GFile(filename, 'rb') as fin:
55 |     stats = np.load(fin)
56 |     return stats
57 | 
58 | 
59 | def classifier_fn_from_tfhub(output_fields, inception_model,
60 |                              return_tensor=False):
61 |   """Returns a function that can be used as a classifier function.
62 | 
63 |   Copied from tfgan, but avoids loading the model each time `_classifier_fn` is called.
64 | 
65 |   Args:
66 |     output_fields: A string, list, or `None`. If present, assume the module
67 |       outputs a dictionary, and select this field.
68 |     inception_model: A model loaded from TFHub.
69 |     return_tensor: If `True`, return a single tensor instead of a dictionary.
70 | 
71 |   Returns:
72 |     A one-argument function that takes an image Tensor and returns outputs.
73 |   """
74 |   if isinstance(output_fields, six.string_types):
75 |     output_fields = [output_fields]
76 | 
77 |   def _classifier_fn(images):
78 |     output = inception_model(images)
79 |     if output_fields is not None:
80 |       output = {x: output[x] for x in output_fields}
81 |     if return_tensor:
82 |       assert len(output) == 1
83 |       output = list(output.values())[0]
84 |     return tf.nest.map_structure(tf.compat.v1.layers.flatten, output)
85 | 
86 |   return _classifier_fn
87 | 
88 | 
89 | @tf.function
90 | def run_inception_jit(inputs,
91 |                       inception_model,
92 |                       num_batches=1,
93 |                       inceptionv3=False):
94 |   """Run the Inception network. Inputs are assumed to be within [0, 255]."""
95 |   if not inceptionv3:
96 |     inputs = (tf.cast(inputs, tf.float32) - 127.5) / 127.5
97 |   else:
98 |     inputs = tf.cast(inputs, tf.float32) / 255.
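  # The TF-GAN Inception module expects inputs scaled to [-1, 1], while the
  # TF-Hub InceptionV3 feature-vector module expects [0, 1]; hence the two
  # normalization branches above.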
99 | 100 | return tfgan.eval.run_classifier_fn( 101 | inputs, 102 | num_batches=num_batches, 103 | classifier_fn=classifier_fn_from_tfhub(None, inception_model), 104 | dtypes=_DEFAULT_DTYPES) 105 | 106 | 107 | @tf.function 108 | def run_inception_distributed(input_tensor, 109 | inception_model, 110 | num_batches=1, 111 | inceptionv3=False): 112 | """Distribute the inception network computation to all available TPUs. 113 | 114 | Args: 115 | input_tensor: The input images. Assumed to be within [0, 255]. 116 | inception_model: The inception network model obtained from `tfhub`. 117 | num_batches: The number of batches used for dividing the input. 118 | inceptionv3: If `True`, use InceptionV3, otherwise use InceptionV1. 119 | 120 | Returns: 121 | A dictionary with key `pool_3` and `logits`, representing the pool_3 and 122 | logits of the inception network respectively. 123 | """ 124 | num_tpus = jax.local_device_count() 125 | input_tensors = tf.split(input_tensor, num_tpus, axis=0) 126 | pool3 = [] 127 | logits = [] if not inceptionv3 else None 128 | device_format = '/TPU:{}' if 'TPU' in str(jax.devices()[0]) else '/GPU:{}' 129 | for i, tensor in enumerate(input_tensors): 130 | with tf.device(device_format.format(i)): 131 | tensor_on_device = tf.identity(tensor) 132 | res = run_inception_jit( 133 | tensor_on_device, inception_model, num_batches=num_batches, 134 | inceptionv3=inceptionv3) 135 | 136 | if not inceptionv3: 137 | pool3.append(res['pool_3']) 138 | logits.append(res['logits']) # pytype: disable=attribute-error 139 | else: 140 | pool3.append(res) 141 | 142 | with tf.device('/CPU'): 143 | return { 144 | 'pool_3': tf.concat(pool3, axis=0), 145 | 'logits': tf.concat(logits, axis=0) if not inceptionv3 else None 146 | } 147 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """All functions and modules related to model definition. 17 | """ 18 | 19 | import torch 20 | import sde_lib 21 | import numpy as np 22 | 23 | 24 | _MODELS = {} 25 | 26 | 27 | def register_model(cls=None, *, name=None): 28 | """A decorator for registering model classes.""" 29 | 30 | def _register(cls): 31 | if name is None: 32 | local_name = cls.__name__ 33 | else: 34 | local_name = name 35 | if local_name in _MODELS: 36 | raise ValueError(f'Already registered model with name: {local_name}') 37 | _MODELS[local_name] = cls 38 | return cls 39 | 40 | if cls is None: 41 | return _register 42 | else: 43 | return _register(cls) 44 | 45 | 46 | def get_model(name): 47 | return _MODELS[name] 48 | 49 | 50 | def get_sigmas(config): 51 | """Get sigmas --- the set of noise levels for SMLD from config files. 
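
  The sigmas form a geometric sequence from `sigma_max` down to `sigma_min`:
    sigma_i = sigma_max * (sigma_min / sigma_max) ** (i / (num_scales - 1))
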
52 |   Args:
53 |     config: A ConfigDict object parsed from the config file
54 |   Returns:
55 |     sigmas: a numpy array of noise levels
56 |   """
57 |   sigmas = np.exp(
58 |     np.linspace(np.log(config.model.sigma_max), np.log(config.model.sigma_min), config.model.num_scales))
59 | 
60 |   return sigmas
61 | 
62 | 
63 | def get_ddpm_params(config):
64 |   """Get betas and alphas --- parameters used in the original DDPM paper."""
65 |   num_diffusion_timesteps = 1000
66 |   # parameters need to be adapted if number of time steps differs from 1000
67 |   beta_start = config.model.beta_min / config.model.num_scales
68 |   beta_end = config.model.beta_max / config.model.num_scales
69 |   betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
70 | 
71 |   alphas = 1. - betas
72 |   alphas_cumprod = np.cumprod(alphas, axis=0)
73 |   sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
74 |   sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)
75 | 
76 |   return {
77 |     'betas': betas,
78 |     'alphas': alphas,
79 |     'alphas_cumprod': alphas_cumprod,
80 |     'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
81 |     'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
82 |     'beta_min': beta_start * (num_diffusion_timesteps - 1),
83 |     'beta_max': beta_end * (num_diffusion_timesteps - 1),
84 |     'num_diffusion_timesteps': num_diffusion_timesteps
85 |   }
86 | 
87 | 
88 | def create_model(config):
89 |   """Create the score model."""
90 |   model_name = config.model.name
91 |   score_model = get_model(model_name)(config)
92 |   score_model = score_model.to(config.device)
93 |   if config.device != 'cpu':
94 |     score_model = torch.nn.DataParallel(score_model)
95 |   return score_model
96 | 
97 | 
98 | def get_model_fn(model, train=False):
99 |   """Create a function to give the output of the score-based model.
100 | 
101 |   Args:
102 |     model: The score model.
103 |     train: `True` for training and `False` for evaluation.
104 | 
105 |   Returns:
106 |     A model function.
107 |   """
108 | 
109 |   def model_fn(x, labels):
110 |     """Compute the output of the score-based model.
111 | 
112 |     Args:
113 |       x: A mini-batch of input data.
114 |       labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
115 |         for different models.
116 | 
117 |     Returns:
118 |       The model output.
119 |     """
120 |     if not train:
121 |       model.eval()
122 |       return model(x, labels)
123 |     else:
124 |       model.train()
125 |       return model(x, labels)
126 | 
127 |   return model_fn
128 | 
129 | 
130 | def get_score_fn(sde, model, train=False, continuous=False):
131 |   """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
132 | 
133 |   Args:
134 |     sde: An `sde_lib.SDE` object that represents the forward SDE.
135 |     model: A score model.
136 |     train: `True` for training and `False` for evaluation.
137 |     continuous: If `True`, the score-based model is expected to directly take continuous time steps.
138 | 
139 |   Returns:
140 |     A score function.
141 |   """
142 |   model_fn = get_model_fn(model, train=train)
143 | 
144 |   if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
145 |     def score_fn(x, t):
146 |       # Scale neural network output by standard deviation and flip sign
147 |       if continuous or isinstance(sde, sde_lib.subVPSDE):
148 |         # For VP-trained models, t=0 corresponds to the lowest noise level
149 |         # The maximum value of the time embedding is assumed to be 999 for
150 |         # continuously-trained models.
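        # e.g. t = 0.5 is mapped to the label 499.5, which lies inside the
        # [0, 999] range the time embedding was trained on.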
151 | labels = t * 999 152 | score = model_fn(x, labels) 153 | std = sde.marginal_prob(torch.zeros_like(x), t)[1] 154 | else: 155 | # For VP-trained models, t=0 corresponds to the lowest noise level 156 | labels = t * (sde.N - 1) 157 | score = model_fn(x, labels) 158 | std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()] 159 | 160 | score = -score / std[:, None, None, None] 161 | return score 162 | 163 | elif isinstance(sde, sde_lib.VESDE): 164 | def score_fn(x, t): 165 | if continuous: 166 | labels = sde.marginal_prob(torch.zeros_like(x), t)[1] 167 | else: 168 | # For VE-trained models, t=0 corresponds to the highest noise level 169 | labels = sde.T - t 170 | labels *= sde.N - 1 171 | labels = torch.round(labels).long() 172 | 173 | score = model_fn(x, labels) 174 | return score 175 | 176 | else: 177 | raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") 178 | 179 | return score_fn 180 | 181 | 182 | def to_flattened_numpy(x): 183 | """Flatten a torch tensor `x` and convert it to numpy.""" 184 | return x.detach().cpu().numpy().reshape((-1,)) 185 | 186 | 187 | def from_flattened_numpy(x, shape): 188 | """Form a torch tensor with the given `shape` from a flattened numpy array `x`.""" 189 | return torch.from_numpy(x.reshape(shape)) -------------------------------------------------------------------------------- /op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | upfirdn2d_op = load( 11 | "upfirdn2d", 12 | sources=[ 13 | os.path.join(module_path, "upfirdn2d.cpp"), 14 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 15 | ], 16 | ) 17 | 18 | 19 | class UpFirDn2dBackward(Function): 20 | @staticmethod 21 | def forward( 22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 23 | ): 24 | 25 | up_x, up_y = up 26 | down_x, down_y = down 27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 28 | 29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 30 | 31 | grad_input = upfirdn2d_op.upfirdn2d( 32 | grad_output, 33 | grad_kernel, 34 | down_x, 35 | down_y, 36 | up_x, 37 | up_y, 38 | g_pad_x0, 39 | g_pad_x1, 40 | g_pad_y0, 41 | g_pad_y1, 42 | ) 43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 44 | 45 | ctx.save_for_backward(kernel) 46 | 47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 48 | 49 | ctx.up_x = up_x 50 | ctx.up_y = up_y 51 | ctx.down_x = down_x 52 | ctx.down_y = down_y 53 | ctx.pad_x0 = pad_x0 54 | ctx.pad_x1 = pad_x1 55 | ctx.pad_y0 = pad_y0 56 | ctx.pad_y1 = pad_y1 57 | ctx.in_size = in_size 58 | ctx.out_size = out_size 59 | 60 | return grad_input 61 | 62 | @staticmethod 63 | def backward(ctx, gradgrad_input): 64 | kernel, = ctx.saved_tensors 65 | 66 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 67 | 68 | gradgrad_out = upfirdn2d_op.upfirdn2d( 69 | gradgrad_input, 70 | kernel, 71 | ctx.up_x, 72 | ctx.up_y, 73 | ctx.down_x, 74 | ctx.down_y, 75 | ctx.pad_x0, 76 | ctx.pad_x1, 77 | ctx.pad_y0, 78 | ctx.pad_y1, 79 | ) 80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 81 | gradgrad_out = gradgrad_out.view( 82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 83 | ) 84 | 85 | return gradgrad_out, None, None, None, 
None, None, None, None, None 86 | 87 | 88 | class UpFirDn2d(Function): 89 | @staticmethod 90 | def forward(ctx, input, kernel, up, down, pad): 91 | up_x, up_y = up 92 | down_x, down_y = down 93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 94 | 95 | kernel_h, kernel_w = kernel.shape 96 | batch, channel, in_h, in_w = input.shape 97 | ctx.in_size = input.shape 98 | 99 | input = input.reshape(-1, in_h, in_w, 1) 100 | 101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 102 | 103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 105 | ctx.out_size = (out_h, out_w) 106 | 107 | ctx.up = (up_x, up_y) 108 | ctx.down = (down_x, down_y) 109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 110 | 111 | g_pad_x0 = kernel_w - pad_x0 - 1 112 | g_pad_y0 = kernel_h - pad_y0 - 1 113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 115 | 116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 117 | 118 | out = upfirdn2d_op.upfirdn2d( 119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 120 | ) 121 | # out = out.view(major, out_h, out_w, minor) 122 | out = out.view(-1, channel, out_h, out_w) 123 | 124 | return out 125 | 126 | @staticmethod 127 | def backward(ctx, grad_output): 128 | kernel, grad_kernel = ctx.saved_tensors 129 | 130 | grad_input = UpFirDn2dBackward.apply( 131 | grad_output, 132 | kernel, 133 | grad_kernel, 134 | ctx.up, 135 | ctx.down, 136 | ctx.pad, 137 | ctx.g_pad, 138 | ctx.in_size, 139 | ctx.out_size, 140 | ) 141 | 142 | return grad_input, None, None, None, None 143 | 144 | 145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 146 | if input.device.type == "cpu": 147 | out = upfirdn2d_native( 148 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 149 | ) 150 | 151 | else: 152 | out = UpFirDn2d.apply( 153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 154 | ) 155 | 156 | return out 157 | 158 | 159 | def upfirdn2d_native( 160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 161 | ): 162 | _, channel, in_h, in_w = input.shape 163 | input = input.reshape(-1, in_h, in_w, 1) 164 | 165 | _, in_h, in_w, minor = input.shape 166 | kernel_h, kernel_w = kernel.shape 167 | 168 | out = input.view(-1, in_h, 1, in_w, 1, minor) 169 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 170 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 171 | 172 | out = F.pad( 173 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 174 | ) 175 | out = out[ 176 | :, 177 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 178 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 179 | :, 180 | ] 181 | 182 | out = out.permute(0, 3, 1, 2) 183 | out = out.reshape( 184 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 185 | ) 186 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 187 | out = F.conv2d(out, w) 188 | out = out.reshape( 189 | -1, 190 | minor, 191 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 192 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 193 | ) 194 | out = out.permute(0, 2, 3, 1) 195 | out = out[:, ::down_y, ::down_x, :] 196 | 197 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 198 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 199 | 200 | return out.view(-1, channel, out_h, out_w) 201 | 
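A minimal usage sketch of the fused op defined above, assuming the CUDA extension builds (on CPU the `upfirdn2d_native` fallback is used instead). The `[1, 3, 3, 1]` taps match `model.fir_kernel` in the NCSN++ configs; the input tensor, gain, and padding convention are illustrative, following the `out_h`/`out_w` formula in the code above:

import torch
from op import upfirdn2d

# Separable FIR kernel from the [1, 3, 3, 1] taps, normalized to unit sum.
k = torch.tensor([1., 3., 3., 1.])
kernel = k[:, None] * k[None, :]
kernel = kernel / kernel.sum()

x = torch.randn(2, 3, 32, 32)

# 2x upsampling: scale the kernel by up**2 to preserve signal magnitude, and
# choose the padding so that the output resolution is exactly doubled:
# out_h = (in_h * up + pad0 + pad1 - kernel_h) // down + 1 = 64 for pad = (2, 1).
up = 2
p = kernel.shape[0] - up  # = 2
y = upfirdn2d(x, kernel * (up ** 2), up=up, pad=((p + 1) // 2 + up - 1, p // 2))
assert y.shape == (2, 3, 64, 64)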
-------------------------------------------------------------------------------- /models/ddpm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | """DDPM model. 18 | 19 | This code is the pytorch equivalent of: 20 | https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/models/unet.py 21 | """ 22 | import torch 23 | import torch.nn as nn 24 | import functools 25 | 26 | from . import utils, layers, normalization 27 | 28 | RefineBlock = layers.RefineBlock 29 | ResidualBlock = layers.ResidualBlock 30 | ResnetBlockDDPM = layers.ResnetBlockDDPM 31 | Upsample = layers.Upsample 32 | Downsample = layers.Downsample 33 | conv3x3 = layers.ddpm_conv3x3 34 | get_act = layers.get_act 35 | get_normalization = normalization.get_normalization 36 | default_initializer = layers.default_init 37 | 38 | 39 | @utils.register_model(name='ddpm') 40 | class DDPM(nn.Module): 41 | def __init__(self, config): 42 | super().__init__() 43 | self.act = act = get_act(config) 44 | self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config))) 45 | 46 | self.nf = nf = config.model.nf 47 | ch_mult = config.model.ch_mult 48 | self.num_res_blocks = num_res_blocks = config.model.num_res_blocks 49 | self.attn_resolutions = attn_resolutions = config.model.attn_resolutions 50 | dropout = config.model.dropout 51 | resamp_with_conv = config.model.resamp_with_conv 52 | self.num_resolutions = num_resolutions = len(ch_mult) 53 | self.all_resolutions = all_resolutions = [config.data.image_size // (2 ** i) for i in range(num_resolutions)] 54 | 55 | AttnBlock = functools.partial(layers.AttnBlock) 56 | self.conditional = conditional = config.model.conditional 57 | ResnetBlock = functools.partial(ResnetBlockDDPM, act=act, temb_dim=4 * nf, dropout=dropout) 58 | if conditional: 59 | # Condition on noise levels. 
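      # The embedding is a two-layer MLP (nf -> 4*nf -> 4*nf) applied to the
      # sinusoidal timestep embedding computed in `forward`.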
60 | modules = [nn.Linear(nf, nf * 4)] 61 | modules[0].weight.data = default_initializer()(modules[0].weight.data.shape) 62 | nn.init.zeros_(modules[0].bias) 63 | modules.append(nn.Linear(nf * 4, nf * 4)) 64 | modules[1].weight.data = default_initializer()(modules[1].weight.data.shape) 65 | nn.init.zeros_(modules[1].bias) 66 | 67 | self.centered = config.data.centered 68 | channels = config.data.num_channels 69 | 70 | # Downsampling block 71 | modules.append(conv3x3(channels, nf)) 72 | hs_c = [nf] 73 | in_ch = nf 74 | for i_level in range(num_resolutions): 75 | # Residual blocks for this resolution 76 | for i_block in range(num_res_blocks): 77 | out_ch = nf * ch_mult[i_level] 78 | modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) 79 | in_ch = out_ch 80 | if all_resolutions[i_level] in attn_resolutions: 81 | modules.append(AttnBlock(channels=in_ch)) 82 | hs_c.append(in_ch) 83 | if i_level != num_resolutions - 1: 84 | modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv)) 85 | hs_c.append(in_ch) 86 | 87 | in_ch = hs_c[-1] 88 | modules.append(ResnetBlock(in_ch=in_ch)) 89 | modules.append(AttnBlock(channels=in_ch)) 90 | modules.append(ResnetBlock(in_ch=in_ch)) 91 | 92 | # Upsampling block 93 | for i_level in reversed(range(num_resolutions)): 94 | for i_block in range(num_res_blocks + 1): 95 | out_ch = nf * ch_mult[i_level] 96 | modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch)) 97 | in_ch = out_ch 98 | if all_resolutions[i_level] in attn_resolutions: 99 | modules.append(AttnBlock(channels=in_ch)) 100 | if i_level != 0: 101 | modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv)) 102 | 103 | assert not hs_c 104 | modules.append(nn.GroupNorm(num_channels=in_ch, num_groups=32, eps=1e-6)) 105 | modules.append(conv3x3(in_ch, channels, init_scale=0.)) 106 | self.all_modules = nn.ModuleList(modules) 107 | 108 | self.scale_by_sigma = config.model.scale_by_sigma 109 | 110 | def forward(self, x, labels): 111 | modules = self.all_modules 112 | m_idx = 0 113 | if self.conditional: 114 | # timestep/scale embedding 115 | timesteps = labels 116 | temb = layers.get_timestep_embedding(timesteps, self.nf) 117 | temb = modules[m_idx](temb) 118 | m_idx += 1 119 | temb = modules[m_idx](self.act(temb)) 120 | m_idx += 1 121 | else: 122 | temb = None 123 | 124 | if self.centered: 125 | # Input is in [-1, 1] 126 | h = x 127 | else: 128 | # Input is in [0, 1] 129 | h = 2 * x - 1. 
130 | 131 | # Downsampling block 132 | hs = [modules[m_idx](h)] 133 | m_idx += 1 134 | for i_level in range(self.num_resolutions): 135 | # Residual blocks for this resolution 136 | for i_block in range(self.num_res_blocks): 137 | h = modules[m_idx](hs[-1], temb) 138 | m_idx += 1 139 | if h.shape[-1] in self.attn_resolutions: 140 | h = modules[m_idx](h) 141 | m_idx += 1 142 | hs.append(h) 143 | if i_level != self.num_resolutions - 1: 144 | hs.append(modules[m_idx](hs[-1])) 145 | m_idx += 1 146 | 147 | h = hs[-1] 148 | h = modules[m_idx](h, temb) 149 | m_idx += 1 150 | h = modules[m_idx](h) 151 | m_idx += 1 152 | h = modules[m_idx](h, temb) 153 | m_idx += 1 154 | 155 | # Upsampling block 156 | for i_level in reversed(range(self.num_resolutions)): 157 | for i_block in range(self.num_res_blocks + 1): 158 | h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb) 159 | m_idx += 1 160 | if h.shape[-1] in self.attn_resolutions: 161 | h = modules[m_idx](h) 162 | m_idx += 1 163 | if i_level != 0: 164 | h = modules[m_idx](h) 165 | m_idx += 1 166 | 167 | assert not hs 168 | h = self.act(modules[m_idx](h)) 169 | m_idx += 1 170 | h = modules[m_idx](h) 171 | m_idx += 1 172 | assert m_idx == len(modules) 173 | 174 | if self.scale_by_sigma: 175 | # Divide the output by sigmas. Useful for training with the NCSN loss. 176 | # The DDPM loss scales the network output by sigma in the loss function, 177 | # so no need of doing it here. 178 | used_sigmas = self.sigmas[labels, None, None, None] 179 | h = h / used_sigmas 180 | 181 | return h 182 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | """Return training and evaluation/test datasets from config files.""" 18 | import jax 19 | import tensorflow as tf 20 | import tensorflow_datasets as tfds 21 | 22 | 23 | def get_data_scaler(config): 24 | """Data normalizer. Assume data are always in [0, 1].""" 25 | if config.data.centered: 26 | # Rescale to [-1, 1] 27 | return lambda x: x * 2. - 1. 28 | else: 29 | return lambda x: x 30 | 31 | 32 | def get_data_inverse_scaler(config): 33 | """Inverse data normalizer.""" 34 | if config.data.centered: 35 | # Rescale [-1, 1] to [0, 1] 36 | return lambda x: (x + 1.) / 2. 
37 |   else:
38 |     return lambda x: x
39 | 
40 | 
41 | def crop_resize(image, resolution):
42 |   """Crop and resize an image to the given resolution."""
43 |   crop = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
44 |   h, w = tf.shape(image)[0], tf.shape(image)[1]
45 |   image = image[(h - crop) // 2:(h + crop) // 2,
46 |           (w - crop) // 2:(w + crop) // 2]
47 |   image = tf.image.resize(
48 |     image,
49 |     size=(resolution, resolution),
50 |     antialias=True,
51 |     method=tf.image.ResizeMethod.BICUBIC)
52 |   return tf.cast(image, tf.uint8)
53 | 
54 | 
55 | def resize_small(image, resolution):
56 |   """Shrink an image to the given resolution."""
57 |   h, w = image.shape[0], image.shape[1]
58 |   ratio = resolution / min(h, w)
59 |   # Round to integer sizes before resizing.
60 |   h = tf.cast(tf.round(h * ratio), tf.int32)
61 |   w = tf.cast(tf.round(w * ratio), tf.int32)
62 |   return tf.image.resize(image, [h, w], antialias=True)
63 | 
64 | 
65 | def central_crop(image, size):
66 |   """Crop the center of an image to the given size."""
67 |   top = (image.shape[0] - size) // 2
68 |   left = (image.shape[1] - size) // 2
69 |   return tf.image.crop_to_bounding_box(image, top, left, size, size)
70 | 
71 | 
72 | def get_dataset(config, uniform_dequantization=False, evaluation=False):
73 |   """Create data loaders for training and evaluation.
74 | 
75 |   Args:
76 |     config: A ml_collection.ConfigDict parsed from config files.
77 |     uniform_dequantization: If `True`, add uniform dequantization to images.
78 |     evaluation: If `True`, fix number of epochs to 1.
79 | 
80 |   Returns:
81 |     train_ds, eval_ds, dataset_builder.
82 |   """
83 |   # Compute batch size for this worker.
84 |   batch_size = config.training.batch_size if not evaluation else config.eval.batch_size
85 |   if batch_size % jax.device_count() != 0:
86 |     raise ValueError(f'Batch size ({batch_size}) must be divisible by '
87 |                      f'the number of devices ({jax.device_count()}).')
88 | 
89 |   # Reduce this when the image resolution is too large to keep a big shuffle buffer in memory.
90 |   shuffle_buffer_size = 10000
91 |   prefetch_size = tf.data.experimental.AUTOTUNE
92 |   num_epochs = None if not evaluation else 1
93 | 
94 |   # Create dataset builders for each dataset.
94 | if config.data.dataset == 'CIFAR10': 95 | dataset_builder = tfds.builder('cifar10') 96 | train_split_name = 'train' 97 | eval_split_name = 'test' 98 | 99 | def resize_op(img): 100 | img = tf.image.convert_image_dtype(img, tf.float32) 101 | return tf.image.resize(img, [config.data.image_size, config.data.image_size], antialias=True) 102 | 103 | elif config.data.dataset == 'SVHN': 104 | dataset_builder = tfds.builder('svhn_cropped') 105 | train_split_name = 'train' 106 | eval_split_name = 'test' 107 | 108 | def resize_op(img): 109 | img = tf.image.convert_image_dtype(img, tf.float32) 110 | return tf.image.resize(img, [config.data.image_size, config.data.image_size], antialias=True) 111 | 112 | elif config.data.dataset == 'CELEBA': 113 | dataset_builder = tfds.builder('celeb_a') 114 | train_split_name = 'train' 115 | eval_split_name = 'validation' 116 | 117 | def resize_op(img): 118 | img = tf.image.convert_image_dtype(img, tf.float32) 119 | img = central_crop(img, 140) 120 | img = resize_small(img, config.data.image_size) 121 | return img 122 | 123 | elif config.data.dataset == 'LSUN': 124 | dataset_builder = tfds.builder(f'lsun/{config.data.category}') 125 | train_split_name = 'train' 126 | eval_split_name = 'validation' 127 | 128 | if config.data.image_size == 128: 129 | def resize_op(img): 130 | img = tf.image.convert_image_dtype(img, tf.float32) 131 | img = resize_small(img, config.data.image_size) 132 | img = central_crop(img, config.data.image_size) 133 | return img 134 | 135 | else: 136 | def resize_op(img): 137 | img = crop_resize(img, config.data.image_size) 138 | img = tf.image.convert_image_dtype(img, tf.float32) 139 | return img 140 | 141 | elif config.data.dataset in ['FFHQ', 'CelebAHQ']: 142 | dataset_builder = tf.data.TFRecordDataset(config.data.tfrecords_path) 143 | train_split_name = eval_split_name = 'train' 144 | 145 | else: 146 | raise NotImplementedError( 147 | f'Dataset {config.data.dataset} not yet supported.') 148 | 149 | # Customize preprocess functions for each dataset. 150 | if config.data.dataset in ['FFHQ', 'CelebAHQ']: 151 | def preprocess_fn(d): 152 | sample = tf.io.parse_single_example(d, features={ 153 | 'shape': tf.io.FixedLenFeature([3], tf.int64), 154 | 'data': tf.io.FixedLenFeature([], tf.string)}) 155 | data = tf.io.decode_raw(sample['data'], tf.uint8) 156 | data = tf.reshape(data, sample['shape']) 157 | data = tf.transpose(data, (1, 2, 0)) 158 | img = tf.image.convert_image_dtype(data, tf.float32) 159 | if config.data.random_flip and not evaluation: 160 | img = tf.image.random_flip_left_right(img) 161 | if uniform_dequantization: 162 | img = (tf.random.uniform(img.shape, dtype=tf.float32) + img * 255.) / 256. 163 | return dict(image=img, label=None) 164 | 165 | else: 166 | def preprocess_fn(d): 167 | """Basic preprocessing function scales data to [0, 1) and randomly flips.""" 168 | img = resize_op(d['image']) 169 | if config.data.random_flip and not evaluation: 170 | img = tf.image.random_flip_left_right(img) 171 | if uniform_dequantization: 172 | img = (tf.random.uniform(img.shape, dtype=tf.float32) + img * 255.) / 256. 
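        # Dequantization maps a discrete pixel value k (scaled to k/255 above) to
        # a uniform sample in [k/256, (k+1)/256), giving the data a continuous density.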
173 | 174 | return dict(image=img, label=d.get('label', None)) 175 | 176 | def create_dataset(dataset_builder, split): 177 | dataset_options = tf.data.Options() 178 | dataset_options.experimental_optimization.map_parallelization = True 179 | dataset_options.experimental_threading.private_threadpool_size = 48 180 | dataset_options.experimental_threading.max_intra_op_parallelism = 1 181 | read_config = tfds.ReadConfig(options=dataset_options) 182 | if isinstance(dataset_builder, tfds.core.DatasetBuilder): 183 | dataset_builder.download_and_prepare() 184 | ds = dataset_builder.as_dataset( 185 | split=split, shuffle_files=True, read_config=read_config) 186 | else: 187 | ds = dataset_builder.with_options(dataset_options) 188 | ds = ds.repeat(count=num_epochs) 189 | ds = ds.shuffle(shuffle_buffer_size) 190 | ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) 191 | ds = ds.batch(batch_size, drop_remainder=True) 192 | return ds.prefetch(prefetch_size) 193 | 194 | train_ds = create_dataset(dataset_builder, train_split_name) 195 | eval_ds = create_dataset(dataset_builder, eval_split_name) 196 | return train_ds, eval_ds, dataset_builder 197 | -------------------------------------------------------------------------------- /models/normalization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Normalization layers.""" 17 | import torch.nn as nn 18 | import torch 19 | import functools 20 | 21 | 22 | def get_normalization(config, conditional=False): 23 | """Obtain normalization modules from the config file.""" 24 | norm = config.model.normalization 25 | if conditional: 26 | if norm == 'InstanceNorm++': 27 | return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes) 28 | else: 29 | raise NotImplementedError(f'{norm} not implemented yet.') 30 | else: 31 | if norm == 'InstanceNorm': 32 | return nn.InstanceNorm2d 33 | elif norm == 'InstanceNorm++': 34 | return InstanceNorm2dPlus 35 | elif norm == 'VarianceNorm': 36 | return VarianceNorm2d 37 | elif norm == 'GroupNorm': 38 | return nn.GroupNorm 39 | else: 40 | raise ValueError('Unknown normalization: %s' % norm) 41 | 42 | 43 | class ConditionalBatchNorm2d(nn.Module): 44 | def __init__(self, num_features, num_classes, bias=True): 45 | super().__init__() 46 | self.num_features = num_features 47 | self.bias = bias 48 | self.bn = nn.BatchNorm2d(num_features, affine=False) 49 | if self.bias: 50 | self.embed = nn.Embedding(num_classes, num_features * 2) 51 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 52 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 53 | else: 54 | self.embed = nn.Embedding(num_classes, num_features) 55 | self.embed.weight.data.uniform_() 56 | 57 | def forward(self, x, y): 58 | out = self.bn(x) 59 | if self.bias: 60 | gamma, beta = self.embed(y).chunk(2, dim=1) 61 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1) 62 | else: 63 | gamma = self.embed(y) 64 | out = gamma.view(-1, self.num_features, 1, 1) * out 65 | return out 66 | 67 | 68 | class ConditionalInstanceNorm2d(nn.Module): 69 | def __init__(self, num_features, num_classes, bias=True): 70 | super().__init__() 71 | self.num_features = num_features 72 | self.bias = bias 73 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 74 | if bias: 75 | self.embed = nn.Embedding(num_classes, num_features * 2) 76 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 77 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 78 | else: 79 | self.embed = nn.Embedding(num_classes, num_features) 80 | self.embed.weight.data.uniform_() 81 | 82 | def forward(self, x, y): 83 | h = self.instance_norm(x) 84 | if self.bias: 85 | gamma, beta = self.embed(y).chunk(2, dim=-1) 86 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1) 87 | else: 88 | gamma = self.embed(y) 89 | out = gamma.view(-1, self.num_features, 1, 1) * h 90 | return out 91 | 92 | 93 | class ConditionalVarianceNorm2d(nn.Module): 94 | def __init__(self, num_features, num_classes, bias=False): 95 | super().__init__() 96 | self.num_features = num_features 97 | self.bias = bias 98 | self.embed = nn.Embedding(num_classes, num_features) 99 | self.embed.weight.data.normal_(1, 0.02) 100 | 101 | def forward(self, x, y): 102 | vars = torch.var(x, dim=(2, 3), keepdim=True) 103 | h = x / torch.sqrt(vars + 1e-5) 104 | 105 | gamma = self.embed(y) 106 | out = gamma.view(-1, self.num_features, 1, 1) * h 107 | return out 108 | 109 | 110 | class VarianceNorm2d(nn.Module): 111 | def __init__(self, num_features, bias=False): 112 | super().__init__() 113 | self.num_features = num_features 114 | self.bias = bias 115 | self.alpha = 
nn.Parameter(torch.zeros(num_features)) 116 | self.alpha.data.normal_(1, 0.02) 117 | 118 | def forward(self, x): 119 | vars = torch.var(x, dim=(2, 3), keepdim=True) 120 | h = x / torch.sqrt(vars + 1e-5) 121 | 122 | out = self.alpha.view(-1, self.num_features, 1, 1) * h 123 | return out 124 | 125 | 126 | class ConditionalNoneNorm2d(nn.Module): 127 | def __init__(self, num_features, num_classes, bias=True): 128 | super().__init__() 129 | self.num_features = num_features 130 | self.bias = bias 131 | if bias: 132 | self.embed = nn.Embedding(num_classes, num_features * 2) 133 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 134 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 135 | else: 136 | self.embed = nn.Embedding(num_classes, num_features) 137 | self.embed.weight.data.uniform_() 138 | 139 | def forward(self, x, y): 140 | if self.bias: 141 | gamma, beta = self.embed(y).chunk(2, dim=-1) 142 | out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1) 143 | else: 144 | gamma = self.embed(y) 145 | out = gamma.view(-1, self.num_features, 1, 1) * x 146 | return out 147 | 148 | 149 | class NoneNorm2d(nn.Module): 150 | def __init__(self, num_features, bias=True): 151 | super().__init__() 152 | 153 | def forward(self, x): 154 | return x 155 | 156 | 157 | class InstanceNorm2dPlus(nn.Module): 158 | def __init__(self, num_features, bias=True): 159 | super().__init__() 160 | self.num_features = num_features 161 | self.bias = bias 162 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 163 | self.alpha = nn.Parameter(torch.zeros(num_features)) 164 | self.gamma = nn.Parameter(torch.zeros(num_features)) 165 | self.alpha.data.normal_(1, 0.02) 166 | self.gamma.data.normal_(1, 0.02) 167 | if bias: 168 | self.beta = nn.Parameter(torch.zeros(num_features)) 169 | 170 | def forward(self, x): 171 | means = torch.mean(x, dim=(2, 3)) 172 | m = torch.mean(means, dim=-1, keepdim=True) 173 | v = torch.var(means, dim=-1, keepdim=True) 174 | means = (means - m) / (torch.sqrt(v + 1e-5)) 175 | h = self.instance_norm(x) 176 | 177 | if self.bias: 178 | h = h + means[..., None, None] * self.alpha[..., None, None] 179 | out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1) 180 | else: 181 | h = h + means[..., None, None] * self.alpha[..., None, None] 182 | out = self.gamma.view(-1, self.num_features, 1, 1) * h 183 | return out 184 | 185 | 186 | class ConditionalInstanceNorm2dPlus(nn.Module): 187 | def __init__(self, num_features, num_classes, bias=True): 188 | super().__init__() 189 | self.num_features = num_features 190 | self.bias = bias 191 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 192 | if bias: 193 | self.embed = nn.Embedding(num_classes, num_features * 3) 194 | self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02) 195 | self.embed.weight.data[:, 2 * num_features:].zero_() # Initialise bias at 0 196 | else: 197 | self.embed = nn.Embedding(num_classes, 2 * num_features) 198 | self.embed.weight.data.normal_(1, 0.02) 199 | 200 | def forward(self, x, y): 201 | means = torch.mean(x, dim=(2, 3)) 202 | m = torch.mean(means, dim=-1, keepdim=True) 203 | v = torch.var(means, dim=-1, keepdim=True) 204 | means = (means - m) / (torch.sqrt(v + 1e-5)) 205 | h = self.instance_norm(x) 206 | 207 | if self.bias: 208 | gamma, alpha, beta = self.embed(y).chunk(3, 
dim=-1) 209 | h = h + means[..., None, None] * alpha[..., None, None] 210 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1) 211 | else: 212 | gamma, alpha = self.embed(y).chunk(2, dim=-1) 213 | h = h + means[..., None, None] * alpha[..., None, None] 214 | out = gamma.view(-1, self.num_features, 1, 1) * h 215 | return out 216 | -------------------------------------------------------------------------------- /sde_lib.py: -------------------------------------------------------------------------------- 1 | """Abstract SDE classes, Reverse SDE, and VE/VP SDEs.""" 2 | import abc 3 | import torch 4 | import numpy as np 5 | 6 | 7 | class SDE(abc.ABC): 8 | """SDE abstract class. Functions are designed for a mini-batch of inputs.""" 9 | 10 | def __init__(self, N): 11 | """Construct an SDE. 12 | 13 | Args: 14 | N: number of discretization time steps. 15 | """ 16 | super().__init__() 17 | self.N = N 18 | 19 | @property 20 | @abc.abstractmethod 21 | def T(self): 22 | """End time of the SDE.""" 23 | pass 24 | 25 | @abc.abstractmethod 26 | def sde(self, x, t): 27 | pass 28 | 29 | @abc.abstractmethod 30 | def marginal_prob(self, x, t): 31 | """Parameters to determine the marginal distribution of the SDE, $p_t(x)$.""" 32 | pass 33 | 34 | @abc.abstractmethod 35 | def prior_sampling(self, shape): 36 | """Generate one sample from the prior distribution, $p_T(x)$.""" 37 | pass 38 | 39 | @abc.abstractmethod 40 | def prior_logp(self, z): 41 | """Compute log-density of the prior distribution. 42 | 43 | Useful for computing the log-likelihood via probability flow ODE. 44 | 45 | Args: 46 | z: latent code 47 | Returns: 48 | log probability density 49 | """ 50 | pass 51 | 52 | def discretize(self, x, t): 53 | """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i. 54 | 55 | Useful for reverse diffusion sampling and probabiliy flow sampling. 56 | Defaults to Euler-Maruyama discretization. 57 | 58 | Args: 59 | x: a torch tensor 60 | t: a torch float representing the time step (from 0 to `self.T`) 61 | 62 | Returns: 63 | f, G 64 | """ 65 | dt = 1 / self.N 66 | drift, diffusion = self.sde(x, t) 67 | f = drift * dt 68 | G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device)) 69 | return f, G 70 | 71 | def reverse(self, score_fn, probability_flow=False): 72 | """Create the reverse-time SDE/ODE. 73 | 74 | Args: 75 | score_fn: A time-dependent score-based model that takes x and t and returns the score. 76 | probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling. 77 | """ 78 | N = self.N 79 | T = self.T 80 | sde_fn = self.sde 81 | discretize_fn = self.discretize 82 | 83 | # Build the class for reverse-time SDE. 84 | class RSDE(self.__class__): 85 | def __init__(self): 86 | self.N = N 87 | self.probability_flow = probability_flow 88 | 89 | @property 90 | def T(self): 91 | return T 92 | 93 | def sde(self, x, t): 94 | """Create the drift and diffusion functions for the reverse SDE/ODE.""" 95 | drift, diffusion = sde_fn(x, t) 96 | score = score_fn(x, t) 97 | drift = drift - diffusion[:, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.) 98 | # Set the diffusion function to zero for ODEs. 99 | diffusion = 0. 
if self.probability_flow else diffusion 100 | return drift, diffusion 101 | 102 | def discretize(self, x, t): 103 | """Create discretized iteration rules for the reverse diffusion sampler.""" 104 | f, G = discretize_fn(x, t) 105 | rev_f = f - G[:, None, None, None] ** 2 * score_fn(x, t) * (0.5 if self.probability_flow else 1.) 106 | rev_G = torch.zeros_like(G) if self.probability_flow else G 107 | return rev_f, rev_G 108 | 109 | return RSDE() 110 | 111 | 112 | class VPSDE(SDE): 113 | def __init__(self, beta_min=0.1, beta_max=20, N=1000): 114 | """Construct a Variance Preserving SDE. 115 | 116 | Args: 117 | beta_min: value of beta(0) 118 | beta_max: value of beta(1) 119 | N: number of discretization steps 120 | """ 121 | super().__init__(N) 122 | self.beta_0 = beta_min 123 | self.beta_1 = beta_max 124 | self.N = N 125 | self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N) 126 | self.alphas = 1. - self.discrete_betas 127 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 128 | self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) 129 | self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod) 130 | 131 | @property 132 | def T(self): 133 | return 1 134 | 135 | def sde(self, x, t): 136 | beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0) 137 | drift = -0.5 * beta_t[:, None, None, None] * x 138 | diffusion = torch.sqrt(beta_t) 139 | return drift, diffusion 140 | 141 | def marginal_prob(self, x, t): 142 | log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 143 | mean = torch.exp(log_mean_coeff[:, None, None, None]) * x 144 | std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) 145 | return mean, std 146 | 147 | def prior_sampling(self, shape): 148 | return torch.randn(*shape) 149 | 150 | def prior_logp(self, z): 151 | shape = z.shape 152 | N = np.prod(shape[1:]) 153 | logps = -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2. 154 | return logps 155 | 156 | def discretize(self, x, t): 157 | """DDPM discretization.""" 158 | timestep = (t * (self.N - 1) / self.T).long() 159 | beta = self.discrete_betas.to(x.device)[timestep] 160 | alpha = self.alphas.to(x.device)[timestep] 161 | sqrt_beta = torch.sqrt(beta) 162 | f = torch.sqrt(alpha)[:, None, None, None] * x - x 163 | G = sqrt_beta 164 | return f, G 165 | 166 | 167 | class subVPSDE(SDE): 168 | def __init__(self, beta_min=0.1, beta_max=20, N=1000): 169 | """Construct the sub-VP SDE that excels at likelihoods. 170 | 171 | Args: 172 | beta_min: value of beta(0) 173 | beta_max: value of beta(1) 174 | N: number of discretization steps 175 | """ 176 | super().__init__(N) 177 | self.beta_0 = beta_min 178 | self.beta_1 = beta_max 179 | self.N = N 180 | 181 | @property 182 | def T(self): 183 | return 1 184 | 185 | def sde(self, x, t): 186 | beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0) 187 | drift = -0.5 * beta_t[:, None, None, None] * x 188 | discount = 1. - torch.exp(-2 * self.beta_0 * t - (self.beta_1 - self.beta_0) * t ** 2) 189 | diffusion = torch.sqrt(beta_t * discount) 190 | return drift, diffusion 191 | 192 | def marginal_prob(self, x, t): 193 | log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 194 | mean = torch.exp(log_mean_coeff)[:, None, None, None] * x 195 | std = 1 - torch.exp(2. * log_mean_coeff) 196 | return mean, std 197 | 198 | def prior_sampling(self, shape): 199 | return torch.randn(*shape) 200 | 201 | def prior_logp(self, z): 202 | shape = z.shape 203 | N = np.prod(shape[1:]) 204 | return -N / 2. 
* np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2.
205 | 
206 | 
207 | class VESDE(SDE):
208 |   def __init__(self, sigma_min=0.01, sigma_max=50, N=1000):
209 |     """Construct a Variance Exploding SDE.
210 | 
211 |     Args:
212 |       sigma_min: smallest sigma.
213 |       sigma_max: largest sigma.
214 |       N: number of discretization steps
215 |     """
216 |     super().__init__(N)
217 |     self.sigma_min = sigma_min
218 |     self.sigma_max = sigma_max
219 |     self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
220 |     self.N = N
221 | 
222 |   @property
223 |   def T(self):
224 |     return 1
225 | 
226 |   def sde(self, x, t):
227 |     sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
228 |     drift = torch.zeros_like(x)
229 |     diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)),
230 |                                                 device=t.device))
231 |     return drift, diffusion
232 | 
233 |   def marginal_prob(self, x, t):
234 |     std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
235 |     mean = x
236 |     return mean, std
237 | 
238 |   def prior_sampling(self, shape):
239 |     return torch.randn(*shape) * self.sigma_max
240 | 
241 |   def prior_logp(self, z):
242 |     shape = z.shape
243 |     N = np.prod(shape[1:])
244 |     return -N / 2. * np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * self.sigma_max ** 2)
245 | 
246 |   def discretize(self, x, t):
247 |     """SMLD(NCSN) discretization."""
248 |     timestep = (t * (self.N - 1) / self.T).long()
249 |     sigma = self.discrete_sigmas.to(t.device)[timestep]
250 |     adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
251 |                                  self.discrete_sigmas[timestep - 1].to(t.device))
252 |     f = torch.zeros_like(x)
253 |     G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
254 |     return f, G
--------------------------------------------------------------------------------
/controllable_generation.py:
--------------------------------------------------------------------------------
1 | from models import utils as mutils
2 | import torch
3 | import numpy as np
4 | from sampling import NoneCorrector, NonePredictor, shared_corrector_update_fn, shared_predictor_update_fn
5 | import functools
6 | 
7 | 
8 | def get_pc_inpainter(sde, predictor, corrector, inverse_scaler, snr,
9 |                      n_steps=1, probability_flow=False, continuous=False,
10 |                      denoise=True, eps=1e-5):
11 |   """Create an image inpainting function that uses PC samplers.
12 | 
13 |   Args:
14 |     sde: An `sde_lib.SDE` object that represents the forward SDE.
15 |     predictor: A subclass of `sampling.Predictor` that represents a predictor algorithm.
16 |     corrector: A subclass of `sampling.Corrector` that represents a corrector algorithm.
17 |     inverse_scaler: The inverse data normalizer.
18 |     snr: A `float` number. The signal-to-noise ratio for the corrector.
19 |     n_steps: An integer. The number of corrector steps per update of the predictor.
20 |     probability_flow: If `True`, predictor solves the probability flow ODE for sampling.
21 |     continuous: `True` indicates that the score-based model was trained with continuous time.
22 |     denoise: If `True`, add one-step denoising to final samples.
23 |     eps: A `float` number. The reverse-time SDE/ODE is integrated to `eps` for numerical stability.
24 | 
25 |   Returns:
26 |     An inpainting function.
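
  Example (a hypothetical sketch; assumes a trained `model` plus `sde`, `predictor`,
  `corrector`, `inverse_scaler`, `data`, and a 0-1 `mask` already exist):

    inpainter = get_pc_inpainter(sde, predictor, corrector, inverse_scaler, snr=0.16)
    completed = inpainter(model, data, mask)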
27 | """ 28 | # Define predictor & corrector 29 | predictor_update_fn = functools.partial(shared_predictor_update_fn, 30 | sde=sde, 31 | predictor=predictor, 32 | probability_flow=probability_flow, 33 | continuous=continuous) 34 | corrector_update_fn = functools.partial(shared_corrector_update_fn, 35 | sde=sde, 36 | corrector=corrector, 37 | continuous=continuous, 38 | snr=snr, 39 | n_steps=n_steps) 40 | 41 | def get_inpaint_update_fn(update_fn): 42 | """Modify the update function of predictor & corrector to incorporate data information.""" 43 | 44 | def inpaint_update_fn(model, data, mask, x, t): 45 | with torch.no_grad(): 46 | vec_t = torch.ones(data.shape[0], device=data.device) * t 47 | x, x_mean = update_fn(x, vec_t, model=model) 48 | masked_data_mean, std = sde.marginal_prob(data, vec_t) 49 | masked_data = masked_data_mean + torch.randn_like(x) * std[:, None, None, None] 50 | x = x * (1. - mask) + masked_data * mask 51 | x_mean = x * (1. - mask) + masked_data_mean * mask 52 | return x, x_mean 53 | 54 | return inpaint_update_fn 55 | 56 | projector_inpaint_update_fn = get_inpaint_update_fn(predictor_update_fn) 57 | corrector_inpaint_update_fn = get_inpaint_update_fn(corrector_update_fn) 58 | 59 | def pc_inpainter(model, data, mask): 60 | """Predictor-Corrector (PC) sampler for image inpainting. 61 | 62 | Args: 63 | model: A score model. 64 | data: A PyTorch tensor that represents a mini-batch of images to inpaint. 65 | mask: A 0-1 tensor with the same shape of `data`. Value `1` marks known pixels, 66 | and value `0` marks pixels that require inpainting. 67 | 68 | Returns: 69 | Inpainted (complete) images. 70 | """ 71 | with torch.no_grad(): 72 | # Initial sample 73 | x = data * mask + sde.prior_sampling(data.shape).to(data.device) * (1. - mask) 74 | timesteps = torch.linspace(sde.T, eps, sde.N) 75 | for i in range(sde.N): 76 | t = timesteps[i] 77 | x, x_mean = corrector_inpaint_update_fn(model, data, mask, x, t) 78 | x, x_mean = projector_inpaint_update_fn(model, data, mask, x, t) 79 | 80 | return inverse_scaler(x_mean if denoise else x) 81 | 82 | return pc_inpainter 83 | 84 | 85 | def get_pc_colorizer(sde, predictor, corrector, inverse_scaler, 86 | snr, n_steps=1, probability_flow=False, continuous=False, 87 | denoise=True, eps=1e-5): 88 | """Create a image colorization function based on Predictor-Corrector (PC) sampling. 89 | 90 | Args: 91 | sde: An `sde_lib.SDE` object that represents the forward SDE. 92 | predictor: A subclass of `sampling.Predictor` that represents a predictor algorithm. 93 | corrector: A subclass of `sampling.Corrector` that represents a corrector algorithm. 94 | inverse_scaler: The inverse data normalizer. 95 | snr: A `float` number. The signal-to-noise ratio for correctors. 96 | n_steps: An integer. The number of corrector steps per update of the predictor. 97 | probability_flow: If `True`, solve the probability flow ODE for sampling with the predictor. 98 | continuous: `True` indicates that the score-based model was trained with continuous time steps. 99 | denoise: If `True`, add one-step denoising to final samples. 100 | eps: A `float` number. The SDE/ODE will start from `eps` to avoid numerical stabilities. 101 | 102 | Returns: A colorization function. 
103 | """ 104 | 105 | # `M` is an orthonormal matrix to decouple image space to a latent space where the gray-scale image 106 | # occupies a separate channel 107 | M = torch.tensor([[5.7735014e-01, -8.1649649e-01, 4.7008697e-08], 108 | [5.7735026e-01, 4.0824834e-01, 7.0710671e-01], 109 | [5.7735026e-01, 4.0824822e-01, -7.0710683e-01]]) 110 | # `invM` is the inverse transformation of `M` 111 | invM = torch.inverse(M) 112 | 113 | # Decouple a gray-scale image with `M` 114 | def decouple(inputs): 115 | return torch.einsum('bihw,ij->bjhw', inputs, M.to(inputs.device)) 116 | 117 | # The inverse function to `decouple`. 118 | def couple(inputs): 119 | return torch.einsum('bihw,ij->bjhw', inputs, invM.to(inputs.device)) 120 | 121 | predictor_update_fn = functools.partial(shared_predictor_update_fn, 122 | sde=sde, 123 | predictor=predictor, 124 | probability_flow=probability_flow, 125 | continuous=continuous) 126 | corrector_update_fn = functools.partial(shared_corrector_update_fn, 127 | sde=sde, 128 | corrector=corrector, 129 | continuous=continuous, 130 | snr=snr, 131 | n_steps=n_steps) 132 | 133 | def get_colorization_update_fn(update_fn): 134 | """Modify update functions of predictor & corrector to incorporate information of gray-scale images.""" 135 | 136 | def colorization_update_fn(model, gray_scale_img, x, t): 137 | mask = get_mask(x) 138 | vec_t = torch.ones(x.shape[0], device=x.device) * t 139 | x, x_mean = update_fn(x, vec_t, model=model) 140 | masked_data_mean, std = sde.marginal_prob(decouple(gray_scale_img), vec_t) 141 | masked_data = masked_data_mean + torch.randn_like(x) * std[:, None, None, None] 142 | x = couple(decouple(x) * (1. - mask) + masked_data * mask) 143 | x_mean = couple(decouple(x) * (1. - mask) + masked_data_mean * mask) 144 | return x, x_mean 145 | 146 | return colorization_update_fn 147 | 148 | def get_mask(image): 149 | mask = torch.cat([torch.ones_like(image[:, :1, ...]), 150 | torch.zeros_like(image[:, 1:, ...])], dim=1) 151 | return mask 152 | 153 | predictor_colorize_update_fn = get_colorization_update_fn(predictor_update_fn) 154 | corrector_colorize_update_fn = get_colorization_update_fn(corrector_update_fn) 155 | 156 | def pc_colorizer(model, gray_scale_img): 157 | """Colorize gray-scale images using Predictor-Corrector (PC) sampler. 158 | 159 | Args: 160 | model: A score model. 161 | gray_scale_img: A minibatch of gray-scale images. Their R,G,B channels have same values. 162 | 163 | Returns: 164 | Colorized images. 165 | """ 166 | with torch.no_grad(): 167 | shape = gray_scale_img.shape 168 | mask = get_mask(gray_scale_img) 169 | # Initial sample 170 | x = couple(decouple(gray_scale_img) * mask + \ 171 | decouple(sde.prior_sampling(shape).to(gray_scale_img.device) 172 | * (1. - mask))) 173 | timesteps = torch.linspace(sde.T, eps, sde.N) 174 | for i in range(sde.N): 175 | t = timesteps[i] 176 | x, x_mean = corrector_colorize_update_fn(model, gray_scale_img, x, t) 177 | x, x_mean = predictor_colorize_update_fn(model, gray_scale_img, x, t) 178 | 179 | return inverse_scaler(x_mean if denoise else x) 180 | 181 | return pc_colorizer -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """All functions related to loss computation and optimization. 17 | """ 18 | 19 | import torch 20 | import torch.optim as optim 21 | import numpy as np 22 | from models import utils as mutils 23 | from sde_lib import VESDE, VPSDE 24 | 25 | 26 | def get_optimizer(config, params): 27 | """Returns a PyTorch optimizer object based on `config`.""" 28 | if config.optim.optimizer == 'Adam': 29 | optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps, 30 | weight_decay=config.optim.weight_decay) 31 | else: 32 | raise NotImplementedError( 33 | f'Optimizer {config.optim.optimizer} not supported yet!') 34 | 35 | return optimizer 36 | 37 | 38 | def optimization_manager(config): 39 | """Returns an optimize_fn based on `config`.""" 40 | 41 | def optimize_fn(optimizer, params, step, lr=config.optim.lr, 42 | warmup=config.optim.warmup, 43 | grad_clip=config.optim.grad_clip): 44 | """Optimizes with warmup and gradient clipping (disabled if negative).""" 45 | if warmup > 0: 46 | for g in optimizer.param_groups: 47 | g['lr'] = lr * np.minimum(step / warmup, 1.0) 48 | if grad_clip >= 0: 49 | torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip) 50 | optimizer.step() 51 | 52 | return optimize_fn 53 | 54 | 55 | def get_sde_loss_fn(sde, train, reduce_mean=True, continuous=True, likelihood_weighting=True, eps=1e-5): 56 | """Create a loss function for training with arbitrary SDEs. 57 | 58 | Args: 59 | sde: An `sde_lib.SDE` object that represents the forward SDE. 60 | train: `True` for training loss and `False` for evaluation loss. 61 | reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions. 62 | continuous: `True` indicates that the model is defined to take continuous time steps. Otherwise it requires 63 | ad-hoc interpolation to take continuous time steps. 64 | likelihood_weighting: If `True`, weight the mixture of score matching losses 65 | according to https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended in our paper. 66 | eps: A `float` number. The smallest time step to sample from. 67 | 68 | Returns: 69 | A loss function. 70 | """ 71 | reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs) 72 | 73 | def loss_fn(model, batch): 74 | """Compute the loss function. 75 | 76 | Args: 77 | model: A score model. 78 | batch: A mini-batch of training data. 79 | 80 | Returns: 81 | loss: A scalar that represents the average loss value across the mini-batch.
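--------------------------------------------------------------------------------
The warmup in `optimize_fn` above is a plain linear ramp on the learning rate. A self-contained sketch (the values are illustrative assumptions, not repo defaults):

import numpy as np
import torch

params = [torch.nn.Parameter(torch.zeros(3))]
optimizer = torch.optim.Adam(params, lr=2e-4)

lr, warmup = 2e-4, 5000
for step in (0, 1000, 5000, 10000):
    for g in optimizer.param_groups:
        g['lr'] = lr * np.minimum(step / warmup, 1.0)  # ramp up, then hold constant
    print(step, optimizer.param_groups[0]['lr'])       # 0.0, 4e-05, 2e-04, 2e-04
--------------------------------------------------------------------------------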
82 | """ 83 | score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous) 84 | t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps 85 | z = torch.randn_like(batch) 86 | mean, std = sde.marginal_prob(batch, t) 87 | perturbed_data = mean + std[:, None, None, None] * z 88 | score = score_fn(perturbed_data, t) 89 | 90 | if not likelihood_weighting: 91 | losses = torch.square(score * std[:, None, None, None] + z) 92 | losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) 93 | else: 94 | g2 = sde.sde(torch.zeros_like(batch), t)[1] ** 2 95 | losses = torch.square(score + z / std[:, None, None, None]) 96 | losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) * g2 97 | 98 | loss = torch.mean(losses) 99 | return loss 100 | 101 | return loss_fn 102 | 103 | 104 | def get_smld_loss_fn(vesde, train, reduce_mean=False): 105 | """Legacy code to reproduce previous results on SMLD(NCSN). Not recommended for new work.""" 106 | assert isinstance(vesde, VESDE), "SMLD training only works for VESDEs." 107 | 108 | # Previous SMLD models assume descending sigmas 109 | smld_sigma_array = torch.flip(vesde.discrete_sigmas, dims=(0,)) 110 | reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs) 111 | 112 | def loss_fn(model, batch): 113 | model_fn = mutils.get_model_fn(model, train=train) 114 | labels = torch.randint(0, vesde.N, (batch.shape[0],), device=batch.device) 115 | sigmas = smld_sigma_array.to(batch.device)[labels] 116 | noise = torch.randn_like(batch) * sigmas[:, None, None, None] 117 | perturbed_data = noise + batch 118 | score = model_fn(perturbed_data, labels) 119 | target = -noise / (sigmas ** 2)[:, None, None, None] 120 | losses = torch.square(score - target) 121 | losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) * sigmas ** 2 122 | loss = torch.mean(losses) 123 | return loss 124 | 125 | return loss_fn 126 | 127 | 128 | def get_ddpm_loss_fn(vpsde, train, reduce_mean=True): 129 | """Legacy code to reproduce previous results on DDPM. Not recommended for new work.""" 130 | assert isinstance(vpsde, VPSDE), "DDPM training only works for VPSDEs." 131 | 132 | reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs) 133 | 134 | def loss_fn(model, batch): 135 | model_fn = mutils.get_model_fn(model, train=train) 136 | labels = torch.randint(0, vpsde.N, (batch.shape[0],), device=batch.device) 137 | sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod.to(batch.device) 138 | sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod.to(batch.device) 139 | noise = torch.randn_like(batch) 140 | perturbed_data = sqrt_alphas_cumprod[labels, None, None, None] * batch + \ 141 | sqrt_1m_alphas_cumprod[labels, None, None, None] * noise 142 | score = model_fn(perturbed_data, labels) 143 | losses = torch.square(score - noise) 144 | losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) 145 | loss = torch.mean(losses) 146 | return loss 147 | 148 | return loss_fn 149 | 150 | 151 | def get_step_fn(sde, train, optimize_fn=None, reduce_mean=False, continuous=True, likelihood_weighting=False): 152 | """Create a one-step training/evaluation function. 153 | 154 | Args: 155 | sde: An `sde_lib.SDE` object that represents the forward SDE. 156 | optimize_fn: An optimization function. 157 | reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions. 
158 | continuous: `True` indicates that the model is defined to take continuous time steps. 159 | likelihood_weighting: If `True`, weight the mixture of score matching losses according to 160 | https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended by our paper. 161 | 162 | Returns: 163 | A one-step function for training or evaluation. 164 | """ 165 | if continuous: 166 | loss_fn = get_sde_loss_fn(sde, train, reduce_mean=reduce_mean, 167 | continuous=True, likelihood_weighting=likelihood_weighting) 168 | else: 169 | assert not likelihood_weighting, "Likelihood weighting is not supported for original SMLD/DDPM training." 170 | if isinstance(sde, VESDE): 171 | loss_fn = get_smld_loss_fn(sde, train, reduce_mean=reduce_mean) 172 | elif isinstance(sde, VPSDE): 173 | loss_fn = get_ddpm_loss_fn(sde, train, reduce_mean=reduce_mean) 174 | else: 175 | raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.") 176 | 177 | def step_fn(state, batch): 178 | """Run one step of training or evaluation. 179 | 180 | This function runs eagerly in PyTorch. (In the original JAX codebase the analogous 181 | function was wrapped with `jax.lax.scan` so that multiple steps could be pmapped and jit-compiled together.) 182 | 183 | Args: 184 | state: A dictionary of training information, containing the score model, optimizer, 185 | EMA status, and number of optimization steps. 186 | batch: A mini-batch of training/evaluation data. 187 | 188 | Returns: 189 | loss: The average loss value over this mini-batch. 190 | """ 191 | model = state['model'] 192 | if train: 193 | optimizer = state['optimizer'] 194 | optimizer.zero_grad() 195 | loss = loss_fn(model, batch) 196 | loss.backward() 197 | optimize_fn(optimizer, model.parameters(), step=state['step']) 198 | state['step'] += 1 199 | state['ema'].update(model.parameters()) 200 | else: 201 | with torch.no_grad(): 202 | ema = state['ema'] 203 | ema.store(model.parameters()) 204 | ema.copy_to(model.parameters()) 205 | loss = loss_fn(model, batch) 206 | ema.restore(model.parameters()) 207 | 208 | return loss 209 | 210 | return step_fn 211 | -------------------------------------------------------------------------------- /models/up_or_down_sampling.py: -------------------------------------------------------------------------------- 1 | """Layers used for up-sampling or down-sampling images. 2 | 3 | Many functions are ported from https://github.com/NVlabs/stylegan2.
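--------------------------------------------------------------------------------
How the pieces of losses.py fit together, as a hypothetical sketch: `config`, `score_model`, and `train_loader` are assumptions, and the EMA helper is assumed to match the class in models/ema.py:

import sde_lib
import losses
from models.ema import ExponentialMovingAverage

sde = sde_lib.VESDE(sigma_min=0.01, sigma_max=50, N=1000)  # args are assumptions
optimizer = losses.get_optimizer(config, score_model.parameters())
ema = ExponentialMovingAverage(score_model.parameters(), decay=0.999)
state = dict(model=score_model, optimizer=optimizer, ema=ema, step=0)

train_step_fn = losses.get_step_fn(sde, train=True,
                                   optimize_fn=losses.optimization_manager(config),
                                   reduce_mean=True, continuous=True)
for batch in train_loader:
    loss = train_step_fn(state, batch)  # one optimizer + EMA update per mini-batch
--------------------------------------------------------------------------------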
4 | """ 5 | 6 | import torch.nn as nn 7 | import torch 8 | import torch.nn.functional as F 9 | import numpy as np 10 | from op import upfirdn2d 11 | 12 | 13 | # Function ported from StyleGAN2 14 | def get_weight(module, 15 | shape, 16 | weight_var='weight', 17 | kernel_init=None): 18 | """Get/create weight tensor for a convolution or fully-connected layer.""" 19 | 20 | return module.param(weight_var, kernel_init, shape) 21 | 22 | 23 | class Conv2d(nn.Module): 24 | """Conv2d layer with optimal upsampling and downsampling (StyleGAN2).""" 25 | 26 | def __init__(self, in_ch, out_ch, kernel, up=False, down=False, 27 | resample_kernel=(1, 3, 3, 1), 28 | use_bias=True, 29 | kernel_init=None): 30 | super().__init__() 31 | assert not (up and down) 32 | assert kernel >= 1 and kernel % 2 == 1 33 | self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel)) 34 | if kernel_init is not None: 35 | self.weight.data = kernel_init(self.weight.data.shape) 36 | if use_bias: 37 | self.bias = nn.Parameter(torch.zeros(out_ch)) 38 | 39 | self.up = up 40 | self.down = down 41 | self.resample_kernel = resample_kernel 42 | self.kernel = kernel 43 | self.use_bias = use_bias 44 | 45 | def forward(self, x): 46 | if self.up: 47 | x = upsample_conv_2d(x, self.weight, k=self.resample_kernel) 48 | elif self.down: 49 | x = conv_downsample_2d(x, self.weight, k=self.resample_kernel) 50 | else: 51 | x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2) 52 | 53 | if self.use_bias: 54 | x = x + self.bias.reshape(1, -1, 1, 1) 55 | 56 | return x 57 | 58 | 59 | def naive_upsample_2d(x, factor=2): 60 | _N, C, H, W = x.shape 61 | x = torch.reshape(x, (-1, C, H, 1, W, 1)) 62 | x = x.repeat(1, 1, 1, factor, 1, factor) 63 | return torch.reshape(x, (-1, C, H * factor, W * factor)) 64 | 65 | 66 | def naive_downsample_2d(x, factor=2): 67 | _N, C, H, W = x.shape 68 | x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor)) 69 | return torch.mean(x, dim=(3, 5)) 70 | 71 | 72 | def upsample_conv_2d(x, w, k=None, factor=2, gain=1): 73 | """Fused `upsample_2d()` followed by `tf.nn.conv2d()`. 74 | 75 | Padding is performed only once at the beginning, not between the 76 | operations. 77 | The fused op is considerably more efficient than performing the same 78 | calculation 79 | using standard TensorFlow ops. It supports gradients of arbitrary order. 80 | Args: 81 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 82 | C]`. 83 | w: Weight tensor of the shape `[filterH, filterW, inChannels, 84 | outChannels]`. Grouped convolution can be performed by `inChannels = 85 | x.shape[0] // numGroups`. 86 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 87 | (separable). The default is `[1] * factor`, which corresponds to 88 | nearest-neighbor upsampling. 89 | factor: Integer upsampling factor (default: 2). 90 | gain: Scaling factor for signal magnitude (default: 1.0). 91 | 92 | Returns: 93 | Tensor of the shape `[N, C, H * factor, W * factor]` or 94 | `[N, H * factor, W * factor, C]`, and same datatype as `x`. 95 | """ 96 | 97 | assert isinstance(factor, int) and factor >= 1 98 | 99 | # Check weight shape. 100 | assert len(w.shape) == 4 101 | convH = w.shape[2] 102 | convW = w.shape[3] 103 | inC = w.shape[1] 104 | outC = w.shape[0] 105 | 106 | assert convW == convH 107 | 108 | # Setup filter kernel. 
109 | if k is None: 110 | k = [1] * factor 111 | k = _setup_kernel(k) * (gain * (factor ** 2)) 112 | p = (k.shape[0] - factor) - (convW - 1) 113 | 114 | stride = (factor, factor) 115 | 116 | # Determine data dimensions. 117 | # (`F.conv_transpose2d` expects a two-entry stride; the original TF code used a length-4 stride.) 118 | output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW) 119 | output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH, 120 | output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW) 121 | assert output_padding[0] >= 0 and output_padding[1] >= 0 122 | num_groups = _shape(x, 1) // inC 123 | 124 | # Transpose weights. 125 | w = torch.reshape(w, (num_groups, -1, inC, convH, convW)) 126 | w = torch.flip(w, (3, 4)).permute(0, 2, 1, 3, 4)  # tensors do not support negative-step slicing (w[..., ::-1, ::-1]) 127 | w = torch.reshape(w, (num_groups * inC, -1, convH, convW)) 128 | 129 | x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0) 130 | ## Original TF code. 131 | # x = tf.nn.conv2d_transpose( 132 | # x, 133 | # w, 134 | # output_shape=output_shape, 135 | # strides=stride, 136 | # padding='VALID', 137 | # data_format=data_format) 138 | ## JAX equivalent 139 | 140 | return upfirdn2d(x, torch.tensor(k, device=x.device), 141 | pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) 142 | 143 | 144 | def conv_downsample_2d(x, w, k=None, factor=2, gain=1): 145 | """Fused `tf.nn.conv2d()` followed by `downsample_2d()`. 146 | 147 | Padding is performed only once at the beginning, not between the operations. 148 | The fused op is considerably more efficient than performing the same 149 | calculation 150 | using standard TensorFlow ops. It supports gradients of arbitrary order. 151 | Args: 152 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 153 | C]`. 154 | w: Weight tensor of the shape `[filterH, filterW, inChannels, 155 | outChannels]`. Grouped convolution can be performed by `inChannels = 156 | x.shape[0] // numGroups`. 157 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 158 | (separable). The default is `[1] * factor`, which corresponds to 159 | average pooling. 160 | factor: Integer downsampling factor (default: 2). 161 | gain: Scaling factor for signal magnitude (default: 1.0). 162 | 163 | Returns: 164 | Tensor of the shape `[N, C, H // factor, W // factor]` or 165 | `[N, H // factor, W // factor, C]`, and same datatype as `x`. 166 | """ 167 | 168 | assert isinstance(factor, int) and factor >= 1 169 | _outC, _inC, convH, convW = w.shape 170 | assert convW == convH 171 | if k is None: 172 | k = [1] * factor 173 | k = _setup_kernel(k) * gain 174 | p = (k.shape[0] - factor) + (convW - 1) 175 | s = [factor, factor] 176 | x = upfirdn2d(x, torch.tensor(k, device=x.device), 177 | pad=((p + 1) // 2, p // 2)) 178 | return F.conv2d(x, w, stride=s, padding=0) 179 | 180 | 181 | def _setup_kernel(k): 182 | k = np.asarray(k, dtype=np.float32) 183 | if k.ndim == 1: 184 | k = np.outer(k, k) 185 | k /= np.sum(k) 186 | assert k.ndim == 2 187 | assert k.shape[0] == k.shape[1] 188 | return k 189 | 190 | 191 | def _shape(x, dim): 192 | return x.shape[dim] 193 | 194 | 195 | def upsample_2d(x, k=None, factor=2, gain=1): 196 | r"""Upsample a batch of 2D images with the given filter. 197 | 198 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 199 | and upsamples each image with the given filter. The filter is normalized so 200 | that 201 | if the input pixels are constant, they will be scaled by the specified 202 | `gain`.
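--------------------------------------------------------------------------------
A quick check (illustrative only; the fused ops need the compiled `upfirdn2d` extension from op/ and a CUDA device) of the kernel normalization and the shape contract of the functions above:

import numpy as np
import torch

# _setup_kernel turns 1-D taps into a normalized separable 2-D filter.
k = np.outer([1., 3., 3., 1.], [1., 3., 3., 1.])
k /= k.sum()                  # same result as _setup_kernel([1, 3, 3, 1]); shape (4, 4)

x = torch.randn(2, 8, 16, 16).cuda()
w = torch.randn(8, 8, 3, 3).cuda()                      # [outC, inC, kH, kW]
print(upsample_conv_2d(x, w, k=(1, 3, 3, 1)).shape)     # torch.Size([2, 8, 32, 32])
print(conv_downsample_2d(x, w, k=(1, 3, 3, 1)).shape)   # torch.Size([2, 8, 8, 8])
--------------------------------------------------------------------------------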
203 | Pixels outside the image are assumed to be zero, and the filter is padded 204 | with 205 | zeros so that its shape is a multiple of the upsampling factor. 206 | Args: 207 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 208 | C]`. 209 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 210 | (separable). The default is `[1] * factor`, which corresponds to 211 | nearest-neighbor upsampling. 212 | factor: Integer upsampling factor (default: 2). 213 | gain: Scaling factor for signal magnitude (default: 1.0). 214 | 215 | Returns: 216 | Tensor of the shape `[N, C, H * factor, W * factor]` 217 | """ 218 | assert isinstance(factor, int) and factor >= 1 219 | if k is None: 220 | k = [1] * factor 221 | k = _setup_kernel(k) * (gain * (factor ** 2)) 222 | p = k.shape[0] - factor 223 | return upfirdn2d(x, torch.tensor(k, device=x.device), 224 | up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) 225 | 226 | 227 | def downsample_2d(x, k=None, factor=2, gain=1): 228 | r"""Downsample a batch of 2D images with the given filter. 229 | 230 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 231 | and downsamples each image with the given filter. The filter is normalized 232 | so that 233 | if the input pixels are constant, they will be scaled by the specified 234 | `gain`. 235 | Pixels outside the image are assumed to be zero, and the filter is padded 236 | with 237 | zeros so that its shape is a multiple of the downsampling factor. 238 | Args: 239 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 240 | C]`. 241 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 242 | (separable). The default is `[1] * factor`, which corresponds to 243 | average pooling. 244 | factor: Integer downsampling factor (default: 2). 245 | gain: Scaling factor for signal magnitude (default: 1.0). 246 | 247 | Returns: 248 | Tensor of the shape `[N, C, H // factor, W // factor]` 249 | """ 250 | 251 | assert isinstance(factor, int) and factor >= 1 252 | if k is None: 253 | k = [1] * factor 254 | k = _setup_kernel(k) * gain 255 | p = k.shape[0] - factor 256 | return upfirdn2d(x, torch.tensor(k, device=x.device), 257 | down=factor, pad=((p + 1) // 2, p // 2)) 258 | -------------------------------------------------------------------------------- /models/layerspp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | """Layers for defining NCSN++. 18 | """ 19 | from . import layers 20 | from . 
import up_or_down_sampling 21 | import torch.nn as nn 22 | import torch 23 | import torch.nn.functional as F 24 | import numpy as np 25 | 26 | conv1x1 = layers.ddpm_conv1x1 27 | conv3x3 = layers.ddpm_conv3x3 28 | NIN = layers.NIN 29 | default_init = layers.default_init 30 | 31 | 32 | class GaussianFourierProjection(nn.Module): 33 | """Gaussian Fourier embeddings for noise levels.""" 34 | 35 | def __init__(self, embedding_size=256, scale=1.0): 36 | super().__init__() 37 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 38 | 39 | def forward(self, x): 40 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 41 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 42 | 43 | 44 | class Combine(nn.Module): 45 | """Combine information from skip connections.""" 46 | 47 | def __init__(self, dim1, dim2, method='cat'): 48 | super().__init__() 49 | self.Conv_0 = conv1x1(dim1, dim2) 50 | self.method = method 51 | 52 | def forward(self, x, y): 53 | h = self.Conv_0(x) 54 | if self.method == 'cat': 55 | return torch.cat([h, y], dim=1) 56 | elif self.method == 'sum': 57 | return h + y 58 | else: 59 | raise ValueError(f'Method {self.method} not recognized.') 60 | 61 | 62 | class AttnBlockpp(nn.Module): 63 | """Channel-wise self-attention block. Modified from DDPM.""" 64 | 65 | def __init__(self, channels, skip_rescale=False, init_scale=0.): 66 | super().__init__() 67 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, 68 | eps=1e-6) 69 | self.NIN_0 = NIN(channels, channels) 70 | self.NIN_1 = NIN(channels, channels) 71 | self.NIN_2 = NIN(channels, channels) 72 | self.NIN_3 = NIN(channels, channels, init_scale=init_scale) 73 | self.skip_rescale = skip_rescale 74 | 75 | def forward(self, x): 76 | B, C, H, W = x.shape 77 | h = self.GroupNorm_0(x) 78 | q = self.NIN_0(h) 79 | k = self.NIN_1(h) 80 | v = self.NIN_2(h) 81 | 82 | w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5)) 83 | w = torch.reshape(w, (B, H, W, H * W)) 84 | w = F.softmax(w, dim=-1) 85 | w = torch.reshape(w, (B, H, W, H, W)) 86 | h = torch.einsum('bhwij,bcij->bchw', w, v) 87 | h = self.NIN_3(h) 88 | if not self.skip_rescale: 89 | return x + h 90 | else: 91 | return (x + h) / np.sqrt(2.) 
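--------------------------------------------------------------------------------
Shape sanity checks (illustrative, not part of the repo) for the two modules above:

import torch

emb = GaussianFourierProjection(embedding_size=256, scale=16.)
t = torch.rand(8)              # one noise level / time per example
print(emb(t).shape)            # torch.Size([8, 512]): sin and cos halves

attn = AttnBlockpp(channels=64, skip_rescale=True)
x = torch.randn(8, 64, 16, 16)
print(attn(x).shape)           # torch.Size([8, 64, 16, 16]); residual (x + h) / sqrt(2)
--------------------------------------------------------------------------------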
92 | 93 | 94 | class Upsample(nn.Module): 95 | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, 96 | fir_kernel=(1, 3, 3, 1)): 97 | super().__init__() 98 | out_ch = out_ch if out_ch else in_ch 99 | if not fir: 100 | if with_conv: 101 | self.Conv_0 = conv3x3(in_ch, out_ch) 102 | else: 103 | if with_conv: 104 | self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch, 105 | kernel=3, up=True, 106 | resample_kernel=fir_kernel, 107 | use_bias=True, 108 | kernel_init=default_init()) 109 | self.fir = fir 110 | self.with_conv = with_conv 111 | self.fir_kernel = fir_kernel 112 | self.out_ch = out_ch 113 | 114 | def forward(self, x): 115 | B, C, H, W = x.shape 116 | if not self.fir: 117 | h = F.interpolate(x, (H * 2, W * 2), mode='nearest')  # `mode` must be a keyword; the third positional arg is `scale_factor` 118 | if self.with_conv: 119 | h = self.Conv_0(h) 120 | else: 121 | if not self.with_conv: 122 | h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2) 123 | else: 124 | h = self.Conv2d_0(x) 125 | 126 | return h 127 | 128 | 129 | class Downsample(nn.Module): 130 | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, 131 | fir_kernel=(1, 3, 3, 1)): 132 | super().__init__() 133 | out_ch = out_ch if out_ch else in_ch 134 | if not fir: 135 | if with_conv: 136 | self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0) 137 | else: 138 | if with_conv: 139 | self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch, 140 | kernel=3, down=True, 141 | resample_kernel=fir_kernel, 142 | use_bias=True, 143 | kernel_init=default_init()) 144 | self.fir = fir 145 | self.fir_kernel = fir_kernel 146 | self.with_conv = with_conv 147 | self.out_ch = out_ch 148 | 149 | def forward(self, x): 150 | B, C, H, W = x.shape 151 | if not self.fir: 152 | if self.with_conv: 153 | x = F.pad(x, (0, 1, 0, 1)) 154 | x = self.Conv_0(x) 155 | else: 156 | x = F.avg_pool2d(x, 2, stride=2) 157 | else: 158 | if not self.with_conv: 159 | x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2) 160 | else: 161 | x = self.Conv2d_0(x) 162 | 163 | return x 164 | 165 | 166 | class ResnetBlockDDPMpp(nn.Module): 167 | """ResBlock adapted from DDPM.""" 168 | 169 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, 170 | dropout=0.1, skip_rescale=False, init_scale=0.): 171 | super().__init__() 172 | out_ch = out_ch if out_ch else in_ch 173 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) 174 | self.Conv_0 = conv3x3(in_ch, out_ch) 175 | if temb_dim is not None: 176 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 177 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) 178 | nn.init.zeros_(self.Dense_0.bias) 179 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) 180 | self.Dropout_0 = nn.Dropout(dropout) 181 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) 182 | if in_ch != out_ch: 183 | if conv_shortcut: 184 | self.Conv_2 = conv3x3(in_ch, out_ch) 185 | else: 186 | self.NIN_0 = NIN(in_ch, out_ch) 187 | 188 | self.skip_rescale = skip_rescale 189 | self.act = act 190 | self.out_ch = out_ch 191 | self.conv_shortcut = conv_shortcut 192 | 193 | def forward(self, x, temb=None): 194 | h = self.act(self.GroupNorm_0(x)) 195 | h = self.Conv_0(h) 196 | if temb is not None: 197 | h += self.Dense_0(self.act(temb))[:, :, None, None] 198 | h = self.act(self.GroupNorm_1(h)) 199 | h = self.Dropout_0(h) 200 | h = self.Conv_1(h) 201 | if x.shape[1] != self.out_ch: 202 | if self.conv_shortcut: 203 | x = self.Conv_2(x) 204
| else: 205 | x = self.NIN_0(x) 206 | if not self.skip_rescale: 207 | return x + h 208 | else: 209 | return (x + h) / np.sqrt(2.) 210 | 211 | 212 | class ResnetBlockBigGANpp(nn.Module): 213 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False, 214 | dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1), 215 | skip_rescale=True, init_scale=0.): 216 | super().__init__() 217 | 218 | out_ch = out_ch if out_ch else in_ch 219 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) 220 | self.up = up 221 | self.down = down 222 | self.fir = fir 223 | self.fir_kernel = fir_kernel 224 | 225 | self.Conv_0 = conv3x3(in_ch, out_ch) 226 | if temb_dim is not None: 227 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 228 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape) 229 | nn.init.zeros_(self.Dense_0.bias) 230 | 231 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) 232 | self.Dropout_0 = nn.Dropout(dropout) 233 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) 234 | if in_ch != out_ch or up or down: 235 | self.Conv_2 = conv1x1(in_ch, out_ch) 236 | 237 | self.skip_rescale = skip_rescale 238 | self.act = act 239 | self.in_ch = in_ch 240 | self.out_ch = out_ch 241 | 242 | def forward(self, x, temb=None): 243 | h = self.act(self.GroupNorm_0(x)) 244 | 245 | if self.up: 246 | if self.fir: 247 | h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2) 248 | x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2) 249 | else: 250 | h = up_or_down_sampling.naive_upsample_2d(h, factor=2) 251 | x = up_or_down_sampling.naive_upsample_2d(x, factor=2) 252 | elif self.down: 253 | if self.fir: 254 | h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2) 255 | x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2) 256 | else: 257 | h = up_or_down_sampling.naive_downsample_2d(h, factor=2) 258 | x = up_or_down_sampling.naive_downsample_2d(x, factor=2) 259 | 260 | h = self.Conv_0(h) 261 | # Add bias to each feature map conditioned on the time embedding 262 | if temb is not None: 263 | h += self.Dense_0(self.act(temb))[:, :, None, None] 264 | h = self.act(self.GroupNorm_1(h)) 265 | h = self.Dropout_0(h) 266 | h = self.Conv_1(h) 267 | 268 | if self.in_ch != self.out_ch or self.up or self.down: 269 | x = self.Conv_2(x) 270 | 271 | if not self.skip_rescale: 272 | return x + h 273 | else: 274 | return (x + h) / np.sqrt(2.) 275 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template <typename scalar_t> 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) *
p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, 108 | int kernel_h, int kernel_w, int tile_out_h, int tile_out_w> 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 |
} 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = pad_x1; 234 | p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64> 315 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 316 | x.data_ptr<scalar_t>(), 317 | k.data_ptr<scalar_t>(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64> 323 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 324 | x.data_ptr<scalar_t>(), 325 | k.data_ptr<scalar_t>(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64> 331 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 332 | x.data_ptr<scalar_t>(), 333 | k.data_ptr<scalar_t>(), p); 334 |
335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64> 339 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 340 | x.data_ptr<scalar_t>(), 341 | k.data_ptr<scalar_t>(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32> 347 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 348 | x.data_ptr<scalar_t>(), 349 | k.data_ptr<scalar_t>(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32> 355 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 356 | x.data_ptr<scalar_t>(), 357 | k.data_ptr<scalar_t>(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>( 363 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), 364 | k.data_ptr<scalar_t>(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } --------------------------------------------------------------------------------
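--------------------------------------------------------------------------------
For reference, a minimal pure-PyTorch sketch of what this CUDA op computes (zero-insertion upsampling, padding, FIR filtering, then decimation). It assumes NCHW input, a 2-D kernel, and symmetric, non-negative integer padding; it is a readability aid, not the repo's autograd-aware wrapper in op/upfirdn2d.py:

import torch
import torch.nn.functional as F

def upfirdn2d_ref(x, k, up=1, down=1, pad=(0, 0)):
    n, c, h, w = x.shape
    kh, kw = k.shape
    # 1) Upsample by inserting (up - 1) zeros between samples.
    out = x.new_zeros(n, c, h * up, w * up)
    out[:, :, ::up, ::up] = x
    # 2) Pad with pad[0] before and pad[1] after, in both spatial dimensions.
    out = F.pad(out, [pad[0], pad[1], pad[0], pad[1]])
    # 3) FIR filter: depthwise convolution with the flipped kernel.
    weight = torch.flip(k, [0, 1]).to(out).view(1, 1, kh, kw).repeat(c, 1, 1, 1)
    out = F.conv2d(out, weight, groups=c)
    # 4) Downsample by keeping every down-th sample.
    return out[:, :, ::down, ::down]
--------------------------------------------------------------------------------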