├── RectifiedFlow ├── __init__.py ├── configs │ ├── __init__.py │ ├── celeba_hq_pytorch_rf_gaussian.py │ └── default_configs.py ├── op │ ├── __init__.py │ ├── fused_bias_act.cpp │ ├── upfirdn2d.cpp │ ├── fused_act.py │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── utils.py ├── models │ ├── __init__.py │ ├── utils.py │ ├── ema.py │ ├── ddpm.py │ ├── normalization.py │ ├── up_or_down_sampling.py │ ├── layerspp.py │ ├── ncsnpp.py │ ├── ncsnv2.py │ └── layers.py └── datasets.py ├── demo └── celeba.jpg ├── github_misc ├── fig1.png └── example.png ├── main.py ├── README.md └── utils ├── run_lib_flowgrad.py ├── DiffAugment_pytorch.py └── flowgrad_utils.py /RectifiedFlow/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /RectifiedFlow/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /demo/celeba.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gnobitab/FlowGrad/HEAD/demo/celeba.jpg -------------------------------------------------------------------------------- /github_misc/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gnobitab/FlowGrad/HEAD/github_misc/fig1.png -------------------------------------------------------------------------------- /github_misc/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gnobitab/FlowGrad/HEAD/github_misc/example.png -------------------------------------------------------------------------------- /RectifiedFlow/op/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /RectifiedFlow/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import logging 4 | 5 | 6 | def restore_checkpoint(ckpt_dir, state, device): 7 | loaded_state = torch.load(ckpt_dir, map_location=device) 8 | state['model'].load_state_dict(loaded_state['model'], strict=False) 9 | state['ema'].load_state_dict(loaded_state['ema']) 10 | state['step'] = loaded_state['step'] 11 | return state 12 | 13 | 14 | def save_checkpoint(ckpt_dir, state): 15 | saved_state = { 16 | 'optimizer': state['optimizer'].state_dict(), 17 | 'model': state['model'].state_dict(), 18 | 'ema': state['ema'].state_dict(), 19 | 'step': state['step'] 20 | } 21 | torch.save(saved_state, ckpt_dir) 22 | -------------------------------------------------------------------------------- /RectifiedFlow/models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | -------------------------------------------------------------------------------- /RectifiedFlow/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /RectifiedFlow/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return 
upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /RectifiedFlow/datasets.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | """Return training and evaluation/test datasets from config files.""" 18 | import os 19 | 20 | def get_data_scaler(config): 21 | """Data normalizer. Assume data are always in [0, 1].""" 22 | if config.data.centered: 23 | # Rescale to [-1, 1] 24 | return lambda x: x * 2. - 1. 25 | else: 26 | return lambda x: x 27 | 28 | 29 | def get_data_inverse_scaler(config): 30 | """Inverse data normalizer.""" 31 | if config.data.centered: 32 | # Rescale [-1, 1] to [0, 1] 33 | return lambda x: (x + 1.) / 2. 
34 | else: 35 | return lambda x: x 36 | 37 | 38 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from utils import run_lib_flowgrad 4 | from absl import app 5 | from absl import flags 6 | from ml_collections.config_flags import config_flags 7 | import logging 8 | import os 9 | 10 | FLAGS = flags.FLAGS 11 | 12 | config_flags.DEFINE_config_file("config", None, "Rectified Flow Model configuration.", lock_config=True) 13 | flags.DEFINE_enum("mode", 'flowgrad-edit', ["flowgrad-edit"], "Running mode.") 14 | flags.DEFINE_string("text_prompt", None, "text prompt for editing") 15 | flags.DEFINE_float("alpha", 0.7, "The coefficient to balance the edit loss and the reconstruction loss.") 16 | flags.DEFINE_string("model_path", None, "Path to pre-trained model checkpoint.") 17 | flags.DEFINE_string("image_path", None, "The path to the image that will be edited") 18 | flags.DEFINE_string("output_folder", "output", "The folder name for storing output") 19 | flags.mark_flags_as_required(["model_path", "text_prompt", "alpha", "image_path"]) 20 | 21 | 22 | def main(argv): 23 | if FLAGS.mode == "flowgrad-edit": 24 | run_lib_flowgrad.flowgrad_edit(FLAGS.config, FLAGS.text_prompt, FLAGS.alpha, FLAGS.model_path, FLAGS.image_path, FLAGS.output_folder) 25 | else: 26 | raise ValueError(f"Mode {FLAGS.mode} not recognized.") 27 | 28 | 29 | if __name__ == "__main__": 30 | app.run(main) 31 | -------------------------------------------------------------------------------- /RectifiedFlow/configs/celeba_hq_pytorch_rf_gaussian.py: -------------------------------------------------------------------------------- 1 | """Training rectified Flow on CelebA HQ.""" 2 | 3 | import ml_collections 4 | from RectifiedFlow.configs.default_configs import get_default_configs 5 | 6 | 7 | def get_config(): 8 | config = get_default_configs() 9 | # training 10 | 
training = config.training 11 | training.sde = 'rectified_flow' 12 | training.continuous = False 13 | training.reduce_mean = True 14 | training.snapshot_freq = 100000 15 | training.data_dir = 'DATA_DIR' 16 | 17 | # sampling 18 | sampling = config.sampling 19 | sampling.method = 'rectified_flow' 20 | sampling.init_type = 'gaussian' 21 | sampling.init_noise_scale = 1.0 22 | sampling.use_ode_sampler = 'rk45' 23 | 24 | # data 25 | data = config.data 26 | data.dataset = 'CelebA-HQ-Pytorch' 27 | data.centered = True 28 | 29 | # model 30 | model = config.model 31 | model.name = 'ncsnpp' 32 | model.scale_by_sigma = True 33 | model.ema_rate = 0.999 34 | model.normalization = 'GroupNorm' 35 | model.nonlinearity = 'swish' 36 | model.nf = 128 37 | model.ch_mult = (1, 1, 2, 2, 2, 2, 2) 38 | model.num_res_blocks = 2 39 | model.attn_resolutions = (16,) 40 | model.resamp_with_conv = True 41 | model.conditional = True 42 | model.fir = True 43 | model.fir_kernel = [1, 3, 3, 1] 44 | model.skip_rescale = True 45 | model.resblock_type = 'biggan' 46 | model.progressive = 'output_skip' 47 | model.progressive_input = 'input_skip' 48 | model.progressive_combine = 'sum' 49 | model.attention_type = 'ddpm' 50 | model.init_scale = 0. 
51 | model.fourier_scale = 16 52 | model.conv_size = 3 53 | 54 | return config 55 | -------------------------------------------------------------------------------- /RectifiedFlow/configs/default_configs.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | import torch 3 | 4 | 5 | def get_default_configs(): 6 | config = ml_collections.ConfigDict() 7 | # training 8 | config.training = training = ml_collections.ConfigDict() 9 | config.training.batch_size = 64 10 | training.n_iters = 2400001 11 | training.snapshot_freq = 50000 12 | training.log_freq = 50 13 | training.eval_freq = 100 14 | ## store additional checkpoints for preemption in cloud computing environments 15 | training.snapshot_freq_for_preemption = 5000 16 | ## produce samples at each snapshot. 17 | training.snapshot_sampling = True 18 | training.likelihood_weighting = False 19 | training.continuous = True 20 | training.reduce_mean = False 21 | 22 | # sampling 23 | config.sampling = sampling = ml_collections.ConfigDict() 24 | sampling.n_steps_each = 1 25 | sampling.noise_removal = True 26 | sampling.probability_flow = False 27 | sampling.snr = 0.075 28 | 29 | sampling.sigma_variance = 0.0 # NOTE: XC: sigma variance for turning ODe to SDE 30 | sampling.init_noise_scale = 1.0 31 | sampling.use_ode_sampler = 'ode' 32 | sampling.ode_tol = 1e-5 33 | sampling.sample_N = 1000 34 | 35 | # evaluation 36 | config.eval = evaluate = ml_collections.ConfigDict() 37 | evaluate.begin_ckpt = 50 38 | evaluate.end_ckpt = 96 39 | evaluate.batch_size = 512 40 | evaluate.enable_sampling = False 41 | evaluate.enable_figures_only = False 42 | evaluate.num_samples = 50000 43 | evaluate.enable_loss = False 44 | evaluate.enable_bpd = False 45 | evaluate.bpd_dataset = 'test' 46 | 47 | # data 48 | config.data = data = ml_collections.ConfigDict() 49 | data.dataset = 'LSUN' 50 | data.image_size = 256 51 | data.random_flip = True 52 | data.uniform_dequantization = False 53 | 
data.centered = False 54 | data.num_channels = 3 55 | data.root_path = 'YOUR_ROOT_PATH' 56 | 57 | # model 58 | config.model = model = ml_collections.ConfigDict() 59 | model.sigma_max = 378 60 | model.sigma_min = 0.01 61 | model.num_scales = 2000 62 | model.beta_min = 0.1 63 | model.beta_max = 20. 64 | model.dropout = 0. 65 | model.embedding_type = 'fourier' 66 | 67 | # optimization 68 | config.optim = optim = ml_collections.ConfigDict() 69 | optim.weight_decay = 0 70 | optim.optimizer = 'Adam' 71 | optim.lr = 2e-4 72 | optim.beta1 = 0.9 73 | optim.eps = 1e-8 74 | optim.warmup = 5000 75 | optim.grad_clip = 1. 76 | 77 | config.seed = 42 78 | config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 79 | 80 | return config 81 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FlowGrad 2 | 3 | This is the official implementation of the CVPR2023 paper 4 | ## [FlowGrad: Controlling the Output of Generative ODEs With Gradients](https://openaccess.thecvf.com/content/CVPR2023/html/Liu_FlowGrad_Controlling_the_Output_of_Generative_ODEs_With_Gradients_CVPR_2023_paper.html) 5 | by *Xingchao Liu, Lemeng Wu, Shujian Zhang, Chengyue Gong, Wei Ping, Qiang Liu* from NVIDIA and UT Austin 6 | 7 | ![](github_misc/fig1.png) 8 | 9 | ## Interactive Colab notebook 10 | 11 | We provide an introductory Colab notebook on a toy 2D example to help users understand the method. Play [here](https://colab.research.google.com/drive/1rx3-WbC6yyx1jnES3xVQ0b463xfihFJU?usp=sharing). 12 | 13 | ## Controlling Rectified Flow on CelebA-HQ 14 | 15 | We provide the scripts for applying FlowGrad to control the output of pre-trained Rectified Flow model on CelebA-HQ. 
16 | First, clone and enter the repo with, 17 | 18 | ``` 19 | git clone https://github.com/gnobitab/FlowGrad.git 20 | cd FlowGrad 21 | ``` 22 | 23 | The pre-trained generative model can be downloaded from [Rectified Flow CelebA-HQ](https://drive.google.com/file/d/1ryhuJGz75S35GEdWDLiq4XFrsbwPdHnF/view?usp=sharing). 24 | Just put it in ``` ./ ``` 25 | 26 | ### Dependencies 27 | The following packages are required, 28 | 29 | ``` 30 | torch, numpy, lpips, clip, ml_collections, absl-py 31 | ``` 32 | 33 | ### Run 34 | In our example, we use the demo image ```demo/celeba.jpg``` and text prompt ```A photo of a smiling face.``` The following command can be used to do this editing. 35 | 36 | ``` 37 | python -u main.py --config RectifiedFlow/configs/celeba_hq_pytorch_rf_gaussian.py --text_prompt 'A photo of a smiling face.' --alpha 0.7 --model_path ./checkpoint_10.pth --image_path demo/celeba.jpg 38 | ``` 39 | 40 | The images will be saved in ```output/figs/```. The folder includes, 41 | 42 | * ```original.png```: the original image. 43 | 44 | * ```reconstruct.png```: the image generated from the encoded latent of the original image by running the ODE in the reverse direction. There is a subtle difference from the original image due to the discretization error. 45 | 46 | * ```optimized.png```: the image generated after editing with FlowGrad. 
47 | 48 | ![](github_misc/example.png) 49 | 50 | ## Citation 51 | If you use the code or our work is related to yours, please cite us: 52 | ``` 53 | @InProceedings{Liu_2023_CVPR, 54 | author = {Liu, Xingchao and Wu, Lemeng and Zhang, Shujian and Gong, Chengyue and Ping, Wei and Liu, Qiang}, 55 | title = {FlowGrad: Controlling the Output of Generative ODEs With Gradients}, 56 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 57 | month = {June}, 58 | year = {2023}, 59 | pages = {24335-24344} 60 | } 61 | ``` 62 | -------------------------------------------------------------------------------- /utils/run_lib_flowgrad.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import io 3 | import os 4 | import time 5 | 6 | import numpy as np 7 | import logging 8 | 9 | # Keep the import below for registering all model definitions 10 | from RectifiedFlow.models import ddpm, ncsnv2, ncsnpp 11 | from RectifiedFlow.models import utils as mutils 12 | from RectifiedFlow.models.ema import ExponentialMovingAverage 13 | from absl import flags 14 | import torch 15 | from torchvision.utils import make_grid, save_image 16 | from RectifiedFlow.utils import save_checkpoint, restore_checkpoint 17 | import RectifiedFlow.datasets as datasets 18 | 19 | from RectifiedFlow.models.utils import get_model_fn 20 | from RectifiedFlow.models import utils as mutils 21 | 22 | from .flowgrad_utils import get_img, embed_to_latent, clip_semantic_loss, save_img, generate_traj, flowgrad_optimization 23 | 24 | FLAGS = flags.FLAGS 25 | 26 | def flowgrad_edit(config, text_prompt, alpha, model_path, image_path, output_folder="output"): 27 | # Create data normalizer and its inverse 28 | scaler = datasets.get_data_scaler(config) 29 | inverse_scaler = datasets.get_data_inverse_scaler(config) 30 | 31 | # Initialize model 32 | score_model = mutils.create_model(config) 33 | ema = 
ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate) 34 | state = dict(model=score_model, ema=ema, step=0) 35 | 36 | state = restore_checkpoint(model_path, state, device=config.device) 37 | ema.copy_to(score_model.parameters()) 38 | 39 | model_fn = mutils.get_model_fn(score_model, train=False) 40 | 41 | # Load the image to edit 42 | original_img = get_img(image_path) 43 | 44 | log_folder = os.path.join(output_folder, 'figs') 45 | print('Images will be saved to:', log_folder) 46 | if not os.path.exists(log_folder): os.makedirs(log_folder) 47 | save_img(original_img, path=os.path.join(log_folder, 'original.png')) 48 | 49 | # Get latent code of the image and save reconstruction 50 | original_img = original_img.to(config.device) 51 | clip_loss = clip_semantic_loss(text_prompt, original_img, config.device, alpha=alpha, inverse_scaler=inverse_scaler) 52 | 53 | t_s = time.time() 54 | latent = embed_to_latent(model_fn, scaler(original_img)) 55 | traj = generate_traj(model_fn, latent, N=100) 56 | save_img(inverse_scaler(traj[-1]), path=os.path.join(log_folder, 'reconstruct.png')) 57 | print('Finished getting latent code and reconstruction; image saved.') 58 | 59 | # Edit according to text prompt 60 | u_ind = [i for i in range(100)] 61 | opt_u = flowgrad_optimization(latent, u_ind, model_fn, generate_traj, N=100, L_N=clip_loss.L_N, u_init=None, number_of_iterations=10, straightness_threshold=5e-3, lr=10.0) 62 | 63 | traj = generate_traj(model_fn, latent, u=opt_u, N=100) 64 | 65 | print('Total time:', time.time() - t_s) 66 | save_img(inverse_scaler(traj[-1]), path=os.path.join(log_folder, 'optimized.png')) 67 | print('Finished Editting; images saved.') 68 | 69 | -------------------------------------------------------------------------------- /RectifiedFlow/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as 
F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output, empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | grad_bias = grad_input.sum(dim).detach() 39 | 40 | return grad_input, grad_bias 41 | 42 | @staticmethod 43 | def backward(ctx, gradgrad_input, gradgrad_bias): 44 | out, = ctx.saved_tensors 45 | gradgrad_out = fused.fused_bias_act( 46 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 47 | ) 48 | 49 | return gradgrad_out, None, None, None 50 | 51 | 52 | class FusedLeakyReLUFunction(Function): 53 | @staticmethod 54 | def forward(ctx, input, bias, negative_slope, scale): 55 | empty = input.new_empty(0) 56 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 57 | ctx.save_for_backward(out) 58 | ctx.negative_slope = negative_slope 59 | ctx.scale = scale 60 | 61 | return out 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | out, = ctx.saved_tensors 66 | 67 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 68 | grad_output, out, ctx.negative_slope, ctx.scale 69 | ) 70 | 71 | return grad_input, grad_bias, None, None 72 | 73 | 74 | class FusedLeakyReLU(nn.Module): 75 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 76 | super().__init__() 77 | 78 | self.bias = 
nn.Parameter(torch.zeros(channel)) 79 | self.negative_slope = negative_slope 80 | self.scale = scale 81 | 82 | def forward(self, input): 83 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 84 | 85 | 86 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 87 | if input.device.type == "cpu": 88 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 89 | return ( 90 | F.leaky_relu( 91 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 92 | ) 93 | * scale 94 | ) 95 | 96 | else: 97 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 98 | -------------------------------------------------------------------------------- /RectifiedFlow/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? 
p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /RectifiedFlow/models/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """All functions and modules related to model definition. 17 | """ 18 | 19 | import torch 20 | import numpy as np 21 | 22 | 23 | _MODELS = {} 24 | 25 | 26 | def register_model(cls=None, *, name=None): 27 | """A decorator for registering model classes.""" 28 | 29 | def _register(cls): 30 | if name is None: 31 | local_name = cls.__name__ 32 | else: 33 | local_name = name 34 | if local_name in _MODELS: 35 | raise ValueError(f'Already registered model with name: {local_name}') 36 | _MODELS[local_name] = cls 37 | return cls 38 | #print(cls, name) 39 | if cls is None: 40 | return _register 41 | else: 42 | return _register(cls) 43 | 44 | 45 | def get_model(name): 46 | #print(_MODELS) 47 | return _MODELS[name] 48 | 49 | def get_sigmas(config): 50 | """Get sigmas --- the set of noise levels for SMLD from config files. 
51 | Args: 52 | config: A ConfigDict object parsed from the config file 53 | Returns: 54 | sigmas: a jax numpy arrary of noise levels 55 | """ 56 | sigmas = np.exp( 57 | np.linspace(np.log(config.model.sigma_max), np.log(config.model.sigma_min), config.model.num_scales)) 58 | 59 | return sigmas 60 | 61 | 62 | 63 | def create_model(config): 64 | """Create the score model.""" 65 | model_name = config.model.name 66 | score_model = get_model(model_name)(config) 67 | score_model = score_model.to(config.device) 68 | 69 | num_params = 0 70 | for p in score_model.parameters(): 71 | num_params += p.numel() 72 | print('Number of Parameters in the Score Model:', num_params) 73 | 74 | score_model = torch.nn.DataParallel(score_model) 75 | return score_model 76 | 77 | 78 | def get_model_fn(model, train=False): 79 | """Create a function to give the output of the score-based model. 80 | 81 | Args: 82 | model: The score model. 83 | train: `True` for training and `False` for evaluation. 84 | 85 | Returns: 86 | A model function. 87 | """ 88 | 89 | def model_fn(x, labels): 90 | """Compute the output of the score-based model. 91 | 92 | Args: 93 | x: A mini-batch of input data. 94 | labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently 95 | for different models. 
96 | 97 | Returns: 98 | A tuple of (model output, new mutable states) 99 | """ 100 | if not train: 101 | model.eval() 102 | return model(x, labels) 103 | else: 104 | model.train() 105 | return model(x, labels) 106 | 107 | return model_fn 108 | 109 | 110 | def to_flattened_numpy(x): 111 | """Flatten a torch tensor `x` and convert it to numpy.""" 112 | return x.detach().cpu().numpy().reshape((-1,)) 113 | 114 | 115 | def from_flattened_numpy(x, shape): 116 | """Form a torch tensor with the given `shape` from a flattened numpy array `x`.""" 117 | return torch.from_numpy(x.reshape(shape)) 118 | -------------------------------------------------------------------------------- /RectifiedFlow/models/ema.py: -------------------------------------------------------------------------------- 1 | # Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py 2 | 3 | from __future__ import division 4 | from __future__ import unicode_literals 5 | 6 | import torch 7 | 8 | 9 | # Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py 10 | class ExponentialMovingAverage: 11 | """ 12 | Maintains (exponential) moving average of a set of parameters. 13 | """ 14 | 15 | def __init__(self, parameters, decay, use_num_updates=True): 16 | """ 17 | Args: 18 | parameters: Iterable of `torch.nn.Parameter`; usually the result of 19 | `model.parameters()`. 20 | decay: The exponential decay. 21 | use_num_updates: Whether to use number of updates when computing 22 | averages. 23 | """ 24 | if decay < 0.0 or decay > 1.0: 25 | raise ValueError('Decay must be between 0 and 1') 26 | self.decay = decay 27 | self.num_updates = 0 if use_num_updates else None 28 | self.shadow_params = [p.clone().detach() 29 | for p in parameters if p.requires_grad] 30 | self.collected_params = [] 31 | 32 | def update(self, parameters): 33 | """ 34 | Update currently maintained parameters. 
35 | 36 | Call this every time the parameters are updated, such as the result of 37 | the `optimizer.step()` call. 38 | 39 | Args: 40 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of 41 | parameters used to initialize this object. 42 | """ 43 | decay = self.decay 44 | if self.num_updates is not None: 45 | self.num_updates += 1 46 | decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates)) 47 | one_minus_decay = 1.0 - decay 48 | with torch.no_grad(): 49 | parameters = [p for p in parameters if p.requires_grad] 50 | for s_param, param in zip(self.shadow_params, parameters): 51 | s_param.sub_(one_minus_decay * (s_param - param)) 52 | 53 | def copy_to(self, parameters): 54 | """ 55 | Copy current parameters into given collection of parameters. 56 | 57 | Args: 58 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 59 | updated with the stored moving averages. 60 | """ 61 | parameters = [p for p in parameters if p.requires_grad] 62 | for s_param, param in zip(self.shadow_params, parameters): 63 | if param.requires_grad: 64 | param.data.copy_(s_param.data) 65 | 66 | def store(self, parameters): 67 | """ 68 | Save the current parameters for restoring later. 69 | 70 | Args: 71 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 72 | temporarily stored. 73 | """ 74 | self.collected_params = [param.clone() for param in parameters] 75 | 76 | def restore(self, parameters): 77 | """ 78 | Restore the parameters stored with the `store` method. 79 | Useful to validate the model with EMA parameters without affecting the 80 | original optimization process. Store the parameters before the 81 | `copy_to` method. After validation (or model saving), use this to 82 | restore the former parameters. 83 | 84 | Args: 85 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 86 | updated with the stored parameters. 
87 | """ 88 | for c_param, param in zip(self.collected_params, parameters): 89 | param.data.copy_(c_param.data) 90 | 91 | def state_dict(self): 92 | return dict(decay=self.decay, num_updates=self.num_updates, 93 | shadow_params=self.shadow_params) 94 | 95 | def load_state_dict(self, state_dict): 96 | self.decay = state_dict['decay'] 97 | self.num_updates = state_dict['num_updates'] 98 | self.shadow_params = state_dict['shadow_params'] -------------------------------------------------------------------------------- /utils/DiffAugment_pytorch.py: -------------------------------------------------------------------------------- 1 | # Differentiable Augmentation for Data-Efficient GAN Training 2 | # Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han 3 | # https://arxiv.org/pdf/2006.10738 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import numpy as np 8 | 9 | 10 | def DiffAugment(x, policy='', channels_first=True): 11 | if policy: 12 | if not channels_first: 13 | x = x.permute(0, 3, 1, 2) 14 | for p in policy.split(','): 15 | for f in AUGMENT_FNS[p]: 16 | x = f(x) 17 | if not channels_first: 18 | x = x.permute(0, 2, 3, 1) 19 | x = x.contiguous() 20 | return x 21 | 22 | 23 | def rand_brightness(x): 24 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) 25 | return x 26 | 27 | 28 | def rand_saturation(x): 29 | x_mean = x.mean(dim=1, keepdim=True) 30 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean 31 | return x 32 | 33 | 34 | def rand_contrast(x): 35 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 36 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean 37 | return x 38 | 39 | 40 | def rand_translation(x, ratio=0.125): ### ratio: org: 0.125 41 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 42 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 43 | 
def rand_resize(x, min_ratio=0.8, max_ratio=1.2):
    """Randomly rescale a batch, then pad or crop it back to its input size.

    One ratio is drawn per call from ``np.random`` (uniform in
    ``[min_ratio, max_ratio)``) and applied to the whole batch. A shrunken
    image is zero-padded around its centre; an enlarged one is centre-cropped.
    Height and width are handled independently, so non-square batches are
    supported.

    Args:
        x: 4-D tensor whose last two axes are the spatial axes.
        min_ratio: Lower bound of the resize ratio.
        max_ratio: Upper bound of the resize ratio.

    Returns:
        Tensor with the same shape as ``x``.
    """
    height, width = x.shape[2], x.shape[3]
    resize_ratio = np.random.rand() * (max_ratio - min_ratio) + min_ratio
    # Guard against a zero-sized target for very small inputs.
    new_h = max(1, int(resize_ratio * height))
    new_w = max(1, int(resize_ratio * width))
    # align_corners=False is what the implicit default resolves to; passing it
    # explicitly silences the warning without changing the result.
    resized = F.interpolate(x, size=(new_h, new_w), mode='bilinear',
                            align_corners=False)

    def _fit(t, dim, target):
        """Zero-pad or centre-crop axis ``dim`` of ``t`` to length ``target``."""
        current = t.shape[dim]
        if current < target:
            before = (target - current) // 2
            after = target - current - before
            # F.pad lists the last axis first: (left, right, top, bottom).
            pad = [before, after] if dim == 3 else [0, 0, before, after]
            return F.pad(t, pad, "constant", 0.)
        start = (current - target) // 2
        return t.narrow(dim, start, target)

    out = _fit(_fit(resized, 2, height), 3, width)
    assert out.shape[2] == height
    assert out.shape[3] == width
    return out
class UpFirDn2dBackward(Function):
    """Autograd function producing the input gradient of ``UpFirDn2d``.

    Because it defines its own ``backward``, the gradient it produces can be
    differentiated again.
    """

    @staticmethod
    def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad,
                in_size, out_size):
        """Run the compiled op on ``grad_output`` with ``grad_kernel``.

        The up and down factors are handed to the op in swapped positions and
        ``g_pad`` is used as the padding. The result is viewed with the four
        sizes in ``in_size``.
        """
        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        flat_grad = grad_output.reshape(-1, out_size[0], out_size[1], 1)
        grad_input = upfirdn2d_op.upfirdn2d(
            flat_grad, grad_kernel,
            down_x, down_y, up_x, up_y,
            g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1,
        )
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

        # Everything backward() needs: the forward kernel plus the forward
        # geometry.
        ctx.save_for_backward(kernel)
        ctx.pad_x0, ctx.pad_x1, ctx.pad_y0, ctx.pad_y1 = pad
        ctx.up_x, ctx.up_y = up_x, up_y
        ctx.down_x, ctx.down_y = down_x, down_y
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        """Apply the op with the saved kernel and forward geometry.

        Only the first forward argument receives a gradient; the remaining
        eight get ``None``.
        """
        (kernel,) = ctx.saved_tensors
        in_h, in_w = ctx.in_size[2], ctx.in_size[3]
        out_h, out_w = ctx.out_size[0], ctx.out_size[1]

        flat = gradgrad_input.reshape(-1, in_h, in_w, 1)
        gradgrad_out = upfirdn2d_op.upfirdn2d(
            flat, kernel,
            ctx.up_x, ctx.up_y, ctx.down_x, ctx.down_y,
            ctx.pad_x0, ctx.pad_x1, ctx.pad_y0, ctx.pad_y1,
        )
        gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], out_h, out_w)

        return (gradgrad_out,) + (None,) * 8
def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    """Pure-PyTorch upsample, pad, FIR-filter and downsample of a 4-D batch.

    Steps, each applied to every channel independently:
      1. insert ``up - 1`` zeros after every sample along each spatial axis;
      2. pad (positive values) or crop (negative values) the borders;
      3. filter with ``kernel`` (the kernel is flipped before ``conv2d``);
      4. keep every ``down``-th sample.

    Args:
        input: Tensor of shape ``[N, C, H, W]``.
        kernel: 2-D FIR filter. It is cast to the dtype and device of
            ``input`` so that, for example, a float32 kernel can filter a
            float64 batch.
        up_x, up_y: Integer zero-insertion factors.
        down_x, down_y: Integer decimation factors.
        pad_x0, pad_x1, pad_y0, pad_y1: Padding before/after each spatial
            axis; negative values crop.

    Returns:
        Tensor of shape ``[N, C, out_h, out_w]`` with
        ``out = (in * up + pad0 + pad1 - kernel_size) // down + 1``.
    """
    _, channel, in_h, in_w = input.shape
    kernel = kernel.to(device=input.device, dtype=input.dtype)
    kernel_h, kernel_w = kernel.shape

    # Fold batch and channel together; the trailing axis of size 1 is the
    # "minor" axis of the layout used by the steps below.
    input = input.reshape(-1, in_h, in_w, 1)
    minor = input.shape[3]

    # Zero insertion: add singleton axes after H and W, right-pad them to
    # length `up`, then merge them back into H and W.
    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Positive padding is applied here, negative padding by the slice below.
    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    padded_h = in_h * up_y + pad_y0 + pad_y1
    padded_w = in_w * up_x + pad_x0 + pad_x1

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, padded_h, padded_w])
    # conv2d computes a cross-correlation, so the kernel is flipped first.
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(-1, minor, padded_h - kernel_h + 1, padded_w - kernel_w + 1)
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (padded_h - kernel_h) // down_y + 1
    out_w = (padded_w - kernel_w) // down_x + 1

    # reshape (not view): the tensor is permuted and strided at this point.
    return out.reshape(-1, channel, out_h, out_w)
@utils.register_model(name='ddpm')
class DDPM(nn.Module):
    """U-Net style model assembled as one flat ``nn.ModuleList``.

    ``__init__`` appends the layers in exactly the order ``forward`` consumes
    them, and ``forward`` walks the list with a running index. Any change to
    one method must be mirrored in the other.
    """

    def __init__(self, config):
        """Build the layer list from ``config.model`` and ``config.data``."""
        super().__init__()
        self.act = act = get_act(config)
        self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))

        self.nf = nf = config.model.nf
        ch_mult = config.model.ch_mult
        self.num_res_blocks = num_res_blocks = config.model.num_res_blocks
        self.attn_resolutions = attn_resolutions = config.model.attn_resolutions
        dropout = config.model.dropout
        resamp_with_conv = config.model.resamp_with_conv
        self.num_resolutions = num_resolutions = len(ch_mult)
        self.all_resolutions = all_resolutions = [
            config.data.image_size // (2 ** i) for i in range(num_resolutions)
        ]

        AttnBlock = functools.partial(layers.AttnBlock)
        self.conditional = conditional = config.model.conditional
        ResnetBlock = functools.partial(ResnetBlockDDPM, act=act, temb_dim=4 * nf, dropout=dropout)

        # The list must exist even for an unconditional model; previously it
        # was only created inside the branch below, so `conditional=False`
        # failed with a NameError at the first append.
        modules = []
        if conditional:
            # Condition on noise levels.
            modules.append(nn.Linear(nf, nf * 4))
            modules[0].weight.data = default_initializer()(modules[0].weight.data.shape)
            nn.init.zeros_(modules[0].bias)
            modules.append(nn.Linear(nf * 4, nf * 4))
            modules[1].weight.data = default_initializer()(modules[1].weight.data.shape)
            nn.init.zeros_(modules[1].bias)

        self.centered = config.data.centered
        channels = config.data.num_channels

        # Downsampling block
        modules.append(conv3x3(channels, nf))
        hs_c = [nf]  # channel count of every skip connection, in push order
        in_ch = nf
        for i_level in range(num_resolutions):
            # Residual blocks for this resolution
            for _ in range(num_res_blocks):
                out_ch = nf * ch_mult[i_level]
                modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
                in_ch = out_ch
                if all_resolutions[i_level] in attn_resolutions:
                    modules.append(AttnBlock(channels=in_ch))
                hs_c.append(in_ch)
            if i_level != num_resolutions - 1:
                modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv))
                hs_c.append(in_ch)

        # Middle block
        in_ch = hs_c[-1]
        modules.append(ResnetBlock(in_ch=in_ch))
        modules.append(AttnBlock(channels=in_ch))
        modules.append(ResnetBlock(in_ch=in_ch))

        # Upsampling block
        for i_level in reversed(range(num_resolutions)):
            for _ in range(num_res_blocks + 1):
                out_ch = nf * ch_mult[i_level]
                # Input is the running feature map concatenated with a skip.
                modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
                in_ch = out_ch
            if all_resolutions[i_level] in attn_resolutions:
                modules.append(AttnBlock(channels=in_ch))
            if i_level != 0:
                modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv))

        assert not hs_c
        modules.append(nn.GroupNorm(num_channels=in_ch, num_groups=32, eps=1e-6))
        modules.append(conv3x3(in_ch, channels, init_scale=0.))
        self.all_modules = nn.ModuleList(modules)

        self.scale_by_sigma = config.model.scale_by_sigma

    def forward(self, x, labels):
        """Run the network.

        Args:
            x: Input batch; rescaled with ``2 * x - 1`` unless
                ``config.data.centered`` was set.
            labels: Used as timesteps for the embedding when the model is
                conditional, and as indices into ``self.sigmas`` when
                ``scale_by_sigma`` is set.

        Returns:
            The output of the final convolution, divided by the selected
            sigmas when ``scale_by_sigma`` is set.
        """
        modules = self.all_modules
        m_idx = 0
        if self.conditional:
            # timestep/scale embedding
            timesteps = labels
            temb = layers.get_timestep_embedding(timesteps, self.nf)
            temb = modules[m_idx](temb)
            m_idx += 1
            temb = modules[m_idx](self.act(temb))
            m_idx += 1
        else:
            # NOTE(review): residual blocks receive temb=None here; confirm
            # ResnetBlockDDPM accepts that.
            temb = None

        if self.centered:
            # Input is in [-1, 1]
            h = x
        else:
            # Input is in [0, 1]
            h = 2 * x - 1.

        # Downsampling block
        hs = [modules[m_idx](h)]
        m_idx += 1
        for i_level in range(self.num_resolutions):
            # Residual blocks for this resolution
            for _ in range(self.num_res_blocks):
                h = modules[m_idx](hs[-1], temb)
                m_idx += 1
                # NOTE(review): __init__ decides on attention from the
                # configured resolutions, this uses the runtime width; the two
                # agree only when the input matches config.data.image_size.
                if h.shape[-1] in self.attn_resolutions:
                    h = modules[m_idx](h)
                    m_idx += 1
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(modules[m_idx](hs[-1]))
                m_idx += 1

        # Middle block
        h = hs[-1]
        h = modules[m_idx](h, temb)
        m_idx += 1
        h = modules[m_idx](h)
        m_idx += 1
        h = modules[m_idx](h, temb)
        m_idx += 1

        # Upsampling block
        for i_level in reversed(range(self.num_resolutions)):
            for _ in range(self.num_res_blocks + 1):
                h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
                m_idx += 1
            if h.shape[-1] in self.attn_resolutions:
                h = modules[m_idx](h)
                m_idx += 1
            if i_level != 0:
                h = modules[m_idx](h)
                m_idx += 1

        assert not hs
        h = self.act(modules[m_idx](h))
        m_idx += 1
        h = modules[m_idx](h)
        m_idx += 1
        # Every module must have been consumed exactly once.
        assert m_idx == len(modules)

        if self.scale_by_sigma:
            # Divide the output by sigmas. Useful for training with the NCSN loss.
            # The DDPM loss scales the network output by sigma in the loss function,
            # so no need of doing it here.
            used_sigmas = self.sigmas[labels, None, None, None]
            h = h / used_sigmas

        return h
class ConditionalInstanceNorm2d(nn.Module):
    """Instance normalisation with a per-class scale and optional shift.

    The scale (and shift, when ``bias`` is true) are rows of an embedding
    table indexed by the class label ``y``.
    """

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm2d(
            num_features, affine=False, track_running_stats=False)
        # With a bias each row stores [scale | shift]; otherwise scale only.
        embed_width = num_features * 2 if bias else num_features
        self.embed = nn.Embedding(num_classes, embed_width)
        if bias:
            # Scales are drawn uniformly from [0, 1); shifts start at zero.
            self.embed.weight.data[:, :num_features].uniform_()
            self.embed.weight.data[:, num_features:].zero_()
        else:
            self.embed.weight.data.uniform_()

    def forward(self, x, y):
        """Normalise ``x`` per instance, then modulate it with the row for ``y``."""
        h = self.instance_norm(x)
        params = self.embed(y)
        if not self.bias:
            return params.view(-1, self.num_features, 1, 1) * h
        gamma, beta = params.chunk(2, dim=-1)
        scaled = gamma.view(-1, self.num_features, 1, 1) * h
        return scaled + beta.view(-1, self.num_features, 1, 1)
class VarianceNorm2d(nn.Module):
    """Divide each channel by its spatial standard deviation, then rescale.

    The mean is not subtracted. ``alpha`` is a learned per-channel scale
    initialised from N(1, 0.02).
    """

    def __init__(self, num_features, bias=False):
        super().__init__()
        self.num_features = num_features
        self.bias = bias  # stored only; forward() does not read it
        self.alpha = nn.Parameter(torch.zeros(num_features).normal_(1, 0.02))

    def forward(self, x):
        """Return ``alpha * x / sqrt(var(x) + 1e-5)``, variance over axes 2 and 3."""
        spatial_var = torch.var(x, dim=(2, 3), keepdim=True)
        normalised = x / torch.sqrt(spatial_var + 1e-5)
        scale = self.alpha.view(-1, self.num_features, 1, 1)
        return scale * normalised
class ConditionalInstanceNorm2dPlus(nn.Module):
    """Class-conditional instance norm that re-injects per-channel means.

    Instance normalisation removes each channel's spatial mean; this layer
    adds back the channel means (standardised across channels) scaled by a
    per-class ``alpha``, then applies a per-class ``gamma`` and, when ``bias``
    is true, a per-class ``beta``.
    """

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm2d(
            num_features, affine=False, track_running_stats=False)
        # Row layout: [gamma | alpha | beta] with a bias, [gamma | alpha] without.
        n_chunks = 3 if bias else 2
        self.embed = nn.Embedding(num_classes, n_chunks * num_features)
        if bias:
            # gamma and alpha start at N(1, 0.02); beta starts at zero.
            self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02)
            self.embed.weight.data[:, 2 * num_features:].zero_()
        else:
            self.embed.weight.data.normal_(1, 0.02)

    def forward(self, x, y):
        """Normalise ``x`` and modulate it with the embedding row for ``y``."""
        channel_means = torch.mean(x, dim=(2, 3))
        # Standardise the channel means across the channel axis.
        centre = torch.mean(channel_means, dim=-1, keepdim=True)
        spread = torch.var(channel_means, dim=-1, keepdim=True)
        channel_means = (channel_means - centre) / (torch.sqrt(spread + 1e-5))
        h = self.instance_norm(x)

        parts = self.embed(y).chunk(3 if self.bias else 2, dim=-1)
        gamma, alpha = parts[0], parts[1]
        h = h + channel_means[..., None, None] * alpha[..., None, None]
        out = gamma.view(-1, self.num_features, 1, 1) * h
        if self.bias:
            out = out + parts[2].view(-1, self.num_features, 1, 1)
        return out
def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
    """Fused upsampling and convolution (PyTorch port of the StyleGAN2 op).

    The input is upsampled by a strided transposed convolution with ``w`` and
    then smoothed by the FIR filter ``k`` through ``upfirdn2d``. Padding is
    applied once, in the FIR step.

    Args:
        x: Input tensor of shape ``[N, C, H, W]``.
        w: Convolution weight of shape ``[outC, inC, kH, kW]`` with
            ``kH == kW``. Grouped convolution is selected implicitly by
            ``inC = C // num_groups``.
        k: FIR filter of shape ``[firH, firW]`` or ``[firN]`` (separable).
            Defaults to ``[1] * factor``.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        The FIR-filtered result of the transposed convolution, as returned by
        ``upfirdn2d``.
    """
    assert isinstance(factor, int) and factor >= 1

    # Check weight shape.
    assert len(w.shape) == 4
    convH, convW = w.shape[2], w.shape[3]
    inC = w.shape[1]
    assert convW == convH

    # Setup filter kernel.
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * (gain * (factor ** 2))
    p = (k.shape[0] - factor) - (convW - 1)

    # conv_transpose2d takes one stride per spatial axis (the TF original used
    # a 4-element NCHW stride, which PyTorch rejects).
    stride = (factor, factor)
    num_groups = x.shape[1] // inC

    # Rearrange [outC, inC, kH, kW] into the transposed-convolution layout
    # [groups * inC, outC // groups, kH, kW], flipping the spatial taps.
    # torch.flip is required: tensors do not support negative-step slicing.
    w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
    w = torch.flip(w, dims=(3, 4)).permute(0, 2, 1, 3, 4)
    w = torch.reshape(w, (num_groups * inC, -1, convH, convW))

    # With padding=0 the output size is already (in - 1) * factor + kernel,
    # which is the size the TF code requested, so no output_padding is needed.
    x = F.conv_transpose2d(x, w, stride=stride, padding=0, groups=num_groups)

    return upfirdn2d(x, torch.tensor(k, device=x.device),
                     pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
154 | w: Weight tensor of the shape `[filterH, filterW, inChannels, 155 | outChannels]`. Grouped convolution can be performed by `inChannels = 156 | x.shape[0] // numGroups`. 157 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 158 | (separable). The default is `[1] * factor`, which corresponds to 159 | average pooling. 160 | factor: Integer downsampling factor (default: 2). 161 | gain: Scaling factor for signal magnitude (default: 1.0). 162 | 163 | Returns: 164 | Tensor of the shape `[N, C, H // factor, W // factor]` or 165 | `[N, H // factor, W // factor, C]`, and same datatype as `x`. 166 | """ 167 | 168 | assert isinstance(factor, int) and factor >= 1 169 | _outC, _inC, convH, convW = w.shape 170 | assert convW == convH 171 | if k is None: 172 | k = [1] * factor 173 | k = _setup_kernel(k) * gain 174 | p = (k.shape[0] - factor) + (convW - 1) 175 | s = [factor, factor] 176 | x = upfirdn2d(x, torch.tensor(k, device=x.device), 177 | pad=((p + 1) // 2, p // 2)) 178 | return F.conv2d(x, w, stride=s, padding=0) 179 | 180 | 181 | def _setup_kernel(k): 182 | k = np.asarray(k, dtype=np.float32) 183 | if k.ndim == 1: 184 | k = np.outer(k, k) 185 | k /= np.sum(k) 186 | assert k.ndim == 2 187 | assert k.shape[0] == k.shape[1] 188 | return k 189 | 190 | 191 | def _shape(x, dim): 192 | return x.shape[dim] 193 | 194 | 195 | def upsample_2d(x, k=None, factor=2, gain=1): 196 | r"""Upsample a batch of 2D images with the given filter. 197 | 198 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 199 | and upsamples each image with the given filter. The filter is normalized so 200 | that 201 | if the input pixels are constant, they will be scaled by the specified 202 | `gain`. 203 | Pixels outside the image are assumed to be zero, and the filter is padded 204 | with 205 | zeros so that its shape is a multiple of the upsampling factor. 206 | Args: 207 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 208 | C]`. 
def downsample_2d(x, k=None, factor=2, gain=1):
    r"""Downsample a batch of 2D images with the given FIR filter.

    The filter is normalised by ``_setup_kernel`` and multiplied by ``gain``,
    then applied through ``upfirdn2d`` with ``down=factor``. The padding is
    chosen from the difference between filter size and ``factor``.

    Args:
        x: Input tensor of the shape `[N, C, H, W]`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
            The default is `[1] * factor`.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        The tensor returned by ``upfirdn2d``; per the original documentation
        its shape is `[N, C, H // factor, W // factor]`.
    """
    assert isinstance(factor, int) and factor >= 1
    taps = [1] * factor if k is None else k
    fir = _setup_kernel(taps) * gain
    excess = fir.shape[0] - factor
    padding = ((excess + 1) // 2, excess // 2)
    return upfirdn2d(x, torch.tensor(fir, device=x.device),
                     down=factor, pad=padding)
import up_or_down_sampling 21 | import torch.nn as nn 22 | import torch 23 | import torch.nn.functional as F 24 | import numpy as np 25 | 26 | conv1x1 = layers.ddpm_conv1x1 27 | conv3x3 = layers.ddpm_conv3x3 28 | NIN = layers.NIN 29 | default_init = layers.default_init 30 | 31 | 32 | class GaussianFourierProjection(nn.Module): 33 | """Gaussian Fourier embeddings for noise levels.""" 34 | 35 | def __init__(self, embedding_size=256, scale=1.0): 36 | super().__init__() 37 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 38 | 39 | def forward(self, x): 40 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 41 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 42 | 43 | 44 | class Combine(nn.Module): 45 | """Combine information from skip connections.""" 46 | 47 | def __init__(self, dim1, dim2, method='cat'): 48 | super().__init__() 49 | self.Conv_0 = conv1x1(dim1, dim2) 50 | self.method = method 51 | 52 | def forward(self, x, y): 53 | h = self.Conv_0(x) 54 | if self.method == 'cat': 55 | return torch.cat([h, y], dim=1) 56 | elif self.method == 'sum': 57 | return h + y 58 | else: 59 | raise ValueError(f'Method {self.method} not recognized.') 60 | 61 | 62 | class AttnBlockpp(nn.Module): 63 | """Channel-wise self-attention block. 
Modified from DDPM.""" 64 | 65 | def __init__(self, channels, skip_rescale=False, init_scale=0.): 66 | super().__init__() 67 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, 68 | eps=1e-6) 69 | self.NIN_0 = NIN(channels, channels) 70 | self.NIN_1 = NIN(channels, channels) 71 | self.NIN_2 = NIN(channels, channels) 72 | self.NIN_3 = NIN(channels, channels, init_scale=init_scale) 73 | self.skip_rescale = skip_rescale 74 | 75 | def forward(self, x): 76 | B, C, H, W = x.shape 77 | h = self.GroupNorm_0(x) 78 | q = self.NIN_0(h) 79 | k = self.NIN_1(h) 80 | v = self.NIN_2(h) 81 | 82 | w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5)) 83 | w = torch.reshape(w, (B, H, W, H * W)) 84 | w = F.softmax(w, dim=-1) 85 | w = torch.reshape(w, (B, H, W, H, W)) 86 | h = torch.einsum('bhwij,bcij->bchw', w, v) 87 | h = self.NIN_3(h) 88 | if not self.skip_rescale: 89 | return x + h 90 | else: 91 | return (x + h) / np.sqrt(2.) 92 | 93 | 94 | class Upsample(nn.Module): 95 | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, 96 | fir_kernel=(1, 3, 3, 1)): 97 | super().__init__() 98 | out_ch = out_ch if out_ch else in_ch 99 | if not fir: 100 | if with_conv: 101 | self.Conv_0 = conv3x3(in_ch, out_ch) 102 | else: 103 | if with_conv: 104 | self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch, 105 | kernel=3, up=True, 106 | resample_kernel=fir_kernel, 107 | use_bias=True, 108 | kernel_init=default_init()) 109 | self.fir = fir 110 | self.with_conv = with_conv 111 | self.fir_kernel = fir_kernel 112 | self.out_ch = out_ch 113 | 114 | def forward(self, x): 115 | B, C, H, W = x.shape 116 | if not self.fir: 117 | h = F.interpolate(x, (H * 2, W * 2), 'nearest') 118 | if self.with_conv: 119 | h = self.Conv_0(h) 120 | else: 121 | if not self.with_conv: 122 | h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2) 123 | else: 124 | h = self.Conv2d_0(x) 125 | 126 | return h 127 | 128 | 129 | class 
Downsample(nn.Module): 130 | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, 131 | fir_kernel=(1, 3, 3, 1)): 132 | super().__init__() 133 | out_ch = out_ch if out_ch else in_ch 134 | if not fir: 135 | if with_conv: 136 | self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0) 137 | else: 138 | if with_conv: 139 | self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch, 140 | kernel=3, down=True, 141 | resample_kernel=fir_kernel, 142 | use_bias=True, 143 | kernel_init=default_init()) 144 | self.fir = fir 145 | self.fir_kernel = fir_kernel 146 | self.with_conv = with_conv 147 | self.out_ch = out_ch 148 | 149 | def forward(self, x): 150 | B, C, H, W = x.shape 151 | if not self.fir: 152 | if self.with_conv: 153 | x = F.pad(x, (0, 1, 0, 1)) 154 | x = self.Conv_0(x) 155 | else: 156 | x = F.avg_pool2d(x, 2, stride=2) 157 | else: 158 | if not self.with_conv: 159 | x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2) 160 | else: 161 | x = self.Conv2d_0(x) 162 | 163 | return x 164 | 165 | 166 | class ResnetBlockDDPMpp(nn.Module): 167 | """ResBlock adapted from DDPM.""" 168 | 169 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, 170 | dropout=0.1, skip_rescale=False, init_scale=0.): 171 | super().__init__() 172 | out_ch = out_ch if out_ch else in_ch 173 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) 174 | self.Conv_0 = conv3x3(in_ch, out_ch) 175 | if temb_dim is not None: 176 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 177 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) 178 | nn.init.zeros_(self.Dense_0.bias) 179 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) 180 | self.Dropout_0 = nn.Dropout(dropout) 181 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) 182 | if in_ch != out_ch: 183 | if conv_shortcut: 184 | self.Conv_2 = conv3x3(in_ch, out_ch) 185 | else: 186 | 
self.NIN_0 = NIN(in_ch, out_ch) 187 | 188 | self.skip_rescale = skip_rescale 189 | self.act = act 190 | self.out_ch = out_ch 191 | self.conv_shortcut = conv_shortcut 192 | 193 | def forward(self, x, temb=None): 194 | h = self.act(self.GroupNorm_0(x)) 195 | h = self.Conv_0(h) 196 | if temb is not None: 197 | h += self.Dense_0(self.act(temb))[:, :, None, None] 198 | h = self.act(self.GroupNorm_1(h)) 199 | h = self.Dropout_0(h) 200 | h = self.Conv_1(h) 201 | if x.shape[1] != self.out_ch: 202 | if self.conv_shortcut: 203 | x = self.Conv_2(x) 204 | else: 205 | x = self.NIN_0(x) 206 | if not self.skip_rescale: 207 | return x + h 208 | else: 209 | return (x + h) / np.sqrt(2.) 210 | 211 | 212 | class ResnetBlockBigGANpp(nn.Module): 213 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False, 214 | dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1), 215 | skip_rescale=True, init_scale=0.): 216 | super().__init__() 217 | 218 | out_ch = out_ch if out_ch else in_ch 219 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) 220 | self.up = up 221 | self.down = down 222 | self.fir = fir 223 | self.fir_kernel = fir_kernel 224 | 225 | self.Conv_0 = conv3x3(in_ch, out_ch) 226 | if temb_dim is not None: 227 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 228 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape) 229 | nn.init.zeros_(self.Dense_0.bias) 230 | 231 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) 232 | self.Dropout_0 = nn.Dropout(dropout) 233 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) 234 | if in_ch != out_ch or up or down: 235 | self.Conv_2 = conv1x1(in_ch, out_ch) 236 | 237 | self.skip_rescale = skip_rescale 238 | self.act = act 239 | self.in_ch = in_ch 240 | self.out_ch = out_ch 241 | 242 | def forward(self, x, temb=None): 243 | h = self.act(self.GroupNorm_0(x)) 244 | 245 | if self.up: 246 | if self.fir: 247 | h = 
up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2) 248 | x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2) 249 | else: 250 | h = up_or_down_sampling.naive_upsample_2d(h, factor=2) 251 | x = up_or_down_sampling.naive_upsample_2d(x, factor=2) 252 | elif self.down: 253 | if self.fir: 254 | h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2) 255 | x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2) 256 | else: 257 | h = up_or_down_sampling.naive_downsample_2d(h, factor=2) 258 | x = up_or_down_sampling.naive_downsample_2d(x, factor=2) 259 | 260 | h = self.Conv_0(h) 261 | # Add bias to each feature map conditioned on the time embedding 262 | if temb is not None: 263 | h += self.Dense_0(self.act(temb))[:, :, None, None] 264 | h = self.act(self.GroupNorm_1(h)) 265 | h = self.Dropout_0(h) 266 | h = self.Conv_1(h) 267 | 268 | if self.in_ch != self.out_ch or self.up or self.down: 269 | x = self.Conv_2(x) 270 | 271 | if not self.skip_rescale: 272 | return x + h 273 | else: 274 | return (x + h) / np.sqrt(2.) 
275 | -------------------------------------------------------------------------------- /utils/flowgrad_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torchvision 5 | import numpy as np 6 | import abc 7 | 8 | from RectifiedFlow.models.utils import from_flattened_numpy, to_flattened_numpy 9 | from scipy import integrate 10 | 11 | 12 | import imageio 13 | 14 | import lpips 15 | import clip 16 | 17 | from .DiffAugment_pytorch import DiffAugment 18 | 19 | import os 20 | import time 21 | 22 | @torch.no_grad() 23 | def embed_to_latent(model_fn, img): 24 | device = img.device 25 | def ode_func(t, x): 26 | x = from_flattened_numpy(x, img.shape).to(device).type(torch.float32) 27 | vec_t = torch.ones(img.shape[0], device=x.device) * t 28 | drift = model_fn(x, vec_t*999) 29 | return to_flattened_numpy(drift) 30 | 31 | rtol=atol=1e-5 32 | method='RK45' 33 | eps=1e-3 34 | 35 | # Initial sample 36 | x = img.detach().clone() 37 | 38 | solution = integrate.solve_ivp(ode_func, (1., eps), to_flattened_numpy(x), 39 | rtol=rtol, atol=atol, method=method) 40 | nfe = solution.nfev 41 | x = torch.tensor(solution.y[:, -1]).reshape(img.shape).to(device).type(torch.float32) 42 | 43 | return x 44 | 45 | @torch.no_grad() 46 | def generate_traj(dynamic, z0, u=None, N=100, straightness_threshold=None): 47 | traj = [] 48 | 49 | # Initial sample 50 | z = z0.detach().clone() 51 | traj.append(z.detach().clone().cpu()) 52 | batchsize = z0.shape[0] 53 | 54 | dt = 1./N 55 | eps = 1e-3 56 | pred_list = [] 57 | for i in range(N): 58 | if (u is not None): 59 | try: 60 | z = z + u[i] 61 | except: 62 | pass 63 | 64 | t = torch.ones(z0.shape[0], device=z0.device) * i / N * (1.-eps) + eps 65 | pred = dynamic(z, t*999) 66 | z = z.detach().clone() + pred * dt 67 | 68 | traj.append(z.detach().clone()) 69 | 70 | pred_list.append(pred.detach().clone().cpu()) 71 | 72 | if straightness_threshold is not None: 73 | ### 
compute straightness and construct G 74 | non_uniform_set = {} 75 | non_uniform_set['indices'] = [] 76 | non_uniform_set['length'] = {} 77 | accumulate_length = 0 78 | accumulate_straightness = 0 79 | cur_index = 0 80 | for i in range(N): 81 | try: 82 | d1 = (pred_list[i-1] - pred_list[i]).pow(2).sum() / pred_list[i].pow(2).sum() 83 | except: 84 | d1 = 0 85 | 86 | try: 87 | d2 = (pred_list[i+1] - pred_list[i]).pow(2).sum() / pred_list[i].pow(2).sum() 88 | except: 89 | d2 = 0 90 | 91 | d = max(d1, d2) 92 | accumulate_straightness += d 93 | accumulate_length += 1 94 | if (accumulate_straightness > straightness_threshold) or (i==(N-1)): 95 | non_uniform_set['length'][cur_index] = accumulate_length 96 | non_uniform_set['indices'].append(cur_index) 97 | 98 | accumulate_straightness = 0 99 | accumulate_length = 0 100 | cur_index = i+1 101 | 102 | return traj, non_uniform_set 103 | else: 104 | return traj 105 | 106 | @torch.no_grad() 107 | def generate_traj_with_guidance(dynamic, z0, N=100, L=None, alpha_L=1.0): 108 | traj = [] 109 | 110 | # Initial sample 111 | z = z0.detach().clone() 112 | traj.append(z.detach().clone().cpu()) 113 | batchsize = z0.shape[0] 114 | 115 | dt = 1./N 116 | eps = 1e-3 117 | for i in range(N): 118 | t = torch.ones(z0.shape[0], device=z0.device) * i / N * (1.-eps) + eps 119 | 120 | if L is not None: 121 | with torch.enable_grad(): 122 | inputs = z.detach().clone() 123 | inputs.requires_grad = True 124 | pred = dynamic(inputs, t*999) 125 | #loss = L(inputs) ### NOTE: compute loss on xt 126 | loss = L(inputs + pred * (1. 
- t.detach().clone())) ### NOTE: compute loss on x1 127 | g = torch.autograd.grad(loss, inputs)[0] 128 | g *= alpha_L 129 | print(i, loss.item()) 130 | 131 | z = z.detach().clone() + pred * dt 132 | 133 | if L is not None: 134 | z = z - g.detach().clone() 135 | 136 | traj.append(z.detach().clone()) 137 | 138 | return traj 139 | 140 | def get_img(path=None): 141 | img = imageio.imread(path) ### 4-no expression 142 | img = img / 255. 143 | img = img[np.newaxis, :, :, :] 144 | img = img.transpose(0, 3, 1, 2) 145 | print('read image from:', path, 'img range:', img.min(), img.max()) 146 | img = torch.tensor(img).float() 147 | img = torch.nn.functional.interpolate(img, size=256) 148 | 149 | return img 150 | 151 | def save_img(img, path=None): 152 | torchvision.utils.save_image(img.clamp_(0.0, 1.0), os.path.join(path), nrow=16, normalize=False) 153 | 154 | class clip_semantic_loss(): 155 | def __init__(self, text, img, device, alpha=0.5, replicate=20, inverse_scaler=None): 156 | self.loss_fn_alex = lpips.LPIPS(net='alex', spatial=False).to(device) 157 | self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).to(device) 158 | self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).to(device) 159 | clip_mode="ViT-B/32" 160 | self.interp_mode='bilinear' 161 | self.clip_model, _ = clip.load(clip_mode, device=device) 162 | self.clip_c = self.clip_model.logit_scale.exp() 163 | self.text_tok = clip.tokenize([text]).to(device) 164 | self.policy = 'color,translation,resize,cutout' 165 | self.replicate = 20 166 | self.img = img 167 | self.alpha = alpha 168 | self.inverse_scaler = inverse_scaler 169 | 170 | def L_N(self, x): 171 | sim = (self.inverse_scaler(x) - self.img).abs().mean() 172 | 173 | img_aug = DiffAugment(x.repeat(self.replicate, 1, 1, 1), policy=self.policy) 174 | img_aug = self.inverse_scaler(img_aug) 175 | img_aug = torch.nn.functional.interpolate(img_aug, size=224, mode=self.interp_mode) 176 | img_aug.sub_(self.mean[None, :, None, 
None]).div_(self.std[None, :, None, None]) 177 | 178 | logits_per_image, logits_per_text = self.clip_model(img_aug, self.text_tok) 179 | logits_per_image = logits_per_image / self.clip_c 180 | concept_loss = (-1.) * logits_per_image 181 | 182 | return self.alpha * concept_loss.mean() + (1.-self.alpha) * sim.sum() 183 | 184 | def flowgrad_optimization(z0, u_ind, dynamic, generate_traj, N=100, L_N=None, u_init=None, number_of_iterations=10, straightness_threshold=5e-3, lr=1.0): 185 | device = z0.device 186 | shape = z0.shape 187 | u = {} 188 | if u_init is None: 189 | for ind in u_ind: 190 | u[ind] = torch.zeros_like(z0).to(z0.device) 191 | u[ind].requires_grad = True 192 | u[ind].grad = torch.zeros_like(u[ind], device=u[ind].device) 193 | else: 194 | for ind in u_init.keys(): 195 | u[ind] = u_init[ind].detach().clone().to(z0.device) 196 | 197 | for ind in u_ind: 198 | try: 199 | u[ind].requires_grad = True 200 | except: 201 | u[ind] = torch.zeros_like(z0).to(z0.device) 202 | u[ind].requires_grad = True 203 | u[ind].grad = torch.zeros_like(u[ind], device=u[ind].device) 204 | 205 | u_optimizer = torch.optim.SGD([u[key] for key in u_ind], lr=lr) ### white black 5e-3 206 | 207 | ### L is supposed to be a function (ideally, a lambda expression). The output of L should a scalar. 
208 | L_best = 1e6 209 | for i in range(number_of_iterations): 210 | u_optimizer.zero_grad() 211 | 212 | ### get the forward simulation result and the non-uniform discretization trajectory 213 | ### non_uniform_set: indices and interval length (t_{j+1} - t_j) 214 | z_traj, non_uniform_set = generate_traj(dynamic, z0, u=u, N=N, straightness_threshold=straightness_threshold) 215 | print(non_uniform_set) 216 | 217 | t_s = time.time() 218 | ### use lambda to store \nabla L 219 | inputs = torch.zeros(z_traj[-1].shape, device=device) 220 | inputs.data = z_traj[-1].to(device).detach().clone() 221 | inputs.requires_grad = True 222 | loss = L_N(inputs) 223 | lam = torch.autograd.grad(loss, inputs)[0] 224 | lam = lam.detach().clone() 225 | 226 | print('iteration:', i) 227 | print(' inputs:', inputs.view(-1).detach().cpu().numpy()) 228 | print(' L:%.6f'%loss.detach().cpu().numpy()) 229 | print(' lambda:', lam.reshape(-1).detach().cpu().numpy()) 230 | 231 | if loss.detach().cpu().numpy() < L_best: 232 | opt_u = {} 233 | for ind in u.keys(): 234 | opt_u[ind] = u[ind].detach().clone() 235 | L_best = loss.detach().cpu().numpy() 236 | print(' L_best:%.6f'%L_best) 237 | if i == number_of_iterations: break 238 | 239 | eps = 1e-3 # default: 1e-3 240 | g_old = None 241 | d = [] 242 | for j in range(N-1, -1, -1): 243 | if j in non_uniform_set['indices']: 244 | assert j in u_ind 245 | else: 246 | continue 247 | 248 | ### compute lambda: correct vjp version 249 | inputs = torch.zeros(lam.shape, device=device) 250 | inputs.data = z_traj[j].to(device).detach().clone() 251 | inputs.requires_grad = True 252 | t = (torch.ones((1, )) * j / N * (1.-eps) + eps) * 999 253 | func = lambda x: (x.contiguous().reshape(shape) + u[j].detach().clone() + \ 254 | dynamic(x.contiguous().reshape(shape) + u[j].detach().clone(), t.detach().clone()) * non_uniform_set['length'][j] / N).view(-1) 255 | output, vjp = torch.autograd.functional.vjp(func, inputs=inputs.view(-1), v=lam.detach().clone().reshape(-1)) 
256 | lam = vjp.detach().clone().contiguous().reshape(shape) 257 | 258 | u[j].grad = lam.detach().clone() 259 | del inputs 260 | if j == 0: break 261 | 262 | print('BP time:', time.time() - t_s) 263 | ### Re-assignment 264 | for j in range(len(non_uniform_set['indices'])): 265 | start = non_uniform_set['indices'][j] 266 | try: 267 | end = non_uniform_set['indices'][j+1] 268 | except: 269 | end = N 270 | 271 | for k in range(start, end): 272 | if k in u_ind: 273 | u[k].grad = u[start].grad.detach().clone() 274 | 275 | u_optimizer.step() 276 | 277 | opt_u = {} 278 | for ind in u.keys(): 279 | opt_u[ind] = u[ind].detach().clone() 280 | 281 | return opt_u 282 | -------------------------------------------------------------------------------- /RectifiedFlow/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = 
min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast(*x_p) * static_cast(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h 
- 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 
| #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = pad_x1; 234 | p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && 
p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | upfirdn2d_kernel 315 | <<>>(out.data_ptr(), 316 | x.data_ptr(), 317 | k.data_ptr(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel 323 | <<>>(out.data_ptr(), 324 | x.data_ptr(), 325 | k.data_ptr(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel 331 | <<>>(out.data_ptr(), 332 | x.data_ptr(), 333 | k.data_ptr(), p); 334 | 335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel 339 | <<>>(out.data_ptr(), 340 | 
x.data_ptr(), 341 | k.data_ptr(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel 347 | <<>>(out.data_ptr(), 348 | x.data_ptr(), 349 | k.data_ptr(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel 355 | <<>>(out.data_ptr(), 356 | x.data_ptr(), 357 | k.data_ptr(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<<>>( 363 | out.data_ptr(), x.data_ptr(), 364 | k.data_ptr(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } -------------------------------------------------------------------------------- /RectifiedFlow/models/ncsnpp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | 18 | from . 
# Short module-level aliases for the building blocks used by NCSNpp below.
ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
Combine = layerspp.Combine
conv3x3 = layerspp.conv3x3
conv1x1 = layerspp.conv1x1
get_act = layers.get_act
get_normalization = normalization.get_normalization
default_initializer = layers.default_init


@utils.register_model(name='ncsnpp')
class NCSNpp(nn.Module):
    """NCSN++ model.

    All sub-modules are stored in one flat `nn.ModuleList` (`self.all_modules`).
    `__init__` appends them in a fixed order and `forward` consumes them with a
    running index (`m_idx`), so the two methods must be kept in lock-step:
    every `modules.append(...)` in `__init__` has exactly one matching
    `modules[m_idx](...)` in `forward`, guarded by the same condition.
    """

    def __init__(self, config):
        """Build the flat module list from `config`.

        Reads `config.model.*` (architecture switches), `config.data.image_size`,
        `config.data.num_channels` and, for Fourier embeddings,
        `config.training.continuous` / `config.training.sde`.

        Raises:
            ValueError: if the embedding type, resblock type or progressive mode
                is not one of the supported names.
        """
        super().__init__()
        self.config = config
        self.act = act = get_act(config)
        self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))

        self.nf = nf = config.model.nf
        ch_mult = config.model.ch_mult
        self.num_res_blocks = num_res_blocks = config.model.num_res_blocks
        self.attn_resolutions = attn_resolutions = config.model.attn_resolutions
        dropout = config.model.dropout
        resamp_with_conv = config.model.resamp_with_conv
        self.num_resolutions = num_resolutions = len(ch_mult)
        # Spatial size at each level: image_size halved once per level.
        self.all_resolutions = all_resolutions = [config.data.image_size // (2 ** i) for i in range(num_resolutions)]

        self.conditional = conditional = config.model.conditional  # noise-conditional
        fir = config.model.fir
        fir_kernel = config.model.fir_kernel
        self.skip_rescale = skip_rescale = config.model.skip_rescale
        self.resblock_type = resblock_type = config.model.resblock_type.lower()
        self.progressive = progressive = config.model.progressive.lower()
        self.progressive_input = progressive_input = config.model.progressive_input.lower()
        self.embedding_type = embedding_type = config.model.embedding_type.lower()
        init_scale = config.model.init_scale
        assert progressive in ['none', 'output_skip', 'residual']
        assert progressive_input in ['none', 'input_skip', 'residual']
        assert embedding_type in ['fourier', 'positional']
        combine_method = config.model.progressive_combine.lower()
        combiner = functools.partial(Combine, method=combine_method)

        # Flat list of sub-modules; the append order below is the contract with forward().
        modules = []
        # timestep/noise_level embedding; only for continuous training
        if embedding_type == 'fourier':
            # Gaussian Fourier features embeddings.
            assert config.training.continuous or config.training.sde=='rectified_flow', "Fourier features are only used for continuous training."

            modules.append(layerspp.GaussianFourierProjection(
                embedding_size=nf, scale=config.model.fourier_scale
            ))
            embed_dim = 2 * nf

        elif embedding_type == 'positional':
            # No module is appended here: forward() computes this embedding
            # with layers.get_timestep_embedding instead.
            embed_dim = nf

        else:
            raise ValueError(f'embedding type {embedding_type} unknown.')

        if conditional:
            # Two linear layers applied to the time embedding (embed_dim -> 4*nf -> 4*nf),
            # with weights from the default initializer and zero biases.
            modules.append(nn.Linear(embed_dim, nf * 4))
            modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
            nn.init.zeros_(modules[-1].bias)
            modules.append(nn.Linear(nf * 4, nf * 4))
            modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
            nn.init.zeros_(modules[-1].bias)

        # Pre-bind the shared hyper-parameters so the loops below only pass channel counts.
        AttnBlock = functools.partial(layerspp.AttnBlockpp,
                                      init_scale=init_scale,
                                      skip_rescale=skip_rescale)

        Upsample = functools.partial(layerspp.Upsample,
                                     with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

        if progressive == 'output_skip':
            # Stored as an attribute (not in `modules`); forward() calls it by name.
            self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
        elif progressive == 'residual':
            pyramid_upsample = functools.partial(layerspp.Upsample,
                                                 fir=fir, fir_kernel=fir_kernel, with_conv=True)

        Downsample = functools.partial(layerspp.Downsample,
                                       with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

        if progressive_input == 'input_skip':
            # Stored as an attribute (not in `modules`); forward() calls it by name.
            self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
        elif progressive_input == 'residual':
            pyramid_downsample = functools.partial(layerspp.Downsample,
                                                   fir=fir, fir_kernel=fir_kernel, with_conv=True)

        if resblock_type == 'ddpm':
            ResnetBlock = functools.partial(ResnetBlockDDPM,
                                            act=act,
                                            dropout=dropout,
                                            init_scale=init_scale,
                                            skip_rescale=skip_rescale,
                                            temb_dim=nf * 4)

        elif resblock_type == 'biggan':
            ResnetBlock = functools.partial(ResnetBlockBigGAN,
                                            act=act,
                                            dropout=dropout,
                                            fir=fir,
                                            fir_kernel=fir_kernel,
                                            init_scale=init_scale,
                                            skip_rescale=skip_rescale,
                                            temb_dim=nf * 4)

        else:
            raise ValueError(f'resblock type {resblock_type} unrecognized.')

        # Downsampling block

        channels = config.data.num_channels
        if progressive_input != 'none':
            input_pyramid_ch = channels

        modules.append(conv3x3(channels, nf))
        # hs_c records the channel count of every skip activation pushed in forward();
        # the upsampling loop pops from it to size the concatenated inputs.
        hs_c = [nf]

        in_ch = nf
        for i_level in range(num_resolutions):
            # Residual blocks for this resolution
            for i_block in range(num_res_blocks):
                out_ch = nf * ch_mult[i_level]
                modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
                in_ch = out_ch

                if all_resolutions[i_level] in attn_resolutions:
                    modules.append(AttnBlock(channels=in_ch))
                hs_c.append(in_ch)

            # Every level except the last ends with a downsampling step.
            if i_level != num_resolutions - 1:
                if resblock_type == 'ddpm':
                    modules.append(Downsample(in_ch=in_ch))
                else:
                    modules.append(ResnetBlock(down=True, in_ch=in_ch))

                if progressive_input == 'input_skip':
                    modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
                    if combine_method == 'cat':
                        # Concatenation doubles the channel count seen by later blocks.
                        in_ch *= 2

                elif progressive_input == 'residual':
                    modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
                    input_pyramid_ch = in_ch

                hs_c.append(in_ch)

        # Middle of the network: resnet block, attention, resnet block.
        in_ch = hs_c[-1]
        modules.append(ResnetBlock(in_ch=in_ch))
        modules.append(AttnBlock(channels=in_ch))
        modules.append(ResnetBlock(in_ch=in_ch))

        pyramid_ch = 0
        # Upsampling block
        for i_level in reversed(range(num_resolutions)):
            # One more block per level than on the way down, so that every
            # entry pushed onto hs_c is consumed (checked by the assert below).
            for i_block in range(num_res_blocks + 1):
                out_ch = nf * ch_mult[i_level]
                modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(),
                                           out_ch=out_ch))
                in_ch = out_ch

            if all_resolutions[i_level] in attn_resolutions:
                modules.append(AttnBlock(channels=in_ch))

            if progressive != 'none':
                if i_level == num_resolutions - 1:
                    # First (coarsest) level: start the output pyramid.
                    if progressive == 'output_skip':
                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                                                    num_channels=in_ch, eps=1e-6))
                        modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
                        pyramid_ch = channels
                    elif progressive == 'residual':
                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                                                    num_channels=in_ch, eps=1e-6))
                        modules.append(conv3x3(in_ch, in_ch, bias=True))
                        pyramid_ch = in_ch
                    else:
                        raise ValueError(f'{progressive} is not a valid name.')
                else:
                    # Later levels: extend the pyramid started above.
                    if progressive == 'output_skip':
                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                                                    num_channels=in_ch, eps=1e-6))
                        modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
                        pyramid_ch = channels
                    elif progressive == 'residual':
                        modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
                        pyramid_ch = in_ch
                    else:
                        raise ValueError(f'{progressive} is not a valid name')

            # Every level except the finest ends with an upsampling step.
            if i_level != 0:
                if resblock_type == 'ddpm':
                    modules.append(Upsample(in_ch=in_ch))
                else:
                    modules.append(ResnetBlock(in_ch=in_ch, up=True))

        # All recorded skip-connection widths must have been consumed.
        assert not hs_c

        if progressive != 'output_skip':
            # Final norm + conv back to the image channel count; with
            # 'output_skip' forward() returns the pyramid instead.
            modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                                        num_channels=in_ch, eps=1e-6))
            modules.append(conv3x3(in_ch, channels, init_scale=init_scale))

        self.all_modules = nn.ModuleList(modules)

    def forward(self, x, time_cond):
        """Run the network on `x` conditioned on `time_cond`.

        Args:
            x: input batch; concatenation along dim=1 and `x.shape[0]` below
                imply a batch-first, channel-second layout
                (presumably (N, C, H, W) — TODO confirm against callers).
            time_cond: per-sample conditioning value. With 'fourier' embeddings
                its logarithm is embedded; with 'positional' embeddings it is
                used both as a timestep and (cast to long) as an index into
                `self.sigmas`.

        Returns:
            The network output, divided by the per-sample sigma when
            `config.model.scale_by_sigma` is set.

        Raises:
            ValueError: on an unknown embedding type or progressive mode.
        """
        # timestep/noise_level embedding; only for continuous training
        modules = self.all_modules
        # Running cursor into the flat module list; see the class docstring.
        m_idx = 0
        if self.embedding_type == 'fourier':
            # Gaussian Fourier features embeddings.
            used_sigmas = time_cond
            temb = modules[m_idx](torch.log(used_sigmas))
            m_idx += 1

        elif self.embedding_type == 'positional':
            # Sinusoidal positional embeddings.
            timesteps = time_cond
            used_sigmas = self.sigmas[time_cond.long()]
            temb = layers.get_timestep_embedding(timesteps, self.nf)

        else:
            raise ValueError(f'embedding type {self.embedding_type} unknown.')

        if self.conditional:
            temb = modules[m_idx](temb)
            m_idx += 1
            temb = modules[m_idx](self.act(temb))
            m_idx += 1
        else:
            temb = None

        if not self.config.data.centered:
            # If input data is in [0, 1]
            x = 2 * x - 1.

        # Downsampling block
        input_pyramid = None
        if self.progressive_input != 'none':
            input_pyramid = x

        # hs is the stack of skip activations consumed by the upsampling path.
        hs = [modules[m_idx](x)]
        m_idx += 1
        for i_level in range(self.num_resolutions):
            # Residual blocks for this resolution
            for i_block in range(self.num_res_blocks):
                h = modules[m_idx](hs[-1], temb)
                m_idx += 1
                # NOTE(review): attention is selected from the runtime width here but
                # from config.data.image_size in __init__; the two agree only when
                # the input has the configured size. The final
                # `assert m_idx == len(modules)` trips if the module count differs.
                if h.shape[-1] in self.attn_resolutions:
                    h = modules[m_idx](h)
                    m_idx += 1

                hs.append(h)

            if i_level != self.num_resolutions - 1:
                if self.resblock_type == 'ddpm':
                    h = modules[m_idx](hs[-1])
                    m_idx += 1
                else:
                    h = modules[m_idx](hs[-1], temb)
                    m_idx += 1

                if self.progressive_input == 'input_skip':
                    input_pyramid = self.pyramid_downsample(input_pyramid)
                    h = modules[m_idx](input_pyramid, h)
                    m_idx += 1

                elif self.progressive_input == 'residual':
                    input_pyramid = modules[m_idx](input_pyramid)
                    m_idx += 1
                    if self.skip_rescale:
                        input_pyramid = (input_pyramid + h) / np.sqrt(2.)
                    else:
                        input_pyramid = input_pyramid + h
                    h = input_pyramid

                hs.append(h)

        # Middle of the network: resnet block, attention, resnet block.
        h = hs[-1]
        h = modules[m_idx](h, temb)
        m_idx += 1
        h = modules[m_idx](h)
        m_idx += 1
        h = modules[m_idx](h, temb)
        m_idx += 1

        pyramid = None

        # Upsampling block
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                # Concatenate the matching skip activation along the channel axis.
                h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
                m_idx += 1

            if h.shape[-1] in self.attn_resolutions:
                h = modules[m_idx](h)
                m_idx += 1

            if self.progressive != 'none':
                if i_level == self.num_resolutions - 1:
                    # Coarsest level: start the pyramid (norm -> act -> conv).
                    if self.progressive == 'output_skip':
                        pyramid = self.act(modules[m_idx](h))
                        m_idx += 1
                        pyramid = modules[m_idx](pyramid)
                        m_idx += 1
                    elif self.progressive == 'residual':
                        pyramid = self.act(modules[m_idx](h))
                        m_idx += 1
                        pyramid = modules[m_idx](pyramid)
                        m_idx += 1
                    else:
                        raise ValueError(f'{self.progressive} is not a valid name.')
                else:
                    if self.progressive == 'output_skip':
                        # Upsample the running pyramid and add this level's contribution.
                        pyramid = self.pyramid_upsample(pyramid)
                        pyramid_h = self.act(modules[m_idx](h))
                        m_idx += 1
                        pyramid_h = modules[m_idx](pyramid_h)
                        m_idx += 1
                        pyramid = pyramid + pyramid_h
                    elif self.progressive == 'residual':
                        pyramid = modules[m_idx](pyramid)
                        m_idx += 1
                        if self.skip_rescale:
                            pyramid = (pyramid + h) / np.sqrt(2.)
                        else:
                            pyramid = pyramid + h
                        h = pyramid
                    else:
                        raise ValueError(f'{self.progressive} is not a valid name')

            if i_level != 0:
                if self.resblock_type == 'ddpm':
                    h = modules[m_idx](h)
                    m_idx += 1
                else:
                    h = modules[m_idx](h, temb)
                    m_idx += 1

        # Every skip activation must have been consumed.
        assert not hs

        if self.progressive == 'output_skip':
            h = pyramid
        else:
            h = self.act(modules[m_idx](h))
            m_idx += 1
            h = modules[m_idx](h)
            m_idx += 1

        # Guards the __init__/forward lock-step: every module was used exactly once.
        assert m_idx == len(modules)
        if self.config.model.scale_by_sigma:
            # Reshape to (batch, 1, 1, ...) so the division broadcasts over x's trailing dims.
            used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
            h = h / used_sigmas

        return h
CondResidualBlock = ConditionalResidualBlock
conv3x3 = ncsn_conv3x3


def _data_channels(config):
    """Return the number of image channels declared in `config.data`.

    `config.data.channels` (the key this module has always read) wins when it
    is present. Otherwise fall back to `config.data.num_channels`, the key the
    NCSN++ model reads, so that a single data config can serve both families.
    """
    data = config.data
    channels = getattr(data, 'channels', None)
    if channels is None:
        channels = data.num_channels
    return channels


def _stage(make_block, in_ch, out_ch, resample=None, first_only=None, **shared):
    """Build one encoder stage: two residual blocks in an `nn.ModuleList`.

    The first block maps `in_ch -> out_ch` and applies `resample`; the second
    keeps `out_ch` channels and never resamples. `shared` keyword arguments
    are passed to both blocks, `first_only` to the first block only. The first
    block is constructed before the second, as in the hand-written lists this
    helper replaces, so random initialisation order is preserved.
    """
    first_kwargs = dict(shared)
    if first_only:
        first_kwargs.update(first_only)
    return nn.ModuleList([
        make_block(in_ch, out_ch, resample=resample, **first_kwargs),
        make_block(out_ch, out_ch, resample=None, **shared),
    ])


def _centered_input(x, centered):
    """Return `x` unchanged if the data is centered, else map it via 2 * x - 1."""
    if not centered:
        return 2 * x - 1.
    return x


def _sigma_scaled_output(model, features, x, y):
    """Apply `model`'s final norm/activation/conv and divide by the chosen sigmas.

    `y` indexes `model.sigmas`; the selected values are reshaped to
    (batch, 1, 1, ...) so the division broadcasts over the remaining dims of `x`.
    """
    output = model.normalizer(features)
    output = model.act(output)
    output = model.end_conv(output)

    used_sigmas = model.sigmas[y].view(x.shape[0], *([1] * len(x.shape[1:])))
    return output / used_sigmas


def get_network(config):
    """Pick the NCSNv2 variant for `config.data.image_size`.

    Returns a `functools.partial` that builds the chosen class with `config`
    already bound.

    Raises:
        NotImplementedError: for image sizes above 256.
    """
    if config.data.image_size < 96:
        return functools.partial(NCSNv2, config=config)
    elif 96 <= config.data.image_size <= 128:
        return functools.partial(NCSNv2_128, config=config)
    elif 128 < config.data.image_size <= 256:
        return functools.partial(NCSNv2_256, config=config)
    else:
        raise NotImplementedError(
            f'No network suitable for {config.data.image_size}px implemented yet.')


@register_model(name='ncsnv2_64')
class NCSNv2(nn.Module):
    """NCSNv2 architecture with four residual stages and four refine blocks."""

    def __init__(self, config):
        super().__init__()
        self.centered = config.data.centered
        self.norm = get_normalization(config)
        self.nf = nf = config.model.nf

        self.act = act = get_act(config)
        self.register_buffer('sigmas', torch.tensor(get_sigmas(config)))
        self.config = config

        channels = _data_channels(config)
        self.begin_conv = nn.Conv2d(channels, nf, 3, stride=1, padding=1)

        self.normalizer = self.norm(nf, config.model.num_scales)
        self.end_conv = nn.Conv2d(nf, channels, 3, stride=1, padding=1)

        block = functools.partial(ResidualBlock, act=act, normalization=self.norm)
        self.res1 = _stage(block, nf, nf)
        self.res2 = _stage(block, nf, 2 * nf, resample='down')
        self.res3 = _stage(block, 2 * nf, 2 * nf, resample='down', dilation=2)
        # Only 28px inputs request adjusted padding in the last downsampling block.
        self.res4 = _stage(block, 2 * nf, 2 * nf, resample='down',
                           first_only={'adjust_padding': config.data.image_size == 28},
                           dilation=4)

        self.refine1 = RefineBlock([2 * nf], 2 * nf, act=act, start=True)
        self.refine2 = RefineBlock([2 * nf, 2 * nf], 2 * nf, act=act)
        self.refine3 = RefineBlock([2 * nf, 2 * nf], nf, act=act)
        self.refine4 = RefineBlock([nf, nf], nf, act=act, end=True)

    def _compute_cond_module(self, module, x):
        """Apply every block in `module` to `x` in order."""
        for m in module:
            x = m(x)
        return x

    def forward(self, x, y):
        """Run the network on `x`; `y` selects one entry of `self.sigmas` per sample."""
        h = _centered_input(x, self.centered)

        output = self.begin_conv(h)

        layer1 = self._compute_cond_module(self.res1, output)
        layer2 = self._compute_cond_module(self.res2, layer1)
        layer3 = self._compute_cond_module(self.res3, layer2)
        layer4 = self._compute_cond_module(self.res4, layer3)

        # Refine from the deepest features back up, each time targeting the
        # spatial size of the encoder feature being merged in.
        ref1 = self.refine1([layer4], layer4.shape[2:])
        ref2 = self.refine2([layer3, ref1], layer3.shape[2:])
        ref3 = self.refine3([layer2, ref2], layer2.shape[2:])
        output = self.refine4([layer1, ref3], layer1.shape[2:])

        return _sigma_scaled_output(self, output, x, y)


@register_model(name='ncsn')
class NCSN(nn.Module):
    """Conditional variant: `y` is passed to every residual/refine/norm block.

    Unlike the NCSNv2 classes in this module, it registers no `sigmas` buffer
    and does not divide its output by sigma.
    """

    def __init__(self, config):
        super().__init__()
        self.centered = config.data.centered
        self.norm = get_normalization(config)
        self.nf = nf = config.model.nf
        self.act = act = get_act(config)
        self.config = config

        channels = _data_channels(config)
        num_scales = config.model.num_scales
        self.begin_conv = nn.Conv2d(channels, nf, 3, stride=1, padding=1)

        self.normalizer = self.norm(nf, num_scales)
        self.end_conv = nn.Conv2d(nf, channels, 3, stride=1, padding=1)

        def block(in_ch, out_ch, **kwargs):
            # num_scales is the third positional argument of the conditional block.
            return ConditionalResidualBlock(in_ch, out_ch, num_scales, act=act,
                                            normalization=self.norm, **kwargs)

        self.res1 = _stage(block, nf, nf)
        self.res2 = _stage(block, nf, 2 * nf, resample='down')
        self.res3 = _stage(block, 2 * nf, 2 * nf, resample='down', dilation=2)
        # Only 28px inputs request adjusted padding in the last downsampling block.
        self.res4 = _stage(block, 2 * nf, 2 * nf, resample='down',
                           first_only={'adjust_padding': config.data.image_size == 28},
                           dilation=4)

        self.refine1 = CondRefineBlock([2 * nf], 2 * nf, num_scales, self.norm, act=act, start=True)
        self.refine2 = CondRefineBlock([2 * nf, 2 * nf], 2 * nf, num_scales, self.norm, act=act)
        self.refine3 = CondRefineBlock([2 * nf, 2 * nf], nf, num_scales, self.norm, act=act)
        self.refine4 = CondRefineBlock([nf, nf], nf, num_scales, self.norm, act=act, end=True)

    def _compute_cond_module(self, module, x, y):
        """Apply every block in `module` to `x` in order, conditioning each on `y`."""
        for m in module:
            x = m(x, y)
        return x

    def forward(self, x, y):
        """Run the network on `x`, conditioning every block on `y`."""
        h = _centered_input(x, self.centered)

        output = self.begin_conv(h)

        layer1 = self._compute_cond_module(self.res1, output, y)
        layer2 = self._compute_cond_module(self.res2, layer1, y)
        layer3 = self._compute_cond_module(self.res3, layer2, y)
        layer4 = self._compute_cond_module(self.res4, layer3, y)

        ref1 = self.refine1([layer4], y, layer4.shape[2:])
        ref2 = self.refine2([layer3, ref1], y, layer3.shape[2:])
        ref3 = self.refine3([layer2, ref2], y, layer2.shape[2:])
        output = self.refine4([layer1, ref3], y, layer1.shape[2:])

        output = self.normalizer(output, y)
        output = self.act(output)
        output = self.end_conv(output)

        return output


@register_model(name='ncsnv2_128')
class NCSNv2_128(nn.Module):
    """NCSNv2 model architecture for 128px images."""

    def __init__(self, config):
        super().__init__()
        self.centered = config.data.centered
        self.norm = get_normalization(config)
        self.nf = nf = config.model.nf
        self.act = act = get_act(config)
        self.register_buffer('sigmas', torch.tensor(get_sigmas(config)))
        self.config = config

        channels = _data_channels(config)
        self.begin_conv = nn.Conv2d(channels, nf, 3, stride=1, padding=1)
        self.normalizer = self.norm(nf, config.model.num_scales)

        self.end_conv = nn.Conv2d(nf, channels, 3, stride=1, padding=1)

        block = functools.partial(ResidualBlock, act=act, normalization=self.norm)
        self.res1 = _stage(block, nf, nf)
        self.res2 = _stage(block, nf, 2 * nf, resample='down')
        self.res3 = _stage(block, 2 * nf, 2 * nf, resample='down')
        self.res4 = _stage(block, 2 * nf, 4 * nf, resample='down', dilation=2)
        self.res5 = _stage(block, 4 * nf, 4 * nf, resample='down', dilation=4)

        self.refine1 = RefineBlock([4 * nf], 4 * nf, act=act, start=True)
        self.refine2 = RefineBlock([4 * nf, 4 * nf], 2 * nf, act=act)
        self.refine3 = RefineBlock([2 * nf, 2 * nf], 2 * nf, act=act)
        self.refine4 = RefineBlock([2 * nf, 2 * nf], nf, act=act)
        self.refine5 = RefineBlock([nf, nf], nf, act=act, end=True)

    def _compute_cond_module(self, module, x):
        """Apply every block in `module` to `x` in order."""
        for m in module:
            x = m(x)
        return x

    def forward(self, x, y):
        """Run the network on `x`; `y` selects one entry of `self.sigmas` per sample."""
        h = _centered_input(x, self.centered)

        output = self.begin_conv(h)

        layer1 = self._compute_cond_module(self.res1, output)
        layer2 = self._compute_cond_module(self.res2, layer1)
        layer3 = self._compute_cond_module(self.res3, layer2)
        layer4 = self._compute_cond_module(self.res4, layer3)
        layer5 = self._compute_cond_module(self.res5, layer4)

        ref1 = self.refine1([layer5], layer5.shape[2:])
        ref2 = self.refine2([layer4, ref1], layer4.shape[2:])
        ref3 = self.refine3([layer3, ref2], layer3.shape[2:])
        ref4 = self.refine4([layer2, ref3], layer2.shape[2:])
        output = self.refine5([layer1, ref4], layer1.shape[2:])

        return _sigma_scaled_output(self, output, x, y)


@register_model(name='ncsnv2_256')
class NCSNv2_256(nn.Module):
    """NCSNv2 model architecture for 256px images."""

    def __init__(self, config):
        super().__init__()
        self.centered = config.data.centered
        self.norm = get_normalization(config)
        self.nf = nf = config.model.nf
        self.act = act = get_act(config)
        self.register_buffer('sigmas', torch.tensor(get_sigmas(config)))
        self.config = config

        channels = _data_channels(config)
        self.begin_conv = nn.Conv2d(channels, nf, 3, stride=1, padding=1)
        self.normalizer = self.norm(nf, config.model.num_scales)

        self.end_conv = nn.Conv2d(nf, channels, 3, stride=1, padding=1)

        block = functools.partial(ResidualBlock, act=act, normalization=self.norm)
        self.res1 = _stage(block, nf, nf)
        self.res2 = _stage(block, nf, 2 * nf, resample='down')
        self.res3 = _stage(block, 2 * nf, 2 * nf, resample='down')
        # Extra downsampling stage compared with the 128px variant.
        self.res31 = _stage(block, 2 * nf, 2 * nf, resample='down')
        self.res4 = _stage(block, 2 * nf, 4 * nf, resample='down', dilation=2)
        self.res5 = _stage(block, 4 * nf, 4 * nf, resample='down', dilation=4)

        self.refine1 = RefineBlock([4 * nf], 4 * nf, act=act, start=True)
        self.refine2 = RefineBlock([4 * nf, 4 * nf], 2 * nf, act=act)
        self.refine3 = RefineBlock([2 * nf, 2 * nf], 2 * nf, act=act)
        self.refine31 = RefineBlock([2 * nf, 2 * nf], 2 * nf, act=act)
        self.refine4 = RefineBlock([2 * nf, 2 * nf], nf, act=act)
        self.refine5 = RefineBlock([nf, nf], nf, act=act, end=True)

    def _compute_cond_module(self, module, x):
        """Apply every block in `module` to `x` in order."""
        for m in module:
            x = m(x)
        return x

    def forward(self, x, y):
        """Run the network on `x`; `y` selects one entry of `self.sigmas` per sample."""
        h = _centered_input(x, self.centered)

        output = self.begin_conv(h)

        layer1 = self._compute_cond_module(self.res1, output)
        layer2 = self._compute_cond_module(self.res2, layer1)
        layer3 = self._compute_cond_module(self.res3, layer2)
        layer31 = self._compute_cond_module(self.res31, layer3)
        layer4 = self._compute_cond_module(self.res4, layer31)
        layer5 = self._compute_cond_module(self.res5, layer4)

        ref1 = self.refine1([layer5], layer5.shape[2:])
        ref2 = self.refine2([layer4, ref1], layer4.shape[2:])
        ref31 = self.refine31([layer31, ref2], layer31.shape[2:])
        ref3 = self.refine3([layer3, ref31], layer3.shape[2:])
        ref4 = self.refine4([layer2, ref3], layer2.shape[2:])
        output = self.refine5([layer1, ref4], layer1.shape[2:])

        return _sigma_scaled_output(self, output, x, y)
def get_act(config):
    """Get activation functions from the config file.

    Reads `config.model.nonlinearity` case-insensitively.

    Raises:
        NotImplementedError: if the name is not one of elu/relu/lrelu/swish.
    """
    name = config.model.nonlinearity.lower()
    if name == 'elu':
        return nn.ELU()
    elif name == 'relu':
        return nn.ReLU()
    elif name == 'lrelu':
        return nn.LeakyReLU(negative_slope=0.2)
    elif name == 'swish':
        return nn.SiLU()
    else:
        raise NotImplementedError('activation function does not exist!')


def _scale_conv_(conv, init_scale):
    """Scale a conv's default-initialised parameters in place by `init_scale`.

    A scale of exactly 0 is replaced by 1e-10. The bias is scaled only when
    the layer has one: `nn.Conv2d(..., bias=False)` leaves `conv.bias` as None.
    """
    init_scale = 1e-10 if init_scale == 0 else init_scale
    conv.weight.data *= init_scale
    if conv.bias is not None:
        conv.bias.data *= init_scale
    return conv


def _ddpm_init_(conv, init_scale):
    """Apply DDPM initialisation: `default_init` weights and, if present, a zero bias."""
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    if conv.bias is not None:
        nn.init.zeros_(conv.bias)
    return conv


def ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=0):
    """1x1 convolution. Same as NCSNv1/v2."""
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,
                     padding=padding)
    return _scale_conv_(conv, init_scale)


def variance_scaling(scale, mode, distribution,
                     in_axis=1, out_axis=0,
                     dtype=torch.float32,
                     device='cpu'):
    """Ported from JAX.

    Returns an `init(shape, dtype=..., device=...)` function that samples a
    tensor whose variance is `scale / denominator`, where the denominator is
    the fan-in, fan-out or their average depending on `mode`.

    Raises (from the returned function):
        ValueError: if `mode` or `distribution` is not recognised.
    """

    def _compute_fans(shape, in_axis=1, out_axis=0):
        # Everything other than the in/out axes counts as receptive field.
        receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
        fan_in = shape[in_axis] * receptive_field_size
        fan_out = shape[out_axis] * receptive_field_size
        return fan_in, fan_out

    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        elif mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError(
                "invalid mode for variance scaling initializer: {}".format(mode))
        variance = scale / denominator
        if distribution == "normal":
            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
        elif distribution == "uniform":
            # U(-a, a) has variance a^2 / 3, hence a = sqrt(3 * variance).
            return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
        else:
            raise ValueError("invalid distribution for variance scaling initializer")

    return init


def default_init(scale=1.):
    """The same initialization used in DDPM."""
    scale = 1e-10 if scale == 0 else scale
    return variance_scaling(scale, 'fan_avg', 'uniform')


class Dense(nn.Module):
    """Linear layer with `default_init`."""

    # NOTE(review): the visible code defines no parameters or forward(); this
    # looks like an unfinished placeholder.
    def __init__(self):
        super().__init__()


def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):
    """1x1 convolution with DDPM initialization."""
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
    return _ddpm_init_(conv, init_scale)


def ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
    """3x3 convolution with PyTorch initialization. Same as NCSNv1/NCSNv2."""
    conv = nn.Conv2d(in_planes, out_planes, stride=stride, bias=bias,
                     dilation=dilation, padding=padding, kernel_size=3)
    return _scale_conv_(conv, init_scale)


def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
    """3x3 convolution with DDPM initialization."""
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
                     dilation=dilation, bias=bias)
    return _ddpm_init_(conv, init_scale)


###########################################################################
# Functions below are ported over from the NCSNv1/NCSNv2 codebase:
# https://github.com/ermongroup/ncsn
# https://github.com/ermongroup/ncsnv2
###########################################################################


class CRPBlock(nn.Module):
    """Chain of `n_stages` (5x5 pool -> bias-free 3x3 conv) steps with running residual sum."""

    def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True):
        super().__init__()
        self.convs = nn.ModuleList()
        for i in range(n_stages):
            self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
        self.n_stages = n_stages
        # Kernel 5, stride 1, padding 2 keeps the spatial size unchanged.
        if maxpool:
            self.pool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
        else:
            self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)

        self.act = act

    def forward(self, x):
        x = self.act(x)
        path = x
        for i in range(self.n_stages):
            # Each stage feeds on the previous stage's output and is added to x.
            path = self.pool(path)
            path = self.convs[i](path)
            x = path + x
        return x


class CondCRPBlock(nn.Module):
    """Conditional CRP block: each stage is (norm(path, y) -> 5x5 avg pool -> 3x3 conv)."""

    def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()):
        super().__init__()
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.normalizer = normalizer
        for i in range(n_stages):
            self.norms.append(normalizer(features, num_classes, bias=True))
            self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))

        self.n_stages = n_stages
        self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
        self.act = act

    def forward(self, x, y):
        x = self.act(x)
        path = x
        for i in range(self.n_stages):
            path = self.norms[i](path, y)
            path = self.pool(path)
            path = self.convs[i](path)

            x = path + x
        return x


class RCUBlock(nn.Module):
    """`n_blocks` residual units, each made of `n_stages` (act -> bias-free 3x3 conv) steps."""

    def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()):
        super().__init__()

        # Convs are registered as attributes named '<block>_<stage>_conv' (1-based);
        # these names are part of the module's state_dict keys.
        for i in range(n_blocks):
            for j in range(n_stages):
                setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))

        self.stride = 1
        self.n_blocks = n_blocks
        self.n_stages = n_stages
        self.act = act

    def forward(self, x):
        for i in range(self.n_blocks):
            residual = x
            for j in range(self.n_stages):
                x = self.act(x)
                x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)

            x += residual
        return x
range(self.n_stages): 226 | x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y) 227 | x = self.act(x) 228 | x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x) 229 | 230 | x += residual 231 | return x 232 | 233 | 234 | class MSFBlock(nn.Module): 235 | def __init__(self, in_planes, features): 236 | super().__init__() 237 | assert isinstance(in_planes, list) or isinstance(in_planes, tuple) 238 | self.convs = nn.ModuleList() 239 | self.features = features 240 | 241 | for i in range(len(in_planes)): 242 | self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True)) 243 | 244 | def forward(self, xs, shape): 245 | sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device) 246 | for i in range(len(self.convs)): 247 | h = self.convs[i](xs[i]) 248 | h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True) 249 | sums += h 250 | return sums 251 | 252 | 253 | class CondMSFBlock(nn.Module): 254 | def __init__(self, in_planes, features, num_classes, normalizer): 255 | super().__init__() 256 | assert isinstance(in_planes, list) or isinstance(in_planes, tuple) 257 | 258 | self.convs = nn.ModuleList() 259 | self.norms = nn.ModuleList() 260 | self.features = features 261 | self.normalizer = normalizer 262 | 263 | for i in range(len(in_planes)): 264 | self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True)) 265 | self.norms.append(normalizer(in_planes[i], num_classes, bias=True)) 266 | 267 | def forward(self, xs, y, shape): 268 | sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device) 269 | for i in range(len(self.convs)): 270 | h = self.norms[i](xs[i], y) 271 | h = self.convs[i](h) 272 | h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True) 273 | sums += h 274 | return sums 275 | 276 | 277 | class RefineBlock(nn.Module): 278 | def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True): 279 | super().__init__() 280 | 281 | assert 
isinstance(in_planes, tuple) or isinstance(in_planes, list) 282 | self.n_blocks = n_blocks = len(in_planes) 283 | 284 | self.adapt_convs = nn.ModuleList() 285 | for i in range(n_blocks): 286 | self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act)) 287 | 288 | self.output_convs = RCUBlock(features, 3 if end else 1, 2, act) 289 | 290 | if not start: 291 | self.msf = MSFBlock(in_planes, features) 292 | 293 | self.crp = CRPBlock(features, 2, act, maxpool=maxpool) 294 | 295 | def forward(self, xs, output_shape): 296 | assert isinstance(xs, tuple) or isinstance(xs, list) 297 | hs = [] 298 | for i in range(len(xs)): 299 | h = self.adapt_convs[i](xs[i]) 300 | hs.append(h) 301 | 302 | if self.n_blocks > 1: 303 | h = self.msf(hs, output_shape) 304 | else: 305 | h = hs[0] 306 | 307 | h = self.crp(h) 308 | h = self.output_convs(h) 309 | 310 | return h 311 | 312 | 313 | class CondRefineBlock(nn.Module): 314 | def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False): 315 | super().__init__() 316 | 317 | assert isinstance(in_planes, tuple) or isinstance(in_planes, list) 318 | self.n_blocks = n_blocks = len(in_planes) 319 | 320 | self.adapt_convs = nn.ModuleList() 321 | for i in range(n_blocks): 322 | self.adapt_convs.append( 323 | CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act) 324 | ) 325 | 326 | self.output_convs = CondRCUBlock(features, 3 if end else 1, 2, num_classes, normalizer, act) 327 | 328 | if not start: 329 | self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer) 330 | 331 | self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act) 332 | 333 | def forward(self, xs, y, output_shape): 334 | assert isinstance(xs, tuple) or isinstance(xs, list) 335 | hs = [] 336 | for i in range(len(xs)): 337 | h = self.adapt_convs[i](xs[i], y) 338 | hs.append(h) 339 | 340 | if self.n_blocks > 1: 341 | h = self.msf(hs, y, output_shape) 342 | else: 343 | h = hs[0] 344 | 345 | h = self.crp(h, y) 346 | h 
= self.output_convs(h, y) 347 | 348 | return h 349 | 350 | 351 | class ConvMeanPool(nn.Module): 352 | def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False): 353 | super().__init__() 354 | if not adjust_padding: 355 | conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) 356 | self.conv = conv 357 | else: 358 | conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) 359 | 360 | self.conv = nn.Sequential( 361 | nn.ZeroPad2d((1, 0, 1, 0)), 362 | conv 363 | ) 364 | 365 | def forward(self, inputs): 366 | output = self.conv(inputs) 367 | output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2], 368 | output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4. 369 | return output 370 | 371 | 372 | class MeanPoolConv(nn.Module): 373 | def __init__(self, input_dim, output_dim, kernel_size=3, biases=True): 374 | super().__init__() 375 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) 376 | 377 | def forward(self, inputs): 378 | output = inputs 379 | output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2], 380 | output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4. 
381 | return self.conv(output) 382 | 383 | 384 | class UpsampleConv(nn.Module): 385 | def __init__(self, input_dim, output_dim, kernel_size=3, biases=True): 386 | super().__init__() 387 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) 388 | self.pixelshuffle = nn.PixelShuffle(upscale_factor=2) 389 | 390 | def forward(self, inputs): 391 | output = inputs 392 | output = torch.cat([output, output, output, output], dim=1) 393 | output = self.pixelshuffle(output) 394 | return self.conv(output) 395 | 396 | 397 | class ConditionalResidualBlock(nn.Module): 398 | def __init__(self, input_dim, output_dim, num_classes, resample=1, act=nn.ELU(), 399 | normalization=ConditionalInstanceNorm2dPlus, adjust_padding=False, dilation=None): 400 | super().__init__() 401 | self.non_linearity = act 402 | self.input_dim = input_dim 403 | self.output_dim = output_dim 404 | self.resample = resample 405 | self.normalization = normalization 406 | if resample == 'down': 407 | if dilation > 1: 408 | self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation) 409 | self.normalize2 = normalization(input_dim, num_classes) 410 | self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation) 411 | conv_shortcut = partial(ncsn_conv3x3, dilation=dilation) 412 | else: 413 | self.conv1 = ncsn_conv3x3(input_dim, input_dim) 414 | self.normalize2 = normalization(input_dim, num_classes) 415 | self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding) 416 | conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding) 417 | 418 | elif resample is None: 419 | if dilation > 1: 420 | conv_shortcut = partial(ncsn_conv3x3, dilation=dilation) 421 | self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation) 422 | self.normalize2 = normalization(output_dim, num_classes) 423 | self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation) 424 | else: 425 | conv_shortcut = nn.Conv2d 426 
| self.conv1 = ncsn_conv3x3(input_dim, output_dim) 427 | self.normalize2 = normalization(output_dim, num_classes) 428 | self.conv2 = ncsn_conv3x3(output_dim, output_dim) 429 | else: 430 | raise Exception('invalid resample value') 431 | 432 | if output_dim != input_dim or resample is not None: 433 | self.shortcut = conv_shortcut(input_dim, output_dim) 434 | 435 | self.normalize1 = normalization(input_dim, num_classes) 436 | 437 | def forward(self, x, y): 438 | output = self.normalize1(x, y) 439 | output = self.non_linearity(output) 440 | output = self.conv1(output) 441 | output = self.normalize2(output, y) 442 | output = self.non_linearity(output) 443 | output = self.conv2(output) 444 | 445 | if self.output_dim == self.input_dim and self.resample is None: 446 | shortcut = x 447 | else: 448 | shortcut = self.shortcut(x) 449 | 450 | return shortcut + output 451 | 452 | 453 | class ResidualBlock(nn.Module): 454 | def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(), 455 | normalization=nn.InstanceNorm2d, adjust_padding=False, dilation=1): 456 | super().__init__() 457 | self.non_linearity = act 458 | self.input_dim = input_dim 459 | self.output_dim = output_dim 460 | self.resample = resample 461 | self.normalization = normalization 462 | if resample == 'down': 463 | if dilation > 1: 464 | self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation) 465 | self.normalize2 = normalization(input_dim) 466 | self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation) 467 | conv_shortcut = partial(ncsn_conv3x3, dilation=dilation) 468 | else: 469 | self.conv1 = ncsn_conv3x3(input_dim, input_dim) 470 | self.normalize2 = normalization(input_dim) 471 | self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding) 472 | conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding) 473 | 474 | elif resample is None: 475 | if dilation > 1: 476 | conv_shortcut = partial(ncsn_conv3x3, dilation=dilation) 477 | 
self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation) 478 | self.normalize2 = normalization(output_dim) 479 | self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation) 480 | else: 481 | # conv_shortcut = nn.Conv2d ### Something wierd here. 482 | conv_shortcut = partial(ncsn_conv1x1) 483 | self.conv1 = ncsn_conv3x3(input_dim, output_dim) 484 | self.normalize2 = normalization(output_dim) 485 | self.conv2 = ncsn_conv3x3(output_dim, output_dim) 486 | else: 487 | raise Exception('invalid resample value') 488 | 489 | if output_dim != input_dim or resample is not None: 490 | self.shortcut = conv_shortcut(input_dim, output_dim) 491 | 492 | self.normalize1 = normalization(input_dim) 493 | 494 | def forward(self, x): 495 | output = self.normalize1(x) 496 | output = self.non_linearity(output) 497 | output = self.conv1(output) 498 | output = self.normalize2(output) 499 | output = self.non_linearity(output) 500 | output = self.conv2(output) 501 | 502 | if self.output_dim == self.input_dim and self.resample is None: 503 | shortcut = x 504 | else: 505 | shortcut = self.shortcut(x) 506 | 507 | return shortcut + output 508 | 509 | 510 | ########################################################################### 511 | # Functions below are ported over from the DDPM codebase: 512 | # https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py 513 | ########################################################################### 514 | 515 | def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): 516 | assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 517 | half_dim = embedding_dim // 2 518 | # magic number 10000 is from transformers 519 | emb = math.log(max_positions) / (half_dim - 1) 520 | # emb = math.log(2.) 
/ (half_dim - 1) 521 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) 522 | # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] 523 | # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] 524 | emb = timesteps.float()[:, None] * emb[None, :] 525 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 526 | if embedding_dim % 2 == 1: # zero pad 527 | emb = F.pad(emb, (0, 1), mode='constant') 528 | assert emb.shape == (timesteps.shape[0], embedding_dim) 529 | return emb 530 | 531 | 532 | def _einsum(a, b, c, x, y): 533 | einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c)) 534 | return torch.einsum(einsum_str, x, y) 535 | 536 | 537 | def contract_inner(x, y): 538 | """tensordot(x, y, 1).""" 539 | x_chars = list(string.ascii_lowercase[:len(x.shape)]) 540 | y_chars = list(string.ascii_lowercase[len(x.shape):len(y.shape) + len(x.shape)]) 541 | y_chars[0] = x_chars[-1] # first axis of y and last of x get summed 542 | out_chars = x_chars[:-1] + y_chars[1:] 543 | return _einsum(x_chars, y_chars, out_chars, x, y) 544 | 545 | 546 | class NIN(nn.Module): 547 | def __init__(self, in_dim, num_units, init_scale=0.1): 548 | super().__init__() 549 | self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True) 550 | self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True) 551 | 552 | def forward(self, x): 553 | x = x.permute(0, 2, 3, 1) 554 | y = contract_inner(x, self.W) + self.b 555 | return y.permute(0, 3, 1, 2) 556 | 557 | 558 | class AttnBlock(nn.Module): 559 | """Channel-wise self-attention block.""" 560 | def __init__(self, channels): 561 | super().__init__() 562 | self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6) 563 | self.NIN_0 = NIN(channels, channels) 564 | self.NIN_1 = NIN(channels, channels) 565 | self.NIN_2 = NIN(channels, channels) 566 | self.NIN_3 = NIN(channels, channels, init_scale=0.) 
567 | 568 | def forward(self, x): 569 | B, C, H, W = x.shape 570 | h = self.GroupNorm_0(x) 571 | q = self.NIN_0(h) 572 | k = self.NIN_1(h) 573 | v = self.NIN_2(h) 574 | 575 | w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5)) 576 | w = torch.reshape(w, (B, H, W, H * W)) 577 | w = F.softmax(w, dim=-1) 578 | w = torch.reshape(w, (B, H, W, H, W)) 579 | h = torch.einsum('bhwij,bcij->bchw', w, v) 580 | h = self.NIN_3(h) 581 | return x + h 582 | 583 | 584 | class Upsample(nn.Module): 585 | def __init__(self, channels, with_conv=False): 586 | super().__init__() 587 | if with_conv: 588 | self.Conv_0 = ddpm_conv3x3(channels, channels) 589 | self.with_conv = with_conv 590 | 591 | def forward(self, x): 592 | B, C, H, W = x.shape 593 | h = F.interpolate(x, (H * 2, W * 2), mode='nearest') 594 | if self.with_conv: 595 | h = self.Conv_0(h) 596 | return h 597 | 598 | 599 | class Downsample(nn.Module): 600 | def __init__(self, channels, with_conv=False): 601 | super().__init__() 602 | if with_conv: 603 | self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0) 604 | self.with_conv = with_conv 605 | 606 | def forward(self, x): 607 | B, C, H, W = x.shape 608 | # Emulate 'SAME' padding 609 | if self.with_conv: 610 | x = F.pad(x, (0, 1, 0, 1)) 611 | x = self.Conv_0(x) 612 | else: 613 | x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0) 614 | 615 | assert x.shape == (B, C, H // 2, W // 2) 616 | return x 617 | 618 | 619 | class ResnetBlockDDPM(nn.Module): 620 | """The ResNet Blocks used in DDPM.""" 621 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1): 622 | super().__init__() 623 | if out_ch is None: 624 | out_ch = in_ch 625 | self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6) 626 | self.act = act 627 | self.Conv_0 = ddpm_conv3x3(in_ch, out_ch) 628 | if temb_dim is not None: 629 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 630 | self.Dense_0.weight.data = 
default_init()(self.Dense_0.weight.data.shape) 631 | nn.init.zeros_(self.Dense_0.bias) 632 | 633 | self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6) 634 | self.Dropout_0 = nn.Dropout(dropout) 635 | self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.) 636 | if in_ch != out_ch: 637 | if conv_shortcut: 638 | self.Conv_2 = ddpm_conv3x3(in_ch, out_ch) 639 | else: 640 | self.NIN_0 = NIN(in_ch, out_ch) 641 | self.out_ch = out_ch 642 | self.in_ch = in_ch 643 | self.conv_shortcut = conv_shortcut 644 | 645 | def forward(self, x, temb=None): 646 | B, C, H, W = x.shape 647 | assert C == self.in_ch 648 | out_ch = self.out_ch if self.out_ch else self.in_ch 649 | h = self.act(self.GroupNorm_0(x)) 650 | h = self.Conv_0(h) 651 | # Add bias to each feature map conditioned on the time embedding 652 | if temb is not None: 653 | h += self.Dense_0(self.act(temb))[:, :, None, None] 654 | h = self.act(self.GroupNorm_1(h)) 655 | h = self.Dropout_0(h) 656 | h = self.Conv_1(h) 657 | if C != out_ch: 658 | if self.conv_shortcut: 659 | x = self.Conv_2(x) 660 | else: 661 | x = self.NIN_0(x) 662 | return x + h --------------------------------------------------------------------------------