├── LICENSE ├── README.md ├── demo ├── gradio_app.py ├── run_mar.ipynb └── visual.png ├── diffusion ├── __init__.py ├── diffusion_utils.py ├── gaussian_diffusion.py └── respace.py ├── engine_mar.py ├── environment.yaml ├── fid_stats └── adm_in256_stats.npz ├── main_cache.py ├── main_mar.py ├── models ├── diffloss.py ├── mar.py └── vae.py └── util ├── crop.py ├── download.py ├── loader.py ├── lr_sched.py └── misc.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Tianhong Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Autoregressive Image Generation without Vector Quantization
Official PyTorch Implementation 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv%20paper-2406.11838-b31b1b.svg)](https://arxiv.org/abs/2406.11838)  4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/autoregressive-image-generation-without/image-generation-on-imagenet-256x256)](https://paperswithcode.com/sota/image-generation-on-imagenet-256x256?p=autoregressive-image-generation-without) 5 | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/LTH14/mar/blob/main/demo/run_mar.ipynb) 6 | [![huggingface](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-mar-yellow)](https://huggingface.co/jadechoghari/mar)  7 | 8 |

9 | 10 |

11 | 12 | This is a PyTorch/GPU implementation of the paper [Autoregressive Image Generation without Vector Quantization](https://arxiv.org/abs/2406.11838) (NeurIPS 2024 Spotlight Presentation): 13 | 14 | ``` 15 | @article{li2024autoregressive, 16 | title={Autoregressive Image Generation without Vector Quantization}, 17 | author={Li, Tianhong and Tian, Yonglong and Li, He and Deng, Mingyang and He, Kaiming}, 18 | journal={arXiv preprint arXiv:2406.11838}, 19 | year={2024} 20 | } 21 | ``` 22 | 23 | This repo contains: 24 | 25 | * 🪐 A simple PyTorch implementation of [MAR](models/mar.py) and [DiffLoss](models/diffloss.py) 26 | * ⚡️ Pre-trained class-conditional MAR models trained on ImageNet 256x256 27 | * 💥 A self-contained [Colab notebook](http://colab.research.google.com/github/LTH14/mar/blob/main/demo/run_mar.ipynb) for running various pre-trained MAR models 28 | * 🛸 An MAR+DiffLoss [training and evaluation script](main_mar.py) using PyTorch DDP 29 | * 🎉 Also check out our [Hugging Face model cards](https://huggingface.co/jadechoghari/mar) and [Gradio demo](https://huggingface.co/spaces/jadechoghari/mar) (thanks [@jadechoghari](https://github.com/jadechoghari)). 30 | 31 | ## Preparation 32 | 33 | ### Dataset 34 | Download the [ImageNet](http://image-net.org/download) dataset and place it in your `IMAGENET_PATH`.
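A quick sanity check on the dataset folder can save a failed multi-node launch. The sketch below assumes the common `torchvision.datasets.ImageFolder` layout, `${IMAGENET_PATH}/train/<class_folder>/<image>`, with one folder per ImageNet-1k class. This README does not spell out that layout, so treat the `train` subfolder name and the expected count of 1000 as assumptions, and adjust them if the training scripts read the data differently.

```python
import os


def count_class_dirs(imagenet_path: str, split: str = "train") -> int:
    """Count class subdirectories under <imagenet_path>/<split> (assumed ImageFolder layout)."""
    split_dir = os.path.join(imagenet_path, split)
    if not os.path.isdir(split_dir):
        raise FileNotFoundError(f"expected a '{split}' folder at {split_dir}")
    # ImageFolder treats every immediate subdirectory as one class; loose files are ignored.
    return sum(1 for entry in os.scandir(split_dir) if entry.is_dir())
```

For example, `count_class_dirs(os.environ["IMAGENET_PATH"])` should return 1000 for a complete ImageNet-1k training split under this assumed layout.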
35 | 36 | ### Installation 37 | 38 | Download the code: 39 | ``` 40 | git clone https://github.com/LTH14/mar.git 41 | cd mar 42 | ``` 43 | 44 | A suitable [conda](https://conda.io/) environment named `mar` can be created and activated with: 45 | 46 | ``` 47 | conda env create -f environment.yaml 48 | conda activate mar 49 | ``` 50 | 51 | Download the pre-trained VAE and MAR models: 52 | 53 | ``` 54 | python util/download.py 55 | ``` 56 | 57 | For convenience, our pre-trained MAR models can be downloaded directly here as well: 58 | 59 | | MAR Model | FID-50K | Inception Score | #params | 60 | |------------------------------------------------------------------------|---------|-----------------|---------| 61 | | [MAR-B](https://www.dropbox.com/scl/fi/f6dpuyjb7fudzxcyhvrhk/checkpoint-last.pth?rlkey=a6i4bo71vhfo4anp33n9ukujb&dl=0) | 2.31 | 281.7 | 208M | 62 | | [MAR-L](https://www.dropbox.com/scl/fi/pxacc5b2mrt3ifw4cah6k/checkpoint-last.pth?rlkey=m48ovo6g7ivcbosrbdaz0ehqt&dl=0) | 1.78 | 296.0 | 479M | 63 | | [MAR-H](https://www.dropbox.com/scl/fi/1qmfx6fpy3k7j9vcjjs3s/checkpoint-last.pth?rlkey=4lae281yzxb406atp32vzc83o&dl=0) | 1.55 | 303.7 | 943M | 64 | 65 | ### (Optional) Caching VAE Latents 66 | 67 | Because our data augmentation consists only of center cropping and random flipping, 68 | the VAE latents can be pre-computed and saved to `CACHED_PATH` to save computation during MAR training: 69 | 70 | ``` 71 | torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \ 72 | main_cache.py \ 73 | --img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 \ 74 | --batch_size 128 \ 75 | --data_path ${IMAGENET_PATH} --cached_path ${CACHED_PATH} 76 | ``` 77 | 78 | ## Usage 79 | 80 | ### Demo 81 | Run our interactive visualization [demo](http://colab.research.google.com/github/LTH14/mar/blob/main/demo/run_mar.ipynb) using the Colab notebook!
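If you prefer a script to the notebook, you can reuse the call that `demo/gradio_app.py` makes without the Gradio UI. The sketch below copies the pipeline name (`jadechoghari/mar`), the `mar_huge` model type and the keyword arguments from that file. The helper names `parse_class_labels` and `generate`, the 0 to 999 range check, and treating the return value as opaque are all assumptions of this note. The pipeline is third-party remote code, so consult its model card for the exact return type and for other accepted `model_type` values.

```python
def parse_class_labels(text: str) -> list[int]:
    """Turn "207, 360, 388" into [207, 360, 388], rejecting non-ImageNet-1k indices."""
    labels = [int(tok.strip()) for tok in text.split(",") if tok.strip()]
    bad = [lab for lab in labels if not 0 <= lab < 1000]
    if bad:
        raise ValueError(f"ImageNet-1k class indices must be in [0, 999], got {bad}")
    return labels


def generate(class_labels: str, seed: int = 0, num_ar_steps: int = 64,
             cfg_scale: float = 4.0, cfg_schedule: str = "constant",
             model_type: str = "mar_huge", output_dir: str = "./images"):
    """Run the Hugging Face MAR pipeline outside Gradio (hypothetical helper mirroring demo/gradio_app.py)."""
    # Imported lazily so parse_class_labels can be used without diffusers installed.
    from diffusers import DiffusionPipeline

    pipeline = DiffusionPipeline.from_pretrained(
        "jadechoghari/mar", trust_remote_code=True, custom_pipeline="jadechoghari/mar"
    )
    return pipeline(
        model_type=model_type,
        seed=int(seed),                  # Gradio's Number widget may hand over a float
        num_ar_steps=int(num_ar_steps),
        class_labels=parse_class_labels(class_labels),
        cfg_scale=cfg_scale,
        cfg_schedule=cfg_schedule,
        output_dir=output_dir,
    )
```

The defaults (64 AR steps, CFG scale 4, constant schedule) match the initial values of the Gradio widgets. The first call to `generate("207, 360, 388")` downloads the weights.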
82 | 83 | ### Local Gradio App 84 | 85 | ``` 86 | python demo/gradio_app.py 87 | ``` 88 | 89 | 90 | 91 | ### Training 92 | Script for the default setting (MAR-L, DiffLoss MLP with 3 blocks and a width of 1024 channels, 400 epochs): 93 | ``` 94 | torchrun --nproc_per_node=8 --nnodes=4 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \ 95 | main_mar.py \ 96 | --img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \ 97 | --model mar_large --diffloss_d 3 --diffloss_w 1024 \ 98 | --epochs 400 --warmup_epochs 100 --batch_size 64 --blr 1.0e-4 --diffusion_batch_mul 4 \ 99 | --output_dir ${OUTPUT_DIR} --resume ${OUTPUT_DIR} \ 100 | --data_path ${IMAGENET_PATH} 101 | ``` 102 | - Training time is ~1d7h on 32 H100 GPUs with `--batch_size 64`. 103 | - Add `--online_eval` to evaluate FID during training (every 40 epochs). 104 | - (Optional) To train with cached VAE latents, add `--use_cached --cached_path ${CACHED_PATH}` to the arguments. 105 | Training time with cached latents is ~1d11h on 16 H100 GPUs with `--batch_size 128` (nearly 2x faster than without caching). 106 | - (Optional) To save GPU memory during training by using gradient checkpointing (thanks to @Jiawei-Yang), add `--grad_checkpointing` to the arguments. 107 | Note that this may slightly reduce training speed. 
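The flags above fix the *per-GPU* batch size and a *base* learning rate (`--blr`), so the effective values depend on how many processes `torchrun` launches. The sketch below assumes the linear scaling rule used in MAE (on which much of this code is based), `lr = blr * effective_batch / 256`. This README does not state that rule, so check `main_mar.py` before relying on the exact number. The function name `effective_batch_and_lr` is hypothetical.

```python
def effective_batch_and_lr(batch_size_per_gpu: int, world_size: int, blr: float,
                           reference_batch: int = 256) -> tuple[int, float]:
    """Effective batch size and absolute LR under the (assumed) linear scaling rule."""
    eff_batch = batch_size_per_gpu * world_size  # samples per optimizer step across all GPUs
    lr = blr * eff_batch / reference_batch       # lr = blr * eff_batch / 256
    return eff_batch, lr


# Default MAR-L recipe: --batch_size 64 on 8 GPUs x 4 nodes = 32 processes.
print(effective_batch_and_lr(64, 32, 1.0e-4))   # effective batch 2048, lr of about 8e-4
# Cached-latent variant from the notes above: --batch_size 128 on 16 GPUs.
print(effective_batch_and_lr(128, 16, 1.0e-4))  # also 2048 and about 8e-4
```

Both recipes reach the same effective batch of 2048, so under this rule they also train with the same learning rate.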
108 | 109 | ### Evaluation (ImageNet 256x256) 110 | 111 | Evaluate MAR-B (DiffLoss MLP with 6 blocks and a width of 1024 channels, 800 epochs) with classifier-free guidance: 112 | ``` 113 | torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \ 114 | main_mar.py \ 115 | --model mar_base --diffloss_d 6 --diffloss_w 1024 \ 116 | --eval_bsz 256 --num_images 50000 \ 117 | --num_iter 256 --num_sampling_steps 100 --cfg 2.9 --cfg_schedule linear --temperature 1.0 \ 118 | --output_dir pretrained_models/mar/mar_base \ 119 | --resume pretrained_models/mar/mar_base \ 120 | --data_path ${IMAGENET_PATH} --evaluate 121 | ``` 122 | 123 | Evaluate MAR-L (DiffLoss MLP with 8 blocks and a width of 1280 channels, 800 epochs) with classifier-free guidance: 124 | ``` 125 | torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \ 126 | main_mar.py \ 127 | --model mar_large --diffloss_d 8 --diffloss_w 1280 \ 128 | --eval_bsz 256 --num_images 50000 \ 129 | --num_iter 256 --num_sampling_steps 100 --cfg 3.0 --cfg_schedule linear --temperature 1.0 \ 130 | --output_dir pretrained_models/mar/mar_large \ 131 | --resume pretrained_models/mar/mar_large \ 132 | --data_path ${IMAGENET_PATH} --evaluate 133 | ``` 134 | 135 | Evaluate MAR-H (DiffLoss MLP with 12 blocks and a width of 1536 channels, 800 epochs) with classifier-free guidance: 136 | ``` 137 | torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \ 138 | main_mar.py \ 139 | --model mar_huge --diffloss_d 12 --diffloss_w 1536 \ 140 | --eval_bsz 128 --num_images 50000 \ 141 | --num_iter 256 --num_sampling_steps 100 --cfg 3.2 --cfg_schedule linear --temperature 1.0 \ 142 | --output_dir pretrained_models/mar/mar_huge \ 143 | --resume pretrained_models/mar/mar_huge \ 144 | --data_path ${IMAGENET_PATH} --evaluate 145 | ``` 146 | 147 | - Set `--cfg 1.0 --temperature 0.95` to evaluate without classifier-free guidance. 148 | - Generation speed can be significantly increased by reducing the number of autoregressive iterations (e.g., `--num_iter 64`). 
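`--num_iter` sets how many autoregressive iterations are used to fill the latent token grid. With `--img_size 256`, `--vae_stride 16` and `--patch_size 1` that grid is 16x16 = 256 tokens. `--cfg_schedule linear` ramps the guidance scale as tokens are revealed. The sketch below assumes a MaskGIT-style cosine masking schedule and a linear ramp from 1 to `--cfg`. Both are illustrations of the technique written for this note, not code taken from `models/mar.py`.

```python
import math


def tokens_per_step(seq_len: int, num_iter: int) -> list[int]:
    """Tokens generated at each iteration under a cosine masking schedule (assumed)."""
    remaining = seq_len  # every token starts masked
    counts = []
    for step in range(num_iter):
        if step == num_iter - 1:
            mask_len = 0  # the last iteration fills in everything that is left
        else:
            # Fraction of tokens that should still be masked after this iteration.
            mask_ratio = math.cos(math.pi / 2.0 * (step + 1) / num_iter)
            mask_len = math.floor(seq_len * mask_ratio)
            # Reveal at least one token now, and leave at least one for later.
            mask_len = max(1, min(remaining - 1, mask_len))
        counts.append(remaining - mask_len)
        remaining = mask_len
    return counts


def linear_cfg(cfg: float, seq_len: int, mask_len: int) -> float:
    """Guidance scale growing from 1 to `cfg` as masked tokens are filled in (assumed)."""
    return 1.0 + (cfg - 1.0) * (seq_len - mask_len) / seq_len


fast = tokens_per_step(256, 64)
slow = tokens_per_step(256, 256)
print(len(fast), sum(fast))  # 64 iterations covering all 256 tokens
print(len(slow), sum(slow))  # 256 iterations, one token at a time near the start
print(linear_cfg(2.9, 256, 256), linear_cfg(2.9, 256, 0))  # starts at 1.0, ends at 2.9
```

Each iteration needs a forward pass of the model, so cutting `--num_iter` from 256 to 64 cuts the number of passes by 4x. This is consistent with the speed-up noted above. The linear ramp keeps guidance weak in early iterations, when few tokens are known, and applies the full `--cfg` value only near the end.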
149 | 150 | ## Acknowledgements 151 | We thank Congyue Deng and Xinlei Chen for helpful discussions. We thank 152 | Google TPU Research Cloud (TRC) for granting us access to TPUs, and Google Cloud Platform for 153 | supporting GPU resources. 154 | 155 | A large portion of the code in this repo is based on [MAE](https://github.com/facebookresearch/mae), [MAGE](https://github.com/LTH14/mage) and [DiT](https://github.com/facebookresearch/DiT). 156 | 157 | ## Contact 158 | 159 | If you have any questions, feel free to contact me by email (tianhong@mit.edu). Enjoy! 160 | -------------------------------------------------------------------------------- /demo/gradio_app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from diffusers import DiffusionPipeline 3 | import os 4 | import torch 5 | import shutil 6 | import spaces 7 | 8 | 9 | def find_cuda(): 10 | # Check if CUDA_HOME or CUDA_PATH environment variables are set 11 | cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') 12 | 13 | if cuda_home and os.path.exists(cuda_home): 14 | return cuda_home 15 | 16 | # Search for the nvcc executable in the system's PATH 17 | nvcc_path = shutil.which('nvcc') 18 | 19 | if nvcc_path: 20 | # Remove the 'bin/nvcc' part to get the CUDA installation path 21 | cuda_path = os.path.dirname(os.path.dirname(nvcc_path)) 22 | return cuda_path 23 | 24 | return None 25 | 26 | 27 | cuda_path = find_cuda() 28 | 29 | if cuda_path: 30 | print(f"CUDA installation found at: {cuda_path}") 31 | else: 32 | print("CUDA installation not found") 33 | 34 | # check if cuda is available 35 | device = "cuda" if torch.cuda.is_available() else "cpu" 36 | 37 | # load the pipeline/model 38 | pipeline = DiffusionPipeline.from_pretrained("jadechoghari/mar", trust_remote_code=True, 39 | custom_pipeline="jadechoghari/mar") 40 | 41 | 42 | # function that generates images 43 | @spaces.GPU 44 | def 
generate_image(seed, num_ar_steps,
class_labels, cfg_scale, cfg_schedule): 45 | generated_image = pipeline( 46 | model_type="mar_huge", # using mar_huge 47 | seed=seed, 48 | num_ar_steps=num_ar_steps, 49 | class_labels=[int(label.strip()) for label in class_labels.split(',')], 50 | cfg_scale=cfg_scale, 51 | cfg_schedule=cfg_schedule, 52 | output_dir="./images" 53 | ) 54 | return generated_image 55 | 56 | 57 | with gr.Blocks() as demo: 58 | gr.Markdown(""" 59 | # MAR Image Generation Demo 🚀 60 | 61 | Welcome to the demo for **MAR** (Masked Autoregressive Model), an approach to image generation that eliminates the need for vector quantization. Instead of predicting indices into a discrete codebook, MAR models each continuous-valued image token with a small diffusion model (Diffusion Loss). 62 | 63 | Adjust the parameters below and click **Generate Image** to create your own images. 64 | 65 | The model is class-conditional, not text-conditional: provide valid **ImageNet class labels** as comma-separated integers from 0 to 999. For a complete list of ImageNet classes, check out [this reference](https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/). 66 | 67 | For more details, visit the [GitHub repository](https://github.com/LTH14/mar).
68 | """) 69 | 70 | seed = gr.Number(value=0, label="Seed") 71 | num_ar_steps = gr.Slider(minimum=1, maximum=256, value=64, label="Number of AR Steps") 72 | class_labels = gr.Textbox(value="207, 360, 388, 113, 355, 980, 323, 979", 73 | label="Class Labels (comma-separated ImageNet labels)") 74 | cfg_scale = gr.Slider(minimum=1, maximum=10, value=4, label="CFG Scale") 75 | cfg_schedule = gr.Dropdown(choices=["constant", "linear"], label="CFG Schedule", value="constant") 76 | 77 | image_output = gr.Image(label="Generated Image") 78 | 79 | generate_button = gr.Button("Generate Image") 80 | 81 | # we link the button to the function and display the output 82 | generate_button.click(generate_image, inputs=[seed, num_ar_steps, class_labels, cfg_scale, cfg_schedule], 83 | outputs=image_output) 84 | 85 | gr.Interface( 86 | generate_image, 87 | inputs=[seed, num_ar_steps, class_labels, cfg_scale, cfg_schedule], 88 | outputs=image_output, 89 | ) 90 | 91 | demo.launch() -------------------------------------------------------------------------------- /demo/visual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LTH14/mar/fe470ac24afbee924668d8c5c83e9fec60af3a73/demo/visual.png -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Adapted from DiT, which is modified from OpenAI's diffusion repos 2 | # DiT: https://github.com/facebookresearch/DiT/diffusion 3 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 4 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 5 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 6 | 7 | from .
import gaussian_diffusion as gd 8 | from .respace import SpacedDiffusion, space_timesteps 9 | 10 | 11 | def create_diffusion( 12 | timestep_respacing, 13 | noise_schedule="linear", 14 | use_kl=False, 15 | sigma_small=False, 16 | predict_xstart=False, 17 | learn_sigma=True, 18 | rescale_learned_sigmas=False, 19 | diffusion_steps=1000 20 | ): 21 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 22 | if use_kl: 23 | loss_type = gd.LossType.RESCALED_KL 24 | elif rescale_learned_sigmas: 25 | loss_type = gd.LossType.RESCALED_MSE 26 | else: 27 | loss_type = gd.LossType.MSE 28 | if timestep_respacing is None or timestep_respacing == "": 29 | timestep_respacing = [diffusion_steps] 30 | return SpacedDiffusion( 31 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 32 | betas=betas, 33 | model_mean_type=( 34 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 35 | ), 36 | model_var_type=( 37 | ( 38 | gd.ModelVarType.FIXED_LARGE 39 | if not sigma_small 40 | else gd.ModelVarType.FIXED_SMALL 41 | ) 42 | if not learn_sigma 43 | else gd.ModelVarType.LEARNED_RANGE 44 | ), 45 | loss_type=loss_type 46 | # rescale_timesteps=rescale_timesteps, 47 | ) 48 | -------------------------------------------------------------------------------- /diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two Gaussians.
13 | Shapes are automatically broadcast, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a Gaussian distribution discretized to a 50 | given image. 51 | :param x: the target images. It is assumed that these were uint8 values, 52 | rescaled to the range [-1, 1]. 53 | :param means: the Gaussian mean Tensor. 54 | :param log_scales: the Gaussian log stddev Tensor. 55 | :return: a tensor like x of log probabilities (in nats).
56 | """ 57 | assert x.shape == means.shape == log_scales.shape 58 | centered_x = x - means 59 | inv_stdv = th.exp(-log_scales) 60 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 61 | cdf_plus = approx_standard_normal_cdf(plus_in) 62 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 63 | cdf_min = approx_standard_normal_cdf(min_in) 64 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 65 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 66 | cdf_delta = cdf_plus - cdf_min 67 | log_probs = th.where( 68 | x < -0.999, 69 | log_cdf_plus, 70 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 71 | ) 72 | assert log_probs.shape == x.shape 73 | return log_probs 74 | -------------------------------------------------------------------------------- /diffusion/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | 7 | import math 8 | 9 | import numpy as np 10 | import torch as th 11 | import enum 12 | 13 | from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl 14 | 15 | 16 | def mean_flat(tensor): 17 | """ 18 | Take the mean over all non-batch dimensions. 19 | """ 20 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 21 | 22 | 23 | class ModelMeanType(enum.Enum): 24 | """ 25 | Which type of output the model predicts. 26 | """ 27 | 28 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1} 29 | START_X = enum.auto() # the model predicts x_0 30 | EPSILON = enum.auto() # the model predicts epsilon 31 | 32 | 33 | class ModelVarType(enum.Enum): 34 | """ 35 | What is used as the model's output variance. 
36 | The LEARNED_RANGE option has been added to allow the model to predict 37 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 38 | """ 39 | 40 | LEARNED = enum.auto() 41 | FIXED_SMALL = enum.auto() 42 | FIXED_LARGE = enum.auto() 43 | LEARNED_RANGE = enum.auto() 44 | 45 | 46 | class LossType(enum.Enum): 47 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances) 48 | RESCALED_MSE = ( 49 | enum.auto() 50 | ) # use raw MSE loss (with RESCALED_KL when learning variances) 51 | KL = enum.auto() # use the variational lower-bound 52 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB 53 | 54 | def is_vb(self): 55 | return self == LossType.KL or self == LossType.RESCALED_KL 56 | 57 | 58 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): 59 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 60 | warmup_time = int(num_diffusion_timesteps * warmup_frac) 61 | betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) 62 | return betas 63 | 64 | 65 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 66 | """ 67 | This is the deprecated API for creating beta schedules. 68 | See get_named_beta_schedule() for the new library of schedules. 
69 | """ 70 | if beta_schedule == "quad": 71 | betas = ( 72 | np.linspace( 73 | beta_start ** 0.5, 74 | beta_end ** 0.5, 75 | num_diffusion_timesteps, 76 | dtype=np.float64, 77 | ) 78 | ** 2 79 | ) 80 | elif beta_schedule == "linear": 81 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 82 | elif beta_schedule == "warmup10": 83 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) 84 | elif beta_schedule == "warmup50": 85 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) 86 | elif beta_schedule == "const": 87 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 88 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 89 | betas = 1.0 / np.linspace( 90 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 91 | ) 92 | else: 93 | raise NotImplementedError(beta_schedule) 94 | assert betas.shape == (num_diffusion_timesteps,) 95 | return betas 96 | 97 | 98 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 99 | """ 100 | Get a pre-defined beta schedule for the given name. 101 | The beta schedule library consists of beta schedules which remain similar 102 | in the limit of num_diffusion_timesteps. 103 | Beta schedules may be added, but should not be removed or changed once 104 | they are committed to maintain backwards compatibility. 105 | """ 106 | if schedule_name == "linear": 107 | # Linear schedule from Ho et al, extended to work for any number of 108 | # diffusion steps. 
109 | scale = 1000 / num_diffusion_timesteps 110 | return get_beta_schedule( 111 | "linear", 112 | beta_start=scale * 0.0001, 113 | beta_end=scale * 0.02, 114 | num_diffusion_timesteps=num_diffusion_timesteps, 115 | ) 116 | elif schedule_name == "cosine": 117 | return betas_for_alpha_bar( 118 | num_diffusion_timesteps, 119 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 120 | ) 121 | else: 122 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 123 | 124 | 125 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 126 | """ 127 | Create a beta schedule that discretizes the given alpha_t_bar function, 128 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 129 | :param num_diffusion_timesteps: the number of betas to produce. 130 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 131 | produces the cumulative product of (1-beta) up to that 132 | part of the diffusion process. 133 | :param max_beta: the maximum beta to use; use values lower than 1 to 134 | prevent singularities. 135 | """ 136 | betas = [] 137 | for i in range(num_diffusion_timesteps): 138 | t1 = i / num_diffusion_timesteps 139 | t2 = (i + 1) / num_diffusion_timesteps 140 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 141 | return np.array(betas) 142 | 143 | 144 | class GaussianDiffusion: 145 | """ 146 | Utilities for training and sampling diffusion models. 147 | Originally ported from this codebase: 148 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 149 | :param betas: a 1-D numpy array of betas for each diffusion timestep, 150 | starting at T and going to 1.
151 | """ 152 | 153 | def __init__( 154 | self, 155 | *, 156 | betas, 157 | model_mean_type, 158 | model_var_type, 159 | loss_type 160 | ): 161 | 162 | self.model_mean_type = model_mean_type 163 | self.model_var_type = model_var_type 164 | self.loss_type = loss_type 165 | 166 | # Use float64 for accuracy. 167 | betas = np.array(betas, dtype=np.float64) 168 | self.betas = betas 169 | assert len(betas.shape) == 1, "betas must be 1-D" 170 | assert (betas > 0).all() and (betas <= 1).all() 171 | 172 | self.num_timesteps = int(betas.shape[0]) 173 | 174 | alphas = 1.0 - betas 175 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 176 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) 177 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) 178 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) 179 | 180 | # calculations for diffusion q(x_t | x_{t-1}) and others 181 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 182 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 183 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) 184 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) 185 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 186 | 187 | # calculations for posterior q(x_{t-1} | x_t, x_0) 188 | self.posterior_variance = ( 189 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 190 | ) 191 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 192 | self.posterior_log_variance_clipped = np.log( 193 | np.append(self.posterior_variance[1], self.posterior_variance[1:]) 194 | ) if len(self.posterior_variance) > 1 else np.array([]) 195 | 196 | self.posterior_mean_coef1 = ( 197 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 198 | ) 199 | self.posterior_mean_coef2 = ( 200 | (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - 
self.alphas_cumprod) 201 | ) 202 | 203 | def q_mean_variance(self, x_start, t): 204 | """ 205 | Get the distribution q(x_t | x_0). 206 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 207 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 208 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 209 | """ 210 | mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 211 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) 212 | log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) 213 | return mean, variance, log_variance 214 | 215 | def q_sample(self, x_start, t, noise=None): 216 | """ 217 | Diffuse the data for a given number of diffusion steps. 218 | In other words, sample from q(x_t | x_0). 219 | :param x_start: the initial data batch. 220 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 221 | :param noise: if specified, the split-out normal noise. 222 | :return: A noisy version of x_start. 
223 | """ 224 | if noise is None: 225 | noise = th.randn_like(x_start) 226 | assert noise.shape == x_start.shape 227 | return ( 228 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 229 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 230 | ) 231 | 232 | def q_posterior_mean_variance(self, x_start, x_t, t): 233 | """ 234 | Compute the mean and variance of the diffusion posterior: 235 | q(x_{t-1} | x_t, x_0) 236 | """ 237 | assert x_start.shape == x_t.shape 238 | posterior_mean = ( 239 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 240 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 241 | ) 242 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) 243 | posterior_log_variance_clipped = _extract_into_tensor( 244 | self.posterior_log_variance_clipped, t, x_t.shape 245 | ) 246 | assert ( 247 | posterior_mean.shape[0] 248 | == posterior_variance.shape[0] 249 | == posterior_log_variance_clipped.shape[0] 250 | == x_start.shape[0] 251 | ) 252 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 253 | 254 | def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): 255 | """ 256 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 257 | the initial x, x_0. 258 | :param model: the model, which takes a signal and a batch of timesteps 259 | as input. 260 | :param x: the [N x C x ...] tensor at time t. 261 | :param t: a 1-D Tensor of timesteps. 262 | :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 263 | :param denoised_fn: if not None, a function which applies to the 264 | x_start prediction before it is used to sample. Applies before 265 | clip_denoised. 266 | :param model_kwargs: if not None, a dict of extra keyword arguments to 267 | pass to the model. This can be used for conditioning. 
268 | :return: a dict with the following keys: 269 | - 'mean': the model mean output. 270 | - 'variance': the model variance output. 271 | - 'log_variance': the log of 'variance'. 272 | - 'pred_xstart': the prediction for x_0. 273 | """ 274 | if model_kwargs is None: 275 | model_kwargs = {} 276 | 277 | B, C = x.shape[:2] 278 | assert t.shape == (B,) 279 | model_output = model(x, t, **model_kwargs) 280 | if isinstance(model_output, tuple): 281 | model_output, extra = model_output 282 | else: 283 | extra = None 284 | 285 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: 286 | assert model_output.shape == (B, C * 2, *x.shape[2:]) 287 | model_output, model_var_values = th.split(model_output, C, dim=1) 288 | min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) 289 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) 290 | # The model_var_values is [-1, 1] for [min_var, max_var]. 291 | frac = (model_var_values + 1) / 2 292 | model_log_variance = frac * max_log + (1 - frac) * min_log 293 | model_variance = th.exp(model_log_variance) 294 | else: 295 | model_variance, model_log_variance = { 296 | # for fixedlarge, we set the initial (log-)variance like so 297 | # to get a better decoder log likelihood. 
298 | ModelVarType.FIXED_LARGE: ( 299 | np.append(self.posterior_variance[1], self.betas[1:]), 300 | np.log(np.append(self.posterior_variance[1], self.betas[1:])), 301 | ), 302 | ModelVarType.FIXED_SMALL: ( 303 | self.posterior_variance, 304 | self.posterior_log_variance_clipped, 305 | ), 306 | }[self.model_var_type] 307 | model_variance = _extract_into_tensor(model_variance, t, x.shape) 308 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) 309 | 310 | def process_xstart(x): 311 | if denoised_fn is not None: 312 | x = denoised_fn(x) 313 | if clip_denoised: 314 | return x.clamp(-1, 1) 315 | return x 316 | 317 | if self.model_mean_type == ModelMeanType.START_X: 318 | pred_xstart = process_xstart(model_output) 319 | else: 320 | pred_xstart = process_xstart( 321 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) 322 | ) 323 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) 324 | 325 | assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 326 | return { 327 | "mean": model_mean, 328 | "variance": model_variance, 329 | "log_variance": model_log_variance, 330 | "pred_xstart": pred_xstart, 331 | "extra": extra, 332 | } 333 | 334 | def _predict_xstart_from_eps(self, x_t, t, eps): 335 | assert x_t.shape == eps.shape 336 | return ( 337 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 338 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 339 | ) 340 | 341 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 342 | return ( 343 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart 344 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 345 | 346 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 347 | """ 348 | Compute the mean for the previous step, given a function cond_fn that 349 | computes the gradient of a conditional log probability 
with respect to 350 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to 351 | condition on y. 352 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 353 | """ 354 | gradient = cond_fn(x, t, **model_kwargs) 355 | new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() 356 | return new_mean 357 | 358 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 359 | """ 360 | Compute what the p_mean_variance output would have been, should the 361 | model's score function be conditioned by cond_fn. 362 | See condition_mean() for details on cond_fn. 363 | Unlike condition_mean(), this instead uses the conditioning strategy 364 | from Song et al. (2020). 365 | """ 366 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 367 | 368 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) 369 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) 370 | 371 | out = p_mean_var.copy() 372 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) 373 | out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) 374 | return out 375 | 376 | def p_sample( 377 | self, 378 | model, 379 | x, 380 | t, 381 | clip_denoised=True, 382 | denoised_fn=None, 383 | cond_fn=None, 384 | model_kwargs=None, 385 | temperature=1.0 386 | ): 387 | """ 388 | Sample x_{t-1} from the model at the given timestep. 389 | :param model: the model to sample from. 390 | :param x: the current tensor at x_t. 391 | :param t: the value of t, starting at 0 for the first diffusion step. 392 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 393 | :param denoised_fn: if not None, a function which applies to the 394 | x_start prediction before it is used to sample. 395 | :param cond_fn: if not None, this is a gradient function that acts 396 | similarly to the model.
397 | :param model_kwargs: if not None, a dict of extra keyword arguments to 398 | pass to the model. This can be used for conditioning. 399 | :param temperature: temperature scaling during Diff Loss sampling. 400 | :return: a dict containing the following keys: 401 | - 'sample': a random sample from the model. 402 | - 'pred_xstart': a prediction of x_0. 403 | """ 404 | out = self.p_mean_variance( 405 | model, 406 | x, 407 | t, 408 | clip_denoised=clip_denoised, 409 | denoised_fn=denoised_fn, 410 | model_kwargs=model_kwargs, 411 | ) 412 | noise = th.randn_like(x) 413 | nonzero_mask = ( 414 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 415 | ) # no noise when t == 0 416 | if cond_fn is not None: 417 | out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) 418 | # scale the noise by temperature 419 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise * temperature 420 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 421 | 422 | def p_sample_loop( 423 | self, 424 | model, 425 | shape, 426 | noise=None, 427 | clip_denoised=True, 428 | denoised_fn=None, 429 | cond_fn=None, 430 | model_kwargs=None, 431 | device=None, 432 | progress=False, 433 | temperature=1.0, 434 | ): 435 | """ 436 | Generate samples from the model. 437 | :param model: the model module. 438 | :param shape: the shape of the samples, (N, C, H, W). 439 | :param noise: if specified, the initial noise to start sampling from. 440 | Should be of the same shape as `shape`. 441 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 442 | :param denoised_fn: if not None, a function which applies to the 443 | x_start prediction before it is used to sample. 444 | :param cond_fn: if not None, this is a gradient function that acts 445 | similarly to the model. 446 | :param model_kwargs: if not None, a dict of extra keyword arguments to 447 | pass to the model. This can be used for conditioning.
448 | :param device: currently unused; new noise and timestep tensors are 449 | always created on the default CUDA device via .cuda(). 450 | :param progress: if True, show a tqdm progress bar. 451 | :param temperature: temperature scaling during Diff Loss sampling. 452 | :return: a non-differentiable batch of samples. 453 | """ 454 | final = None 455 | for sample in self.p_sample_loop_progressive( 456 | model, 457 | shape, 458 | noise=noise, 459 | clip_denoised=clip_denoised, 460 | denoised_fn=denoised_fn, 461 | cond_fn=cond_fn, 462 | model_kwargs=model_kwargs, 463 | device=device, 464 | progress=progress, 465 | temperature=temperature, 466 | ): 467 | final = sample 468 | return final["sample"] 469 | 470 | def p_sample_loop_progressive( 471 | self, 472 | model, 473 | shape, 474 | noise=None, 475 | clip_denoised=True, 476 | denoised_fn=None, 477 | cond_fn=None, 478 | model_kwargs=None, 479 | device=None, 480 | progress=False, 481 | temperature=1.0, 482 | ): 483 | """ 484 | Generate samples from the model and yield intermediate samples from 485 | each timestep of diffusion. 486 | Arguments are the same as p_sample_loop(). 487 | Returns a generator over dicts, where each dict is the return value of 488 | p_sample(). 489 | """ 490 | assert isinstance(shape, (tuple, list)) 491 | if noise is not None: 492 | img = noise 493 | else: 494 | img = th.randn(*shape).cuda() 495 | indices = list(range(self.num_timesteps))[::-1] 496 | 497 | if progress: 498 | # Lazy import so that we don't depend on tqdm.
499 | from tqdm.auto import tqdm 500 | 501 | indices = tqdm(indices) 502 | 503 | for i in indices: 504 | t = th.tensor([i] * shape[0]).cuda() 505 | with th.no_grad(): 506 | out = self.p_sample( 507 | model, 508 | img, 509 | t, 510 | clip_denoised=clip_denoised, 511 | denoised_fn=denoised_fn, 512 | cond_fn=cond_fn, 513 | model_kwargs=model_kwargs, 514 | temperature=temperature, 515 | ) 516 | yield out 517 | img = out["sample"] 518 | 519 | def ddim_sample( 520 | self, 521 | model, 522 | x, 523 | t, 524 | clip_denoised=True, 525 | denoised_fn=None, 526 | cond_fn=None, 527 | model_kwargs=None, 528 | eta=0.0, 529 | ): 530 | """ 531 | Sample x_{t-1} from the model using DDIM. 532 | Same usage as p_sample(). 533 | """ 534 | out = self.p_mean_variance( 535 | model, 536 | x, 537 | t, 538 | clip_denoised=clip_denoised, 539 | denoised_fn=denoised_fn, 540 | model_kwargs=model_kwargs, 541 | ) 542 | if cond_fn is not None: 543 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 544 | 545 | # Usually our model outputs epsilon, but we re-derive it 546 | # in case we used x_start or x_prev prediction. 547 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) 548 | 549 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 550 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) 551 | sigma = ( 552 | eta 553 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) 554 | * th.sqrt(1 - alpha_bar / alpha_bar_prev) 555 | ) 556 | # Equation 12. 
557 | noise = th.randn_like(x) 558 | mean_pred = ( 559 | out["pred_xstart"] * th.sqrt(alpha_bar_prev) 560 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps 561 | ) 562 | nonzero_mask = ( 563 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 564 | ) # no noise when t == 0 565 | sample = mean_pred + nonzero_mask * sigma * noise 566 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 567 | 568 | def ddim_reverse_sample( 569 | self, 570 | model, 571 | x, 572 | t, 573 | clip_denoised=True, 574 | denoised_fn=None, 575 | cond_fn=None, 576 | model_kwargs=None, 577 | eta=0.0, 578 | ): 579 | """ 580 | Sample x_{t+1} from the model using DDIM reverse ODE. 581 | """ 582 | assert eta == 0.0, "Reverse ODE only for deterministic path" 583 | out = self.p_mean_variance( 584 | model, 585 | x, 586 | t, 587 | clip_denoised=clip_denoised, 588 | denoised_fn=denoised_fn, 589 | model_kwargs=model_kwargs, 590 | ) 591 | if cond_fn is not None: 592 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 593 | # Usually our model outputs epsilon, but we re-derive it 594 | # in case we used x_start or x_prev prediction. 595 | eps = ( 596 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x 597 | - out["pred_xstart"] 598 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) 599 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) 600 | 601 | # Equation 12. reversed 602 | mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps 603 | 604 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} 605 | 606 | def ddim_sample_loop( 607 | self, 608 | model, 609 | shape, 610 | noise=None, 611 | clip_denoised=True, 612 | denoised_fn=None, 613 | cond_fn=None, 614 | model_kwargs=None, 615 | device=None, 616 | progress=False, 617 | eta=0.0, 618 | ): 619 | """ 620 | Generate samples from the model using DDIM. 621 | Same usage as p_sample_loop(). 
622 | """ 623 | final = None 624 | for sample in self.ddim_sample_loop_progressive( 625 | model, 626 | shape, 627 | noise=noise, 628 | clip_denoised=clip_denoised, 629 | denoised_fn=denoised_fn, 630 | cond_fn=cond_fn, 631 | model_kwargs=model_kwargs, 632 | device=device, 633 | progress=progress, 634 | eta=eta, 635 | ): 636 | final = sample 637 | return final["sample"] 638 | 639 | def ddim_sample_loop_progressive( 640 | self, 641 | model, 642 | shape, 643 | noise=None, 644 | clip_denoised=True, 645 | denoised_fn=None, 646 | cond_fn=None, 647 | model_kwargs=None, 648 | device=None, 649 | progress=False, 650 | eta=0.0, 651 | ): 652 | """ 653 | Use DDIM to sample from the model and yield intermediate samples from 654 | each timestep of DDIM. 655 | Same usage as p_sample_loop_progressive(). 656 | """ 657 | assert isinstance(shape, (tuple, list)) 658 | if noise is not None: 659 | img = noise 660 | else: 661 | img = th.randn(*shape).cuda() 662 | indices = list(range(self.num_timesteps))[::-1] 663 | 664 | if progress: 665 | # Lazy import so that we don't depend on tqdm. 666 | from tqdm.auto import tqdm 667 | 668 | indices = tqdm(indices) 669 | 670 | for i in indices: 671 | t = th.tensor([i] * shape[0]).cuda() 672 | with th.no_grad(): 673 | out = self.ddim_sample( 674 | model, 675 | img, 676 | t, 677 | clip_denoised=clip_denoised, 678 | denoised_fn=denoised_fn, 679 | cond_fn=cond_fn, 680 | model_kwargs=model_kwargs, 681 | eta=eta, 682 | ) 683 | yield out 684 | img = out["sample"] 685 | 686 | def _vb_terms_bpd( 687 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None 688 | ): 689 | """ 690 | Get a term for the variational lower-bound. 691 | The resulting units are bits (rather than nats, as one might expect). 692 | This allows for comparison to other papers. 693 | :return: a dict with the following keys: 694 | - 'output': a shape [N] tensor of NLLs or KLs. 695 | - 'pred_xstart': the x_0 predictions. 
696 | """ 697 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( 698 | x_start=x_start, x_t=x_t, t=t 699 | ) 700 | out = self.p_mean_variance( 701 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs 702 | ) 703 | kl = normal_kl( 704 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] 705 | ) 706 | kl = mean_flat(kl) / np.log(2.0) 707 | 708 | decoder_nll = -discretized_gaussian_log_likelihood( 709 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] 710 | ) 711 | assert decoder_nll.shape == x_start.shape 712 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0) 713 | 714 | # At the first timestep return the decoder NLL, 715 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) 716 | output = th.where((t == 0), decoder_nll, kl) 717 | return {"output": output, "pred_xstart": out["pred_xstart"]} 718 | 719 | def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): 720 | """ 721 | Compute training losses for a single timestep. 722 | :param model: the model to evaluate loss on. 723 | :param x_start: the [N x C x ...] tensor of inputs. 724 | :param t: a batch of timestep indices. 725 | :param model_kwargs: if not None, a dict of extra keyword arguments to 726 | pass to the model. This can be used for conditioning. 727 | :param noise: if specified, the specific Gaussian noise to try to remove. 728 | :return: a dict with the key "loss" containing a tensor of shape [N]. 729 | Some mean or variance settings may also have other keys. 
730 | """ 731 | if model_kwargs is None: 732 | model_kwargs = {} 733 | if noise is None: 734 | noise = th.randn_like(x_start) 735 | x_t = self.q_sample(x_start, t, noise=noise) 736 | 737 | terms = {} 738 | 739 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: 740 | terms["loss"] = self._vb_terms_bpd( 741 | model=model, 742 | x_start=x_start, 743 | x_t=x_t, 744 | t=t, 745 | clip_denoised=False, 746 | model_kwargs=model_kwargs, 747 | )["output"] 748 | if self.loss_type == LossType.RESCALED_KL: 749 | terms["loss"] *= self.num_timesteps 750 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: 751 | model_output = model(x_t, t, **model_kwargs) 752 | 753 | if self.model_var_type in [ 754 | ModelVarType.LEARNED, 755 | ModelVarType.LEARNED_RANGE, 756 | ]: 757 | B, C = x_t.shape[:2] 758 | assert model_output.shape == (B, C * 2, *x_t.shape[2:]) 759 | model_output, model_var_values = th.split(model_output, C, dim=1) 760 | # Learn the variance using the variational bound, but don't let 761 | # it affect our mean prediction. 762 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) 763 | terms["vb"] = self._vb_terms_bpd( 764 | model=lambda *args, r=frozen_out: r, 765 | x_start=x_start, 766 | x_t=x_t, 767 | t=t, 768 | clip_denoised=False, 769 | )["output"] 770 | if self.loss_type == LossType.RESCALED_MSE: 771 | # Divide by 1000 for equivalence with initial implementation. 772 | # Without a factor of 1/1000, the VB term hurts the MSE term. 
773 | terms["vb"] *= self.num_timesteps / 1000.0 774 | 775 | target = { 776 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( 777 | x_start=x_start, x_t=x_t, t=t 778 | )[0], 779 | ModelMeanType.START_X: x_start, 780 | ModelMeanType.EPSILON: noise, 781 | }[self.model_mean_type] 782 | assert model_output.shape == target.shape == x_start.shape 783 | terms["mse"] = mean_flat((target - model_output) ** 2) 784 | if "vb" in terms: 785 | terms["loss"] = terms["mse"] + terms["vb"] 786 | else: 787 | terms["loss"] = terms["mse"] 788 | else: 789 | raise NotImplementedError(self.loss_type) 790 | 791 | return terms 792 | 793 | def _prior_bpd(self, x_start): 794 | """ 795 | Get the prior KL term for the variational lower-bound, measured in 796 | bits-per-dim. 797 | This term can't be optimized, as it only depends on the encoder. 798 | :param x_start: the [N x C x ...] tensor of inputs. 799 | :return: a batch of [N] KL values (in bits), one per batch element. 800 | """ 801 | batch_size = x_start.shape[0] 802 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) 803 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) 804 | kl_prior = normal_kl( 805 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 806 | ) 807 | return mean_flat(kl_prior) / np.log(2.0) 808 | 809 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): 810 | """ 811 | Compute the entire variational lower-bound, measured in bits-per-dim, 812 | as well as other related quantities. 813 | :param model: the model to evaluate loss on. 814 | :param x_start: the [N x C x ...] tensor of inputs. 815 | :param clip_denoised: if True, clip denoised samples. 816 | :param model_kwargs: if not None, a dict of extra keyword arguments to 817 | pass to the model. This can be used for conditioning. 818 | :return: a dict containing the following keys: 819 | - total_bpd: the total variational lower-bound, per batch element. 
820 | - prior_bpd: the prior term in the lower-bound. 821 | - vb: an [N x T] tensor of terms in the lower-bound. 822 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. 823 | - mse: an [N x T] tensor of epsilon MSEs for each timestep. 824 | """ 825 | device = x_start.device 826 | batch_size = x_start.shape[0] 827 | 828 | vb = [] 829 | xstart_mse = [] 830 | mse = [] 831 | for t in list(range(self.num_timesteps))[::-1]: 832 | t_batch = th.tensor([t] * batch_size, device=device) 833 | noise = th.randn_like(x_start) 834 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) 835 | # Calculate VLB term at the current timestep 836 | with th.no_grad(): 837 | out = self._vb_terms_bpd( 838 | model, 839 | x_start=x_start, 840 | x_t=x_t, 841 | t=t_batch, 842 | clip_denoised=clip_denoised, 843 | model_kwargs=model_kwargs, 844 | ) 845 | vb.append(out["output"]) 846 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) 847 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) 848 | mse.append(mean_flat((eps - noise) ** 2)) 849 | 850 | vb = th.stack(vb, dim=1) 851 | xstart_mse = th.stack(xstart_mse, dim=1) 852 | mse = th.stack(mse, dim=1) 853 | 854 | prior_bpd = self._prior_bpd(x_start) 855 | total_bpd = vb.sum(dim=1) + prior_bpd 856 | return { 857 | "total_bpd": total_bpd, 858 | "prior_bpd": prior_bpd, 859 | "vb": vb, 860 | "xstart_mse": xstart_mse, 861 | "mse": mse, 862 | } 863 | 864 | 865 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 866 | """ 867 | Extract values from a 1-D numpy array for a batch of indices. 868 | :param arr: the 1-D numpy array. 869 | :param timesteps: a tensor of indices into the array to extract. 870 | :param broadcast_shape: a larger shape of K dimensions with the batch 871 | dimension equal to the length of timesteps. 872 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
873 | """ 874 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 875 | while len(res.shape) < len(broadcast_shape): 876 | res = res[..., None] 877 | return res + th.zeros(broadcast_shape, device=timesteps.device) 878 | -------------------------------------------------------------------------------- /diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a set of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there are 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use.
30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {desired_count} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process.
71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | if isinstance(model, _WrappedModel): 107 | return model 108 | return _WrappedModel( 109 | model, self.timestep_map, self.original_num_steps 110 | ) 111 | 112 | def _scale_timesteps(self, t): 113 | # Scaling is done by the wrapped model. 
114 | return t 115 | 116 | 117 | class _WrappedModel: 118 | def __init__(self, model, timestep_map, original_num_steps): 119 | self.model = model 120 | self.timestep_map = timestep_map 121 | # self.rescale_timesteps = rescale_timesteps 122 | self.original_num_steps = original_num_steps 123 | 124 | def __call__(self, x, ts, **kwargs): 125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 126 | new_ts = map_tensor[ts] 127 | # if self.rescale_timesteps: 128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 129 | return self.model(x, new_ts, **kwargs) 130 | -------------------------------------------------------------------------------- /engine_mar.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | from typing import Iterable 4 | 5 | import torch 6 | 7 | import util.misc as misc 8 | import util.lr_sched as lr_sched 9 | from models.vae import DiagonalGaussianDistribution 10 | import torch_fidelity 11 | import shutil 12 | import cv2 13 | import numpy as np 14 | import os 15 | import copy 16 | import time 17 | 18 | 19 | def update_ema(target_params, source_params, rate=0.99): 20 | """ 21 | Update target parameters to be closer to those of source parameters using 22 | an exponential moving average. 23 | 24 | :param target_params: the target parameter sequence. 25 | :param source_params: the source parameter sequence. 26 | :param rate: the EMA rate (closer to 1 means slower). 
27 | """ 28 | for targ, src in zip(target_params, source_params): 29 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 30 | 31 | 32 | def train_one_epoch(model, vae, 33 | model_params, ema_params, 34 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 35 | device: torch.device, epoch: int, loss_scaler, 36 | log_writer=None, 37 | args=None): 38 | model.train(True) 39 | metric_logger = misc.MetricLogger(delimiter=" ") 40 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 41 | header = 'Epoch: [{}]'.format(epoch) 42 | print_freq = 20 43 | 44 | optimizer.zero_grad() 45 | 46 | if log_writer is not None: 47 | print('log_dir: {}'.format(log_writer.log_dir)) 48 | 49 | for data_iter_step, (samples, labels) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 50 | 51 | # we use a per iteration (instead of per epoch) lr scheduler 52 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 53 | 54 | samples = samples.to(device, non_blocking=True) 55 | labels = labels.to(device, non_blocking=True) 56 | 57 | with torch.no_grad(): 58 | if args.use_cached: 59 | moments = samples 60 | posterior = DiagonalGaussianDistribution(moments) 61 | else: 62 | posterior = vae.encode(samples) 63 | 64 | # normalize the std of latent to be 1. 
Change it if you use a different tokenizer 65 | x = posterior.sample().mul_(0.2325) 66 | 67 | # forward 68 | with torch.cuda.amp.autocast(): 69 | loss = model(x, labels) 70 | 71 | loss_value = loss.item() 72 | 73 | if not math.isfinite(loss_value): 74 | print("Loss is {}, stopping training".format(loss_value)) 75 | sys.exit(1) 76 | 77 | loss_scaler(loss, optimizer, clip_grad=args.grad_clip, parameters=model.parameters(), update_grad=True) 78 | optimizer.zero_grad() 79 | 80 | torch.cuda.synchronize() 81 | 82 | update_ema(ema_params, model_params, rate=args.ema_rate) 83 | 84 | metric_logger.update(loss=loss_value) 85 | 86 | lr = optimizer.param_groups[0]["lr"] 87 | metric_logger.update(lr=lr) 88 | 89 | loss_value_reduce = misc.all_reduce_mean(loss_value) 90 | if log_writer is not None: 91 | """ We use epoch_1000x as the x-axis in tensorboard. 92 | This calibrates different curves when batch size changes. 93 | """ 94 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 95 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) 96 | log_writer.add_scalar('lr', lr, epoch_1000x) 97 | 98 | # gather the stats from all processes 99 | metric_logger.synchronize_between_processes() 100 | print("Averaged stats:", metric_logger) 101 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 102 | 103 | 104 | def evaluate(model_without_ddp, vae, ema_params, args, epoch, batch_size=16, log_writer=None, cfg=1.0, 105 | use_ema=True): 106 | model_without_ddp.eval() 107 | num_steps = args.num_images // (batch_size * misc.get_world_size()) + 1 108 | save_folder = os.path.join(args.output_dir, "ariter{}-diffsteps{}-temp{}-{}cfg{}-image{}".format(args.num_iter, 109 | args.num_sampling_steps, 110 | args.temperature, 111 | args.cfg_schedule, 112 | cfg, 113 | args.num_images)) 114 | if use_ema: 115 | save_folder = save_folder + "_ema" 116 | if args.evaluate: 117 | save_folder = save_folder + "_evaluate" 118 | print("Save to:", save_folder) 
119 | if misc.get_rank() == 0: 120 | if not os.path.exists(save_folder): 121 | os.makedirs(save_folder) 122 | 123 | # switch to ema params 124 | if use_ema: 125 | model_state_dict = copy.deepcopy(model_without_ddp.state_dict()) 126 | ema_state_dict = copy.deepcopy(model_without_ddp.state_dict()) 127 | for i, (name, _value) in enumerate(model_without_ddp.named_parameters()): 128 | assert name in ema_state_dict 129 | ema_state_dict[name] = ema_params[i] 130 | print("Switch to ema") 131 | model_without_ddp.load_state_dict(ema_state_dict) 132 | 133 | class_num = args.class_num 134 | assert args.num_images % class_num == 0 # number of images per class must be the same 135 | class_label_gen_world = np.arange(0, class_num).repeat(args.num_images // class_num) 136 | class_label_gen_world = np.hstack([class_label_gen_world, np.zeros(50000)]) 137 | world_size = misc.get_world_size() 138 | local_rank = misc.get_rank() 139 | used_time = 0 140 | gen_img_cnt = 0 141 | 142 | for i in range(num_steps): 143 | print("Generation step {}/{}".format(i, num_steps)) 144 | 145 | labels_gen = class_label_gen_world[world_size * batch_size * i + local_rank * batch_size: 146 | world_size * batch_size * i + (local_rank + 1) * batch_size] 147 | labels_gen = torch.Tensor(labels_gen).long().cuda() 148 | 149 | 150 | torch.cuda.synchronize() 151 | start_time = time.time() 152 | 153 | # generation 154 | with torch.no_grad(): 155 | with torch.cuda.amp.autocast(): 156 | sampled_tokens = model_without_ddp.sample_tokens(bsz=batch_size, num_iter=args.num_iter, cfg=cfg, 157 | cfg_schedule=args.cfg_schedule, labels=labels_gen, 158 | temperature=args.temperature) 159 | sampled_images = vae.decode(sampled_tokens / 0.2325) 160 | 161 | # measure speed after the first generation batch 162 | if i >= 1: 163 | torch.cuda.synchronize() 164 | used_time += time.time() - start_time 165 | gen_img_cnt += batch_size 166 | print("Generating {} images takes {:.5f} seconds, {:.5f} sec per image".format(gen_img_cnt, 
used_time, used_time / gen_img_cnt)) 167 | 168 | torch.distributed.barrier() 169 | sampled_images = sampled_images.detach().cpu() 170 | sampled_images = (sampled_images + 1) / 2 171 | 172 | # distributed save 173 | for b_id in range(sampled_images.size(0)): 174 | img_id = i * sampled_images.size(0) * world_size + local_rank * sampled_images.size(0) + b_id 175 | if img_id >= args.num_images: 176 | break 177 | gen_img = np.round(np.clip(sampled_images[b_id].numpy().transpose([1, 2, 0]) * 255, 0, 255)) 178 | gen_img = gen_img.astype(np.uint8)[:, :, ::-1] 179 | cv2.imwrite(os.path.join(save_folder, '{}.png'.format(str(img_id).zfill(5))), gen_img) 180 | 181 | torch.distributed.barrier() 182 | time.sleep(10) 183 | 184 | # back to no ema 185 | if use_ema: 186 | print("Switch back from ema") 187 | model_without_ddp.load_state_dict(model_state_dict) 188 | 189 | # compute FID and IS 190 | if log_writer is not None: 191 | if args.img_size == 256: 192 | input2 = None 193 | fid_statistics_file = 'fid_stats/adm_in256_stats.npz' 194 | else: 195 | raise NotImplementedError 196 | metrics_dict = torch_fidelity.calculate_metrics( 197 | input1=save_folder, 198 | input2=input2, 199 | fid_statistics_file=fid_statistics_file, 200 | cuda=True, 201 | isc=True, 202 | fid=True, 203 | kid=False, 204 | prc=False, 205 | verbose=False, 206 | ) 207 | fid = metrics_dict['frechet_inception_distance'] 208 | inception_score = metrics_dict['inception_score_mean'] 209 | postfix = "" 210 | if use_ema: 211 | postfix = postfix + "_ema" 212 | if not cfg == 1.0: 213 | postfix = postfix + "_cfg{}".format(cfg) 214 | log_writer.add_scalar('fid{}'.format(postfix), fid, epoch) 215 | log_writer.add_scalar('is{}'.format(postfix), inception_score, epoch) 216 | print("FID: {:.4f}, Inception Score: {:.4f}".format(fid, inception_score)) 217 | # remove temporary saving folder 218 | shutil.rmtree(save_folder) 219 | 220 | torch.distributed.barrier() 221 | time.sleep(10) 222 | 223 | 224 | def 
cache_latents(vae, 225 |
data_loader: Iterable, 226 | device: torch.device, 227 | args=None): 228 | metric_logger = misc.MetricLogger(delimiter=" ") 229 | header = 'Caching: ' 230 | print_freq = 20 231 | 232 | for data_iter_step, (samples, _, paths) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 233 | 234 | samples = samples.to(device, non_blocking=True) 235 | 236 | with torch.no_grad(): 237 | posterior = vae.encode(samples) 238 | moments = posterior.parameters 239 | posterior_flip = vae.encode(samples.flip(dims=[3])) 240 | moments_flip = posterior_flip.parameters 241 | 242 | for i, path in enumerate(paths): 243 | save_path = os.path.join(args.cached_path, path + '.npz') 244 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 245 | np.savez(save_path, moments=moments[i].cpu().numpy(), moments_flip=moments_flip[i].cpu().numpy()) 246 | 247 | if misc.is_dist_avail_and_initialized(): 248 | torch.cuda.synchronize() 249 | 250 | return 251 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: mar 2 | channels: 3 | - pytorch 4 | - defaults 5 | - nvidia 6 | dependencies: 7 | - python=3.8.5 8 | - pip=20.3 9 | - pytorch-cuda=11.8 10 | - pytorch=2.2.2 11 | - torchvision=0.17.2 12 | - numpy=1.22 13 | - pip: 14 | - opencv-python==4.1.2.30 15 | - timm==0.9.12 16 | - tensorboard==2.10.0 17 | - scipy==1.9.1 18 | - gdown==5.2.0 19 | - -e git+https://github.com/LTH14/torch-fidelity.git@master#egg=torch-fidelity 20 | -------------------------------------------------------------------------------- /fid_stats/adm_in256_stats.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LTH14/mar/fe470ac24afbee924668d8c5c83e9fec60af3a73/fid_stats/adm_in256_stats.npz -------------------------------------------------------------------------------- /main_cache.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import numpy as np 4 | import os 5 | import time 6 | from pathlib import Path 7 | 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | from torch.utils.tensorboard import SummaryWriter 11 | import torchvision.transforms as transforms 12 | 13 | import util.misc as misc 14 | from util.loader import ImageFolderWithFilename 15 | 16 | from models.vae import AutoencoderKL 17 | from engine_mar import cache_latents 18 | 19 | from util.crop import center_crop_arr 20 | 21 | 22 | def get_args_parser(): 23 | parser = argparse.ArgumentParser('Cache VAE latents', add_help=False) 24 | parser.add_argument('--batch_size', default=128, type=int, 25 | help='Batch size per GPU (effective batch size is batch_size * # GPUs)') 26 | 27 | # VAE parameters 28 | parser.add_argument('--img_size', default=256, type=int, 29 | help='input image size') 30 | parser.add_argument('--vae_path', default="pretrained_models/vae/kl16.ckpt", type=str, 31 | help='path to the pretrained VAE checkpoint') 32 | parser.add_argument('--vae_embed_dim', default=16, type=int, 33 | help='vae output embedding dimension') 34 | # Dataset parameters 35 | parser.add_argument('--data_path', default='./data/imagenet', type=str, 36 | help='dataset path') 37 | parser.add_argument('--device', default='cuda', 38 | help='device to use for training / testing') 39 | parser.add_argument('--seed', default=0, type=int) 40 | 41 | parser.add_argument('--num_workers', default=10, type=int) 42 | parser.add_argument('--pin_mem', action='store_true', 43 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 44 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 45 | parser.set_defaults(pin_mem=True) 46 | 47 | # distributed training parameters 48 | parser.add_argument('--world_size', default=1, type=int, 49 | help='number of distributed processes') 50 |
parser.add_argument('--local_rank', default=-1, type=int) 51 | parser.add_argument('--dist_on_itp', action='store_true') 52 | parser.add_argument('--dist_url', default='env://', 53 | help='url used to set up distributed training') 54 | 55 | # caching latents 56 | parser.add_argument('--cached_path', default='', help='path to cached latents') 57 | 58 | return parser 59 | 60 | 61 | def main(args): 62 | misc.init_distributed_mode(args) 63 | 64 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 65 | print("{}".format(args).replace(', ', ',\n')) 66 | 67 | device = torch.device(args.device) 68 | 69 | # fix the seed for reproducibility 70 | seed = args.seed + misc.get_rank() 71 | torch.manual_seed(seed) 72 | np.random.seed(seed) 73 | 74 | cudnn.benchmark = True 75 | 76 | num_tasks = misc.get_world_size() 77 | global_rank = misc.get_rank() 78 | 79 | # augmentation following DiT and ADM 80 | transform_train = transforms.Compose([ 81 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)), 82 | # transforms.RandomHorizontalFlip(), 83 | transforms.ToTensor(), 84 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 85 | ]) 86 | 87 | dataset_train = ImageFolderWithFilename(os.path.join(args.data_path, 'train'), transform=transform_train) 88 | print(dataset_train) 89 | 90 | sampler_train = torch.utils.data.DistributedSampler( 91 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=False, 92 | ) 93 | print("Sampler_train = %s" % str(sampler_train)) 94 | 95 | data_loader_train = torch.utils.data.DataLoader( 96 | dataset_train, sampler=sampler_train, 97 | batch_size=args.batch_size, 98 | num_workers=args.num_workers, 99 | pin_memory=args.pin_mem, 100 | drop_last=False, # Don't drop in cache 101 | ) 102 | 103 | # define the vae 104 | vae = AutoencoderKL(embed_dim=args.vae_embed_dim, ch_mult=(1, 1, 2, 2, 4), ckpt_path=args.vae_path).cuda().eval() 105 | 106 | # caching 107 | print(f"Start caching VAE latents") 108 |
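# Editorial note on what the caching step writes (based on cache_latents in engine_mar.py):
# for each image, np.savez writes one file at os.path.join(args.cached_path, path + '.npz'),
# where `path` is the filename yielded by ImageFolderWithFilename. Each file holds two arrays:
#   'moments'      -> posterior.parameters of the VAE encoding of the center-cropped image
#   'moments_flip' -> the same for the horizontally flipped image (samples.flip(dims=[3]))
# Because the flipped encoding is computed explicitly inside cache_latents,
# RandomHorizontalFlip stays commented out in transform_train above.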
start_time = time.time() 109 | cache_latents( 110 | vae, 111 | data_loader_train, 112 | device, 113 | args=args 114 | ) 115 | total_time = time.time() - start_time 116 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 117 | print('Caching time {}'.format(total_time_str)) 118 | 119 | 120 | if __name__ == '__main__': 121 | args = get_args_parser() 122 | args = args.parse_args() 123 | main(args) 124 | -------------------------------------------------------------------------------- /main_mar.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import numpy as np 4 | import os 5 | import time 6 | from pathlib import Path 7 | 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | from torch.utils.tensorboard import SummaryWriter 11 | import torchvision.transforms as transforms 12 | import torchvision.datasets as datasets 13 | 14 | from util.crop import center_crop_arr 15 | import util.misc as misc 16 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 17 | from util.loader import CachedFolder 18 | 19 | from models.vae import AutoencoderKL 20 | from models import mar 21 | from engine_mar import train_one_epoch, evaluate 22 | import copy 23 | 24 | 25 | def get_args_parser(): 26 | parser = argparse.ArgumentParser('MAR training with Diffusion Loss', add_help=False) 27 | parser.add_argument('--batch_size', default=16, type=int, 28 | help='Batch size per GPU (effective batch size is batch_size * # GPUs)') 29 | parser.add_argument('--epochs', default=400, type=int) 30 | 31 | # Model parameters 32 | parser.add_argument('--model', default='mar_large', type=str, metavar='MODEL', 33 | help='Name of model to train') 34 | 35 | # VAE parameters 36 | parser.add_argument('--img_size', default=256, type=int, 37 | help='input image size') 38 | parser.add_argument('--vae_path', default="pretrained_models/vae/kl16.ckpt", type=str, 39 | help='path to the pretrained VAE 
checkpoint') 40 |
parser.add_argument('--vae_embed_dim', default=16, type=int, 41 | help='vae output embedding dimension') 42 | parser.add_argument('--vae_stride', default=16, type=int, 43 | help='tokenizer stride (default 16, matching the KL-16 VAE)') 44 | parser.add_argument('--patch_size', default=1, type=int, 45 | help='side length p of each p x p group of latent tokens merged into one patch') 46 | 47 | # Generation parameters 48 | parser.add_argument('--num_iter', default=64, type=int, 49 | help='number of autoregressive iterations to generate an image') 50 | parser.add_argument('--num_images', default=50000, type=int, 51 | help='number of images to generate') 52 | parser.add_argument('--cfg', default=1.0, type=float, help="classifier-free guidance scale") 53 | parser.add_argument('--cfg_schedule', default="linear", type=str) 54 | parser.add_argument('--label_drop_prob', default=0.1, type=float) 55 | parser.add_argument('--eval_freq', type=int, default=40, help='evaluation frequency (in epochs)') 56 | parser.add_argument('--save_last_freq', type=int, default=5, help='frequency (in epochs) of saving the "last" checkpoint') 57 | parser.add_argument('--online_eval', action='store_true') 58 | parser.add_argument('--evaluate', action='store_true') 59 | parser.add_argument('--eval_bsz', type=int, default=64, help='generation batch size') 60 | 61 | # Optimizer parameters 62 | parser.add_argument('--weight_decay', type=float, default=0.02, 63 | help='weight decay (default: 0.02)') 64 | 65 | parser.add_argument('--grad_checkpointing', action='store_true') 66 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 67 | help='learning rate (absolute lr)') 68 | parser.add_argument('--blr', type=float, default=1e-4, metavar='LR', 69 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 70 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 71 | help='lower lr bound for cyclic schedulers that hit 0') 72 | 
parser.add_argument('--lr_schedule', type=str, default='constant', 73 | help='learning rate schedule') 74 |
parser.add_argument('--warmup_epochs', type=int, default=100, metavar='N', 75 | help='epochs to warmup LR') 76 | parser.add_argument('--ema_rate', default=0.9999, type=float) 77 | 78 | # MAR params 79 | parser.add_argument('--mask_ratio_min', type=float, default=0.7, 80 | help='Minimum mask ratio') 81 | parser.add_argument('--grad_clip', type=float, default=3.0, 82 | help='Gradient clip') 83 | parser.add_argument('--attn_dropout', type=float, default=0.1, 84 | help='attention dropout') 85 | parser.add_argument('--proj_dropout', type=float, default=0.1, 86 | help='projection dropout') 87 | parser.add_argument('--buffer_size', type=int, default=64) 88 | 89 | # Diffusion Loss params 90 | parser.add_argument('--diffloss_d', type=int, default=12) 91 | parser.add_argument('--diffloss_w', type=int, default=1536) 92 | parser.add_argument('--num_sampling_steps', type=str, default="100") 93 | parser.add_argument('--diffusion_batch_mul', type=int, default=1) 94 | parser.add_argument('--temperature', default=1.0, type=float, help='diffusion loss sampling temperature') 95 | 96 | # Dataset parameters 97 | parser.add_argument('--data_path', default='./data/imagenet', type=str, 98 | help='dataset path') 99 | parser.add_argument('--class_num', default=1000, type=int) 100 | 101 | parser.add_argument('--output_dir', default='./output_dir', 102 | help='path to save outputs; empty for no saving') 103 | parser.add_argument('--log_dir', default='./output_dir', 104 | help='path for TensorBoard logs') 105 | parser.add_argument('--device', default='cuda', 106 | help='device to use for training / testing') 107 | parser.add_argument('--seed', default=1, type=int) 108 | parser.add_argument('--resume', default='', 109 | help='directory containing checkpoint-last.pth to resume from') 110 | 111 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 112 | help='start epoch') 113 | parser.add_argument('--num_workers', default=10, type=int) 114 | 
parser.add_argument('--pin_mem', action='store_true', 115 |
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 116 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 117 | parser.set_defaults(pin_mem=True) 118 | 119 | # distributed training parameters 120 | parser.add_argument('--world_size', default=1, type=int, 121 | help='number of distributed processes') 122 | parser.add_argument('--local_rank', default=-1, type=int) 123 | parser.add_argument('--dist_on_itp', action='store_true') 124 | parser.add_argument('--dist_url', default='env://', 125 | help='url used to set up distributed training') 126 | 127 | # caching latents 128 | parser.add_argument('--use_cached', action='store_true', dest='use_cached', 129 | help='Use cached latents') 130 | parser.set_defaults(use_cached=False) 131 | parser.add_argument('--cached_path', default='', help='path to cached latents') 132 | 133 | return parser 134 | 135 | 136 | def main(args): 137 | misc.init_distributed_mode(args) 138 | 139 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 140 | print("{}".format(args).replace(', ', ',\n')) 141 | 142 | device = torch.device(args.device) 143 | 144 | # fix the seed for reproducibility 145 | seed = args.seed + misc.get_rank() 146 | torch.manual_seed(seed) 147 | np.random.seed(seed) 148 | 149 | cudnn.benchmark = True 150 | 151 | num_tasks = misc.get_world_size() 152 | global_rank = misc.get_rank() 153 | 154 | if global_rank == 0 and args.log_dir is not None: 155 | os.makedirs(args.log_dir, exist_ok=True) 156 | log_writer = SummaryWriter(log_dir=args.log_dir) 157 | else: 158 | log_writer = None 159 | 160 | # augmentation following DiT and ADM 161 | transform_train = transforms.Compose([ 162 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)), 163 | transforms.RandomHorizontalFlip(), 164 | transforms.ToTensor(), 165 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 166 | ]) 167 | 168 | if args.use_cached: 169 | dataset_train = 
CachedFolder(args.cached_path) 170 | else: 171 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train) 172 | print(dataset_train) 173 | 174 | sampler_train = torch.utils.data.DistributedSampler( 175 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 176 | ) 177 | print("Sampler_train = %s" % str(sampler_train)) 178 | 179 | data_loader_train = torch.utils.data.DataLoader( 180 | dataset_train, sampler=sampler_train, 181 | batch_size=args.batch_size, 182 | num_workers=args.num_workers, 183 | pin_memory=args.pin_mem, 184 | drop_last=True, 185 | ) 186 | 187 | # define the vae and mar model 188 | vae = AutoencoderKL(embed_dim=args.vae_embed_dim, ch_mult=(1, 1, 2, 2, 4), ckpt_path=args.vae_path).cuda().eval() 189 | for param in vae.parameters(): 190 | param.requires_grad = False 191 | 192 | model = mar.__dict__[args.model]( 193 | img_size=args.img_size, 194 | vae_stride=args.vae_stride, 195 | patch_size=args.patch_size, 196 | vae_embed_dim=args.vae_embed_dim, 197 | mask_ratio_min=args.mask_ratio_min, 198 | label_drop_prob=args.label_drop_prob, 199 | class_num=args.class_num, 200 | attn_dropout=args.attn_dropout, 201 | proj_dropout=args.proj_dropout, 202 | buffer_size=args.buffer_size, 203 | diffloss_d=args.diffloss_d, 204 | diffloss_w=args.diffloss_w, 205 | num_sampling_steps=args.num_sampling_steps, 206 | diffusion_batch_mul=args.diffusion_batch_mul, 207 | grad_checkpointing=args.grad_checkpointing, 208 | ) 209 | 210 | print("Model = %s" % str(model)) 211 | # count trainable parameters 212 | n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 213 | print("Number of trainable parameters: {}M".format(n_params / 1e6)) 214 | 215 | model.to(device) 216 | model_without_ddp = model 217 | 218 | eff_batch_size = args.batch_size * misc.get_world_size() 219 | 220 | if args.lr is None: # only base_lr is specified 221 | args.lr = args.blr * eff_batch_size / 256 222 |
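# Worked example of the linear lr scaling rule above (the GPU count is a hypothetical choice):
#   blr = 1e-4 (default), batch_size = 16 per GPU (default), 32 GPUs
#   -> eff_batch_size = 16 * 32 = 512
#   -> lr = 1e-4 * 512 / 256 = 2e-4
# Passing --lr explicitly skips this rule and uses the given absolute value.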
223 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 224 | print("actual lr: %.2e" % args.lr) 225 | print("effective batch size: %d" % eff_batch_size) 226 | 227 | if args.distributed: 228 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 229 | model_without_ddp = model.module 230 | 231 | # no weight decay on bias, norm layers, and diffloss MLP 232 | param_groups = misc.add_weight_decay(model_without_ddp, args.weight_decay) 233 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 234 | print(optimizer) 235 | loss_scaler = NativeScaler() 236 | 237 | # resume training 238 | if args.resume and os.path.exists(os.path.join(args.resume, "checkpoint-last.pth")): 239 | checkpoint = torch.load(os.path.join(args.resume, "checkpoint-last.pth"), map_location='cpu') 240 | model_without_ddp.load_state_dict(checkpoint['model']) 241 | model_params = list(model_without_ddp.parameters()) 242 | ema_state_dict = checkpoint['model_ema'] 243 | ema_params = [ema_state_dict[name].cuda() for name, _ in model_without_ddp.named_parameters()] 244 | print("Resume checkpoint %s" % args.resume) 245 | 246 | if 'optimizer' in checkpoint and 'epoch' in checkpoint: 247 | optimizer.load_state_dict(checkpoint['optimizer']) 248 | args.start_epoch = checkpoint['epoch'] + 1 249 | if 'scaler' in checkpoint: 250 | loss_scaler.load_state_dict(checkpoint['scaler']) 251 | print("With optim & sched!") 252 | del checkpoint 253 | else: 254 | model_params = list(model_without_ddp.parameters()) 255 | ema_params = copy.deepcopy(model_params) 256 | print("Training from scratch") 257 | 258 | # evaluate FID and IS 259 | if args.evaluate: 260 | torch.cuda.empty_cache() 261 | evaluate(model_without_ddp, vae, ema_params, args, 0, batch_size=args.eval_bsz, log_writer=log_writer, 262 | cfg=args.cfg, use_ema=True) 263 | return 264 | 265 | # training 266 | print(f"Start training for {args.epochs} epochs") 267 | start_time = time.time() 268 | for epoch in 
range(args.start_epoch, args.epochs): 269 | if args.distributed: 270 | data_loader_train.sampler.set_epoch(epoch) 271 | 272 | train_one_epoch( 273 | model, vae, 274 | model_params, ema_params, 275 | data_loader_train, 276 | optimizer, device, epoch, loss_scaler, 277 | log_writer=log_writer, 278 | args=args 279 | ) 280 | 281 | # save checkpoint 282 | if epoch % args.save_last_freq == 0 or epoch + 1 == args.epochs: 283 | misc.save_model(args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 284 | loss_scaler=loss_scaler, epoch=epoch, ema_params=ema_params, epoch_name="last") 285 | 286 | # online evaluation 287 | if args.online_eval and (epoch % args.eval_freq == 0 or epoch + 1 == args.epochs): 288 | torch.cuda.empty_cache() 289 | evaluate(model_without_ddp, vae, ema_params, args, epoch, batch_size=args.eval_bsz, log_writer=log_writer, 290 | cfg=1.0, use_ema=True) 291 | if not (args.cfg == 1.0 or args.cfg == 0.0): 292 | evaluate(model_without_ddp, vae, ema_params, args, epoch, batch_size=args.eval_bsz // 2, 293 | log_writer=log_writer, cfg=args.cfg, use_ema=True) 294 | torch.cuda.empty_cache() 295 | 296 | if misc.is_main_process(): 297 | if log_writer is not None: 298 | log_writer.flush() 299 | 300 | total_time = time.time() - start_time 301 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 302 | print('Training time {}'.format(total_time_str)) 303 | 304 | 305 | if __name__ == '__main__': 306 | args = get_args_parser() 307 | args = args.parse_args() 308 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 309 | args.log_dir = args.output_dir 310 | main(args) 311 | -------------------------------------------------------------------------------- /models/diffloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint 4 | import math 5 | 6 | from diffusion import create_diffusion 7 | 8 | 9 | class 
DiffLoss(nn.Module): 10 | """Diffusion Loss""" 11 | def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps, grad_checkpointing=False): 12 | super(DiffLoss, self).__init__() 13 | self.in_channels = target_channels 14 | self.net = SimpleMLPAdaLN( 15 | in_channels=target_channels, 16 | model_channels=width, 17 | out_channels=target_channels * 2, # for vlb loss 18 | z_channels=z_channels, 19 | num_res_blocks=depth, 20 | grad_checkpointing=grad_checkpointing 21 | ) 22 | 23 | self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine") 24 | self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine") 25 | 26 | def forward(self, target, z, mask=None): 27 | t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device) 28 | model_kwargs = dict(c=z) 29 | loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs) 30 | loss = loss_dict["loss"] 31 | if mask is not None: 32 | loss = (loss * mask).sum() / mask.sum() 33 | return loss.mean() 34 | 35 | def sample(self, z, temperature=1.0, cfg=1.0): 36 | # diffusion loss sampling 37 | if not cfg == 1.0: 38 | noise = torch.randn(z.shape[0] // 2, self.in_channels).cuda() 39 | noise = torch.cat([noise, noise], dim=0) 40 | model_kwargs = dict(c=z, cfg_scale=cfg) 41 | sample_fn = self.net.forward_with_cfg 42 | else: 43 | noise = torch.randn(z.shape[0], self.in_channels).cuda() 44 | model_kwargs = dict(c=z) 45 | sample_fn = self.net.forward 46 | 47 | sampled_token_latent = self.gen_diffusion.p_sample_loop( 48 | sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False, 49 | temperature=temperature 50 | ) 51 | 52 | return sampled_token_latent 53 | 54 | 55 | def modulate(x, shift, scale): 56 | return x * (1 + scale) + shift 57 | 58 | 59 | class TimestepEmbedder(nn.Module): 60 | """ 61 | Embeds scalar timesteps into vector representations. 
62 | """ 63 | def __init__(self, hidden_size, frequency_embedding_size=256): 64 | super().__init__() 65 | self.mlp = nn.Sequential( 66 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 67 | nn.SiLU(), 68 | nn.Linear(hidden_size, hidden_size, bias=True), 69 | ) 70 | self.frequency_embedding_size = frequency_embedding_size 71 | 72 | @staticmethod 73 | def timestep_embedding(t, dim, max_period=10000): 74 | """ 75 | Create sinusoidal timestep embeddings. 76 | :param t: a 1-D Tensor of N indices, one per batch element. 77 | These may be fractional. 78 | :param dim: the dimension of the output. 79 | :param max_period: controls the minimum frequency of the embeddings. 80 | :return: an (N, D) Tensor of positional embeddings. 81 | """ 82 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 83 | half = dim // 2 84 | freqs = torch.exp( 85 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 86 | ).to(device=t.device) 87 | args = t[:, None].float() * freqs[None] 88 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 89 | if dim % 2: 90 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 91 | return embedding 92 | 93 | def forward(self, t): 94 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 95 | t_emb = self.mlp(t_freq) 96 | return t_emb 97 | 98 | 99 | class ResBlock(nn.Module): 100 | """ 101 | A residual block that can optionally change the number of channels. 102 | :param channels: the number of input channels. 
103 | """ 104 | 105 | def __init__( 106 | self, 107 | channels 108 | ): 109 | super().__init__() 110 | self.channels = channels 111 | 112 | self.in_ln = nn.LayerNorm(channels, eps=1e-6) 113 | self.mlp = nn.Sequential( 114 | nn.Linear(channels, channels, bias=True), 115 | nn.SiLU(), 116 | nn.Linear(channels, channels, bias=True), 117 | ) 118 | 119 | self.adaLN_modulation = nn.Sequential( 120 | nn.SiLU(), 121 | nn.Linear(channels, 3 * channels, bias=True) 122 | ) 123 | 124 | def forward(self, x, y): 125 | shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1) 126 | h = modulate(self.in_ln(x), shift_mlp, scale_mlp) 127 | h = self.mlp(h) 128 | return x + gate_mlp * h 129 | 130 | 131 | class FinalLayer(nn.Module): 132 | """ 133 | The final layer adopted from DiT. 134 | """ 135 | def __init__(self, model_channels, out_channels): 136 | super().__init__() 137 | self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6) 138 | self.linear = nn.Linear(model_channels, out_channels, bias=True) 139 | self.adaLN_modulation = nn.Sequential( 140 | nn.SiLU(), 141 | nn.Linear(model_channels, 2 * model_channels, bias=True) 142 | ) 143 | 144 | def forward(self, x, c): 145 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) 146 | x = modulate(self.norm_final(x), shift, scale) 147 | x = self.linear(x) 148 | return x 149 | 150 | 151 | class SimpleMLPAdaLN(nn.Module): 152 | """ 153 | The MLP for Diffusion Loss. 154 | :param in_channels: channels in the input Tensor. 155 | :param model_channels: base channel count for the model. 156 | :param out_channels: channels in the output Tensor. 157 | :param z_channels: channels in the condition. 158 | :param num_res_blocks: number of residual blocks per downsample. 
159 | """ 160 | 161 | def __init__( 162 | self, 163 | in_channels, 164 | model_channels, 165 | out_channels, 166 | z_channels, 167 | num_res_blocks, 168 | grad_checkpointing=False 169 | ): 170 | super().__init__() 171 | 172 | self.in_channels = in_channels 173 | self.model_channels = model_channels 174 | self.out_channels = out_channels 175 | self.num_res_blocks = num_res_blocks 176 | self.grad_checkpointing = grad_checkpointing 177 | 178 | self.time_embed = TimestepEmbedder(model_channels) 179 | self.cond_embed = nn.Linear(z_channels, model_channels) 180 | 181 | self.input_proj = nn.Linear(in_channels, model_channels) 182 | 183 | res_blocks = [] 184 | for i in range(num_res_blocks): 185 | res_blocks.append(ResBlock( 186 | model_channels, 187 | )) 188 | 189 | self.res_blocks = nn.ModuleList(res_blocks) 190 | self.final_layer = FinalLayer(model_channels, out_channels) 191 | 192 | self.initialize_weights() 193 | 194 | def initialize_weights(self): 195 | def _basic_init(module): 196 | if isinstance(module, nn.Linear): 197 | torch.nn.init.xavier_uniform_(module.weight) 198 | if module.bias is not None: 199 | nn.init.constant_(module.bias, 0) 200 | self.apply(_basic_init) 201 | 202 | # Initialize timestep embedding MLP 203 | nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) 204 | nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) 205 | 206 | # Zero-out adaLN modulation layers 207 | for block in self.res_blocks: 208 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 209 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 210 | 211 | # Zero-out output layers 212 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 213 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 214 | nn.init.constant_(self.final_layer.linear.weight, 0) 215 | nn.init.constant_(self.final_layer.linear.bias, 0) 216 | 217 | def forward(self, x, t, c): 218 | """ 219 | Apply the model to an input batch. 220 | :param x: an [N x C] Tensor of inputs. 
221 | :param t: a 1-D batch of timesteps. 222 | :param c: conditioning from AR transformer. 223 | :return: an [N x C] Tensor of outputs. 224 | """ 225 | x = self.input_proj(x) 226 | t = self.time_embed(t) 227 | c = self.cond_embed(c) 228 | 229 | y = t + c 230 | 231 | if self.grad_checkpointing and not torch.jit.is_scripting(): 232 | for block in self.res_blocks: 233 | x = checkpoint(block, x, y) 234 | else: 235 | for block in self.res_blocks: 236 | x = block(x, y) 237 | 238 | return self.final_layer(x, y) 239 | 240 | def forward_with_cfg(self, x, t, c, cfg_scale): 241 | half = x[: len(x) // 2] 242 | combined = torch.cat([half, half], dim=0) 243 | model_out = self.forward(combined, t, c) 244 | eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] 245 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 246 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 247 | eps = torch.cat([half_eps, half_eps], dim=0) 248 | return torch.cat([eps, rest], dim=1) 249 | -------------------------------------------------------------------------------- /models/mar.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import numpy as np 4 | from tqdm import tqdm 5 | import scipy.stats as stats 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.checkpoint import checkpoint 10 | 11 | from timm.models.vision_transformer import Block 12 | 13 | from models.diffloss import DiffLoss 14 | 15 | 16 | def mask_by_order(mask_len, order, bsz, seq_len): 17 | masking = torch.zeros(bsz, seq_len).cuda() 18 | masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).cuda()).bool() 19 | return masking 20 | 21 | 22 | class MAR(nn.Module): 23 | """ Masked Autoencoder with VisionTransformer backbone 24 | """ 25 | def __init__(self, img_size=256, vae_stride=16, patch_size=1, 26 | encoder_embed_dim=1024, 
encoder_depth=16, encoder_num_heads=16, 27 | decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16, 28 | mlp_ratio=4., norm_layer=nn.LayerNorm, 29 | vae_embed_dim=16, 30 | mask_ratio_min=0.7, 31 | label_drop_prob=0.1, 32 | class_num=1000, 33 | attn_dropout=0.1, 34 | proj_dropout=0.1, 35 | buffer_size=64, 36 | diffloss_d=3, 37 | diffloss_w=1024, 38 | num_sampling_steps='100', 39 | diffusion_batch_mul=4, 40 | grad_checkpointing=False, 41 | ): 42 | super().__init__() 43 | 44 | # -------------------------------------------------------------------------- 45 | # VAE and patchify specifics 46 | self.vae_embed_dim = vae_embed_dim 47 | 48 | self.img_size = img_size 49 | self.vae_stride = vae_stride 50 | self.patch_size = patch_size 51 | self.seq_h = self.seq_w = img_size // vae_stride // patch_size 52 | self.seq_len = self.seq_h * self.seq_w 53 | self.token_embed_dim = vae_embed_dim * patch_size**2 54 | self.grad_checkpointing = grad_checkpointing 55 | 56 | # -------------------------------------------------------------------------- 57 | # Class Embedding 58 | self.num_classes = class_num 59 | self.class_emb = nn.Embedding(class_num, encoder_embed_dim) 60 | self.label_drop_prob = label_drop_prob 61 | # Fake class embedding for CFG's unconditional generation 62 | self.fake_latent = nn.Parameter(torch.zeros(1, encoder_embed_dim)) 63 | 64 | # -------------------------------------------------------------------------- 65 | # MAR variant masking ratio, a left-half truncated Gaussian centered at 100% masking ratio with std 0.25 66 | self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25) 67 | 68 | # -------------------------------------------------------------------------- 69 | # MAR encoder specifics 70 | self.z_proj = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True) 71 | self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6) 72 | self.buffer_size = buffer_size 73 | self.encoder_pos_embed_learned = 
nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, encoder_embed_dim)) 74 | 75 | self.encoder_blocks = nn.ModuleList([ 76 | Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, 77 | proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(encoder_depth)]) 78 | self.encoder_norm = norm_layer(encoder_embed_dim) 79 | 80 | # -------------------------------------------------------------------------- 81 | # MAR decoder specifics 82 | self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True) 83 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 84 | self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim)) 85 | 86 | self.decoder_blocks = nn.ModuleList([ 87 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, 88 | proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)]) 89 | 90 | self.decoder_norm = norm_layer(decoder_embed_dim) 91 | self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim)) 92 | 93 | self.initialize_weights() 94 | 95 | # -------------------------------------------------------------------------- 96 | # Diffusion Loss 97 | self.diffloss = DiffLoss( 98 | target_channels=self.token_embed_dim, 99 | z_channels=decoder_embed_dim, 100 | width=diffloss_w, 101 | depth=diffloss_d, 102 | num_sampling_steps=num_sampling_steps, 103 | grad_checkpointing=grad_checkpointing 104 | ) 105 | self.diffusion_batch_mul = diffusion_batch_mul 106 | 107 | def initialize_weights(self): 108 | # parameters 109 | torch.nn.init.normal_(self.class_emb.weight, std=.02) 110 | torch.nn.init.normal_(self.fake_latent, std=.02) 111 | torch.nn.init.normal_(self.mask_token, std=.02) 112 | torch.nn.init.normal_(self.encoder_pos_embed_learned, std=.02) 113 | torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02) 114 | 
torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02) 115 | 116 | # initialize nn.Linear and nn.LayerNorm 117 | self.apply(self._init_weights) 118 | 119 | def _init_weights(self, m): 120 | if isinstance(m, nn.Linear): 121 | # we use xavier_uniform following official JAX ViT: 122 | torch.nn.init.xavier_uniform_(m.weight) 123 | if isinstance(m, nn.Linear) and m.bias is not None: 124 | nn.init.constant_(m.bias, 0) 125 | elif isinstance(m, nn.LayerNorm): 126 | if m.bias is not None: 127 | nn.init.constant_(m.bias, 0) 128 | if m.weight is not None: 129 | nn.init.constant_(m.weight, 1.0) 130 | 131 | def patchify(self, x): 132 | bsz, c, h, w = x.shape 133 | p = self.patch_size 134 | h_, w_ = h // p, w // p 135 | 136 | x = x.reshape(bsz, c, h_, p, w_, p) 137 | x = torch.einsum('nchpwq->nhwcpq', x) 138 | x = x.reshape(bsz, h_ * w_, c * p ** 2) 139 | return x # [n, l, d] 140 | 141 | def unpatchify(self, x): 142 | bsz = x.shape[0] 143 | p = self.patch_size 144 | c = self.vae_embed_dim 145 | h_, w_ = self.seq_h, self.seq_w 146 | 147 | x = x.reshape(bsz, h_, w_, c, p, p) 148 | x = torch.einsum('nhwcpq->nchpwq', x) 149 | x = x.reshape(bsz, c, h_ * p, w_ * p) 150 | return x # [n, c, h, w] 151 | 152 | def sample_orders(self, bsz): 153 | # generate a batch of random generation orders 154 | orders = [] 155 | for _ in range(bsz): 156 | order = np.array(list(range(self.seq_len))) 157 | np.random.shuffle(order) 158 | orders.append(order) 159 | orders = torch.Tensor(np.array(orders)).cuda().long() 160 | return orders 161 | 162 | def random_masking(self, x, orders): 163 | # generate token mask 164 | bsz, seq_len, embed_dim = x.shape 165 | mask_rate = self.mask_ratio_generator.rvs(1)[0] 166 | num_masked_tokens = int(np.ceil(seq_len * mask_rate)) 167 | mask = torch.zeros(bsz, seq_len, device=x.device) 168 | mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens], 169 | src=torch.ones(bsz, seq_len, device=x.device)) 170 | return mask 171 | 172 | def 
forward_mae_encoder(self, x, mask, class_embedding): 173 | x = self.z_proj(x) 174 | bsz, seq_len, embed_dim = x.shape 175 | 176 | # concat buffer 177 | x = torch.cat([torch.zeros(bsz, self.buffer_size, embed_dim, device=x.device), x], dim=1) 178 | mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1) 179 | 180 | # randomly drop class embedding during training 181 | if self.training: 182 | drop_latent_mask = torch.rand(bsz) < self.label_drop_prob 183 | drop_latent_mask = drop_latent_mask.unsqueeze(-1).cuda().to(x.dtype) 184 | class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding 185 | 186 | x[:, :self.buffer_size] = class_embedding.unsqueeze(1) 187 | 188 | # encoder position embedding 189 | x = x + self.encoder_pos_embed_learned 190 | x = self.z_proj_ln(x) 191 | 192 | # dropping 193 | x = x[(1-mask_with_buffer).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim) 194 | 195 | # apply Transformer blocks 196 | if self.grad_checkpointing and not torch.jit.is_scripting(): 197 | for block in self.encoder_blocks: 198 | x = checkpoint(block, x) 199 | else: 200 | for block in self.encoder_blocks: 201 | x = block(x) 202 | x = self.encoder_norm(x) 203 | 204 | return x 205 | 206 | def forward_mae_decoder(self, x, mask): 207 | 208 | x = self.decoder_embed(x) 209 | mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1) 210 | 211 | # pad mask tokens 212 | mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype) 213 | x_after_pad = mask_tokens.clone() 214 | x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2]) 215 | 216 | # decoder position embedding 217 | x = x_after_pad + self.decoder_pos_embed_learned 218 | 219 | # apply Transformer blocks 220 | if self.grad_checkpointing and not torch.jit.is_scripting(): 221 | for block in
self.decoder_blocks: 222 | x = checkpoint(block, x) 223 | else: 224 | for block in self.decoder_blocks: 225 | x = block(x) 226 | x = self.decoder_norm(x) 227 | 228 | x = x[:, self.buffer_size:] 229 | x = x + self.diffusion_pos_embed_learned 230 | return x 231 | 232 | def forward_loss(self, z, target, mask): 233 | bsz, seq_len, _ = target.shape 234 | target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1) 235 | z = z.reshape(bsz*seq_len, -1).repeat(self.diffusion_batch_mul, 1) 236 | mask = mask.reshape(bsz*seq_len).repeat(self.diffusion_batch_mul) 237 | loss = self.diffloss(z=z, target=target, mask=mask) 238 | return loss 239 | 240 | def forward(self, imgs, labels): 241 | 242 | # class embed 243 | class_embedding = self.class_emb(labels) 244 | 245 | # patchify and mask (drop) tokens 246 | x = self.patchify(imgs) 247 | gt_latents = x.clone().detach() 248 | orders = self.sample_orders(bsz=x.size(0)) 249 | mask = self.random_masking(x, orders) 250 | 251 | # mae encoder 252 | x = self.forward_mae_encoder(x, mask, class_embedding) 253 | 254 | # mae decoder 255 | z = self.forward_mae_decoder(x, mask) 256 | 257 | # diffloss 258 | loss = self.forward_loss(z=z, target=gt_latents, mask=mask) 259 | 260 | return loss 261 | 262 | def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False): 263 | 264 | # init and sample generation orders 265 | mask = torch.ones(bsz, self.seq_len).cuda() 266 | tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda() 267 | orders = self.sample_orders(bsz) 268 | 269 | indices = list(range(num_iter)) 270 | if progress: 271 | indices = tqdm(indices) 272 | # generate latents 273 | for step in indices: 274 | cur_tokens = tokens.clone() 275 | 276 | # class embedding and CFG 277 | if labels is not None: 278 | class_embedding = self.class_emb(labels) 279 | else: 280 | class_embedding = self.fake_latent.repeat(bsz, 1) 281 | if not cfg == 1.0: 282 | tokens = 
torch.cat([tokens, tokens], dim=0) 283 | class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0) 284 | mask = torch.cat([mask, mask], dim=0) 285 | 286 | # mae encoder 287 | x = self.forward_mae_encoder(tokens, mask, class_embedding) 288 | 289 | # mae decoder 290 | z = self.forward_mae_decoder(x, mask) 291 | 292 | # mask ratio for the next round, following MaskGIT and MAGE. 293 | mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter) 294 | mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).cuda() 295 | 296 | # masks out at least one for the next iteration 297 | mask_len = torch.maximum(torch.Tensor([1]).cuda(), 298 | torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len)) 299 | 300 | # get masking for next iteration and locations to be predicted in this iteration 301 | mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len) 302 | if step >= num_iter - 1: 303 | mask_to_pred = mask[:bsz].bool() 304 | else: 305 | mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool()) 306 | mask = mask_next 307 | if not cfg == 1.0: 308 | mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0) 309 | 310 | # sample token latents for this step 311 | z = z[mask_to_pred.nonzero(as_tuple=True)] 312 | # cfg schedule follows Muse 313 | if cfg_schedule == "linear": 314 | cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len 315 | elif cfg_schedule == "constant": 316 | cfg_iter = cfg 317 | else: 318 | raise NotImplementedError 319 | sampled_token_latent = self.diffloss.sample(z, temperature, cfg_iter) 320 | if not cfg == 1.0: 321 | sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0) # Remove null class samples 322 | mask_to_pred, _ = mask_to_pred.chunk(2, dim=0) 323 | 324 | cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent 325 | tokens = cur_tokens.clone() 326 | 327 | # unpatchify 328 | tokens = self.unpatchify(tokens) 329 | return tokens 330 | 331 | 332 | def
mar_base(**kwargs): 333 | model = MAR( 334 | encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12, 335 | decoder_embed_dim=768, decoder_depth=12, decoder_num_heads=12, 336 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 337 | return model 338 | 339 | 340 | def mar_large(**kwargs): 341 | model = MAR( 342 | encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16, 343 | decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16, 344 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 345 | return model 346 | 347 | 348 | def mar_huge(**kwargs): 349 | model = MAR( 350 | encoder_embed_dim=1280, encoder_depth=20, encoder_num_heads=16, 351 | decoder_embed_dim=1280, decoder_depth=20, decoder_num_heads=16, 352 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 353 | return model 354 | -------------------------------------------------------------------------------- /models/vae.py: -------------------------------------------------------------------------------- 1 | # Adapted from LDM's KL-VAE: https://github.com/CompVis/latent-diffusion 2 | import torch 3 | import torch.nn as nn 4 | 5 | import numpy as np 6 | 7 | 8 | def nonlinearity(x): 9 | # swish 10 | return x * torch.sigmoid(x) 11 | 12 | 13 | def Normalize(in_channels, num_groups=32): 14 | return torch.nn.GroupNorm( 15 | num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True 16 | ) 17 | 18 | 19 | class Upsample(nn.Module): 20 | def __init__(self, in_channels, with_conv): 21 | super().__init__() 22 | self.with_conv = with_conv 23 | if self.with_conv: 24 | self.conv = torch.nn.Conv2d( 25 | in_channels, in_channels, kernel_size=3, stride=1, padding=1 26 | ) 27 | 28 | def forward(self, x): 29 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 30 | if self.with_conv: 31 | x = self.conv(x) 32 | return x 33 | 34 | 35 | class Downsample(nn.Module): 36 | def __init__(self, in_channels, with_conv): 37 |
super().__init__() 38 | self.with_conv = with_conv 39 | if self.with_conv: 40 | # no asymmetric padding in torch conv, must do it ourselves 41 | self.conv = torch.nn.Conv2d( 42 | in_channels, in_channels, kernel_size=3, stride=2, padding=0 43 | ) 44 | 45 | def forward(self, x): 46 | if self.with_conv: 47 | pad = (0, 1, 0, 1) 48 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 49 | x = self.conv(x) 50 | else: 51 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 52 | return x 53 | 54 | 55 | class ResnetBlock(nn.Module): 56 | def __init__( 57 | self, 58 | *, 59 | in_channels, 60 | out_channels=None, 61 | conv_shortcut=False, 62 | dropout, 63 | temb_channels=512, 64 | ): 65 | super().__init__() 66 | self.in_channels = in_channels 67 | out_channels = in_channels if out_channels is None else out_channels 68 | self.out_channels = out_channels 69 | self.use_conv_shortcut = conv_shortcut 70 | 71 | self.norm1 = Normalize(in_channels) 72 | self.conv1 = torch.nn.Conv2d( 73 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 74 | ) 75 | if temb_channels > 0: 76 | self.temb_proj = torch.nn.Linear(temb_channels, out_channels) 77 | self.norm2 = Normalize(out_channels) 78 | self.dropout = torch.nn.Dropout(dropout) 79 | self.conv2 = torch.nn.Conv2d( 80 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 81 | ) 82 | if self.in_channels != self.out_channels: 83 | if self.use_conv_shortcut: 84 | self.conv_shortcut = torch.nn.Conv2d( 85 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 86 | ) 87 | else: 88 | self.nin_shortcut = torch.nn.Conv2d( 89 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 90 | ) 91 | 92 | def forward(self, x, temb): 93 | h = x 94 | h = self.norm1(h) 95 | h = nonlinearity(h) 96 | h = self.conv1(h) 97 | 98 | if temb is not None: 99 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 100 | 101 | h = self.norm2(h) 102 | h = nonlinearity(h) 103 | h = self.dropout(h) 104 | h 
= self.conv2(h) 105 | 106 | if self.in_channels != self.out_channels: 107 | if self.use_conv_shortcut: 108 | x = self.conv_shortcut(x) 109 | else: 110 | x = self.nin_shortcut(x) 111 | 112 | return x + h 113 | 114 | 115 | class AttnBlock(nn.Module): 116 | def __init__(self, in_channels): 117 | super().__init__() 118 | self.in_channels = in_channels 119 | 120 | self.norm = Normalize(in_channels) 121 | self.q = torch.nn.Conv2d( 122 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 123 | ) 124 | self.k = torch.nn.Conv2d( 125 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 126 | ) 127 | self.v = torch.nn.Conv2d( 128 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 129 | ) 130 | self.proj_out = torch.nn.Conv2d( 131 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 132 | ) 133 | 134 | def forward(self, x): 135 | h_ = x 136 | h_ = self.norm(h_) 137 | q = self.q(h_) 138 | k = self.k(h_) 139 | v = self.v(h_) 140 | 141 | # compute attention 142 | b, c, h, w = q.shape 143 | q = q.reshape(b, c, h * w) 144 | q = q.permute(0, 2, 1) # b,hw,c 145 | k = k.reshape(b, c, h * w) # b,c,hw 146 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 147 | w_ = w_ * (int(c) ** (-0.5)) 148 | w_ = torch.nn.functional.softmax(w_, dim=2) 149 | 150 | # attend to values 151 | v = v.reshape(b, c, h * w) 152 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 153 | h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 154 | h_ = h_.reshape(b, c, h, w) 155 | 156 | h_ = self.proj_out(h_) 157 | 158 | return x + h_ 159 | 160 | 161 | class Encoder(nn.Module): 162 | def __init__( 163 | self, 164 | *, 165 | ch=128, 166 | out_ch=3, 167 | ch_mult=(1, 1, 2, 2, 4), 168 | num_res_blocks=2, 169 | attn_resolutions=(16,), 170 | dropout=0.0, 171 | resamp_with_conv=True, 172 | in_channels=3, 173 | resolution=256, 174 | z_channels=16, 175 | double_z=True, 176 | **ignore_kwargs, 177 | ): 178 | super().__init__() 
179 | self.ch = ch 180 | self.temb_ch = 0 181 | self.num_resolutions = len(ch_mult) 182 | self.num_res_blocks = num_res_blocks 183 | self.resolution = resolution 184 | self.in_channels = in_channels 185 | 186 | # downsampling 187 | self.conv_in = torch.nn.Conv2d( 188 | in_channels, self.ch, kernel_size=3, stride=1, padding=1 189 | ) 190 | 191 | curr_res = resolution 192 | in_ch_mult = (1,) + tuple(ch_mult) 193 | self.down = nn.ModuleList() 194 | for i_level in range(self.num_resolutions): 195 | block = nn.ModuleList() 196 | attn = nn.ModuleList() 197 | block_in = ch * in_ch_mult[i_level] 198 | block_out = ch * ch_mult[i_level] 199 | for i_block in range(self.num_res_blocks): 200 | block.append( 201 | ResnetBlock( 202 | in_channels=block_in, 203 | out_channels=block_out, 204 | temb_channels=self.temb_ch, 205 | dropout=dropout, 206 | ) 207 | ) 208 | block_in = block_out 209 | if curr_res in attn_resolutions: 210 | attn.append(AttnBlock(block_in)) 211 | down = nn.Module() 212 | down.block = block 213 | down.attn = attn 214 | if i_level != self.num_resolutions - 1: 215 | down.downsample = Downsample(block_in, resamp_with_conv) 216 | curr_res = curr_res // 2 217 | self.down.append(down) 218 | 219 | # middle 220 | self.mid = nn.Module() 221 | self.mid.block_1 = ResnetBlock( 222 | in_channels=block_in, 223 | out_channels=block_in, 224 | temb_channels=self.temb_ch, 225 | dropout=dropout, 226 | ) 227 | self.mid.attn_1 = AttnBlock(block_in) 228 | self.mid.block_2 = ResnetBlock( 229 | in_channels=block_in, 230 | out_channels=block_in, 231 | temb_channels=self.temb_ch, 232 | dropout=dropout, 233 | ) 234 | 235 | # end 236 | self.norm_out = Normalize(block_in) 237 | self.conv_out = torch.nn.Conv2d( 238 | block_in, 239 | 2 * z_channels if double_z else z_channels, 240 | kernel_size=3, 241 | stride=1, 242 | padding=1, 243 | ) 244 | 245 | def forward(self, x): 246 | # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) 
247 | 248 | # timestep embedding 249 | temb = None 250 | 251 | # downsampling 252 | hs = [self.conv_in(x)] 253 | for i_level in range(self.num_resolutions): 254 | for i_block in range(self.num_res_blocks): 255 | h = self.down[i_level].block[i_block](hs[-1], temb) 256 | if len(self.down[i_level].attn) > 0: 257 | h = self.down[i_level].attn[i_block](h) 258 | hs.append(h) 259 | if i_level != self.num_resolutions - 1: 260 | hs.append(self.down[i_level].downsample(hs[-1])) 261 | 262 | # middle 263 | h = hs[-1] 264 | h = self.mid.block_1(h, temb) 265 | h = self.mid.attn_1(h) 266 | h = self.mid.block_2(h, temb) 267 | 268 | # end 269 | h = self.norm_out(h) 270 | h = nonlinearity(h) 271 | h = self.conv_out(h) 272 | return h 273 | 274 | 275 | class Decoder(nn.Module): 276 | def __init__( 277 | self, 278 | *, 279 | ch=128, 280 | out_ch=3, 281 | ch_mult=(1, 1, 2, 2, 4), 282 | num_res_blocks=2, 283 | attn_resolutions=(), 284 | dropout=0.0, 285 | resamp_with_conv=True, 286 | in_channels=3, 287 | resolution=256, 288 | z_channels=16, 289 | give_pre_end=False, 290 | **ignore_kwargs, 291 | ): 292 | super().__init__() 293 | self.ch = ch 294 | self.temb_ch = 0 295 | self.num_resolutions = len(ch_mult) 296 | self.num_res_blocks = num_res_blocks 297 | self.resolution = resolution 298 | self.in_channels = in_channels 299 | self.give_pre_end = give_pre_end 300 | 301 | # compute in_ch_mult, block_in and curr_res at lowest res 302 | in_ch_mult = (1,) + tuple(ch_mult) 303 | block_in = ch * ch_mult[self.num_resolutions - 1] 304 | curr_res = resolution // 2 ** (self.num_resolutions - 1) 305 | self.z_shape = (1, z_channels, curr_res, curr_res) 306 | print( 307 | "Working with z of shape {} = {} dimensions.".format( 308 | self.z_shape, np.prod(self.z_shape) 309 | ) 310 | ) 311 | 312 | # z to block_in 313 | self.conv_in = torch.nn.Conv2d( 314 | z_channels, block_in, kernel_size=3, stride=1, padding=1 315 | ) 316 | 317 | # middle 318 | self.mid = nn.Module() 319 | self.mid.block_1 = ResnetBlock( 
320 | in_channels=block_in, 321 | out_channels=block_in, 322 | temb_channels=self.temb_ch, 323 | dropout=dropout, 324 | ) 325 | self.mid.attn_1 = AttnBlock(block_in) 326 | self.mid.block_2 = ResnetBlock( 327 | in_channels=block_in, 328 | out_channels=block_in, 329 | temb_channels=self.temb_ch, 330 | dropout=dropout, 331 | ) 332 | 333 | # upsampling 334 | self.up = nn.ModuleList() 335 | for i_level in reversed(range(self.num_resolutions)): 336 | block = nn.ModuleList() 337 | attn = nn.ModuleList() 338 | block_out = ch * ch_mult[i_level] 339 | for i_block in range(self.num_res_blocks + 1): 340 | block.append( 341 | ResnetBlock( 342 | in_channels=block_in, 343 | out_channels=block_out, 344 | temb_channels=self.temb_ch, 345 | dropout=dropout, 346 | ) 347 | ) 348 | block_in = block_out 349 | if curr_res in attn_resolutions: 350 | attn.append(AttnBlock(block_in)) 351 | up = nn.Module() 352 | up.block = block 353 | up.attn = attn 354 | if i_level != 0: 355 | up.upsample = Upsample(block_in, resamp_with_conv) 356 | curr_res = curr_res * 2 357 | self.up.insert(0, up) # prepend to get consistent order 358 | 359 | # end 360 | self.norm_out = Normalize(block_in) 361 | self.conv_out = torch.nn.Conv2d( 362 | block_in, out_ch, kernel_size=3, stride=1, padding=1 363 | ) 364 | 365 | def forward(self, z): 366 | # assert z.shape[1:] == self.z_shape[1:] 367 | self.last_z_shape = z.shape 368 | 369 | # timestep embedding 370 | temb = None 371 | 372 | # z to block_in 373 | h = self.conv_in(z) 374 | 375 | # middle 376 | h = self.mid.block_1(h, temb) 377 | h = self.mid.attn_1(h) 378 | h = self.mid.block_2(h, temb) 379 | 380 | # upsampling 381 | for i_level in reversed(range(self.num_resolutions)): 382 | for i_block in range(self.num_res_blocks + 1): 383 | h = self.up[i_level].block[i_block](h, temb) 384 | if len(self.up[i_level].attn) > 0: 385 | h = self.up[i_level].attn[i_block](h) 386 | if i_level != 0: 387 | h = self.up[i_level].upsample(h) 388 | 389 | # end 390 | if self.give_pre_end: 
391 | return h 392 | 393 | h = self.norm_out(h) 394 | h = nonlinearity(h) 395 | h = self.conv_out(h) 396 | return h 397 | 398 | 399 | class DiagonalGaussianDistribution(object): 400 | def __init__(self, parameters, deterministic=False): 401 | self.parameters = parameters 402 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 403 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 404 | self.deterministic = deterministic 405 | self.std = torch.exp(0.5 * self.logvar) 406 | self.var = torch.exp(self.logvar) 407 | if self.deterministic: 408 | self.var = self.std = torch.zeros_like(self.mean).to( 409 | device=self.parameters.device 410 | ) 411 | 412 | def sample(self): 413 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 414 | device=self.parameters.device 415 | ) 416 | return x 417 | 418 | def kl(self, other=None): 419 | if self.deterministic: 420 | return torch.Tensor([0.0]) 421 | else: 422 | if other is None: 423 | return 0.5 * torch.sum( 424 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 425 | dim=[1, 2, 3], 426 | ) 427 | else: 428 | return 0.5 * torch.sum( 429 | torch.pow(self.mean - other.mean, 2) / other.var 430 | + self.var / other.var 431 | - 1.0 432 | - self.logvar 433 | + other.logvar, 434 | dim=[1, 2, 3], 435 | ) 436 | 437 | def nll(self, sample, dims=[1, 2, 3]): 438 | if self.deterministic: 439 | return torch.Tensor([0.0]) 440 | logtwopi = np.log(2.0 * np.pi) 441 | return 0.5 * torch.sum( 442 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 443 | dim=dims, 444 | ) 445 | 446 | def mode(self): 447 | return self.mean 448 | 449 | 450 | class AutoencoderKL(nn.Module): 451 | def __init__(self, embed_dim, ch_mult, use_variational=True, ckpt_path=None): 452 | super().__init__() 453 | self.encoder = Encoder(ch_mult=ch_mult, z_channels=embed_dim) 454 | self.decoder = Decoder(ch_mult=ch_mult, z_channels=embed_dim) 455 | self.use_variational = use_variational 456 | mult = 2 if self.use_variational else 1 457 | 
self.quant_conv = torch.nn.Conv2d(2 * embed_dim, mult * embed_dim, 1) 458 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, embed_dim, 1) 459 | self.embed_dim = embed_dim 460 | if ckpt_path is not None: 461 | self.init_from_ckpt(ckpt_path) 462 | 463 | def init_from_ckpt(self, path): 464 | sd = torch.load(path, map_location="cpu")["model"] 465 | msg = self.load_state_dict(sd, strict=False) 466 | print("Loading pre-trained KL-VAE") 467 | print("Missing keys:") 468 | print(msg.missing_keys) 469 | print("Unexpected keys:") 470 | print(msg.unexpected_keys) 471 | print(f"Restored from {path}") 472 | 473 | def encode(self, x): 474 | h = self.encoder(x) 475 | moments = self.quant_conv(h) 476 | if not self.use_variational: 477 | moments = torch.cat((moments, torch.ones_like(moments)), 1) 478 | posterior = DiagonalGaussianDistribution(moments) 479 | return posterior 480 | 481 | def decode(self, z): 482 | z = self.post_quant_conv(z) 483 | dec = self.decoder(z) 484 | return dec 485 | 486 | def forward(self, inputs, disable=True, train=True, optimizer_idx=0): 487 | if train: 488 | return self.training_step(inputs, disable, optimizer_idx) 489 | else: 490 | return self.validation_step(inputs, disable) 491 | -------------------------------------------------------------------------------- /util/crop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | 5 | def center_crop_arr(pil_image, image_size): 6 | """ 7 | Center cropping implementation from ADM. 
8 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 9 | """ 10 | while min(*pil_image.size) >= 2 * image_size: 11 | pil_image = pil_image.resize( 12 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 13 | ) 14 | 15 | scale = image_size / min(*pil_image.size) 16 | pil_image = pil_image.resize( 17 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 18 | ) 19 | 20 | arr = np.array(pil_image) 21 | crop_y = (arr.shape[0] - image_size) // 2 22 | crop_x = (arr.shape[1] - image_size) // 2 23 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) 24 | -------------------------------------------------------------------------------- /util/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import requests 4 | 5 | 6 | def download_pretrained_vae(overwrite=False): 7 | download_path = "pretrained_models/vae/kl16.ckpt" 8 | if not os.path.exists(download_path) or overwrite: 9 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'} 10 | os.makedirs("pretrained_models/vae", exist_ok=True) 11 | r = requests.get("https://www.dropbox.com/scl/fi/hhmuvaiacrarfg28qxhwz/kl16.ckpt?rlkey=l44xipsezc8atcffdp4q7mwmh&dl=0", stream=True, headers=headers) 12 | print("Downloading KL-16 VAE...") 13 | with open(download_path, 'wb') as f: 14 | for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=254): 15 | if chunk: 16 | f.write(chunk) 17 | 18 | 19 | def download_pretrained_marb(overwrite=False): 20 | download_path = "pretrained_models/mar/mar_base/checkpoint-last.pth" 21 | if not os.path.exists(download_path) or overwrite: 22 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'} 23 | os.makedirs("pretrained_models/mar/mar_base", exist_ok=True) 24 | r = 
requests.get("https://www.dropbox.com/scl/fi/f6dpuyjb7fudzxcyhvrhk/checkpoint-last.pth?rlkey=a6i4bo71vhfo4anp33n9ukujb&dl=0", stream=True, headers=headers) 25 | print("Downloading MAR-B...") 26 | with open(download_path, 'wb') as f: 27 | for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=1587): 28 | if chunk: 29 | f.write(chunk) 30 | 31 | 32 | def download_pretrained_marl(overwrite=False): 33 | download_path = "pretrained_models/mar/mar_large/checkpoint-last.pth" 34 | if not os.path.exists(download_path) or overwrite: 35 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'} 36 | os.makedirs("pretrained_models/mar/mar_large", exist_ok=True) 37 | r = requests.get("https://www.dropbox.com/scl/fi/pxacc5b2mrt3ifw4cah6k/checkpoint-last.pth?rlkey=m48ovo6g7ivcbosrbdaz0ehqt&dl=0", stream=True, headers=headers) 38 | print("Downloading MAR-L...") 39 | with open(download_path, 'wb') as f: 40 | for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=3650): 41 | if chunk: 42 | f.write(chunk) 43 | 44 | 45 | def download_pretrained_marh(overwrite=False): 46 | download_path = "pretrained_models/mar/mar_huge/checkpoint-last.pth" 47 | if not os.path.exists(download_path) or overwrite: 48 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'} 49 | os.makedirs("pretrained_models/mar/mar_huge", exist_ok=True) 50 | r = requests.get("https://www.dropbox.com/scl/fi/1qmfx6fpy3k7j9vcjjs3s/checkpoint-last.pth?rlkey=4lae281yzxb406atp32vzc83o&dl=0", stream=True, headers=headers) 51 | print("Downloading MAR-H...") 52 | with open(download_path, 'wb') as f: 53 | for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=7191): 54 | if chunk: 55 | f.write(chunk) 56 | 57 | 58 | if __name__ == "__main__": 59 | download_pretrained_vae() 60 | download_pretrained_marb() 61 | download_pretrained_marl() 62 | download_pretrained_marh() 63 | -------------------------------------------------------------------------------- /util/loader.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import torch 5 | import torchvision.datasets as datasets 6 | 7 | 8 | class ImageFolderWithFilename(datasets.ImageFolder): 9 | def __getitem__(self, index: int): 10 | """ 11 | Args: 12 | index (int): Index 13 | 14 | Returns: 15 | tuple: (sample, target, filename). 16 | """ 17 | path, target = self.samples[index] 18 | sample = self.loader(path) 19 | if self.transform is not None: 20 | sample = self.transform(sample) 21 | if self.target_transform is not None: 22 | target = self.target_transform(target) 23 | 24 | filename = path.split(os.path.sep)[-2:] 25 | filename = os.path.join(*filename) 26 | return sample, target, filename 27 | 28 | 29 | class CachedFolder(datasets.DatasetFolder): 30 | def __init__( 31 | self, 32 | root: str, 33 | ): 34 | super().__init__( 35 | root, 36 | loader=None, 37 | extensions=(".npz",), 38 | ) 39 | 40 | def __getitem__(self, index: int): 41 | """ 42 | Args: 43 | index (int): Index 44 | 45 | Returns: 46 | tuple: (moments, target). 47 | """ 48 | path, target = self.samples[index] 49 | 50 | data = np.load(path) 51 | if torch.rand(1) < 0.5: # randomly hflip 52 | moments = data['moments'] 53 | else: 54 | moments = data['moments_flip'] 55 | 56 | return moments, target 57 | -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def adjust_learning_rate(optimizer, epoch, args): 5 | """Decay the learning rate with half-cycle cosine after warmup""" 6 | if epoch < args.warmup_epochs: 7 | lr = args.lr * epoch / args.warmup_epochs 8 | else: 9 | if args.lr_schedule == "constant": 10 | lr = args.lr 11 | elif args.lr_schedule == "cosine": 12 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 13 | (1. 
+ math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 14 | else: 15 | raise NotImplementedError 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | import datetime 3 | import os 4 | import time 5 | from collections import defaultdict, deque 6 | from pathlib import Path 7 | 8 | import torch 9 | import torch.distributed as dist 10 | TORCH_MAJOR = int(torch.__version__.split('.')[0]) 11 | TORCH_MINOR = int(torch.__version__.split('.')[1]) 12 | 13 | if TORCH_MAJOR == 1 and TORCH_MINOR < 8: 14 | from torch._six import inf 15 | else: 16 | from torch import inf 17 | import copy 18 | 19 | 20 | class SmoothedValue(object): 21 | """Track a series of values and provide access to smoothed values over a 22 | window or the global series average. 23 | """ 24 | 25 | def __init__(self, window_size=20, fmt=None): 26 | if fmt is None: 27 | fmt = "{median:.4f} ({global_avg:.4f})" 28 | self.deque = deque(maxlen=window_size) 29 | self.total = 0.0 30 | self.count = 0 31 | self.fmt = fmt 32 | 33 | def update(self, value, n=1): 34 | self.deque.append(value) 35 | self.count += n 36 | self.total += value * n 37 | 38 | def synchronize_between_processes(self): 39 | """ 40 | Warning: does not synchronize the deque! 
41 | """ 42 | if not is_dist_avail_and_initialized(): 43 | return 44 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 45 | dist.barrier() 46 | dist.all_reduce(t) 47 | t = t.tolist() 48 | self.count = int(t[0]) 49 | self.total = t[1] 50 | 51 | @property 52 | def median(self): 53 | d = torch.tensor(list(self.deque)) 54 | return d.median().item() 55 | 56 | @property 57 | def avg(self): 58 | d = torch.tensor(list(self.deque), dtype=torch.float32) 59 | return d.mean().item() 60 | 61 | @property 62 | def global_avg(self): 63 | return self.total / self.count 64 | 65 | @property 66 | def max(self): 67 | return max(self.deque) 68 | 69 | @property 70 | def value(self): 71 | return self.deque[-1] 72 | 73 | def __str__(self): 74 | return self.fmt.format( 75 | median=self.median, 76 | avg=self.avg, 77 | global_avg=self.global_avg, 78 | max=self.max, 79 | value=self.value) 80 | 81 | 82 | class MetricLogger(object): 83 | def __init__(self, delimiter="\t"): 84 | self.meters = defaultdict(SmoothedValue) 85 | self.delimiter = delimiter 86 | 87 | def update(self, **kwargs): 88 | for k, v in kwargs.items(): 89 | if v is None: 90 | continue 91 | if isinstance(v, torch.Tensor): 92 | v = v.item() 93 | assert isinstance(v, (float, int)) 94 | self.meters[k].update(v) 95 | 96 | def __getattr__(self, attr): 97 | if attr in self.meters: 98 | return self.meters[attr] 99 | if attr in self.__dict__: 100 | return self.__dict__[attr] 101 | raise AttributeError("'{}' object has no attribute '{}'".format( 102 | type(self).__name__, attr)) 103 | 104 | def __str__(self): 105 | loss_str = [] 106 | for name, meter in self.meters.items(): 107 | loss_str.append( 108 | "{}: {}".format(name, str(meter)) 109 | ) 110 | return self.delimiter.join(loss_str) 111 | 112 | def synchronize_between_processes(self): 113 | for meter in self.meters.values(): 114 | meter.synchronize_between_processes() 115 | 116 | def add_meter(self, name, meter): 117 | self.meters[name] = meter 118 | 119 
| def log_every(self, iterable, print_freq, header=None): 120 | i = 0 121 | if not header: 122 | header = '' 123 | start_time = time.time() 124 | end = time.time() 125 | iter_time = SmoothedValue(fmt='{avg:.4f}') 126 | data_time = SmoothedValue(fmt='{avg:.4f}') 127 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 128 | log_msg = [ 129 | header, 130 | '[{0' + space_fmt + '}/{1}]', 131 | 'eta: {eta}', 132 | '{meters}', 133 | 'time: {time}', 134 | 'data: {data}' 135 | ] 136 | if torch.cuda.is_available(): 137 | log_msg.append('max mem: {memory:.0f}') 138 | log_msg = self.delimiter.join(log_msg) 139 | MB = 1024.0 * 1024.0 140 | for obj in iterable: 141 | data_time.update(time.time() - end) 142 | yield obj 143 | iter_time.update(time.time() - end) 144 | if i % print_freq == 0 or i == len(iterable) - 1: 145 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 146 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 147 | if torch.cuda.is_available(): 148 | print(log_msg.format( 149 | i, len(iterable), eta=eta_string, 150 | meters=str(self), 151 | time=str(iter_time), data=str(data_time), 152 | memory=torch.cuda.max_memory_allocated() / MB)) 153 | else: 154 | print(log_msg.format( 155 | i, len(iterable), eta=eta_string, 156 | meters=str(self), 157 | time=str(iter_time), data=str(data_time))) 158 | i += 1 159 | end = time.time() 160 | total_time = time.time() - start_time 161 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 162 | print('{} Total time: {} ({:.4f} s / it)'.format( 163 | header, total_time_str, total_time / len(iterable))) 164 | 165 | 166 | def setup_for_distributed(is_master): 167 | """ 168 | This function disables printing when not in master process 169 | """ 170 | builtin_print = builtins.print 171 | 172 | def print(*args, **kwargs): 173 | force = kwargs.pop('force', False) 174 | force = force or (get_world_size() > 8) 175 | if is_master or force: 176 | now = datetime.datetime.now().time() 177 | 
builtin_print('[{}] '.format(now), end='') # print with time stamp 178 | builtin_print(*args, **kwargs) 179 | 180 | builtins.print = print 181 | 182 | 183 | def is_dist_avail_and_initialized(): 184 | if not dist.is_available(): 185 | return False 186 | if not dist.is_initialized(): 187 | return False 188 | return True 189 | 190 | 191 | def get_world_size(): 192 | if not is_dist_avail_and_initialized(): 193 | return 1 194 | return dist.get_world_size() 195 | 196 | 197 | def get_rank(): 198 | if not is_dist_avail_and_initialized(): 199 | return 0 200 | return dist.get_rank() 201 | 202 | 203 | def is_main_process(): 204 | return get_rank() == 0 205 | 206 | 207 | def save_on_master(*args, **kwargs): 208 | if is_main_process(): 209 | torch.save(*args, **kwargs) 210 | 211 | 212 | def init_distributed_mode(args): 213 | if args.dist_on_itp: 214 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 215 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 216 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 217 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 218 | os.environ['LOCAL_RANK'] = str(args.gpu) 219 | os.environ['RANK'] = str(args.rank) 220 | os.environ['WORLD_SIZE'] = str(args.world_size) 221 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 222 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 223 | args.rank = int(os.environ["RANK"]) 224 | args.world_size = int(os.environ['WORLD_SIZE']) 225 | args.gpu = int(os.environ['LOCAL_RANK']) 226 | elif 'SLURM_PROCID' in os.environ: 227 | args.rank = int(os.environ['SLURM_PROCID']) 228 | args.gpu = args.rank % torch.cuda.device_count() 229 | else: 230 | print('Not using distributed mode') 231 | setup_for_distributed(is_master=True) # hack 232 | args.distributed = False 233 | return 234 | 235 | args.distributed = True 236 | 237 | torch.cuda.set_device(args.gpu) 238 | args.dist_backend = 'nccl' 239 | print('| distributed init (rank {}): 
{}, gpu {}'.format( 240 | args.rank, args.dist_url, args.gpu), flush=True) 241 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 242 | world_size=args.world_size, rank=args.rank) 243 | torch.distributed.barrier() 244 | setup_for_distributed(args.rank == 0) 245 | 246 | 247 | class NativeScalerWithGradNormCount: 248 | state_dict_key = "amp_scaler" 249 | 250 | def __init__(self): 251 | self._scaler = torch.cuda.amp.GradScaler() 252 | 253 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 254 | self._scaler.scale(loss).backward(create_graph=create_graph) 255 | if update_grad: 256 | if clip_grad is not None: 257 | assert parameters is not None 258 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 259 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 260 | else: 261 | self._scaler.unscale_(optimizer) 262 | norm = get_grad_norm_(parameters) 263 | self._scaler.step(optimizer) 264 | self._scaler.update() 265 | else: 266 | norm = None 267 | return norm 268 | 269 | def state_dict(self): 270 | return self._scaler.state_dict() 271 | 272 | def load_state_dict(self, state_dict): 273 | self._scaler.load_state_dict(state_dict) 274 | 275 | 276 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 277 | if isinstance(parameters, torch.Tensor): 278 | parameters = [parameters] 279 | parameters = [p for p in parameters if p.grad is not None] 280 | norm_type = float(norm_type) 281 | if len(parameters) == 0: 282 | return torch.tensor(0.) 
283 | device = parameters[0].grad.device 284 | if norm_type == inf: 285 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 286 | else: 287 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 288 | return total_norm 289 | 290 | 291 | def add_weight_decay(model, weight_decay=1e-5, skip_list=()): 292 | decay = [] 293 | no_decay = [] 294 | for name, param in model.named_parameters(): 295 | if not param.requires_grad: 296 | continue # frozen weights 297 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list or 'diffloss' in name: 298 | no_decay.append(param) # no weight decay on bias, norm and diffloss 299 | else: 300 | decay.append(param) 301 | return [ 302 | {'params': no_decay, 'weight_decay': 0.}, 303 | {'params': decay, 'weight_decay': weight_decay}] 304 | 305 | 306 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, ema_params=None, epoch_name=None): 307 | if epoch_name is None: 308 | epoch_name = str(epoch) 309 | output_dir = Path(args.output_dir) 310 | checkpoint_path = output_dir / ('checkpoint-%s.pth' % epoch_name) 311 | 312 | # ema 313 | if ema_params is not None: 314 | ema_state_dict = copy.deepcopy(model_without_ddp.state_dict()) 315 | for i, (name, _value) in enumerate(model_without_ddp.named_parameters()): 316 | assert name in ema_state_dict 317 | ema_state_dict[name] = ema_params[i] 318 | else: 319 | ema_state_dict = None 320 | 321 | to_save = { 322 | 'model': model_without_ddp.state_dict(), 323 | 'model_ema': ema_state_dict, 324 | 'optimizer': optimizer.state_dict(), 325 | 'epoch': epoch, 326 | 'scaler': loss_scaler.state_dict(), 327 | 'args': args, 328 | } 329 | save_on_master(to_save, checkpoint_path) 330 | 331 | 332 | def all_reduce_mean(x): 333 | world_size = get_world_size() 334 | if world_size > 1: 335 | x_reduce = torch.tensor(x).cuda() 336 | dist.all_reduce(x_reduce) 337 | x_reduce /= world_size 338 | 
return x_reduce.item() 339 | else: 340 | return x --------------------------------------------------------------------------------
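`SmoothedValue` tracks two kinds of statistics:

- The `deque` behind `median` and `avg` holds only the most recent values.
- `total` and `count` accumulate over every update and drive `global_avg`.

`synchronize_between_processes` all-reduces only `count` and `total`. After a sync, `global_avg` therefore reflects every rank, while `median` and `avg` stay per-process.

The torch-free sketch below illustrates that split. The class name `WindowedMeter` and its `window_size` argument are hypothetical, because the real constructor sits outside this excerpt. It uses `statistics.median_low` because `torch.Tensor.median()` returns the lower of the two middle values for an even-length window.

```python
from collections import deque
import statistics


class WindowedMeter:
    """Hypothetical, torch-free sketch of SmoothedValue's windowed vs. global statistics."""

    def __init__(self, window_size=20):
        self.deque = deque(maxlen=window_size)  # only the most recent values survive
        self.total = 0.0  # running sum over every update
        self.count = 0    # number of updates seen

    def update(self, value):
        self.deque.append(value)
        self.total += value
        self.count += 1

    @property
    def median(self):
        # Lower median, matching torch.Tensor.median() on an even-length window
        return statistics.median_low(self.deque)

    @property
    def avg(self):
        return sum(self.deque) / len(self.deque)

    @property
    def global_avg(self):
        return self.total / self.count


meter = WindowedMeter(window_size=3)
for v in [10.0, 1.0, 2.0, 3.0]:
    meter.update(v)
# The 10.0 has been evicted from the window but still counts toward global_avg
print(meter.avg, meter.median, meter.global_avg)  # 2.0 2.0 4.0
```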
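`log_every` computes its ETA in three steps:

1. Multiply `iter_time.global_avg` by the iterations remaining, `len(iterable) - i`.
2. Truncate the result to whole seconds.
3. Format it with `datetime.timedelta`.

It also pads the step counter to the digit width of `len(iterable)` through `space_fmt`.

The sketch below pulls both computations into two hypothetical helpers, `eta_string` and `progress_field`, so they can be checked in isolation.

```python
import datetime


def eta_string(avg_iter_seconds, num_iters, i):
    # Same arithmetic as log_every: mean seconds per iteration times iterations left
    eta_seconds = avg_iter_seconds * (num_iters - i)
    return str(datetime.timedelta(seconds=int(eta_seconds)))


def progress_field(i, num_iters):
    # Pad the counter to as many digits as num_iters has, e.g. ':4d' for 7500
    space_fmt = ':' + str(len(str(num_iters))) + 'd'
    return ('[{0' + space_fmt + '}/{1}]').format(i, num_iters)


print(progress_field(300, 7500), 'eta:', eta_string(0.5, 7500, 300))
# [ 300/7500] eta: 1:00:00
```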
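`init_distributed_mode` reads rank information from the environment in a fixed order:

1. Open MPI variables, when `args.dist_on_itp` is set.
2. The `RANK`/`WORLD_SIZE`/`LOCAL_RANK` variables set by launchers such as `torchrun`.
3. `SLURM_PROCID`.
4. Otherwise, it falls back to a single non-distributed process.

The sketch below restates that precedence as a pure function over an environment dict. `resolve_dist_env` and `gpus_per_node` are hypothetical names; `gpus_per_node` stands in for `torch.cuda.device_count()`. In the SLURM branch it reports `world_size` as `None`, because the original code does not set `args.world_size` there and keeps whatever value `args` already carried.

```python
def resolve_dist_env(env, dist_on_itp=False, gpus_per_node=8):
    """Hypothetical restatement of init_distributed_mode's environment lookup."""
    if dist_on_itp:
        return {
            "rank": int(env["OMPI_COMM_WORLD_RANK"]),
            "world_size": int(env["OMPI_COMM_WORLD_SIZE"]),
            "gpu": int(env["OMPI_COMM_WORLD_LOCAL_RANK"]),
        }
    if "RANK" in env and "WORLD_SIZE" in env:
        return {
            "rank": int(env["RANK"]),
            "world_size": int(env["WORLD_SIZE"]),
            "gpu": int(env["LOCAL_RANK"]),
        }
    if "SLURM_PROCID" in env:
        rank = int(env["SLURM_PROCID"])
        # world_size is not derived here, mirroring the original SLURM branch
        return {"rank": rank, "world_size": None, "gpu": rank % gpus_per_node}
    return None  # not distributed: the original prints a notice and returns early


print(resolve_dist_env({"RANK": "3", "WORLD_SIZE": "8", "LOCAL_RANK": "3"}))
print(resolve_dist_env({"SLURM_PROCID": "11"}))
print(resolve_dist_env({}))
```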
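`get_grad_norm_` computes the global gradient norm as the `norm_type`-norm of the per-parameter norms. For a finite `norm_type`, this equals the norm of all gradient entries concatenated. For `inf`, it is the largest absolute entry.

Two details of how `NativeScalerWithGradNormCount` uses it:

- It calls `get_grad_norm_` only after `unscale_`, so the logged value is the unscaled norm.
- `parameters` defaults to `None`, and `get_grad_norm_(None)` would fail when iterating. Callers should therefore pass `parameters` even when `clip_grad` is `None`.

The torch-free sketch below works over plain lists. `total_grad_norm` is a hypothetical name, and `None` stands in for a parameter without a gradient.

```python
import math


def total_grad_norm(grads, norm_type=2.0):
    """Hypothetical, torch-free mirror of get_grad_norm_ over lists of floats."""
    grads = [g for g in grads if g is not None]  # skip parameters without a gradient
    norm_type = float(norm_type)
    if not grads:
        return 0.0
    if norm_type == math.inf:
        return max(max(abs(x) for x in g) for g in grads)
    per_param = [sum(abs(x) ** norm_type for x in g) ** (1.0 / norm_type) for g in grads]
    return sum(n ** norm_type for n in per_param) ** (1.0 / norm_type)


grads = [[3.0, 4.0], [12.0], None]
print(total_grad_norm(grads))            # 13.0: same as the 2-norm of [3, 4, 12]
print(total_grad_norm(grads, math.inf))  # 12.0
```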
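`add_weight_decay` sorts every trainable parameter into one of two optimizer groups. A parameter gets `weight_decay=0` if any of these hold:

- It is 1-D (biases, normalization weights).
- Its name ends in `.bias`.
- It appears in `skip_list`.
- Its name contains `diffloss`.

All other trainable parameters get the requested decay.

The sketch below applies the same rule to `(name, shape)` pairs instead of a live module, so the split can be inspected without PyTorch. The function name `split_decay_groups` and the parameter names in the example are hypothetical. The `requires_grad` filter is omitted because plain tuples carry no such flag.

```python
def split_decay_groups(named_shapes, weight_decay=1e-5, skip_list=()):
    """Hypothetical mirror of add_weight_decay's rule, returning names instead of tensors."""
    decay, no_decay = [], []
    for name, shape in named_shapes:
        if len(shape) == 1 or name.endswith(".bias") or name in skip_list or "diffloss" in name:
            no_decay.append(name)  # no decay on biases, norm weights, skipped names, diffloss
        else:
            decay.append(name)
    return [
        {"params": no_decay, "weight_decay": 0.0},
        {"params": decay, "weight_decay": weight_decay},
    ]


# Illustrative parameter names and shapes, not taken from the MAR model
example = [
    ("blocks.0.attn.qkv.weight", (3072, 1024)),
    ("blocks.0.attn.qkv.bias", (3072,)),
    ("blocks.0.norm1.weight", (1024,)),
    ("diffloss.net.input_proj.weight", (1024, 16)),
]
groups = split_decay_groups(example, weight_decay=0.02)
print(groups[0]["params"])  # no decay
print(groups[1]["params"])  # decayed
```

Because the `diffloss` test is a substring match, it turns off decay for the diffusion head's 2-D weight matrices too, not just its biases and norms.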
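In `save_model`, the EMA checkpoint is built in two steps. First it deep-copies the live `state_dict()`. Then it overwrites each entry named by `named_parameters()` with `ema_params[i]`.

Two consequences follow:

- `ema_params` must be ordered exactly like `named_parameters()`.
- Non-parameter entries, such as registered buffers, keep their live values.

Below is a dict-based sketch of that merge. `build_ema_state_dict` and the key names are hypothetical.

```python
import copy


def build_ema_state_dict(state_dict, param_names, ema_params):
    """Hypothetical mirror of save_model's EMA merge, over plain dicts and lists."""
    ema_state = copy.deepcopy(state_dict)  # buffers and other entries are copied as-is
    for i, name in enumerate(param_names):
        assert name in ema_state
        ema_state[name] = ema_params[i]  # relies on ema_params sharing named_parameters() order
    return ema_state


live = {"proj.weight": [1.0, 2.0], "proj.bias": [0.5], "running_stat": [7.0]}
ema = build_ema_state_dict(live, ["proj.weight", "proj.bias"], [[0.9, 1.9], [0.4]])
print(ema)
print(live)  # unchanged by the merge
```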
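The two cross-rank reductions in this file average differently:

- `all_reduce_mean` sums a scalar across ranks with `dist.all_reduce`, whose default op is a sum, then divides by the world size. The result is an unweighted mean of the per-rank values.
- `SmoothedValue.synchronize_between_processes` sums `count` and `total` separately. The synchronized `global_avg` therefore weights each rank by how many updates it recorded.

The simulation below makes the difference concrete. The collective is modeled with a plain list holding one entry per rank, and all function names are hypothetical.

```python
def simulated_all_reduce_sum(per_rank):
    # Models dist.all_reduce with its default SUM op: every rank ends up with the total
    total = sum(per_rank)
    return [total for _ in per_rank]


def all_reduce_mean_sim(per_rank):
    world_size = len(per_rank)
    return [x / world_size for x in simulated_all_reduce_sum(per_rank)]


def synced_global_avg(counts, totals):
    # synchronize_between_processes reduces count and total separately
    count = simulated_all_reduce_sum(counts)[0]
    total = simulated_all_reduce_sum(totals)[0]
    return total / count


# Rank 0 logged 3 updates summing to 3.0 (mean 1.0); rank 1 logged 1 update of 3.0
print(all_reduce_mean_sim([1.0, 3.0]))        # [2.0, 2.0]
print(synced_global_avg([3, 1], [3.0, 3.0]))  # 1.5
```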