├── LICENSE
├── README.md
├── demo
│   ├── gradio_app.py
│   ├── run_mar.ipynb
│   └── visual.png
├── diffusion
│   ├── __init__.py
│   ├── diffusion_utils.py
│   ├── gaussian_diffusion.py
│   └── respace.py
├── engine_mar.py
├── environment.yaml
├── fid_stats
│   └── adm_in256_stats.npz
├── main_cache.py
├── main_mar.py
├── models
│   ├── diffloss.py
│   ├── mar.py
│   └── vae.py
└── util
    ├── crop.py
    ├── download.py
    ├── loader.py
    ├── lr_sched.py
    └── misc.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Tianhong Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Autoregressive Image Generation without Vector Quantization
Official PyTorch Implementation
2 |
3 | [arXiv](https://arxiv.org/abs/2406.11838)
4 | [Papers with Code](https://paperswithcode.com/sota/image-generation-on-imagenet-256x256?p=autoregressive-image-generation-without)
5 | [Open in Colab](http://colab.research.google.com/github/LTH14/mar/blob/main/demo/run_mar.ipynb)
6 | [Hugging Face](https://huggingface.co/jadechoghari/mar)
7 |
8 |
9 |
10 |
11 |
12 | This is a PyTorch/GPU implementation of the paper [Autoregressive Image Generation without Vector Quantization](https://arxiv.org/abs/2406.11838) (NeurIPS 2024 Spotlight Presentation):
13 |
14 | ```
15 | @article{li2024autoregressive,
16 | title={Autoregressive Image Generation without Vector Quantization},
17 | author={Li, Tianhong and Tian, Yonglong and Li, He and Deng, Mingyang and He, Kaiming},
18 | journal={arXiv preprint arXiv:2406.11838},
19 | year={2024}
20 | }
21 | ```
22 |
23 | This repo contains:
24 |
25 | * 🪐 A simple PyTorch implementation of [MAR](models/mar.py) and [DiffLoss](models/diffloss.py)
26 | * ⚡️ Pre-trained class-conditional MAR models trained on ImageNet 256x256
27 | * 💥 A self-contained [Colab notebook](http://colab.research.google.com/github/LTH14/mar/blob/main/demo/run_mar.ipynb) for running various pre-trained MAR models
28 | * 🛸 An MAR+DiffLoss [training and evaluation script](main_mar.py) using PyTorch DDP
29 | * 🎉 Also check out our [Hugging Face model cards](https://huggingface.co/jadechoghari/mar) and [Gradio demo](https://huggingface.co/spaces/jadechoghari/mar) (thanks [@jadechoghari](https://github.com/jadechoghari)).
30 |
31 | ## Preparation
32 |
33 | ### Dataset
34 | Download the [ImageNet](http://image-net.org/download) dataset and place it in your `IMAGENET_PATH`.
35 |
36 | ### Installation
37 |
38 | Download the code:
39 | ```
40 | git clone https://github.com/LTH14/mar.git
41 | cd mar
42 | ```
43 |
44 | A suitable [conda](https://conda.io/) environment named `mar` can be created and activated with:
45 |
46 | ```
47 | conda env create -f environment.yaml
48 | conda activate mar
49 | ```
50 |
51 | Download pre-trained VAE and MAR models:
52 |
53 | ```
54 | python util/download.py
55 | ```
56 |
57 | For convenience, our pre-trained MAR models can be downloaded directly here as well:
58 |
59 | | MAR Model | FID-50K | Inception Score | #params |
60 | |------------------------------------------------------------------------|---------|-----------------|---------|
61 | | [MAR-B](https://www.dropbox.com/scl/fi/f6dpuyjb7fudzxcyhvrhk/checkpoint-last.pth?rlkey=a6i4bo71vhfo4anp33n9ukujb&dl=0) | 2.31 | 281.7 | 208M |
62 | | [MAR-L](https://www.dropbox.com/scl/fi/pxacc5b2mrt3ifw4cah6k/checkpoint-last.pth?rlkey=m48ovo6g7ivcbosrbdaz0ehqt&dl=0) | 1.78 | 296.0 | 479M |
63 | | [MAR-H](https://www.dropbox.com/scl/fi/1qmfx6fpy3k7j9vcjjs3s/checkpoint-last.pth?rlkey=4lae281yzxb406atp32vzc83o&dl=0) | 1.55 | 303.7 | 943M |
64 |
65 | ### (Optional) Caching VAE Latents
66 |
67 | Given that our data augmentation consists of simple center cropping and random flipping,
68 | the VAE latents can be pre-computed and saved to `CACHED_PATH` to save computation during MAR training:
69 |
70 | ```
71 | torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
72 | main_cache.py \
73 | --img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 \
74 | --batch_size 128 \
75 | --data_path ${IMAGENET_PATH} --cached_path ${CACHED_PATH}
76 | ```
77 |
78 | ## Usage
79 |
80 | ### Demo
81 | Run our interactive visualization [demo](http://colab.research.google.com/github/LTH14/mar/blob/main/demo/run_mar.ipynb) in a Colab notebook!
82 |
83 | ### Local Gradio App
84 |
85 | ```
86 | python demo/gradio_app.py
87 | ```
88 |
89 |
90 |
91 | ### Training
92 | Script for the default setting (MAR-L, DiffLoss MLP with 3 blocks and a width of 1024 channels, 400 epochs):
93 | ```
94 | torchrun --nproc_per_node=8 --nnodes=4 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
95 | main_mar.py \
96 | --img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
97 | --model mar_large --diffloss_d 3 --diffloss_w 1024 \
98 | --epochs 400 --warmup_epochs 100 --batch_size 64 --blr 1.0e-4 --diffusion_batch_mul 4 \
99 | --output_dir ${OUTPUT_DIR} --resume ${OUTPUT_DIR} \
100 | --data_path ${IMAGENET_PATH}
101 | ```
102 | - Training time is ~1d7h on 32 H100 GPUs with `--batch_size 64`.
103 | - Add `--online_eval` to evaluate FID during training (every 40 epochs).
104 | - (Optional) To train with cached VAE latents, add `--use_cached --cached_path ${CACHED_PATH}` to the arguments.
105 | Training time with cached latents is ~1d11h on 16 H100 GPUs with `--batch_size 128` (nearly 2x faster than without caching).
106 | - (Optional) To save GPU memory during training by using gradient checkpointing (thanks to @Jiawei-Yang), add `--grad_checkpointing` to the arguments.
107 | Note that this may slightly reduce training speed.
108 |
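The `--blr` flag in the training command is a base learning rate, not the absolute one. The sketch below shows how the absolute rate could follow from it, assuming `main_mar.py` applies the MAE-style linear scaling rule `lr = blr * effective_batch_size / 256`. That rule is an assumption here, since `main_mar.py` is not shown in this section, and the helper `absolute_lr` is hypothetical.

```python
def absolute_lr(blr, batch_size_per_gpu, num_gpus):
    """Assumed MAE-style linear scaling rule: lr = blr * eff_batch / 256."""
    eff_batch_size = batch_size_per_gpu * num_gpus
    return blr * eff_batch_size / 256


# Default setting above: 4 nodes x 8 GPUs with --batch_size 64.
lr_default = absolute_lr(1.0e-4, 64, 32)
# Cached-latent setting: 16 GPUs with --batch_size 128 (same effective batch).
lr_cached = absolute_lr(1.0e-4, 128, 16)
print(lr_default, lr_cached)
```

Under this assumption, both configurations share an effective batch size of 2048, so switching between them would not require retuning `--blr`.
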
109 | ### Evaluation (ImageNet 256x256)
110 |
111 | Evaluate MAR-B (DiffLoss MLP with 6 blocks and a width of 1024 channels, 800 epochs) with classifier-free guidance:
112 | ```
113 | torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
114 | main_mar.py \
115 | --model mar_base --diffloss_d 6 --diffloss_w 1024 \
116 | --eval_bsz 256 --num_images 50000 \
117 | --num_iter 256 --num_sampling_steps 100 --cfg 2.9 --cfg_schedule linear --temperature 1.0 \
118 | --output_dir pretrained_models/mar/mar_base \
119 | --resume pretrained_models/mar/mar_base \
120 | --data_path ${IMAGENET_PATH} --evaluate
121 | ```
122 |
123 | Evaluate MAR-L (DiffLoss MLP with 8 blocks and a width of 1280 channels, 800 epochs) with classifier-free guidance:
124 | ```
125 | torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
126 | main_mar.py \
127 | --model mar_large --diffloss_d 8 --diffloss_w 1280 \
128 | --eval_bsz 256 --num_images 50000 \
129 | --num_iter 256 --num_sampling_steps 100 --cfg 3.0 --cfg_schedule linear --temperature 1.0 \
130 | --output_dir pretrained_models/mar/mar_large \
131 | --resume pretrained_models/mar/mar_large \
132 | --data_path ${IMAGENET_PATH} --evaluate
133 | ```
134 |
135 | Evaluate MAR-H (DiffLoss MLP with 12 blocks and a width of 1536 channels, 800 epochs) with classifier-free guidance:
136 | ```
137 | torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
138 | main_mar.py \
139 | --model mar_huge --diffloss_d 12 --diffloss_w 1536 \
140 | --eval_bsz 128 --num_images 50000 \
141 | --num_iter 256 --num_sampling_steps 100 --cfg 3.2 --cfg_schedule linear --temperature 1.0 \
142 | --output_dir pretrained_models/mar/mar_huge \
143 | --resume pretrained_models/mar/mar_huge \
144 | --data_path ${IMAGENET_PATH} --evaluate
145 | ```
146 |
147 | - Set `--cfg 1.0 --temperature 0.95` to evaluate without classifier-free guidance.
148 | - Generation speed can be significantly increased by reducing the number of autoregressive iterations (e.g., `--num_iter 64`).
149 |
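To see why a smaller `--num_iter` speeds up generation, the sketch below counts how many tokens are produced per autoregressive iteration. It assumes a MaskGIT-style cosine masking schedule; the actual sampler lives in `models/mar.py`, which is not shown in this section, so treat the schedule and the helper name `tokens_per_iteration` as assumptions. The sequence length of 256 follows from the flags above (`--img_size 256`, `--vae_stride 16`, `--patch_size 1` give a 16x16 token grid).

```python
import math


def tokens_per_iteration(num_iter, seq_len=256):
    """Tokens newly generated at each iteration under a cosine mask schedule."""
    masked = seq_len  # every token starts masked
    counts = []
    for step in range(num_iter):
        if step == num_iter - 1:
            remaining = 0  # the last iteration fills in whatever is left
        else:
            ratio = math.cos(math.pi / 2.0 * (step + 1) / num_iter)
            # Keep at least one token masked for later iterations, and
            # unmask at least one token in this iteration.
            remaining = max(1, min(masked - 1, math.floor(seq_len * ratio)))
        counts.append(masked - remaining)
        masked = remaining
    return counts


for n in (256, 64, 16):
    counts = tokens_per_iteration(n)
    print(n, "iterations; tokens in first/last iteration:", counts[0], counts[-1])
```

With fewer iterations, each pass through the model must predict more tokens at once, which trades some quality for speed.
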
150 | ## Acknowledgements
151 | We thank Congyue Deng and Xinlei Chen for helpful discussion. We thank
152 | Google TPU Research Cloud (TRC) for granting us access to TPUs, and Google Cloud Platform for
153 | supporting GPU resources.
154 |
155 | A large portion of the code in this repo is based on [MAE](https://github.com/facebookresearch/mae), [MAGE](https://github.com/LTH14/mage) and [DiT](https://github.com/facebookresearch/DiT).
156 |
157 | ## Contact
158 |
159 | If you have any questions, feel free to contact me through email (tianhong@mit.edu). Enjoy!
160 |
--------------------------------------------------------------------------------
/demo/gradio_app.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from diffusers import DiffusionPipeline
3 | import os
4 | import torch
5 | import shutil
6 | import spaces
7 |
8 |
9 | def find_cuda():
10 |     # Check if CUDA_HOME or CUDA_PATH environment variables are set
11 |     cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
12 |
13 |     if cuda_home and os.path.exists(cuda_home):
14 |         return cuda_home
15 |
16 |     # Search for the nvcc executable in the system's PATH
17 |     nvcc_path = shutil.which('nvcc')
18 |
19 |     if nvcc_path:
20 |         # Remove the 'bin/nvcc' part to get the CUDA installation path
21 |         cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
22 |         return cuda_path
23 |
24 |     return None
25 |
26 |
27 | cuda_path = find_cuda()
28 |
29 | if cuda_path:
30 |     print(f"CUDA installation found at: {cuda_path}")
31 | else:
32 |     print("CUDA installation not found")
33 |
34 | # check if cuda is available
35 | device = "cuda" if torch.cuda.is_available() else "cpu"
36 |
37 | # load the pipeline/model
38 | pipeline = DiffusionPipeline.from_pretrained("jadechoghari/mar", trust_remote_code=True,
39 |                                              custom_pipeline="jadechoghari/mar")
40 |
41 |
42 | # function that generates images
43 | @spaces.GPU
44 | def generate_image(seed, num_ar_steps, class_labels, cfg_scale, cfg_schedule):
45 |     generated_image = pipeline(
46 |         model_type="mar_huge",  # using mar_huge
47 |         seed=seed,
48 |         num_ar_steps=num_ar_steps,
49 |         class_labels=[int(label.strip()) for label in class_labels.split(',')],
50 |         cfg_scale=cfg_scale,
51 |         cfg_schedule=cfg_schedule,
52 |         output_dir="./images"
53 |     )
54 |     return generated_image
55 |
56 |
57 | with gr.Blocks() as demo:
58 |     gr.Markdown("""
59 |     # MAR Image Generation Demo 🚀
60 |
61 |     Welcome to the demo for **MAR** (Masked Autoregressive Model), a novel approach to image generation that eliminates the need for vector quantization. MAR uses a diffusion process to generate images in a continuous-valued space, resulting in faster, more efficient, and higher-quality outputs.
62 |
63 |     Simply adjust the parameters below to create your custom images in real-time.
64 |
65 |     Make sure to provide valid **ImageNet class labels** to see the translation of text to image. For a complete list of ImageNet classes, check out [this reference](https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/).
66 |
67 |     For more details, visit the [GitHub repository](https://github.com/LTH14/mar).
68 |     """)
69 |
70 |     seed = gr.Number(value=0, label="Seed")
71 |     num_ar_steps = gr.Slider(minimum=1, maximum=256, value=64, label="Number of AR Steps")
72 |     class_labels = gr.Textbox(value="207, 360, 388, 113, 355, 980, 323, 979",
73 |                               label="Class Labels (comma-separated ImageNet labels)")
74 |     cfg_scale = gr.Slider(minimum=1, maximum=10, value=4, label="CFG Scale")
75 |     cfg_schedule = gr.Dropdown(choices=["constant", "linear"], label="CFG Schedule", value="constant")
76 |
77 |     image_output = gr.Image(label="Generated Image")
78 |
79 |     generate_button = gr.Button("Generate Image")
80 |
81 |     # we link the button to the function and display the output
82 |     generate_button.click(generate_image, inputs=[seed, num_ar_steps, class_labels, cfg_scale, cfg_schedule],
83 |                           outputs=image_output)
84 |
85 |     gr.Interface(
86 |         generate_image,
87 |         inputs=[seed, num_ar_steps, class_labels, cfg_scale, cfg_schedule],
88 |         outputs=image_output,
89 |     )
90 |
91 | demo.launch()
--------------------------------------------------------------------------------
/demo/visual.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LTH14/mar/fe470ac24afbee924668d8c5c83e9fec60af3a73/demo/visual.png
--------------------------------------------------------------------------------
/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Adopted from DiT, which is modified from OpenAI's diffusion repos
2 | # DiT: https://github.com/facebookresearch/DiT/diffusion
3 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
4 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
5 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
6 |
7 | from . import gaussian_diffusion as gd
8 | from .respace import SpacedDiffusion, space_timesteps
9 |
10 |
11 | def create_diffusion(
12 |     timestep_respacing,
13 |     noise_schedule="linear",
14 |     use_kl=False,
15 |     sigma_small=False,
16 |     predict_xstart=False,
17 |     learn_sigma=True,
18 |     rescale_learned_sigmas=False,
19 |     diffusion_steps=1000
20 | ):
21 |     betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
22 |     if use_kl:
23 |         loss_type = gd.LossType.RESCALED_KL
24 |     elif rescale_learned_sigmas:
25 |         loss_type = gd.LossType.RESCALED_MSE
26 |     else:
27 |         loss_type = gd.LossType.MSE
28 |     if timestep_respacing is None or timestep_respacing == "":
29 |         timestep_respacing = [diffusion_steps]
30 |     return SpacedDiffusion(
31 |         use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
32 |         betas=betas,
33 |         model_mean_type=(
34 |             gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
35 |         ),
36 |         model_var_type=(
37 |             (
38 |                 gd.ModelVarType.FIXED_LARGE
39 |                 if not sigma_small
40 |                 else gd.ModelVarType.FIXED_SMALL
41 |             )
42 |             if not learn_sigma
43 |             else gd.ModelVarType.LEARNED_RANGE
44 |         ),
45 |         loss_type=loss_type
46 |         # rescale_timesteps=rescale_timesteps,
47 |     )
48 |
--------------------------------------------------------------------------------
/diffusion/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import torch as th
7 | import numpy as np
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 |     """
12 |     Compute the KL divergence between two gaussians.
13 |     Shapes are automatically broadcasted, so batches can be compared to
14 |     scalars, among other use cases.
15 |     """
16 |     tensor = None
17 |     for obj in (mean1, logvar1, mean2, logvar2):
18 |         if isinstance(obj, th.Tensor):
19 |             tensor = obj
20 |             break
21 |     assert tensor is not None, "at least one argument must be a Tensor"
22 |
23 |     # Force variances to be Tensors. Broadcasting helps convert scalars to
24 |     # Tensors, but it does not work for th.exp().
25 |     logvar1, logvar2 = [
26 |         x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27 |         for x in (logvar1, logvar2)
28 |     ]
29 |
30 |     return 0.5 * (
31 |         -1.0
32 |         + logvar2
33 |         - logvar1
34 |         + th.exp(logvar1 - logvar2)
35 |         + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36 |     )
37 |
38 |
39 | def approx_standard_normal_cdf(x):
40 |     """
41 |     A fast approximation of the cumulative distribution function of the
42 |     standard normal.
43 |     """
44 |     return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45 |
46 |
47 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
48 |     """
49 |     Compute the log-likelihood of a Gaussian distribution discretizing to a
50 |     given image.
51 |     :param x: the target images. It is assumed that this was uint8 values,
52 |               rescaled to the range [-1, 1].
53 |     :param means: the Gaussian mean Tensor.
54 |     :param log_scales: the Gaussian log stddev Tensor.
55 |     :return: a tensor like x of log probabilities (in nats).
56 |     """
57 |     assert x.shape == means.shape == log_scales.shape
58 |     centered_x = x - means
59 |     inv_stdv = th.exp(-log_scales)
60 |     plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
61 |     cdf_plus = approx_standard_normal_cdf(plus_in)
62 |     min_in = inv_stdv * (centered_x - 1.0 / 255.0)
63 |     cdf_min = approx_standard_normal_cdf(min_in)
64 |     log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
65 |     log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
66 |     cdf_delta = cdf_plus - cdf_min
67 |     log_probs = th.where(
68 |         x < -0.999,
69 |         log_cdf_plus,
70 |         th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
71 |     )
72 |     assert log_probs.shape == x.shape
73 |     return log_probs
74 |
--------------------------------------------------------------------------------
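The closed form used by `normal_kl` in `diffusion/diffusion_utils.py` can be sanity-checked without PyTorch. The sketch below re-implements the same expression for plain floats and compares it against a numerical integral of p(x) log(p(x)/q(x)). The helper names `normal_kl_scalar`, `normal_log_pdf` and `kl_numeric` are hypothetical and not part of the repository.

```python
import math


def normal_kl_scalar(mean1, logvar1, mean2, logvar2):
    """Same closed form as normal_kl, restricted to Python floats."""
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + math.exp(logvar1 - logvar2)
        + (mean1 - mean2) ** 2 * math.exp(-logvar2)
    )


def normal_log_pdf(x, mean, logvar):
    """Log-density of a 1-D Gaussian parameterized by mean and log-variance."""
    return -0.5 * (math.log(2.0 * math.pi) + logvar + (x - mean) ** 2 * math.exp(-logvar))


def kl_numeric(mean1, logvar1, mean2, logvar2, lo=-12.0, hi=12.0, n=20000):
    """Midpoint-rule estimate of the integral of p * (log p - log q)."""
    dx = (hi - lo) / n
    total = 0.0
    for i in range(n):
        x = lo + (i + 0.5) * dx
        log_p = normal_log_pdf(x, mean1, logvar1)
        log_q = normal_log_pdf(x, mean2, logvar2)
        total += math.exp(log_p) * (log_p - log_q) * dx
    return total


closed = normal_kl_scalar(0.5, 0.2, -1.0, -0.3)
numeric = kl_numeric(0.5, 0.2, -1.0, -0.3)
print(closed, numeric)
```

Working in log-densities avoids dividing by a density that underflows in the tails.
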
/diffusion/gaussian_diffusion.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 |
7 | import math
8 |
9 | import numpy as np
10 | import torch as th
11 | import enum
12 |
13 | from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
14 |
15 |
16 | def mean_flat(tensor):
17 |     """
18 |     Take the mean over all non-batch dimensions.
19 |     """
20 |     return tensor.mean(dim=list(range(1, len(tensor.shape))))
21 |
22 |
23 | class ModelMeanType(enum.Enum):
24 |     """
25 |     Which type of output the model predicts.
26 |     """
27 |
28 |     PREVIOUS_X = enum.auto()  # the model predicts x_{t-1}
29 |     START_X = enum.auto()  # the model predicts x_0
30 |     EPSILON = enum.auto()  # the model predicts epsilon
31 |
32 |
33 | class ModelVarType(enum.Enum):
34 |     """
35 |     What is used as the model's output variance.
36 |     The LEARNED_RANGE option has been added to allow the model to predict
37 |     values between FIXED_SMALL and FIXED_LARGE, making its job easier.
38 |     """
39 |
40 |     LEARNED = enum.auto()
41 |     FIXED_SMALL = enum.auto()
42 |     FIXED_LARGE = enum.auto()
43 |     LEARNED_RANGE = enum.auto()
44 |
45 |
46 | class LossType(enum.Enum):
47 |     MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)
48 |     RESCALED_MSE = (
49 |         enum.auto()
50 |     )  # use raw MSE loss (with RESCALED_KL when learning variances)
51 |     KL = enum.auto()  # use the variational lower-bound
52 |     RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB
53 |
54 |     def is_vb(self):
55 |         return self == LossType.KL or self == LossType.RESCALED_KL
56 |
57 |
58 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
59 |     betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
60 |     warmup_time = int(num_diffusion_timesteps * warmup_frac)
61 |     betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
62 |     return betas
63 |
64 |
65 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
66 |     """
67 |     This is the deprecated API for creating beta schedules.
68 |     See get_named_beta_schedule() for the new library of schedules.
69 |     """
70 |     if beta_schedule == "quad":
71 |         betas = (
72 |             np.linspace(
73 |                 beta_start ** 0.5,
74 |                 beta_end ** 0.5,
75 |                 num_diffusion_timesteps,
76 |                 dtype=np.float64,
77 |             )
78 |             ** 2
79 |         )
80 |     elif beta_schedule == "linear":
81 |         betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
82 |     elif beta_schedule == "warmup10":
83 |         betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
84 |     elif beta_schedule == "warmup50":
85 |         betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
86 |     elif beta_schedule == "const":
87 |         betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
88 |     elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
89 |         betas = 1.0 / np.linspace(
90 |             num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
91 |         )
92 |     else:
93 |         raise NotImplementedError(beta_schedule)
94 |     assert betas.shape == (num_diffusion_timesteps,)
95 |     return betas
96 |
97 |
98 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
99 |     """
100 |     Get a pre-defined beta schedule for the given name.
101 |     The beta schedule library consists of beta schedules which remain similar
102 |     in the limit of num_diffusion_timesteps.
103 |     Beta schedules may be added, but should not be removed or changed once
104 |     they are committed to maintain backwards compatibility.
105 |     """
106 |     if schedule_name == "linear":
107 |         # Linear schedule from Ho et al, extended to work for any number of
108 |         # diffusion steps.
109 |         scale = 1000 / num_diffusion_timesteps
110 |         return get_beta_schedule(
111 |             "linear",
112 |             beta_start=scale * 0.0001,
113 |             beta_end=scale * 0.02,
114 |             num_diffusion_timesteps=num_diffusion_timesteps,
115 |         )
116 |     elif schedule_name == "cosine":
117 |         return betas_for_alpha_bar(
118 |             num_diffusion_timesteps,
119 |             lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
120 |         )
121 |     else:
122 |         raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
123 |
124 |
125 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
126 |     """
127 |     Create a beta schedule that discretizes the given alpha_t_bar function,
128 |     which defines the cumulative product of (1-beta) over time from t = [0,1].
129 |     :param num_diffusion_timesteps: the number of betas to produce.
130 |     :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
131 |                       produces the cumulative product of (1-beta) up to that
132 |                       part of the diffusion process.
133 |     :param max_beta: the maximum beta to use; use values lower than 1 to
134 |                      prevent singularities.
135 |     """
136 |     betas = []
137 |     for i in range(num_diffusion_timesteps):
138 |         t1 = i / num_diffusion_timesteps
139 |         t2 = (i + 1) / num_diffusion_timesteps
140 |         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
141 |     return np.array(betas)
142 |
143 |
144 | class GaussianDiffusion:
145 |     """
146 |     Utilities for training and sampling diffusion models.
147 |     Original ported from this codebase:
148 |     https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
149 |     :param betas: a 1-D numpy array of betas for each diffusion timestep,
150 |                   starting at T and going to 1.
151 |     """
152 |
153 |     def __init__(
154 |         self,
155 |         *,
156 |         betas,
157 |         model_mean_type,
158 |         model_var_type,
159 |         loss_type
160 |     ):
161 |
162 |         self.model_mean_type = model_mean_type
163 |         self.model_var_type = model_var_type
164 |         self.loss_type = loss_type
165 |
166 |         # Use float64 for accuracy.
167 |         betas = np.array(betas, dtype=np.float64)
168 |         self.betas = betas
169 |         assert len(betas.shape) == 1, "betas must be 1-D"
170 |         assert (betas > 0).all() and (betas <= 1).all()
171 |
172 |         self.num_timesteps = int(betas.shape[0])
173 |
174 |         alphas = 1.0 - betas
175 |         self.alphas_cumprod = np.cumprod(alphas, axis=0)
176 |         self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
177 |         self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
178 |         assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
179 |
180 |         # calculations for diffusion q(x_t | x_{t-1}) and others
181 |         self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
182 |         self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
183 |         self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
184 |         self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
185 |         self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
186 |
187 |         # calculations for posterior q(x_{t-1} | x_t, x_0)
188 |         self.posterior_variance = (
189 |             betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
190 |         )
191 |         # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
192 |         self.posterior_log_variance_clipped = np.log(
193 |             np.append(self.posterior_variance[1], self.posterior_variance[1:])
194 |         ) if len(self.posterior_variance) > 1 else np.array([])
195 |
196 |         self.posterior_mean_coef1 = (
197 |             betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
198 |         )
199 |         self.posterior_mean_coef2 = (
200 |             (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
201 |         )
202 |
203 |     def q_mean_variance(self, x_start, t):
204 |         """
205 |         Get the distribution q(x_t | x_0).
206 |         :param x_start: the [N x C x ...] tensor of noiseless inputs.
207 |         :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
208 |         :return: A tuple (mean, variance, log_variance), all of x_start's shape.
209 |         """
210 |         mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
211 |         variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
212 |         log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
213 |         return mean, variance, log_variance
214 |
215 |     def q_sample(self, x_start, t, noise=None):
216 |         """
217 |         Diffuse the data for a given number of diffusion steps.
218 |         In other words, sample from q(x_t | x_0).
219 |         :param x_start: the initial data batch.
220 |         :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
221 |         :param noise: if specified, the split-out normal noise.
222 |         :return: A noisy version of x_start.
223 |         """
224 |         if noise is None:
225 |             noise = th.randn_like(x_start)
226 |         assert noise.shape == x_start.shape
227 |         return (
228 |             _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
229 |             + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
230 |         )
231 |
232 |     def q_posterior_mean_variance(self, x_start, x_t, t):
233 |         """
234 |         Compute the mean and variance of the diffusion posterior:
235 |             q(x_{t-1} | x_t, x_0)
236 |         """
237 |         assert x_start.shape == x_t.shape
238 |         posterior_mean = (
239 |             _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
240 |             + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
241 |         )
242 |         posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
243 |         posterior_log_variance_clipped = _extract_into_tensor(
244 |             self.posterior_log_variance_clipped, t, x_t.shape
245 |         )
246 |         assert (
247 |             posterior_mean.shape[0]
248 |             == posterior_variance.shape[0]
249 |             == posterior_log_variance_clipped.shape[0]
250 |             == x_start.shape[0]
251 |         )
252 |         return posterior_mean, posterior_variance, posterior_log_variance_clipped
253 |
254 |     def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
255 |         """
256 |         Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
257 |         the initial x, x_0.
258 |         :param model: the model, which takes a signal and a batch of timesteps
259 |                       as input.
260 |         :param x: the [N x C x ...] tensor at time t.
261 |         :param t: a 1-D Tensor of timesteps.
262 |         :param clip_denoised: if True, clip the denoised signal into [-1, 1].
263 |         :param denoised_fn: if not None, a function which applies to the
264 |             x_start prediction before it is used to sample. Applies before
265 |             clip_denoised.
266 |         :param model_kwargs: if not None, a dict of extra keyword arguments to
267 |             pass to the model. This can be used for conditioning.
268 |         :return: a dict with the following keys:
269 |                  - 'mean': the model mean output.
270 |                  - 'variance': the model variance output.
271 |                  - 'log_variance': the log of 'variance'.
272 |                  - 'pred_xstart': the prediction for x_0.
273 |         """
274 |         if model_kwargs is None:
275 |             model_kwargs = {}
276 |
277 |         B, C = x.shape[:2]
278 |         assert t.shape == (B,)
279 |         model_output = model(x, t, **model_kwargs)
280 |         if isinstance(model_output, tuple):
281 |             model_output, extra = model_output
282 |         else:
283 |             extra = None
284 |
285 |         if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
286 |             assert model_output.shape == (B, C * 2, *x.shape[2:])
287 |             model_output, model_var_values = th.split(model_output, C, dim=1)
288 |             min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
289 |             max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
290 |             # The model_var_values is [-1, 1] for [min_var, max_var].
291 |             frac = (model_var_values + 1) / 2
292 |             model_log_variance = frac * max_log + (1 - frac) * min_log
293 |             model_variance = th.exp(model_log_variance)
294 |         else:
295 |             model_variance, model_log_variance = {
296 |                 # for fixedlarge, we set the initial (log-)variance like so
297 |                 # to get a better decoder log likelihood.
298 |                 ModelVarType.FIXED_LARGE: (
299 |                     np.append(self.posterior_variance[1], self.betas[1:]),
300 |                     np.log(np.append(self.posterior_variance[1], self.betas[1:])),
301 |                 ),
302 |                 ModelVarType.FIXED_SMALL: (
303 |                     self.posterior_variance,
304 |                     self.posterior_log_variance_clipped,
305 |                 ),
306 |             }[self.model_var_type]
307 |             model_variance = _extract_into_tensor(model_variance, t, x.shape)
308 |             model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
309 |
310 |         def process_xstart(x):
311 |             if denoised_fn is not None:
312 |                 x = denoised_fn(x)
313 |             if clip_denoised:
314 |                 return x.clamp(-1, 1)
315 |             return x
316 |
317 |         if self.model_mean_type == ModelMeanType.START_X:
318 |             pred_xstart = process_xstart(model_output)
319 |         else:
320 |             pred_xstart = process_xstart(
321 |                 self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
322 |             )
323 |         model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
324 |
325 |         assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
326 |         return {
327 |             "mean": model_mean,
328 |             "variance": model_variance,
329 |             "log_variance": model_log_variance,
330 |             "pred_xstart": pred_xstart,
331 |             "extra": extra,
332 |         }
333 |
334 | def _predict_xstart_from_eps(self, x_t, t, eps):
335 | assert x_t.shape == eps.shape
336 | return (
337 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
338 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
339 | )
340 |
341 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
342 | return (
343 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
344 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
345 |
346 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
347 | """
348 | Compute the mean for the previous step, given a function cond_fn that
349 | computes the gradient of a conditional log probability with respect to
350 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
351 | condition on y.
352 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
353 | """
354 | gradient = cond_fn(x, t, **(model_kwargs or {}))
355 | new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
356 | return new_mean
357 |
358 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
359 | """
360 | Compute what the p_mean_variance output would have been, should the
361 | model's score function be conditioned by cond_fn.
362 | See condition_mean() for details on cond_fn.
363 | Unlike condition_mean(), this instead uses the conditioning strategy
364 | from Song et al (2020).
365 | """
366 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
367 |
368 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
369 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **(model_kwargs or {}))
370 |
371 | out = p_mean_var.copy()
372 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
373 | out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
374 | return out
375 |
376 | def p_sample(
377 | self,
378 | model,
379 | x,
380 | t,
381 | clip_denoised=True,
382 | denoised_fn=None,
383 | cond_fn=None,
384 | model_kwargs=None,
385 | temperature=1.0
386 | ):
387 | """
388 | Sample x_{t-1} from the model at the given timestep.
389 | :param model: the model to sample from.
390 | :param x: the current tensor at x_t.
391 | :param t: the value of t, starting at 0 for the first diffusion step.
392 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
393 | :param denoised_fn: if not None, a function which applies to the
394 | x_start prediction before it is used to sample.
395 | :param cond_fn: if not None, this is a gradient function that acts
396 | similarly to the model.
397 | :param model_kwargs: if not None, a dict of extra keyword arguments to
398 | pass to the model. This can be used for conditioning.
399 | :param temperature: temperature scaling during Diff Loss sampling.
400 | :return: a dict containing the following keys:
401 | - 'sample': a random sample from the model.
402 | - 'pred_xstart': a prediction of x_0.
403 | """
404 | out = self.p_mean_variance(
405 | model,
406 | x,
407 | t,
408 | clip_denoised=clip_denoised,
409 | denoised_fn=denoised_fn,
410 | model_kwargs=model_kwargs,
411 | )
412 | noise = th.randn_like(x)
413 | nonzero_mask = (
414 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
415 | ) # no noise when t == 0
416 | if cond_fn is not None:
417 | out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
418 | # scale the noise by temperature
419 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise * temperature
420 | return {"sample": sample, "pred_xstart": out["pred_xstart"]}
421 |
422 | def p_sample_loop(
423 | self,
424 | model,
425 | shape,
426 | noise=None,
427 | clip_denoised=True,
428 | denoised_fn=None,
429 | cond_fn=None,
430 | model_kwargs=None,
431 | device=None,
432 | progress=False,
433 | temperature=1.0,
434 | ):
435 | """
436 | Generate samples from the model.
437 | :param model: the model module.
438 | :param shape: the shape of the samples, (N, C, H, W).
439 | :param noise: if specified, the noise from the encoder to sample.
440 | Should be of the same shape as `shape`.
441 | :param clip_denoised: if True, clip x_start predictions to [-1, 1].
442 | :param denoised_fn: if not None, a function which applies to the
443 | x_start prediction before it is used to sample.
444 | :param cond_fn: if not None, this is a gradient function that acts
445 | similarly to the model.
446 | :param model_kwargs: if not None, a dict of extra keyword arguments to
447 | pass to the model. This can be used for conditioning.
448 | :param device: currently ignored; the initial noise (when `noise` is None)
449 | and the timestep tensors are always created on the current CUDA device.
450 | :param progress: if True, show a tqdm progress bar.
451 | :param temperature: temperature scaling during Diff Loss sampling.
452 | :return: a non-differentiable batch of samples.
453 | """
454 | final = None
455 | for sample in self.p_sample_loop_progressive(
456 | model,
457 | shape,
458 | noise=noise,
459 | clip_denoised=clip_denoised,
460 | denoised_fn=denoised_fn,
461 | cond_fn=cond_fn,
462 | model_kwargs=model_kwargs,
463 | device=device,
464 | progress=progress,
465 | temperature=temperature,
466 | ):
467 | final = sample
468 | return final["sample"]
469 |
470 | def p_sample_loop_progressive(
471 | self,
472 | model,
473 | shape,
474 | noise=None,
475 | clip_denoised=True,
476 | denoised_fn=None,
477 | cond_fn=None,
478 | model_kwargs=None,
479 | device=None,
480 | progress=False,
481 | temperature=1.0,
482 | ):
483 | """
484 | Generate samples from the model and yield intermediate samples from
485 | each timestep of diffusion.
486 | Arguments are the same as p_sample_loop().
487 | Returns a generator over dicts, where each dict is the return value of
488 | p_sample().
489 | """
490 | assert isinstance(shape, (tuple, list))
491 | if noise is not None:
492 | img = noise
493 | else:
494 | img = th.randn(*shape).cuda()
495 | indices = list(range(self.num_timesteps))[::-1]
496 |
497 | if progress:
498 | # Lazy import so that we don't depend on tqdm.
499 | from tqdm.auto import tqdm
500 |
501 | indices = tqdm(indices)
502 |
503 | for i in indices:
504 | t = th.tensor([i] * shape[0]).cuda()
505 | with th.no_grad():
506 | out = self.p_sample(
507 | model,
508 | img,
509 | t,
510 | clip_denoised=clip_denoised,
511 | denoised_fn=denoised_fn,
512 | cond_fn=cond_fn,
513 | model_kwargs=model_kwargs,
514 | temperature=temperature,
515 | )
516 | yield out
517 | img = out["sample"]
518 |
519 | def ddim_sample(
520 | self,
521 | model,
522 | x,
523 | t,
524 | clip_denoised=True,
525 | denoised_fn=None,
526 | cond_fn=None,
527 | model_kwargs=None,
528 | eta=0.0,
529 | ):
530 | """
531 | Sample x_{t-1} from the model using DDIM.
532 | Same usage as p_sample().
533 | """
534 | out = self.p_mean_variance(
535 | model,
536 | x,
537 | t,
538 | clip_denoised=clip_denoised,
539 | denoised_fn=denoised_fn,
540 | model_kwargs=model_kwargs,
541 | )
542 | if cond_fn is not None:
543 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
544 |
545 | # Usually our model outputs epsilon, but we re-derive it
546 | # in case we used x_start or x_prev prediction.
547 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
548 |
549 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
550 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
551 | sigma = (
552 | eta
553 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
554 | * th.sqrt(1 - alpha_bar / alpha_bar_prev)
555 | )
556 | # Equation 12 of the DDIM paper.
557 | noise = th.randn_like(x)
558 | mean_pred = (
559 | out["pred_xstart"] * th.sqrt(alpha_bar_prev)
560 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
561 | )
562 | nonzero_mask = (
563 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
564 | ) # no noise when t == 0
565 | sample = mean_pred + nonzero_mask * sigma * noise
566 | return {"sample": sample, "pred_xstart": out["pred_xstart"]}
567 |
568 | def ddim_reverse_sample(
569 | self,
570 | model,
571 | x,
572 | t,
573 | clip_denoised=True,
574 | denoised_fn=None,
575 | cond_fn=None,
576 | model_kwargs=None,
577 | eta=0.0,
578 | ):
579 | """
580 | Sample x_{t+1} from the model using DDIM reverse ODE.
581 | """
582 | assert eta == 0.0, "Reverse ODE only for deterministic path"
583 | out = self.p_mean_variance(
584 | model,
585 | x,
586 | t,
587 | clip_denoised=clip_denoised,
588 | denoised_fn=denoised_fn,
589 | model_kwargs=model_kwargs,
590 | )
591 | if cond_fn is not None:
592 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
593 | # Usually our model outputs epsilon, but we re-derive it
594 | # in case we used x_start or x_prev prediction.
595 | eps = (
596 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
597 | - out["pred_xstart"]
598 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
599 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
600 |
601 | # Equation 12. reversed
602 | mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
603 |
604 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
605 |
606 | def ddim_sample_loop(
607 | self,
608 | model,
609 | shape,
610 | noise=None,
611 | clip_denoised=True,
612 | denoised_fn=None,
613 | cond_fn=None,
614 | model_kwargs=None,
615 | device=None,
616 | progress=False,
617 | eta=0.0,
618 | ):
619 | """
620 | Generate samples from the model using DDIM.
621 | Same usage as p_sample_loop().
622 | """
623 | final = None
624 | for sample in self.ddim_sample_loop_progressive(
625 | model,
626 | shape,
627 | noise=noise,
628 | clip_denoised=clip_denoised,
629 | denoised_fn=denoised_fn,
630 | cond_fn=cond_fn,
631 | model_kwargs=model_kwargs,
632 | device=device,
633 | progress=progress,
634 | eta=eta,
635 | ):
636 | final = sample
637 | return final["sample"]
638 |
639 | def ddim_sample_loop_progressive(
640 | self,
641 | model,
642 | shape,
643 | noise=None,
644 | clip_denoised=True,
645 | denoised_fn=None,
646 | cond_fn=None,
647 | model_kwargs=None,
648 | device=None,
649 | progress=False,
650 | eta=0.0,
651 | ):
652 | """
653 | Use DDIM to sample from the model and yield intermediate samples from
654 | each timestep of DDIM.
655 | Same usage as p_sample_loop_progressive().
656 | """
657 | assert isinstance(shape, (tuple, list))
658 | if noise is not None:
659 | img = noise
660 | else:
661 | img = th.randn(*shape).cuda()
662 | indices = list(range(self.num_timesteps))[::-1]
663 |
664 | if progress:
665 | # Lazy import so that we don't depend on tqdm.
666 | from tqdm.auto import tqdm
667 |
668 | indices = tqdm(indices)
669 |
670 | for i in indices:
671 | t = th.tensor([i] * shape[0]).cuda()
672 | with th.no_grad():
673 | out = self.ddim_sample(
674 | model,
675 | img,
676 | t,
677 | clip_denoised=clip_denoised,
678 | denoised_fn=denoised_fn,
679 | cond_fn=cond_fn,
680 | model_kwargs=model_kwargs,
681 | eta=eta,
682 | )
683 | yield out
684 | img = out["sample"]
685 |
686 | def _vb_terms_bpd(
687 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
688 | ):
689 | """
690 | Get a term for the variational lower-bound.
691 | The resulting units are bits (rather than nats, as one might expect).
692 | This allows for comparison to other papers.
693 | :return: a dict with the following keys:
694 | - 'output': a shape [N] tensor of NLLs or KLs.
695 | - 'pred_xstart': the x_0 predictions.
696 | """
697 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
698 | x_start=x_start, x_t=x_t, t=t
699 | )
700 | out = self.p_mean_variance(
701 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
702 | )
703 | kl = normal_kl(
704 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
705 | )
706 | kl = mean_flat(kl) / np.log(2.0)
707 |
708 | decoder_nll = -discretized_gaussian_log_likelihood(
709 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
710 | )
711 | assert decoder_nll.shape == x_start.shape
712 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
713 |
714 | # At the first timestep return the decoder NLL,
715 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
716 | output = th.where((t == 0), decoder_nll, kl)
717 | return {"output": output, "pred_xstart": out["pred_xstart"]}
718 |
719 | def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
720 | """
721 | Compute training losses for a single timestep.
722 | :param model: the model to evaluate loss on.
723 | :param x_start: the [N x C x ...] tensor of inputs.
724 | :param t: a batch of timestep indices.
725 | :param model_kwargs: if not None, a dict of extra keyword arguments to
726 | pass to the model. This can be used for conditioning.
727 | :param noise: if specified, the specific Gaussian noise to try to remove.
728 | :return: a dict with the key "loss" containing a tensor of shape [N].
729 | Some mean or variance settings may also have other keys.
730 | """
731 | if model_kwargs is None:
732 | model_kwargs = {}
733 | if noise is None:
734 | noise = th.randn_like(x_start)
735 | x_t = self.q_sample(x_start, t, noise=noise)
736 |
737 | terms = {}
738 |
739 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
740 | terms["loss"] = self._vb_terms_bpd(
741 | model=model,
742 | x_start=x_start,
743 | x_t=x_t,
744 | t=t,
745 | clip_denoised=False,
746 | model_kwargs=model_kwargs,
747 | )["output"]
748 | if self.loss_type == LossType.RESCALED_KL:
749 | terms["loss"] *= self.num_timesteps
750 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
751 | model_output = model(x_t, t, **model_kwargs)
752 |
753 | if self.model_var_type in [
754 | ModelVarType.LEARNED,
755 | ModelVarType.LEARNED_RANGE,
756 | ]:
757 | B, C = x_t.shape[:2]
758 | assert model_output.shape == (B, C * 2, *x_t.shape[2:])
759 | model_output, model_var_values = th.split(model_output, C, dim=1)
760 | # Learn the variance using the variational bound, but don't let
761 | # it affect our mean prediction.
762 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
763 | terms["vb"] = self._vb_terms_bpd(
764 | model=lambda *args, r=frozen_out: r,
765 | x_start=x_start,
766 | x_t=x_t,
767 | t=t,
768 | clip_denoised=False,
769 | )["output"]
770 | if self.loss_type == LossType.RESCALED_MSE:
771 | # Divide by 1000 for equivalence with initial implementation.
772 | # Without a factor of 1/1000, the VB term hurts the MSE term.
773 | terms["vb"] *= self.num_timesteps / 1000.0
774 |
775 | target = {
776 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
777 | x_start=x_start, x_t=x_t, t=t
778 | )[0],
779 | ModelMeanType.START_X: x_start,
780 | ModelMeanType.EPSILON: noise,
781 | }[self.model_mean_type]
782 | assert model_output.shape == target.shape == x_start.shape
783 | terms["mse"] = mean_flat((target - model_output) ** 2)
784 | if "vb" in terms:
785 | terms["loss"] = terms["mse"] + terms["vb"]
786 | else:
787 | terms["loss"] = terms["mse"]
788 | else:
789 | raise NotImplementedError(self.loss_type)
790 |
791 | return terms
792 |
793 | def _prior_bpd(self, x_start):
794 | """
795 | Get the prior KL term for the variational lower-bound, measured in
796 | bits-per-dim.
797 | This term can't be optimized, as it only depends on the encoder.
798 | :param x_start: the [N x C x ...] tensor of inputs.
799 | :return: a batch of [N] KL values (in bits), one per batch element.
800 | """
801 | batch_size = x_start.shape[0]
802 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
803 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
804 | kl_prior = normal_kl(
805 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
806 | )
807 | return mean_flat(kl_prior) / np.log(2.0)
808 |
809 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
810 | """
811 | Compute the entire variational lower-bound, measured in bits-per-dim,
812 | as well as other related quantities.
813 | :param model: the model to evaluate loss on.
814 | :param x_start: the [N x C x ...] tensor of inputs.
815 | :param clip_denoised: if True, clip denoised samples.
816 | :param model_kwargs: if not None, a dict of extra keyword arguments to
817 | pass to the model. This can be used for conditioning.
818 | :return: a dict containing the following keys:
819 | - total_bpd: the total variational lower-bound, per batch element.
820 | - prior_bpd: the prior term in the lower-bound.
821 | - vb: an [N x T] tensor of terms in the lower-bound.
822 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
823 | - mse: an [N x T] tensor of epsilon MSEs for each timestep.
824 | """
825 | device = x_start.device
826 | batch_size = x_start.shape[0]
827 |
828 | vb = []
829 | xstart_mse = []
830 | mse = []
831 | for t in list(range(self.num_timesteps))[::-1]:
832 | t_batch = th.tensor([t] * batch_size, device=device)
833 | noise = th.randn_like(x_start)
834 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
835 | # Calculate VLB term at the current timestep
836 | with th.no_grad():
837 | out = self._vb_terms_bpd(
838 | model,
839 | x_start=x_start,
840 | x_t=x_t,
841 | t=t_batch,
842 | clip_denoised=clip_denoised,
843 | model_kwargs=model_kwargs,
844 | )
845 | vb.append(out["output"])
846 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
847 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
848 | mse.append(mean_flat((eps - noise) ** 2))
849 |
850 | vb = th.stack(vb, dim=1)
851 | xstart_mse = th.stack(xstart_mse, dim=1)
852 | mse = th.stack(mse, dim=1)
853 |
854 | prior_bpd = self._prior_bpd(x_start)
855 | total_bpd = vb.sum(dim=1) + prior_bpd
856 | return {
857 | "total_bpd": total_bpd,
858 | "prior_bpd": prior_bpd,
859 | "vb": vb,
860 | "xstart_mse": xstart_mse,
861 | "mse": mse,
862 | }
863 |
864 |
865 | def _extract_into_tensor(arr, timesteps, broadcast_shape):
866 | """
867 | Extract values from a 1-D numpy array for a batch of indices.
868 | :param arr: the 1-D numpy array.
869 | :param timesteps: a tensor of indices into the array to extract.
870 | :param broadcast_shape: a larger shape of K dimensions with the batch
871 | dimension equal to the length of timesteps.
872 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
873 | """
874 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
875 | while len(res.shape) < len(broadcast_shape):
876 | res = res[..., None]
877 | return res + th.zeros(broadcast_shape, device=timesteps.device)
878 |
--------------------------------------------------------------------------------
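The helpers `_predict_xstart_from_eps`, `_predict_eps_from_xstart` and the temperature-scaled `p_sample` step in `diffusion/gaussian_diffusion.py` above can be checked on plain floats. The following is a minimal sketch, not the repository's implementation: it assumes `sqrt_recip_alphas_cumprod` is `sqrt(1 / alpha_bar)` and `sqrt_recipm1_alphas_cumprod` is `sqrt(1 / alpha_bar - 1)` (their definitions are outside this excerpt), uses scalars instead of tensors, and the toy beta schedule and the names `cumprod_alphas` and `sample_step` are hypothetical.

```python
import math
import random


def cumprod_alphas(betas):
    """Return the running product of (1 - beta_t), i.e. alpha_bar_t."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


def predict_xstart_from_eps(x_t, eps, alpha_bar):
    # x_0 = sqrt(1 / alpha_bar) * x_t - sqrt(1 / alpha_bar - 1) * eps
    return math.sqrt(1.0 / alpha_bar) * x_t - math.sqrt(1.0 / alpha_bar - 1.0) * eps


def predict_eps_from_xstart(x_t, pred_xstart, alpha_bar):
    # Inverse of the formula above, solved for eps.
    return (math.sqrt(1.0 / alpha_bar) * x_t - pred_xstart) / math.sqrt(
        1.0 / alpha_bar - 1.0
    )


def sample_step(mean, log_variance, t, temperature=1.0, rng=random):
    """Scalar analogue of the final line of p_sample().

    The noise standard deviation exp(0.5 * log_variance) is multiplied by
    `temperature`, and no noise is added at t == 0.
    """
    noise = rng.gauss(0.0, 1.0)
    nonzero_mask = 0.0 if t == 0 else 1.0
    return mean + nonzero_mask * math.exp(0.5 * log_variance) * noise * temperature


# Toy schedule (an assumption for illustration, not the repo's schedule).
alpha_bars = cumprod_alphas([0.01 * (i + 1) for i in range(10)])
alpha_bar = alpha_bars[5]

# Forward-diffuse a known x_0 with a known eps, then recover both.
x0, eps = 0.3, -1.2
x_t = math.sqrt(alpha_bar) * x0 + math.sqrt(1.0 - alpha_bar) * eps
recovered_x0 = predict_xstart_from_eps(x_t, eps, alpha_bar)
recovered_eps = predict_eps_from_xstart(x_t, x0, alpha_bar)
```

With `temperature=1.0` the step matches standard ancestral sampling; values below 1 shrink the injected noise, and `temperature=0.0` returns the mean unchanged.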
/diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 | from .gaussian_diffusion import GaussianDiffusion
10 |
11 |
12 | def space_timesteps(num_timesteps, section_counts):
13 | """
14 | Create a list of timesteps to use from an original diffusion process,
15 | given the number of timesteps we want to take from equally-sized portions
16 | of the original process.
17 | For example, if there are 300 timesteps and the section counts are [10,15,20]
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 | If the stride is a string starting with "ddim", then the fixed striding
21 | from the DDIM paper is used, and only one section is allowed.
22 | :param num_timesteps: the number of diffusion steps in the original
23 | process to divide up.
24 | :param section_counts: either a list of numbers, or a string containing
25 | comma-separated numbers, indicating the step count
26 | per section. As a special case, use "ddimN" where N
27 | is a number of steps to use the striding from the
28 | DDIM paper.
29 | :return: a set of diffusion steps from the original process to use.
30 | """
31 | if isinstance(section_counts, str):
32 | if section_counts.startswith("ddim"):
33 | desired_count = int(section_counts[len("ddim") :])
34 | for i in range(1, num_timesteps):
35 | if len(range(0, num_timesteps, i)) == desired_count:
36 | return set(range(0, num_timesteps, i))
37 | raise ValueError(
38 | f"cannot create exactly {desired_count} steps with an integer stride"
39 | )
40 | section_counts = [int(x) for x in section_counts.split(",")]
41 | size_per = num_timesteps // len(section_counts)
42 | extra = num_timesteps % len(section_counts)
43 | start_idx = 0
44 | all_steps = []
45 | for i, section_count in enumerate(section_counts):
46 | size = size_per + (1 if i < extra else 0)
47 | if size < section_count:
48 | raise ValueError(
49 | f"cannot divide section of {size} steps into {section_count}"
50 | )
51 | if section_count <= 1:
52 | frac_stride = 1
53 | else:
54 | frac_stride = (size - 1) / (section_count - 1)
55 | cur_idx = 0.0
56 | taken_steps = []
57 | for _ in range(section_count):
58 | taken_steps.append(start_idx + round(cur_idx))
59 | cur_idx += frac_stride
60 | all_steps += taken_steps
61 | start_idx += size
62 | return set(all_steps)
63 |
64 |
65 | class SpacedDiffusion(GaussianDiffusion):
66 | """
67 | A diffusion process which can skip steps in a base diffusion process.
68 | :param use_timesteps: a collection (sequence or set) of timesteps from the
69 | original diffusion process to retain.
70 | :param kwargs: the kwargs to create the base diffusion process.
71 | """
72 |
73 | def __init__(self, use_timesteps, **kwargs):
74 | self.use_timesteps = set(use_timesteps)
75 | self.timestep_map = []
76 | self.original_num_steps = len(kwargs["betas"])
77 |
78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79 | last_alpha_cumprod = 1.0
80 | new_betas = []
81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82 | if i in self.use_timesteps:
83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84 | last_alpha_cumprod = alpha_cumprod
85 | self.timestep_map.append(i)
86 | kwargs["betas"] = np.array(new_betas)
87 | super().__init__(**kwargs)
88 |
89 | def p_mean_variance(
90 | self, model, *args, **kwargs
91 | ): # pylint: disable=signature-differs
92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93 |
94 | def training_losses(
95 | self, model, *args, **kwargs
96 | ): # pylint: disable=signature-differs
97 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
98 |
99 | def condition_mean(self, cond_fn, *args, **kwargs):
100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101 |
102 | def condition_score(self, cond_fn, *args, **kwargs):
103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104 |
105 | def _wrap_model(self, model):
106 | if isinstance(model, _WrappedModel):
107 | return model
108 | return _WrappedModel(
109 | model, self.timestep_map, self.original_num_steps
110 | )
111 |
112 | def _scale_timesteps(self, t):
113 | # Scaling is done by the wrapped model.
114 | return t
115 |
116 |
117 | class _WrappedModel:
118 | def __init__(self, model, timestep_map, original_num_steps):
119 | self.model = model
120 | self.timestep_map = timestep_map
121 | # self.rescale_timesteps = rescale_timesteps
122 | self.original_num_steps = original_num_steps
123 |
124 | def __call__(self, x, ts, **kwargs):
125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
126 | new_ts = map_tensor[ts]
127 | # if self.rescale_timesteps:
128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
129 | return self.model(x, new_ts, **kwargs)
130 |
--------------------------------------------------------------------------------
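The respacing logic in `diffusion/respace.py` above has two parts: choosing which timesteps to keep (the `"ddimN"` branch of `space_timesteps`) and recomputing betas so that the shortened chain reaches the same cumulative alphas at the retained steps (`SpacedDiffusion.__init__`). The sketch below reproduces both on plain Python lists. It is an illustration under assumptions: the function names `ddim_stride_steps` and `respace_betas` are hypothetical, and the constant beta schedule is a toy value.

```python
def ddim_stride_steps(num_timesteps, desired_count):
    """Find an integer stride that yields exactly `desired_count` steps."""
    for stride in range(1, num_timesteps):
        steps = range(0, num_timesteps, stride)
        if len(steps) == desired_count:
            return sorted(steps)
    raise ValueError(
        f"cannot create exactly {desired_count} steps with an integer stride"
    )


def respace_betas(betas, use_timesteps):
    """Recompute betas for a chain that only visits `use_timesteps`.

    Each new beta is 1 - alpha_bar[i] / alpha_bar[previous kept i], so the
    cumulative product over the kept steps equals the original alpha_bar[i].
    """
    use = set(use_timesteps)
    new_betas, timestep_map = [], []
    alpha_cumprod, last_alpha_cumprod = 1.0, 1.0
    for i, beta in enumerate(betas):
        alpha_cumprod *= 1.0 - beta
        if i in use:
            new_betas.append(1.0 - alpha_cumprod / last_alpha_cumprod)
            last_alpha_cumprod = alpha_cumprod
            timestep_map.append(i)
    return new_betas, timestep_map


# Toy example: 10 original steps with a constant beta, respaced to 4 steps.
betas = [0.1] * 10
kept = ddim_stride_steps(10, 4)
new_betas, timestep_map = respace_betas(betas, kept)

# Cumulative alpha of the respaced chain after its last step.
respaced_alpha_bar = 1.0
for b in new_betas:
    respaced_alpha_bar *= 1.0 - b
```

The `timestep_map` list is what `_WrappedModel` uses to translate an index in the shortened chain back to the original timestep before calling the model.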
/engine_mar.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | from typing import Iterable
4 |
5 | import torch
6 |
7 | import util.misc as misc
8 | import util.lr_sched as lr_sched
9 | from models.vae import DiagonalGaussianDistribution
10 | import torch_fidelity
11 | import shutil
12 | import cv2
13 | import numpy as np
14 | import os
15 | import copy
16 | import time
17 |
18 |
19 | def update_ema(target_params, source_params, rate=0.99):
20 | """
21 | Update target parameters to be closer to those of source parameters using
22 | an exponential moving average.
23 |
24 | :param target_params: the target parameter sequence.
25 | :param source_params: the source parameter sequence.
26 | :param rate: the EMA rate (closer to 1 means slower).
27 | """
28 | for targ, src in zip(target_params, source_params):
29 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
30 |
31 |
32 | def train_one_epoch(model, vae,
33 | model_params, ema_params,
34 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
35 | device: torch.device, epoch: int, loss_scaler,
36 | log_writer=None,
37 | args=None):
38 | model.train(True)
39 | metric_logger = misc.MetricLogger(delimiter=" ")
40 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
41 | header = 'Epoch: [{}]'.format(epoch)
42 | print_freq = 20
43 |
44 | optimizer.zero_grad()
45 |
46 | if log_writer is not None:
47 | print('log_dir: {}'.format(log_writer.log_dir))
48 |
49 | for data_iter_step, (samples, labels) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
50 |
51 | # we use a per iteration (instead of per epoch) lr scheduler
52 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
53 |
54 | samples = samples.to(device, non_blocking=True)
55 | labels = labels.to(device, non_blocking=True)
56 |
57 | with torch.no_grad():
58 | if args.use_cached:
59 | moments = samples
60 | posterior = DiagonalGaussianDistribution(moments)
61 | else:
62 | posterior = vae.encode(samples)
63 |
64 | # scale the latent so its std is about 1; 0.2325 is specific to this tokenizer, so change it if you use a different one
65 | x = posterior.sample().mul_(0.2325)
66 |
67 | # forward
68 | with torch.cuda.amp.autocast():
69 | loss = model(x, labels)
70 |
71 | loss_value = loss.item()
72 |
73 | if not math.isfinite(loss_value):
74 | print("Loss is {}, stopping training".format(loss_value))
75 | sys.exit(1)
76 |
77 | loss_scaler(loss, optimizer, clip_grad=args.grad_clip, parameters=model.parameters(), update_grad=True)
78 | optimizer.zero_grad()
79 |
80 | torch.cuda.synchronize()
81 |
82 | update_ema(ema_params, model_params, rate=args.ema_rate)
83 |
84 | metric_logger.update(loss=loss_value)
85 |
86 | lr = optimizer.param_groups[0]["lr"]
87 | metric_logger.update(lr=lr)
88 |
89 | loss_value_reduce = misc.all_reduce_mean(loss_value)
90 | if log_writer is not None:
91 | """ We use epoch_1000x as the x-axis in tensorboard.
92 | This calibrates different curves when batch size changes.
93 | """
94 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
95 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
96 | log_writer.add_scalar('lr', lr, epoch_1000x)
97 |
98 | # gather the stats from all processes
99 | metric_logger.synchronize_between_processes()
100 | print("Averaged stats:", metric_logger)
101 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
102 |
103 |
104 | def evaluate(model_without_ddp, vae, ema_params, args, epoch, batch_size=16, log_writer=None, cfg=1.0,
105 | use_ema=True):
106 | model_without_ddp.eval()
107 | num_steps = args.num_images // (batch_size * misc.get_world_size()) + 1
108 | save_folder = os.path.join(args.output_dir, "ariter{}-diffsteps{}-temp{}-{}cfg{}-image{}".format(args.num_iter,
109 | args.num_sampling_steps,
110 | args.temperature,
111 | args.cfg_schedule,
112 | cfg,
113 | args.num_images))
114 | if use_ema:
115 | save_folder = save_folder + "_ema"
116 | if args.evaluate:
117 | save_folder = save_folder + "_evaluate"
118 | print("Save to:", save_folder)
119 | if misc.get_rank() == 0:
120 | if not os.path.exists(save_folder):
121 | os.makedirs(save_folder)
122 |
123 | # switch to ema params
124 | if use_ema:
125 | model_state_dict = copy.deepcopy(model_without_ddp.state_dict())
126 | ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
127 | for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
128 | assert name in ema_state_dict
129 | ema_state_dict[name] = ema_params[i]
130 | print("Switch to ema")
131 | model_without_ddp.load_state_dict(ema_state_dict)
132 |
133 | class_num = args.class_num
134 | assert args.num_images % class_num == 0 # number of images per class must be the same
135 | class_label_gen_world = np.arange(0, class_num).repeat(args.num_images // class_num)
136 | class_label_gen_world = np.hstack([class_label_gen_world, np.zeros(50000)])
137 | world_size = misc.get_world_size()
138 | local_rank = misc.get_rank()
139 | used_time = 0
140 | gen_img_cnt = 0
141 |
142 | for i in range(num_steps):
143 | print("Generation step {}/{}".format(i, num_steps))
144 |
145 | labels_gen = class_label_gen_world[world_size * batch_size * i + local_rank * batch_size:
146 | world_size * batch_size * i + (local_rank + 1) * batch_size]
147 | labels_gen = torch.Tensor(labels_gen).long().cuda()
148 |
149 |
150 | torch.cuda.synchronize()
151 | start_time = time.time()
152 |
153 | # generation
154 | with torch.no_grad():
155 | with torch.cuda.amp.autocast():
156 | sampled_tokens = model_without_ddp.sample_tokens(bsz=batch_size, num_iter=args.num_iter, cfg=cfg,
157 | cfg_schedule=args.cfg_schedule, labels=labels_gen,
158 | temperature=args.temperature)
159 | sampled_images = vae.decode(sampled_tokens / 0.2325)
160 |
161 | # measure speed after the first generation batch
162 | if i >= 1:
163 | torch.cuda.synchronize()
164 | used_time += time.time() - start_time
165 | gen_img_cnt += batch_size
166 | print("Generating {} images takes {:.5f} seconds, {:.5f} sec per image".format(gen_img_cnt, used_time, used_time / gen_img_cnt))
167 |
168 | torch.distributed.barrier()
169 | sampled_images = sampled_images.detach().cpu()
170 | sampled_images = (sampled_images + 1) / 2
171 |
172 | # distributed save
173 | for b_id in range(sampled_images.size(0)):
174 | img_id = i * sampled_images.size(0) * world_size + local_rank * sampled_images.size(0) + b_id
175 | if img_id >= args.num_images:
176 | break
177 | gen_img = np.round(np.clip(sampled_images[b_id].numpy().transpose([1, 2, 0]) * 255, 0, 255))
178 | gen_img = gen_img.astype(np.uint8)[:, :, ::-1]
179 | cv2.imwrite(os.path.join(save_folder, '{}.png'.format(str(img_id).zfill(5))), gen_img)
180 |
181 | torch.distributed.barrier()
182 | time.sleep(10)
183 |
184 | # back to no ema
185 | if use_ema:
186 | print("Switch back from ema")
187 | model_without_ddp.load_state_dict(model_state_dict)
188 |
189 | # compute FID and IS
190 | if log_writer is not None:
191 | if args.img_size == 256:
192 | input2 = None
193 | fid_statistics_file = 'fid_stats/adm_in256_stats.npz'
194 | else:
195 | raise NotImplementedError
196 | metrics_dict = torch_fidelity.calculate_metrics(
197 | input1=save_folder,
198 | input2=input2,
199 | fid_statistics_file=fid_statistics_file,
200 | cuda=True,
201 | isc=True,
202 | fid=True,
203 | kid=False,
204 | prc=False,
205 | verbose=False,
206 | )
207 | fid = metrics_dict['frechet_inception_distance']
208 | inception_score = metrics_dict['inception_score_mean']
209 | postfix = ""
210 | if use_ema:
211 | postfix = postfix + "_ema"
212 | if not cfg == 1.0:
213 | postfix = postfix + "_cfg{}".format(cfg)
214 | log_writer.add_scalar('fid{}'.format(postfix), fid, epoch)
215 | log_writer.add_scalar('is{}'.format(postfix), inception_score, epoch)
216 | print("FID: {:.4f}, Inception Score: {:.4f}".format(fid, inception_score))
217 | # remove temporary saving folder
218 | shutil.rmtree(save_folder)
219 |
220 | torch.distributed.barrier()
221 | time.sleep(10)
222 |
223 |
224 | def cache_latents(vae,
225 | data_loader: Iterable,
226 | device: torch.device,
227 | args=None):
228 | metric_logger = misc.MetricLogger(delimiter=" ")
229 | header = 'Caching: '
230 | print_freq = 20
231 |
232 | for data_iter_step, (samples, _, paths) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
233 |
234 | samples = samples.to(device, non_blocking=True)
235 |
236 | with torch.no_grad():
237 | posterior = vae.encode(samples)
238 | moments = posterior.parameters
239 | posterior_flip = vae.encode(samples.flip(dims=[3]))
240 | moments_flip = posterior_flip.parameters
241 |
242 | for i, path in enumerate(paths):
243 | save_path = os.path.join(args.cached_path, path + '.npz')
244 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
245 | np.savez(save_path, moments=moments[i].cpu().numpy(), moments_flip=moments_flip[i].cpu().numpy())
246 |
247 | if misc.is_dist_avail_and_initialized():
248 | torch.cuda.synchronize()
249 |
250 | return
251 |
--------------------------------------------------------------------------------
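The image-index arithmetic in `evaluate` (engine_mar.py, lines 174-176) can be sketched in isolation as below. This is a minimal illustration, not the repository's code: the helper names `global_image_index` and `indices_for_rank` are hypothetical, `rank` stands in for `local_rank`, and the ceiling-division step count is an assumption, since the loop that defines `i` is not shown in this part of the file.

```python
def global_image_index(step, batch_size, world_size, rank, b_id):
    """Global id of the b_id-th image produced by `rank` at generation step `step`.

    Mirrors: i * bsz * world_size + local_rank * bsz + b_id
    """
    return step * batch_size * world_size + rank * batch_size + b_id


def indices_for_rank(num_images, batch_size, world_size, rank):
    """All image ids one rank would write, stopping once ids reach num_images."""
    # Assumption: enough steps to cover num_images (ceiling division).
    num_steps = -(-num_images // (batch_size * world_size))
    ids = []
    for step in range(num_steps):
        for b_id in range(batch_size):
            img_id = global_image_index(step, batch_size, world_size, rank, b_id)
            if img_id >= num_images:
                break  # same guard as `if img_id >= args.num_images: break`
            ids.append(img_id)
    return ids


# Two ranks, batch size 3, ten images requested.
all_ids = sorted(
    i for r in range(2) for i in indices_for_rank(10, 3, 2, r)
)
```

Each rank owns a contiguous block of ids within a step, so the saved file names do not collide across processes and the final partial step is truncated by the `break`.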
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: mar
2 | channels:
3 | - pytorch
4 | - defaults
5 | - nvidia
6 | dependencies:
7 | - python=3.8.5
8 | - pip=20.3
9 | - pytorch-cuda=11.8
10 | - pytorch=2.2.2
11 | - torchvision=0.17.2
12 | - numpy=1.22
13 | - pip:
14 | - opencv-python==4.1.2.30
15 | - timm==0.9.12
16 | - tensorboard==2.10.0
17 | - scipy==1.9.1
18 | - gdown==5.2.0
19 | - -e git+https://github.com/LTH14/torch-fidelity.git@master#egg=torch-fidelity
20 |
--------------------------------------------------------------------------------
/fid_stats/adm_in256_stats.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LTH14/mar/fe470ac24afbee924668d8c5c83e9fec60af3a73/fid_stats/adm_in256_stats.npz
--------------------------------------------------------------------------------
/main_cache.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import numpy as np
4 | import os
5 | import time
6 | from pathlib import Path
7 |
8 | import torch
9 | import torch.backends.cudnn as cudnn
10 | from torch.utils.tensorboard import SummaryWriter
11 | import torchvision.transforms as transforms
12 |
13 | import util.misc as misc
14 | from util.loader import ImageFolderWithFilename
15 |
16 | from models.vae import AutoencoderKL
17 | from engine_mar import cache_latents
18 |
19 | from util.crop import center_crop_arr
20 |
21 |
22 | def get_args_parser():
23 | parser = argparse.ArgumentParser('Cache VAE latents', add_help=False)
24 | parser.add_argument('--batch_size', default=128, type=int,
25 | help='Batch size per GPU (effective batch size is batch_size * # gpus)')
26 |
27 | # VAE parameters
28 | parser.add_argument('--img_size', default=256, type=int,
29 | help='images input size')
30 | parser.add_argument('--vae_path', default="pretrained_models/vae/kl16.ckpt", type=str,
31 | help='path to the pretrained VAE checkpoint')
32 | parser.add_argument('--vae_embed_dim', default=16, type=int,
33 | help='vae output embedding dimension')
34 | # Dataset parameters
35 | parser.add_argument('--data_path', default='./data/imagenet', type=str,
36 | help='dataset path')
37 | parser.add_argument('--device', default='cuda',
38 | help='device to use for training / testing')
39 | parser.add_argument('--seed', default=0, type=int)
40 |
41 | parser.add_argument('--num_workers', default=10, type=int)
42 | parser.add_argument('--pin_mem', action='store_true',
43 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
44 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
45 | parser.set_defaults(pin_mem=True)
46 |
47 | # distributed training parameters
48 | parser.add_argument('--world_size', default=1, type=int,
49 | help='number of distributed processes')
50 | parser.add_argument('--local_rank', default=-1, type=int)
51 | parser.add_argument('--dist_on_itp', action='store_true')
52 | parser.add_argument('--dist_url', default='env://',
53 | help='url used to set up distributed training')
54 |
55 | # caching latents
56 | parser.add_argument('--cached_path', default='', help='path to cached latents')
57 |
58 | return parser
59 |
60 |
61 | def main(args):
62 | misc.init_distributed_mode(args)
63 |
64 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
65 | print("{}".format(args).replace(', ', ',\n'))
66 |
67 | device = torch.device(args.device)
68 |
69 | # fix the seed for reproducibility
70 | seed = args.seed + misc.get_rank()
71 | torch.manual_seed(seed)
72 | np.random.seed(seed)
73 |
74 | cudnn.benchmark = True
75 |
76 | num_tasks = misc.get_world_size()
77 | global_rank = misc.get_rank()
78 |
79 | # augmentation following DiT and ADM
80 | transform_train = transforms.Compose([
81 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)),
82 | # transforms.RandomHorizontalFlip(),
83 | transforms.ToTensor(),
84 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
85 | ])
86 |
87 | dataset_train = ImageFolderWithFilename(os.path.join(args.data_path, 'train'), transform=transform_train)
88 | print(dataset_train)
89 |
90 | sampler_train = torch.utils.data.DistributedSampler(
91 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=False,
92 | )
93 | print("Sampler_train = %s" % str(sampler_train))
94 |
95 | data_loader_train = torch.utils.data.DataLoader(
96 | dataset_train, sampler=sampler_train,
97 | batch_size=args.batch_size,
98 | num_workers=args.num_workers,
99 | pin_memory=args.pin_mem,
100 | drop_last=False, # Don't drop in cache
101 | )
102 |
103 | # define the vae
104 | vae = AutoencoderKL(embed_dim=args.vae_embed_dim, ch_mult=(1, 1, 2, 2, 4), ckpt_path=args.vae_path).cuda().eval()
105 |
106 | # caching
107 | print("Start caching VAE latents")
108 | start_time = time.time()
109 | cache_latents(
110 | vae,
111 | data_loader_train,
112 | device,
113 | args=args
114 | )
115 | total_time = time.time() - start_time
116 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
117 | print('Caching time {}'.format(total_time_str))
118 |
119 |
120 | if __name__ == '__main__':
121 | args = get_args_parser()
122 | args = args.parse_args()
123 | main(args)
124 |
--------------------------------------------------------------------------------
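main_cache.py builds a `DistributedSampler` with `shuffle=False` and a `DataLoader` with `drop_last=False`, so every training image is encoded at least once. The sketch below imitates, in plain Python, how such a sampler is commonly understood to shard indices: it pads the index list by wrapping around and then strides by rank. The function `shard_indices` is hypothetical, it assumes `dataset_len >= num_replicas`, and it is an approximation of the PyTorch behavior rather than a copy of it.

```python
import math


def shard_indices(dataset_len, num_replicas, rank):
    """Indices one rank would visit with shuffle=False and no dropping."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by wrapping around so every rank gets the same number of samples.
    padding = total_size - len(indices)
    indices += indices[:padding]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total_size:num_replicas]


shards = [shard_indices(10, 4, r) for r in range(4)]
covered = sorted(set(i for shard in shards for i in shard))
```

Under this scheme a few images are encoded twice when the dataset size is not divisible by the number of processes. For caching that is harmless, because `cache_latents` writes one `.npz` per image path and a repeated write produces the same file.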
/main_mar.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import numpy as np
4 | import os
5 | import time
6 | from pathlib import Path
7 |
8 | import torch
9 | import torch.backends.cudnn as cudnn
10 | from torch.utils.tensorboard import SummaryWriter
11 | import torchvision.transforms as transforms
12 | import torchvision.datasets as datasets
13 |
14 | from util.crop import center_crop_arr
15 | import util.misc as misc
16 | from util.misc import NativeScalerWithGradNormCount as NativeScaler
17 | from util.loader import CachedFolder
18 |
19 | from models.vae import AutoencoderKL
20 | from models import mar
21 | from engine_mar import train_one_epoch, evaluate
22 | import copy
23 |
24 |
25 | def get_args_parser():
26 | parser = argparse.ArgumentParser('MAR training with Diffusion Loss', add_help=False)
27 | parser.add_argument('--batch_size', default=16, type=int,
28 | help='Batch size per GPU (effective batch size is batch_size * # gpus)')
29 | parser.add_argument('--epochs', default=400, type=int)
30 |
31 | # Model parameters
32 | parser.add_argument('--model', default='mar_large', type=str, metavar='MODEL',
33 | help='Name of model to train')
34 |
35 | # VAE parameters
36 | parser.add_argument('--img_size', default=256, type=int,
37 | help='images input size')
38 | parser.add_argument('--vae_path', default="pretrained_models/vae/kl16.ckpt", type=str,
39 | help='path to the pretrained VAE checkpoint')
40 | parser.add_argument('--vae_embed_dim', default=16, type=int,
41 | help='vae output embedding dimension')
42 | parser.add_argument('--vae_stride', default=16, type=int,
43 | help='tokenizer stride, default use KL16')
44 | parser.add_argument('--patch_size', default=1, type=int,
45 | help='number of tokens to group as a patch.')
46 |
47 | # Generation parameters
48 | parser.add_argument('--num_iter', default=64, type=int,
49 | help='number of autoregressive iterations to generate an image')
50 | parser.add_argument('--num_images', default=50000, type=int,
51 | help='number of images to generate')
52 | parser.add_argument('--cfg', default=1.0, type=float, help="classifier-free guidance")
53 | parser.add_argument('--cfg_schedule', default="linear", type=str)
54 | parser.add_argument('--label_drop_prob', default=0.1, type=float)
55 | parser.add_argument('--eval_freq', type=int, default=40, help='evaluation frequency')
56 | parser.add_argument('--save_last_freq', type=int, default=5, help='save last frequency')
57 | parser.add_argument('--online_eval', action='store_true')
58 | parser.add_argument('--evaluate', action='store_true')
59 | parser.add_argument('--eval_bsz', type=int, default=64, help='generation batch size')
60 |
61 | # Optimizer parameters
62 | parser.add_argument('--weight_decay', type=float, default=0.02,
63 | help='weight decay (default: 0.02)')
64 |
65 | parser.add_argument('--grad_checkpointing', action='store_true')
66 | parser.add_argument('--lr', type=float, default=None, metavar='LR',
67 | help='learning rate (absolute lr)')
68 | parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
69 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
70 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
71 | help='lower lr bound for cyclic schedulers that hit 0')
72 | parser.add_argument('--lr_schedule', type=str, default='constant',
73 | help='learning rate schedule')
74 | parser.add_argument('--warmup_epochs', type=int, default=100, metavar='N',
75 | help='epochs to warmup LR')
76 | parser.add_argument('--ema_rate', default=0.9999, type=float)
77 |
78 | # MAR params
79 | parser.add_argument('--mask_ratio_min', type=float, default=0.7,
80 | help='Minimum mask ratio')
81 | parser.add_argument('--grad_clip', type=float, default=3.0,
82 | help='Gradient clip')
83 | parser.add_argument('--attn_dropout', type=float, default=0.1,
84 | help='attention dropout')
85 | parser.add_argument('--proj_dropout', type=float, default=0.1,
86 | help='projection dropout')
87 | parser.add_argument('--buffer_size', type=int, default=64)
88 |
89 | # Diffusion Loss params
90 | parser.add_argument('--diffloss_d', type=int, default=12)
91 | parser.add_argument('--diffloss_w', type=int, default=1536)
92 | parser.add_argument('--num_sampling_steps', type=str, default="100")
93 | parser.add_argument('--diffusion_batch_mul', type=int, default=1)
94 | parser.add_argument('--temperature', default=1.0, type=float, help='diffusion loss sampling temperature')
95 |
96 | # Dataset parameters
97 | parser.add_argument('--data_path', default='./data/imagenet', type=str,
98 | help='dataset path')
99 | parser.add_argument('--class_num', default=1000, type=int)
100 |
101 | parser.add_argument('--output_dir', default='./output_dir',
102 | help='path where to save, empty for no saving')
103 | parser.add_argument('--log_dir', default='./output_dir',
104 | help='path where to write tensorboard logs')
105 | parser.add_argument('--device', default='cuda',
106 | help='device to use for training / testing')
107 | parser.add_argument('--seed', default=1, type=int)
108 | parser.add_argument('--resume', default='',
109 | help='resume from checkpoint')
110 |
111 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
112 | help='start epoch')
113 | parser.add_argument('--num_workers', default=10, type=int)
114 | parser.add_argument('--pin_mem', action='store_true',
115 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
116 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
117 | parser.set_defaults(pin_mem=True)
118 |
119 | # distributed training parameters
120 | parser.add_argument('--world_size', default=1, type=int,
121 | help='number of distributed processes')
122 | parser.add_argument('--local_rank', default=-1, type=int)
123 | parser.add_argument('--dist_on_itp', action='store_true')
124 | parser.add_argument('--dist_url', default='env://',
125 | help='url used to set up distributed training')
126 |
127 | # caching latents
128 | parser.add_argument('--use_cached', action='store_true', dest='use_cached',
129 | help='Use cached latents')
130 | parser.set_defaults(use_cached=False)
131 | parser.add_argument('--cached_path', default='', help='path to cached latents')
132 |
133 | return parser
134 |
135 |
136 | def main(args):
137 | misc.init_distributed_mode(args)
138 |
139 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
140 | print("{}".format(args).replace(', ', ',\n'))
141 |
142 | device = torch.device(args.device)
143 |
144 | # fix the seed for reproducibility
145 | seed = args.seed + misc.get_rank()
146 | torch.manual_seed(seed)
147 | np.random.seed(seed)
148 |
149 | cudnn.benchmark = True
150 |
151 | num_tasks = misc.get_world_size()
152 | global_rank = misc.get_rank()
153 |
154 | if global_rank == 0 and args.log_dir is not None:
155 | os.makedirs(args.log_dir, exist_ok=True)
156 | log_writer = SummaryWriter(log_dir=args.log_dir)
157 | else:
158 | log_writer = None
159 |
160 | # augmentation following DiT and ADM
161 | transform_train = transforms.Compose([
162 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)),
163 | transforms.RandomHorizontalFlip(),
164 | transforms.ToTensor(),
165 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
166 | ])
167 |
168 | if args.use_cached:
169 | dataset_train = CachedFolder(args.cached_path)
170 | else:
171 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
172 | print(dataset_train)
173 |
174 | sampler_train = torch.utils.data.DistributedSampler(
175 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
176 | )
177 | print("Sampler_train = %s" % str(sampler_train))
178 |
179 | data_loader_train = torch.utils.data.DataLoader(
180 | dataset_train, sampler=sampler_train,
181 | batch_size=args.batch_size,
182 | num_workers=args.num_workers,
183 | pin_memory=args.pin_mem,
184 | drop_last=True,
185 | )
186 |
187 | # define the vae and mar model
188 | vae = AutoencoderKL(embed_dim=args.vae_embed_dim, ch_mult=(1, 1, 2, 2, 4), ckpt_path=args.vae_path).cuda().eval()
189 | for param in vae.parameters():
190 | param.requires_grad = False
191 |
192 | model = mar.__dict__[args.model](
193 | img_size=args.img_size,
194 | vae_stride=args.vae_stride,
195 | patch_size=args.patch_size,
196 | vae_embed_dim=args.vae_embed_dim,
197 | mask_ratio_min=args.mask_ratio_min,
198 | label_drop_prob=args.label_drop_prob,
199 | class_num=args.class_num,
200 | attn_dropout=args.attn_dropout,
201 | proj_dropout=args.proj_dropout,
202 | buffer_size=args.buffer_size,
203 | diffloss_d=args.diffloss_d,
204 | diffloss_w=args.diffloss_w,
205 | num_sampling_steps=args.num_sampling_steps,
206 | diffusion_batch_mul=args.diffusion_batch_mul,
207 | grad_checkpointing=args.grad_checkpointing,
208 | )
209 |
210 | print("Model = %s" % str(model))
211 | # count trainable parameters
212 | n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
213 | print("Number of trainable parameters: {}M".format(n_params / 1e6))
214 |
215 | model.to(device)
216 | model_without_ddp = model
217 |
218 | eff_batch_size = args.batch_size * misc.get_world_size()
219 |
220 | if args.lr is None: # only base_lr is specified
221 | args.lr = args.blr * eff_batch_size / 256
222 |
223 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
224 | print("actual lr: %.2e" % args.lr)
225 | print("effective batch size: %d" % eff_batch_size)
226 |
227 | if args.distributed:
228 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
229 | model_without_ddp = model.module
230 |
231 | # no weight decay on bias, norm layers, and diffloss MLP
232 | param_groups = misc.add_weight_decay(model_without_ddp, args.weight_decay)
233 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
234 | print(optimizer)
235 | loss_scaler = NativeScaler()
236 |
237 | # resume training
238 | if args.resume and os.path.exists(os.path.join(args.resume, "checkpoint-last.pth")):
239 | checkpoint = torch.load(os.path.join(args.resume, "checkpoint-last.pth"), map_location='cpu')
240 | model_without_ddp.load_state_dict(checkpoint['model'])
241 | model_params = list(model_without_ddp.parameters())
242 | ema_state_dict = checkpoint['model_ema']
243 | ema_params = [ema_state_dict[name].cuda() for name, _ in model_without_ddp.named_parameters()]
244 | print("Resume checkpoint %s" % args.resume)
245 |
246 | if 'optimizer' in checkpoint and 'epoch' in checkpoint:
247 | optimizer.load_state_dict(checkpoint['optimizer'])
248 | args.start_epoch = checkpoint['epoch'] + 1
249 | if 'scaler' in checkpoint:
250 | loss_scaler.load_state_dict(checkpoint['scaler'])
251 | print("With optim & sched!")
252 | del checkpoint
253 | else:
254 | model_params = list(model_without_ddp.parameters())
255 | ema_params = copy.deepcopy(model_params)
256 | print("Training from scratch")
257 |
258 | # evaluate FID and IS
259 | if args.evaluate:
260 | torch.cuda.empty_cache()
261 | evaluate(model_without_ddp, vae, ema_params, args, 0, batch_size=args.eval_bsz, log_writer=log_writer,
262 | cfg=args.cfg, use_ema=True)
263 | return
264 |
265 | # training
266 | print(f"Start training for {args.epochs} epochs")
267 | start_time = time.time()
268 | for epoch in range(args.start_epoch, args.epochs):
269 | if args.distributed:
270 | data_loader_train.sampler.set_epoch(epoch)
271 |
272 | train_one_epoch(
273 | model, vae,
274 | model_params, ema_params,
275 | data_loader_train,
276 | optimizer, device, epoch, loss_scaler,
277 | log_writer=log_writer,
278 | args=args
279 | )
280 |
281 | # save checkpoint
282 | if epoch % args.save_last_freq == 0 or epoch + 1 == args.epochs:
283 | misc.save_model(args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
284 | loss_scaler=loss_scaler, epoch=epoch, ema_params=ema_params, epoch_name="last")
285 |
286 | # online evaluation
287 | if args.online_eval and (epoch % args.eval_freq == 0 or epoch + 1 == args.epochs):
288 | torch.cuda.empty_cache()
289 | evaluate(model_without_ddp, vae, ema_params, args, epoch, batch_size=args.eval_bsz, log_writer=log_writer,
290 | cfg=1.0, use_ema=True)
291 | if not (args.cfg == 1.0 or args.cfg == 0.0):
292 | evaluate(model_without_ddp, vae, ema_params, args, epoch, batch_size=args.eval_bsz // 2,
293 | log_writer=log_writer, cfg=args.cfg, use_ema=True)
294 | torch.cuda.empty_cache()
295 |
296 | if misc.is_main_process():
297 | if log_writer is not None:
298 | log_writer.flush()
299 |
300 | total_time = time.time() - start_time
301 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
302 | print('Training time {}'.format(total_time_str))
303 |
304 |
305 | if __name__ == '__main__':
306 | args = get_args_parser()
307 | args = args.parse_args()
308 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
309 | args.log_dir = args.output_dir
310 | main(args)
311 |
--------------------------------------------------------------------------------
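The learning-rate rule in main_mar.py (lines 218-221) scales `--blr` linearly with the effective batch size unless `--lr` is given explicitly. A minimal sketch of that arithmetic follows. The helper name `absolute_lr` is hypothetical and the 32-GPU configuration is only an example, not a recommended setting.

```python
def absolute_lr(blr, batch_size_per_gpu, world_size, lr=None):
    """Return (lr, effective_batch_size) following lr = blr * eff_batch / 256."""
    eff_batch_size = batch_size_per_gpu * world_size
    if lr is None:  # only the base lr was specified
        lr = blr * eff_batch_size / 256
    return lr, eff_batch_size


# Script defaults on a single process: blr=1e-4, batch_size=16.
lr_single, eff_single = absolute_lr(1e-4, 16, 1)

# Hypothetical multi-GPU run: 64 images per GPU on 32 GPUs.
lr_multi, eff_multi = absolute_lr(1e-4, 64, 32)

# An explicit --lr bypasses the scaling rule entirely.
lr_fixed, _ = absolute_lr(1e-4, 64, 32, lr=5e-4)
```

With the defaults and one process the absolute rate is well below `blr`, so the per-GPU batch size and the number of processes both change the optimizer's step size unless `--lr` is pinned.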
/models/diffloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.checkpoint import checkpoint
4 | import math
5 |
6 | from diffusion import create_diffusion
7 |
8 |
9 | class DiffLoss(nn.Module):
10 | """Diffusion Loss"""
11 | def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps, grad_checkpointing=False):
12 | super(DiffLoss, self).__init__()
13 | self.in_channels = target_channels
14 | self.net = SimpleMLPAdaLN(
15 | in_channels=target_channels,
16 | model_channels=width,
17 | out_channels=target_channels * 2, # for vlb loss
18 | z_channels=z_channels,
19 | num_res_blocks=depth,
20 | grad_checkpointing=grad_checkpointing
21 | )
22 |
23 | self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine")
24 | self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine")
25 |
26 | def forward(self, target, z, mask=None):
27 | t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
28 | model_kwargs = dict(c=z)
29 | loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs)
30 | loss = loss_dict["loss"]
31 | if mask is not None:
32 | loss = (loss * mask).sum() / mask.sum()
33 | return loss.mean()
34 |
35 | def sample(self, z, temperature=1.0, cfg=1.0):
36 | # diffusion loss sampling
37 | if not cfg == 1.0:
38 | noise = torch.randn(z.shape[0] // 2, self.in_channels).cuda()
39 | noise = torch.cat([noise, noise], dim=0)
40 | model_kwargs = dict(c=z, cfg_scale=cfg)
41 | sample_fn = self.net.forward_with_cfg
42 | else:
43 | noise = torch.randn(z.shape[0], self.in_channels).cuda()
44 | model_kwargs = dict(c=z)
45 | sample_fn = self.net.forward
46 |
47 | sampled_token_latent = self.gen_diffusion.p_sample_loop(
48 | sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False,
49 | temperature=temperature
50 | )
51 |
52 | return sampled_token_latent
53 |
54 |
55 | def modulate(x, shift, scale):
56 | return x * (1 + scale) + shift
57 |
58 |
59 | class TimestepEmbedder(nn.Module):
60 | """
61 | Embeds scalar timesteps into vector representations.
62 | """
63 | def __init__(self, hidden_size, frequency_embedding_size=256):
64 | super().__init__()
65 | self.mlp = nn.Sequential(
66 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
67 | nn.SiLU(),
68 | nn.Linear(hidden_size, hidden_size, bias=True),
69 | )
70 | self.frequency_embedding_size = frequency_embedding_size
71 |
72 | @staticmethod
73 | def timestep_embedding(t, dim, max_period=10000):
74 | """
75 | Create sinusoidal timestep embeddings.
76 | :param t: a 1-D Tensor of N indices, one per batch element.
77 | These may be fractional.
78 | :param dim: the dimension of the output.
79 | :param max_period: controls the minimum frequency of the embeddings.
80 | :return: an (N, D) Tensor of positional embeddings.
81 | """
82 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
83 | half = dim // 2
84 | freqs = torch.exp(
85 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
86 | ).to(device=t.device)
87 | args = t[:, None].float() * freqs[None]
88 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
89 | if dim % 2:
90 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
91 | return embedding
92 |
93 | def forward(self, t):
94 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
95 | t_emb = self.mlp(t_freq)
96 | return t_emb
97 |
98 |
99 | class ResBlock(nn.Module):
100 | """
101 | A residual MLP block with adaLN modulation; the channel count is unchanged.
102 | :param channels: the number of input (and output) channels.
103 | """
104 |
105 | def __init__(
106 | self,
107 | channels
108 | ):
109 | super().__init__()
110 | self.channels = channels
111 |
112 | self.in_ln = nn.LayerNorm(channels, eps=1e-6)
113 | self.mlp = nn.Sequential(
114 | nn.Linear(channels, channels, bias=True),
115 | nn.SiLU(),
116 | nn.Linear(channels, channels, bias=True),
117 | )
118 |
119 | self.adaLN_modulation = nn.Sequential(
120 | nn.SiLU(),
121 | nn.Linear(channels, 3 * channels, bias=True)
122 | )
123 |
124 | def forward(self, x, y):
125 | shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
126 | h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
127 | h = self.mlp(h)
128 | return x + gate_mlp * h
129 |
130 |
131 | class FinalLayer(nn.Module):
132 | """
133 | The final layer adopted from DiT.
134 | """
135 | def __init__(self, model_channels, out_channels):
136 | super().__init__()
137 | self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
138 | self.linear = nn.Linear(model_channels, out_channels, bias=True)
139 | self.adaLN_modulation = nn.Sequential(
140 | nn.SiLU(),
141 | nn.Linear(model_channels, 2 * model_channels, bias=True)
142 | )
143 |
144 | def forward(self, x, c):
145 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
146 | x = modulate(self.norm_final(x), shift, scale)
147 | x = self.linear(x)
148 | return x
149 |
150 |
151 | class SimpleMLPAdaLN(nn.Module):
152 | """
153 | The MLP for Diffusion Loss.
154 | :param in_channels: channels in the input Tensor.
155 | :param model_channels: base channel count for the model.
156 | :param out_channels: channels in the output Tensor.
157 | :param z_channels: channels in the condition.
158 | :param num_res_blocks: number of residual blocks.
159 | """
160 |
161 | def __init__(
162 | self,
163 | in_channels,
164 | model_channels,
165 | out_channels,
166 | z_channels,
167 | num_res_blocks,
168 | grad_checkpointing=False
169 | ):
170 | super().__init__()
171 |
172 | self.in_channels = in_channels
173 | self.model_channels = model_channels
174 | self.out_channels = out_channels
175 | self.num_res_blocks = num_res_blocks
176 | self.grad_checkpointing = grad_checkpointing
177 |
178 | self.time_embed = TimestepEmbedder(model_channels)
179 | self.cond_embed = nn.Linear(z_channels, model_channels)
180 |
181 | self.input_proj = nn.Linear(in_channels, model_channels)
182 |
183 | res_blocks = []
184 | for i in range(num_res_blocks):
185 | res_blocks.append(ResBlock(
186 | model_channels,
187 | ))
188 |
189 | self.res_blocks = nn.ModuleList(res_blocks)
190 | self.final_layer = FinalLayer(model_channels, out_channels)
191 |
192 | self.initialize_weights()
193 |
194 | def initialize_weights(self):
195 | def _basic_init(module):
196 | if isinstance(module, nn.Linear):
197 | torch.nn.init.xavier_uniform_(module.weight)
198 | if module.bias is not None:
199 | nn.init.constant_(module.bias, 0)
200 | self.apply(_basic_init)
201 |
202 | # Initialize timestep embedding MLP
203 | nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
204 | nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
205 |
206 | # Zero-out adaLN modulation layers
207 | for block in self.res_blocks:
208 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
209 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
210 |
211 | # Zero-out output layers
212 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
213 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
214 | nn.init.constant_(self.final_layer.linear.weight, 0)
215 | nn.init.constant_(self.final_layer.linear.bias, 0)
216 |
217 | def forward(self, x, t, c):
218 | """
219 | Apply the model to an input batch.
220 | :param x: an [N x C] Tensor of inputs.
221 | :param t: a 1-D batch of timesteps.
222 | :param c: conditioning from AR transformer.
223 | :return: an [N x C] Tensor of outputs.
224 | """
225 | x = self.input_proj(x)
226 | t = self.time_embed(t)
227 | c = self.cond_embed(c)
228 |
229 | y = t + c
230 |
231 | if self.grad_checkpointing and not torch.jit.is_scripting():
232 | for block in self.res_blocks:
233 | x = checkpoint(block, x, y)
234 | else:
235 | for block in self.res_blocks:
236 | x = block(x, y)
237 |
238 | return self.final_layer(x, y)
239 |
240 | def forward_with_cfg(self, x, t, c, cfg_scale):
241 | half = x[: len(x) // 2]
242 | combined = torch.cat([half, half], dim=0)
243 | model_out = self.forward(combined, t, c)
244 | eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
245 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
246 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
247 | eps = torch.cat([half_eps, half_eps], dim=0)
248 | return torch.cat([eps, rest], dim=1)
249 |
--------------------------------------------------------------------------------
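`forward_with_cfg` in models/diffloss.py splits the predicted noise into a conditional half and an unconditional half and recombines them as `uncond_eps + cfg_scale * (cond_eps - uncond_eps)`. The sketch below shows only that combination rule on plain Python lists; `guided_eps` is a hypothetical helper, and the numbers are made up for illustration.

```python
def guided_eps(cond_eps, uncond_eps, cfg_scale):
    """Classifier-free guidance: extrapolate from the unconditional prediction
    towards (and, for cfg_scale > 1, past) the conditional prediction."""
    return [u + cfg_scale * (c - u) for c, u in zip(cond_eps, uncond_eps)]


cond = [1.0, -2.0, 0.5]    # prediction with the class condition
uncond = [0.0, -1.0, 0.5]  # prediction with the fake (dropped) condition

same_as_cond = guided_eps(cond, uncond, 1.0)    # scale 1.0 recovers cond
same_as_uncond = guided_eps(cond, uncond, 0.0)  # scale 0.0 recovers uncond
amplified = guided_eps(cond, uncond, 3.0)       # scale 3.0 extrapolates
```

A scale of 1.0 reduces to the conditional prediction, which is consistent with `DiffLoss.sample` skipping the guided path when `cfg == 1.0`. Note that the real method relies on the batch being ordered with conditional entries first and unconditional entries second.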
/models/mar.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import numpy as np
4 | from tqdm import tqdm
5 | import scipy.stats as stats
6 | import math
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.checkpoint import checkpoint
10 |
11 | from timm.models.vision_transformer import Block
12 |
13 | from models.diffloss import DiffLoss
14 |
15 |
16 | def mask_by_order(mask_len, order, bsz, seq_len):
17 | masking = torch.zeros(bsz, seq_len).cuda()
18 | masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).cuda()).bool()
19 | return masking
20 |
21 |
22 | class MAR(nn.Module):
23 | """ Masked Autoregressive model (MAR) with a ViT-style encoder-decoder backbone
24 | """
25 | def __init__(self, img_size=256, vae_stride=16, patch_size=1,
26 | encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
27 | decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
28 | mlp_ratio=4., norm_layer=nn.LayerNorm,
29 | vae_embed_dim=16,
30 | mask_ratio_min=0.7,
31 | label_drop_prob=0.1,
32 | class_num=1000,
33 | attn_dropout=0.1,
34 | proj_dropout=0.1,
35 | buffer_size=64,
36 | diffloss_d=3,
37 | diffloss_w=1024,
38 | num_sampling_steps='100',
39 | diffusion_batch_mul=4,
40 | grad_checkpointing=False,
41 | ):
42 | super().__init__()
43 |
44 | # --------------------------------------------------------------------------
45 | # VAE and patchify specifics
46 | self.vae_embed_dim = vae_embed_dim
47 |
48 | self.img_size = img_size
49 | self.vae_stride = vae_stride
50 | self.patch_size = patch_size
51 | self.seq_h = self.seq_w = img_size // vae_stride // patch_size
52 | self.seq_len = self.seq_h * self.seq_w
53 | self.token_embed_dim = vae_embed_dim * patch_size**2
54 | self.grad_checkpointing = grad_checkpointing
55 |
56 | # --------------------------------------------------------------------------
57 | # Class Embedding
58 | self.num_classes = class_num
59 | self.class_emb = nn.Embedding(class_num, encoder_embed_dim)
60 | self.label_drop_prob = label_drop_prob
61 | # Fake class embedding for CFG's unconditional generation
62 | self.fake_latent = nn.Parameter(torch.zeros(1, encoder_embed_dim))
63 |
64 | # --------------------------------------------------------------------------
65 | # MAR variant masking ratio, a left-half truncated Gaussian centered at 100% masking ratio with std 0.25
66 | self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)
67 |
68 | # --------------------------------------------------------------------------
69 | # MAR encoder specifics
70 | self.z_proj = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True)
71 | self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6)
72 | self.buffer_size = buffer_size
73 | self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, encoder_embed_dim))
74 |
75 | self.encoder_blocks = nn.ModuleList([
76 | Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
77 | proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(encoder_depth)])
78 | self.encoder_norm = norm_layer(encoder_embed_dim)
79 |
80 | # --------------------------------------------------------------------------
81 | # MAR decoder specifics
82 | self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
83 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
84 | self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))
85 |
86 | self.decoder_blocks = nn.ModuleList([
87 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
88 | proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])
89 |
90 | self.decoder_norm = norm_layer(decoder_embed_dim)
91 | self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))
92 |
93 | self.initialize_weights()
94 |
95 | # --------------------------------------------------------------------------
96 | # Diffusion Loss
97 | self.diffloss = DiffLoss(
98 | target_channels=self.token_embed_dim,
99 | z_channels=decoder_embed_dim,
100 | width=diffloss_w,
101 | depth=diffloss_d,
102 | num_sampling_steps=num_sampling_steps,
103 | grad_checkpointing=grad_checkpointing
104 | )
105 | self.diffusion_batch_mul = diffusion_batch_mul
106 |
107 | def initialize_weights(self):
108 | # parameters
109 | torch.nn.init.normal_(self.class_emb.weight, std=.02)
110 | torch.nn.init.normal_(self.fake_latent, std=.02)
111 | torch.nn.init.normal_(self.mask_token, std=.02)
112 | torch.nn.init.normal_(self.encoder_pos_embed_learned, std=.02)
113 | torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
114 | torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)
115 |
116 | # initialize nn.Linear and nn.LayerNorm
117 | self.apply(self._init_weights)
118 |
119 | def _init_weights(self, m):
120 | if isinstance(m, nn.Linear):
121 | # we use xavier_uniform following official JAX ViT:
122 | torch.nn.init.xavier_uniform_(m.weight)
123 | if isinstance(m, nn.Linear) and m.bias is not None:
124 | nn.init.constant_(m.bias, 0)
125 | elif isinstance(m, nn.LayerNorm):
126 | if m.bias is not None:
127 | nn.init.constant_(m.bias, 0)
128 | if m.weight is not None:
129 | nn.init.constant_(m.weight, 1.0)
130 |
131 | def patchify(self, x):
132 | bsz, c, h, w = x.shape
133 | p = self.patch_size
134 | h_, w_ = h // p, w // p
135 |
136 | x = x.reshape(bsz, c, h_, p, w_, p)
137 | x = torch.einsum('nchpwq->nhwcpq', x)
138 | x = x.reshape(bsz, h_ * w_, c * p ** 2)
139 | return x # [n, l, d]
140 |
141 | def unpatchify(self, x):
142 | bsz = x.shape[0]
143 | p = self.patch_size
144 | c = self.vae_embed_dim
145 | h_, w_ = self.seq_h, self.seq_w
146 |
147 | x = x.reshape(bsz, h_, w_, c, p, p)
148 | x = torch.einsum('nhwcpq->nchpwq', x)
149 | x = x.reshape(bsz, c, h_ * p, w_ * p)
150 | return x # [n, c, h, w]
151 |
152 | def sample_orders(self, bsz):
153 | # generate a batch of random generation orders
154 | orders = []
155 | for _ in range(bsz):
156 | order = np.array(list(range(self.seq_len)))
157 | np.random.shuffle(order)
158 | orders.append(order)
159 | orders = torch.Tensor(np.array(orders)).cuda().long()
160 | return orders
161 |
162 | def random_masking(self, x, orders):
163 | # generate token mask
164 | bsz, seq_len, embed_dim = x.shape
165 | mask_rate = self.mask_ratio_generator.rvs(1)[0]
166 | num_masked_tokens = int(np.ceil(seq_len * mask_rate))
167 | mask = torch.zeros(bsz, seq_len, device=x.device)
168 | mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
169 | src=torch.ones(bsz, seq_len, device=x.device))
170 | return mask
171 |
172 | def forward_mae_encoder(self, x, mask, class_embedding):
173 | x = self.z_proj(x)
174 | bsz, seq_len, embed_dim = x.shape
175 |
176 | # concat buffer
177 | x = torch.cat([torch.zeros(bsz, self.buffer_size, embed_dim, device=x.device), x], dim=1)
178 | mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)
179 |
180 | # randomly drop the class embedding during training (replaced by fake_latent, for CFG)
181 | if self.training:
182 | drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
183 | drop_latent_mask = drop_latent_mask.unsqueeze(-1).cuda().to(x.dtype)
184 | class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding
185 |
186 | x[:, :self.buffer_size] = class_embedding.unsqueeze(1)
187 |
188 | # encoder position embedding
189 | x = x + self.encoder_pos_embed_learned
190 | x = self.z_proj_ln(x)
191 |
192 | # drop masked tokens: only the buffer and unmasked tokens are passed to the encoder
193 | x = x[(1-mask_with_buffer).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim)
194 |
195 | # apply Transformer blocks
196 | if self.grad_checkpointing and not torch.jit.is_scripting():
197 | for block in self.encoder_blocks:
198 | x = checkpoint(block, x)
199 | else:
200 | for block in self.encoder_blocks:
201 | x = block(x)
202 | x = self.encoder_norm(x)
203 |
204 | return x
205 |
206 | def forward_mae_decoder(self, x, mask):
207 |
208 | x = self.decoder_embed(x)
209 | mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)
210 |
211 | # pad mask tokens
212 | mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
213 | x_after_pad = mask_tokens.clone()
214 | x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
215 |
216 | # decoder position embedding
217 | x = x_after_pad + self.decoder_pos_embed_learned
218 |
219 | # apply Transformer blocks
220 | if self.grad_checkpointing and not torch.jit.is_scripting():
221 | for block in self.decoder_blocks:
222 | x = checkpoint(block, x)
223 | else:
224 | for block in self.decoder_blocks:
225 | x = block(x)
226 | x = self.decoder_norm(x)
227 |
228 | x = x[:, self.buffer_size:]
229 | x = x + self.diffusion_pos_embed_learned
230 | return x
231 |
232 | def forward_loss(self, z, target, mask):
233 | bsz, seq_len, _ = target.shape
234 | target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
235 | z = z.reshape(bsz*seq_len, -1).repeat(self.diffusion_batch_mul, 1)
236 | mask = mask.reshape(bsz*seq_len).repeat(self.diffusion_batch_mul)
237 | loss = self.diffloss(z=z, target=target, mask=mask)
238 | return loss
239 |
240 | def forward(self, imgs, labels):
241 |
242 | # class embed
243 | class_embedding = self.class_emb(labels)
244 |
245 | # patchify and mask (drop) tokens
246 | x = self.patchify(imgs)
247 | gt_latents = x.clone().detach()
248 | orders = self.sample_orders(bsz=x.size(0))
249 | mask = self.random_masking(x, orders)
250 |
251 | # mae encoder
252 | x = self.forward_mae_encoder(x, mask, class_embedding)
253 |
254 | # mae decoder
255 | z = self.forward_mae_decoder(x, mask)
256 |
257 | # diffloss
258 | loss = self.forward_loss(z=z, target=gt_latents, mask=mask)
259 |
260 | return loss
261 |
262 | def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
263 |
264 | # init and sample generation orders
265 | mask = torch.ones(bsz, self.seq_len).cuda()
266 | tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda()
267 | orders = self.sample_orders(bsz)
268 |
269 | indices = list(range(num_iter))
270 | if progress:
271 | indices = tqdm(indices)
272 | # generate latents
273 | for step in indices:
274 | cur_tokens = tokens.clone()
275 |
276 | # class embedding and CFG
277 | if labels is not None:
278 | class_embedding = self.class_emb(labels)
279 | else:
280 | class_embedding = self.fake_latent.repeat(bsz, 1)
281 | if not cfg == 1.0:
282 | tokens = torch.cat([tokens, tokens], dim=0)
283 | class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0)
284 | mask = torch.cat([mask, mask], dim=0)
285 |
286 | # mae encoder
287 | x = self.forward_mae_encoder(tokens, mask, class_embedding)
288 |
289 | # mae decoder
290 | z = self.forward_mae_decoder(x, mask)
291 |
292 | # mask ratio for the next round, following MaskGIT and MAGE.
293 | mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
294 | mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).cuda()
295 |
296 | # keep at least one token masked for the next iteration and, when possible, unmask at least one in this iteration
297 | mask_len = torch.maximum(torch.Tensor([1]).cuda(),
298 | torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))
299 |
300 | # get masking for next iteration and locations to be predicted in this iteration
301 | mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len)
302 | if step >= num_iter - 1:
303 | mask_to_pred = mask[:bsz].bool()
304 | else:
305 | mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
306 | mask = mask_next
307 | if not cfg == 1.0:
308 | mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
309 |
310 | # sample token latents for this step
311 | z = z[mask_to_pred.nonzero(as_tuple=True)]
312 | # CFG schedule follows Muse
313 | if cfg_schedule == "linear":
314 | cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len
315 | elif cfg_schedule == "constant":
316 | cfg_iter = cfg
317 | else:
318 | raise NotImplementedError
319 | sampled_token_latent = self.diffloss.sample(z, temperature, cfg_iter)
320 | if not cfg == 1.0:
321 | sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0) # Remove null class samples
322 | mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)
323 |
324 | cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
325 | tokens = cur_tokens.clone()
326 |
327 | # unpatchify
328 | tokens = self.unpatchify(tokens)
329 | return tokens
330 |
331 |
332 | def mar_base(**kwargs):
333 | model = MAR(
334 | encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12,
335 | decoder_embed_dim=768, decoder_depth=12, decoder_num_heads=12,
336 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
337 | return model
338 |
339 |
340 | def mar_large(**kwargs):
341 | model = MAR(
342 | encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
343 | decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
344 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
345 | return model
346 |
347 |
348 | def mar_huge(**kwargs):
349 | model = MAR(
350 | encoder_embed_dim=1280, encoder_depth=20, encoder_num_heads=16,
351 | decoder_embed_dim=1280, decoder_depth=20, decoder_num_heads=16,
352 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
353 | return model
354 |
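The decoding schedule used by `sample_tokens` above can be traced without a GPU. The following is a minimal pure-Python sketch of the mask-length arithmetic and the linear CFG ramp only. The helper names `tokens_per_step` and `linear_cfg` are illustrative and not part of the repository, and `seq_len=256` and `cfg=3.0` are assumed example values.

```python
import math


def tokens_per_step(seq_len, num_iter):
    """Count how many tokens each decoding step predicts under the cosine schedule."""
    masked = seq_len  # every token starts masked
    counts = []
    for step in range(num_iter):
        # Target number of tokens still masked after this step.
        ratio = math.cos(math.pi / 2.0 * (step + 1) / num_iter)
        next_masked = math.floor(seq_len * ratio)
        # Same clamp as sample_tokens: keep >= 1 masked, reveal >= 1 when possible.
        next_masked = max(1, min(masked - 1, next_masked))
        if step >= num_iter - 1:
            counts.append(masked)  # the final step predicts everything left
        else:
            counts.append(masked - next_masked)
            masked = next_masked
    return counts


def linear_cfg(cfg, seq_len, mask_len):
    """Guidance scale for a step that leaves mask_len tokens masked afterwards."""
    return 1 + (cfg - 1) * (seq_len - mask_len) / seq_len


counts = tokens_per_step(seq_len=256, num_iter=64)
print(counts[:4], counts[-4:], sum(counts))
print(linear_cfg(cfg=3.0, seq_len=256, mask_len=128))
```

Under these assumptions the early steps reveal only a few tokens with a guidance scale close to 1, while later steps reveal more tokens each with a scale approaching `cfg`.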
--------------------------------------------------------------------------------
/models/vae.py:
--------------------------------------------------------------------------------
1 | # Adapted from LDM's KL-VAE: https://github.com/CompVis/latent-diffusion
2 | import torch
3 | import torch.nn as nn
4 |
5 | import numpy as np
6 |
7 |
8 | def nonlinearity(x):
9 | # swish
10 | return x * torch.sigmoid(x)
11 |
12 |
13 | def Normalize(in_channels, num_groups=32):
14 | return torch.nn.GroupNorm(
15 | num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
16 | )
17 |
18 |
19 | class Upsample(nn.Module):
20 | def __init__(self, in_channels, with_conv):
21 | super().__init__()
22 | self.with_conv = with_conv
23 | if self.with_conv:
24 | self.conv = torch.nn.Conv2d(
25 | in_channels, in_channels, kernel_size=3, stride=1, padding=1
26 | )
27 |
28 | def forward(self, x):
29 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
30 | if self.with_conv:
31 | x = self.conv(x)
32 | return x
33 |
34 |
35 | class Downsample(nn.Module):
36 | def __init__(self, in_channels, with_conv):
37 | super().__init__()
38 | self.with_conv = with_conv
39 | if self.with_conv:
40 | # no asymmetric padding in torch conv, must do it ourselves
41 | self.conv = torch.nn.Conv2d(
42 | in_channels, in_channels, kernel_size=3, stride=2, padding=0
43 | )
44 |
45 | def forward(self, x):
46 | if self.with_conv:
47 | pad = (0, 1, 0, 1)
48 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
49 | x = self.conv(x)
50 | else:
51 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
52 | return x
53 |
54 |
55 | class ResnetBlock(nn.Module):
56 | def __init__(
57 | self,
58 | *,
59 | in_channels,
60 | out_channels=None,
61 | conv_shortcut=False,
62 | dropout,
63 | temb_channels=512,
64 | ):
65 | super().__init__()
66 | self.in_channels = in_channels
67 | out_channels = in_channels if out_channels is None else out_channels
68 | self.out_channels = out_channels
69 | self.use_conv_shortcut = conv_shortcut
70 |
71 | self.norm1 = Normalize(in_channels)
72 | self.conv1 = torch.nn.Conv2d(
73 | in_channels, out_channels, kernel_size=3, stride=1, padding=1
74 | )
75 | if temb_channels > 0:
76 | self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
77 | self.norm2 = Normalize(out_channels)
78 | self.dropout = torch.nn.Dropout(dropout)
79 | self.conv2 = torch.nn.Conv2d(
80 | out_channels, out_channels, kernel_size=3, stride=1, padding=1
81 | )
82 | if self.in_channels != self.out_channels:
83 | if self.use_conv_shortcut:
84 | self.conv_shortcut = torch.nn.Conv2d(
85 | in_channels, out_channels, kernel_size=3, stride=1, padding=1
86 | )
87 | else:
88 | self.nin_shortcut = torch.nn.Conv2d(
89 | in_channels, out_channels, kernel_size=1, stride=1, padding=0
90 | )
91 |
92 | def forward(self, x, temb):
93 | h = x
94 | h = self.norm1(h)
95 | h = nonlinearity(h)
96 | h = self.conv1(h)
97 |
98 | if temb is not None:
99 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
100 |
101 | h = self.norm2(h)
102 | h = nonlinearity(h)
103 | h = self.dropout(h)
104 | h = self.conv2(h)
105 |
106 | if self.in_channels != self.out_channels:
107 | if self.use_conv_shortcut:
108 | x = self.conv_shortcut(x)
109 | else:
110 | x = self.nin_shortcut(x)
111 |
112 | return x + h
113 |
114 |
115 | class AttnBlock(nn.Module):
116 | def __init__(self, in_channels):
117 | super().__init__()
118 | self.in_channels = in_channels
119 |
120 | self.norm = Normalize(in_channels)
121 | self.q = torch.nn.Conv2d(
122 | in_channels, in_channels, kernel_size=1, stride=1, padding=0
123 | )
124 | self.k = torch.nn.Conv2d(
125 | in_channels, in_channels, kernel_size=1, stride=1, padding=0
126 | )
127 | self.v = torch.nn.Conv2d(
128 | in_channels, in_channels, kernel_size=1, stride=1, padding=0
129 | )
130 | self.proj_out = torch.nn.Conv2d(
131 | in_channels, in_channels, kernel_size=1, stride=1, padding=0
132 | )
133 |
134 | def forward(self, x):
135 | h_ = x
136 | h_ = self.norm(h_)
137 | q = self.q(h_)
138 | k = self.k(h_)
139 | v = self.v(h_)
140 |
141 | # compute attention
142 | b, c, h, w = q.shape
143 | q = q.reshape(b, c, h * w)
144 | q = q.permute(0, 2, 1) # b,hw,c
145 | k = k.reshape(b, c, h * w) # b,c,hw
146 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
147 | w_ = w_ * (int(c) ** (-0.5))
148 | w_ = torch.nn.functional.softmax(w_, dim=2)
149 |
150 | # attend to values
151 | v = v.reshape(b, c, h * w)
152 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
153 | h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
154 | h_ = h_.reshape(b, c, h, w)
155 |
156 | h_ = self.proj_out(h_)
157 |
158 | return x + h_
159 |
160 |
161 | class Encoder(nn.Module):
162 | def __init__(
163 | self,
164 | *,
165 | ch=128,
166 | out_ch=3,
167 | ch_mult=(1, 1, 2, 2, 4),
168 | num_res_blocks=2,
169 | attn_resolutions=(16,),
170 | dropout=0.0,
171 | resamp_with_conv=True,
172 | in_channels=3,
173 | resolution=256,
174 | z_channels=16,
175 | double_z=True,
176 | **ignore_kwargs,
177 | ):
178 | super().__init__()
179 | self.ch = ch
180 | self.temb_ch = 0
181 | self.num_resolutions = len(ch_mult)
182 | self.num_res_blocks = num_res_blocks
183 | self.resolution = resolution
184 | self.in_channels = in_channels
185 |
186 | # downsampling
187 | self.conv_in = torch.nn.Conv2d(
188 | in_channels, self.ch, kernel_size=3, stride=1, padding=1
189 | )
190 |
191 | curr_res = resolution
192 | in_ch_mult = (1,) + tuple(ch_mult)
193 | self.down = nn.ModuleList()
194 | for i_level in range(self.num_resolutions):
195 | block = nn.ModuleList()
196 | attn = nn.ModuleList()
197 | block_in = ch * in_ch_mult[i_level]
198 | block_out = ch * ch_mult[i_level]
199 | for i_block in range(self.num_res_blocks):
200 | block.append(
201 | ResnetBlock(
202 | in_channels=block_in,
203 | out_channels=block_out,
204 | temb_channels=self.temb_ch,
205 | dropout=dropout,
206 | )
207 | )
208 | block_in = block_out
209 | if curr_res in attn_resolutions:
210 | attn.append(AttnBlock(block_in))
211 | down = nn.Module()
212 | down.block = block
213 | down.attn = attn
214 | if i_level != self.num_resolutions - 1:
215 | down.downsample = Downsample(block_in, resamp_with_conv)
216 | curr_res = curr_res // 2
217 | self.down.append(down)
218 |
219 | # middle
220 | self.mid = nn.Module()
221 | self.mid.block_1 = ResnetBlock(
222 | in_channels=block_in,
223 | out_channels=block_in,
224 | temb_channels=self.temb_ch,
225 | dropout=dropout,
226 | )
227 | self.mid.attn_1 = AttnBlock(block_in)
228 | self.mid.block_2 = ResnetBlock(
229 | in_channels=block_in,
230 | out_channels=block_in,
231 | temb_channels=self.temb_ch,
232 | dropout=dropout,
233 | )
234 |
235 | # end
236 | self.norm_out = Normalize(block_in)
237 | self.conv_out = torch.nn.Conv2d(
238 | block_in,
239 | 2 * z_channels if double_z else z_channels,
240 | kernel_size=3,
241 | stride=1,
242 | padding=1,
243 | )
244 |
245 | def forward(self, x):
246 | # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
247 |
248 | # timestep embedding
249 | temb = None
250 |
251 | # downsampling
252 | hs = [self.conv_in(x)]
253 | for i_level in range(self.num_resolutions):
254 | for i_block in range(self.num_res_blocks):
255 | h = self.down[i_level].block[i_block](hs[-1], temb)
256 | if len(self.down[i_level].attn) > 0:
257 | h = self.down[i_level].attn[i_block](h)
258 | hs.append(h)
259 | if i_level != self.num_resolutions - 1:
260 | hs.append(self.down[i_level].downsample(hs[-1]))
261 |
262 | # middle
263 | h = hs[-1]
264 | h = self.mid.block_1(h, temb)
265 | h = self.mid.attn_1(h)
266 | h = self.mid.block_2(h, temb)
267 |
268 | # end
269 | h = self.norm_out(h)
270 | h = nonlinearity(h)
271 | h = self.conv_out(h)
272 | return h
273 |
274 |
275 | class Decoder(nn.Module):
276 | def __init__(
277 | self,
278 | *,
279 | ch=128,
280 | out_ch=3,
281 | ch_mult=(1, 1, 2, 2, 4),
282 | num_res_blocks=2,
283 | attn_resolutions=(),
284 | dropout=0.0,
285 | resamp_with_conv=True,
286 | in_channels=3,
287 | resolution=256,
288 | z_channels=16,
289 | give_pre_end=False,
290 | **ignore_kwargs,
291 | ):
292 | super().__init__()
293 | self.ch = ch
294 | self.temb_ch = 0
295 | self.num_resolutions = len(ch_mult)
296 | self.num_res_blocks = num_res_blocks
297 | self.resolution = resolution
298 | self.in_channels = in_channels
299 | self.give_pre_end = give_pre_end
300 |
301 | # compute in_ch_mult, block_in and curr_res at lowest res
302 | in_ch_mult = (1,) + tuple(ch_mult)
303 | block_in = ch * ch_mult[self.num_resolutions - 1]
304 | curr_res = resolution // 2 ** (self.num_resolutions - 1)
305 | self.z_shape = (1, z_channels, curr_res, curr_res)
306 | print(
307 | "Working with z of shape {} = {} dimensions.".format(
308 | self.z_shape, np.prod(self.z_shape)
309 | )
310 | )
311 |
312 | # z to block_in
313 | self.conv_in = torch.nn.Conv2d(
314 | z_channels, block_in, kernel_size=3, stride=1, padding=1
315 | )
316 |
317 | # middle
318 | self.mid = nn.Module()
319 | self.mid.block_1 = ResnetBlock(
320 | in_channels=block_in,
321 | out_channels=block_in,
322 | temb_channels=self.temb_ch,
323 | dropout=dropout,
324 | )
325 | self.mid.attn_1 = AttnBlock(block_in)
326 | self.mid.block_2 = ResnetBlock(
327 | in_channels=block_in,
328 | out_channels=block_in,
329 | temb_channels=self.temb_ch,
330 | dropout=dropout,
331 | )
332 |
333 | # upsampling
334 | self.up = nn.ModuleList()
335 | for i_level in reversed(range(self.num_resolutions)):
336 | block = nn.ModuleList()
337 | attn = nn.ModuleList()
338 | block_out = ch * ch_mult[i_level]
339 | for i_block in range(self.num_res_blocks + 1):
340 | block.append(
341 | ResnetBlock(
342 | in_channels=block_in,
343 | out_channels=block_out,
344 | temb_channels=self.temb_ch,
345 | dropout=dropout,
346 | )
347 | )
348 | block_in = block_out
349 | if curr_res in attn_resolutions:
350 | attn.append(AttnBlock(block_in))
351 | up = nn.Module()
352 | up.block = block
353 | up.attn = attn
354 | if i_level != 0:
355 | up.upsample = Upsample(block_in, resamp_with_conv)
356 | curr_res = curr_res * 2
357 | self.up.insert(0, up) # prepend to get consistent order
358 |
359 | # end
360 | self.norm_out = Normalize(block_in)
361 | self.conv_out = torch.nn.Conv2d(
362 | block_in, out_ch, kernel_size=3, stride=1, padding=1
363 | )
364 |
365 | def forward(self, z):
366 | # assert z.shape[1:] == self.z_shape[1:]
367 | self.last_z_shape = z.shape
368 |
369 | # timestep embedding
370 | temb = None
371 |
372 | # z to block_in
373 | h = self.conv_in(z)
374 |
375 | # middle
376 | h = self.mid.block_1(h, temb)
377 | h = self.mid.attn_1(h)
378 | h = self.mid.block_2(h, temb)
379 |
380 | # upsampling
381 | for i_level in reversed(range(self.num_resolutions)):
382 | for i_block in range(self.num_res_blocks + 1):
383 | h = self.up[i_level].block[i_block](h, temb)
384 | if len(self.up[i_level].attn) > 0:
385 | h = self.up[i_level].attn[i_block](h)
386 | if i_level != 0:
387 | h = self.up[i_level].upsample(h)
388 |
389 | # end
390 | if self.give_pre_end:
391 | return h
392 |
393 | h = self.norm_out(h)
394 | h = nonlinearity(h)
395 | h = self.conv_out(h)
396 | return h
397 |
398 |
399 | class DiagonalGaussianDistribution(object):
400 | def __init__(self, parameters, deterministic=False):
401 | self.parameters = parameters
402 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
403 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
404 | self.deterministic = deterministic
405 | self.std = torch.exp(0.5 * self.logvar)
406 | self.var = torch.exp(self.logvar)
407 | if self.deterministic:
408 | self.var = self.std = torch.zeros_like(self.mean).to(
409 | device=self.parameters.device
410 | )
411 |
412 | def sample(self):
413 | x = self.mean + self.std * torch.randn(self.mean.shape).to(
414 | device=self.parameters.device
415 | )
416 | return x
417 |
418 | def kl(self, other=None):
419 | if self.deterministic:
420 | return torch.Tensor([0.0])
421 | else:
422 | if other is None:
423 | return 0.5 * torch.sum(
424 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
425 | dim=[1, 2, 3],
426 | )
427 | else:
428 | return 0.5 * torch.sum(
429 | torch.pow(self.mean - other.mean, 2) / other.var
430 | + self.var / other.var
431 | - 1.0
432 | - self.logvar
433 | + other.logvar,
434 | dim=[1, 2, 3],
435 | )
436 |
437 | def nll(self, sample, dims=[1, 2, 3]):
438 | if self.deterministic:
439 | return torch.Tensor([0.0])
440 | logtwopi = np.log(2.0 * np.pi)
441 | return 0.5 * torch.sum(
442 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
443 | dim=dims,
444 | )
445 |
446 | def mode(self):
447 | return self.mean
448 |
449 |
450 | class AutoencoderKL(nn.Module):
451 | def __init__(self, embed_dim, ch_mult, use_variational=True, ckpt_path=None):
452 | super().__init__()
453 | self.encoder = Encoder(ch_mult=ch_mult, z_channels=embed_dim)
454 | self.decoder = Decoder(ch_mult=ch_mult, z_channels=embed_dim)
455 | self.use_variational = use_variational
456 | mult = 2 if self.use_variational else 1
457 | self.quant_conv = torch.nn.Conv2d(2 * embed_dim, mult * embed_dim, 1)
458 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, embed_dim, 1)
459 | self.embed_dim = embed_dim
460 | if ckpt_path is not None:
461 | self.init_from_ckpt(ckpt_path)
462 |
463 | def init_from_ckpt(self, path):
464 | sd = torch.load(path, map_location="cpu")["model"]
465 | msg = self.load_state_dict(sd, strict=False)
466 | print("Loading pre-trained KL-VAE")
467 | print("Missing keys:")
468 | print(msg.missing_keys)
469 | print("Unexpected keys:")
470 | print(msg.unexpected_keys)
471 | print(f"Restored from {path}")
472 |
473 | def encode(self, x):
474 | h = self.encoder(x)
475 | moments = self.quant_conv(h)
476 | if not self.use_variational:
477 | moments = torch.cat((moments, torch.ones_like(moments)), 1)
478 | posterior = DiagonalGaussianDistribution(moments)
479 | return posterior
480 |
481 | def decode(self, z):
482 | z = self.post_quant_conv(z)
483 | dec = self.decoder(z)
484 | return dec
485 |
486 | def forward(self, inputs, disable=True, train=True, optimizer_idx=0):
487 | if train:
488 | return self.training_step(inputs, disable, optimizer_idx)
489 | else:
490 | return self.validation_step(inputs, disable)
491 |
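The spatial size of the latent produced by `Encoder` follows from the `Downsample` layers above: each pads one pixel on the right and bottom, then applies a 3x3 convolution with stride 2. The sketch below checks that arithmetic in plain Python. The helper names `conv_out_size` and `latent_resolution` are illustrative and not part of the repository.

```python
def conv_out_size(size, kernel, stride, pad_before, pad_after):
    """Standard convolution output-size formula along one spatial axis."""
    return (size + pad_before + pad_after - kernel) // stride + 1


def latent_resolution(resolution, ch_mult):
    """Spatial size after the Encoder's len(ch_mult) - 1 Downsample layers."""
    size = resolution
    for _ in range(len(ch_mult) - 1):
        # Downsample pads (0, 1) on each spatial axis, then conv 3x3 with stride 2.
        size = conv_out_size(size, kernel=3, stride=2, pad_before=0, pad_after=1)
    return size


print(latent_resolution(256, (1, 1, 2, 2, 4)))
```

For even input sizes each `Downsample` halves the resolution exactly, which agrees with the `resolution // 2 ** (num_resolutions - 1)` expression used for `z_shape` in `Decoder`.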
--------------------------------------------------------------------------------
/util/crop.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 |
5 | def center_crop_arr(pil_image, image_size):
6 | """
7 | Center cropping implementation from ADM.
8 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
9 | """
10 | while min(*pil_image.size) >= 2 * image_size:
11 | pil_image = pil_image.resize(
12 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
13 | )
14 |
15 | scale = image_size / min(*pil_image.size)
16 | pil_image = pil_image.resize(
17 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
18 | )
19 |
20 | arr = np.array(pil_image)
21 | crop_y = (arr.shape[0] - image_size) // 2
22 | crop_x = (arr.shape[1] - image_size) // 2
23 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
24 |
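The geometry of `center_crop_arr` can be followed without loading an image. The sketch below reproduces only the size arithmetic (repeated halving, rescaling the short side, centered offsets) in plain Python. The function name `center_crop_geometry` and the 1000x600 input are illustrative assumptions, and the resampling filters themselves are not modeled.

```python
def center_crop_geometry(width, height, image_size):
    """Return the resized (width, height) and the (crop_x, crop_y) offsets."""
    # Halve repeatedly (BOX filter in the original) while the short side is >= 2x target.
    while min(width, height) >= 2 * image_size:
        width, height = width // 2, height // 2
    # Scale so the short side equals image_size (BICUBIC in the original).
    scale = image_size / min(width, height)
    width, height = round(width * scale), round(height * scale)
    # Center the square crop window along each axis.
    crop_x = (width - image_size) // 2
    crop_y = (height - image_size) // 2
    return (width, height), (crop_x, crop_y)


print(center_crop_geometry(1000, 600, 256))
```

Note that PIL reports `size` as (width, height), while the array in `center_crop_arr` is indexed as (height, width), which is why `crop_y` there is derived from `arr.shape[0]`.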
--------------------------------------------------------------------------------
/util/download.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import requests
4 |
5 |
6 | def download_pretrained_vae(overwrite=False):
7 | download_path = "pretrained_models/vae/kl16.ckpt"
8 | if not os.path.exists(download_path) or overwrite:
9 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}
10 | os.makedirs("pretrained_models/vae", exist_ok=True)
11 | r = requests.get("https://www.dropbox.com/scl/fi/hhmuvaiacrarfg28qxhwz/kl16.ckpt?rlkey=l44xipsezc8atcffdp4q7mwmh&dl=0", stream=True, headers=headers)
12 | print("Downloading KL-16 VAE...")
13 | with open(download_path, 'wb') as f:
14 | for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=254):
15 | if chunk:
16 | f.write(chunk)
17 |
18 |
19 | def download_pretrained_marb(overwrite=False):
20 | download_path = "pretrained_models/mar/mar_base/checkpoint-last.pth"
21 | if not os.path.exists(download_path) or overwrite:
22 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}
23 | os.makedirs("pretrained_models/mar/mar_base", exist_ok=True)
24 | r = requests.get("https://www.dropbox.com/scl/fi/f6dpuyjb7fudzxcyhvrhk/checkpoint-last.pth?rlkey=a6i4bo71vhfo4anp33n9ukujb&dl=0", stream=True, headers=headers)
25 | print("Downloading MAR-B...")
26 | with open(download_path, 'wb') as f:
27 | for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=1587):
28 | if chunk:
29 | f.write(chunk)
30 |
31 |
32 | def download_pretrained_marl(overwrite=False):
33 | download_path = "pretrained_models/mar/mar_large/checkpoint-last.pth"
34 | if not os.path.exists(download_path) or overwrite:
35 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}
36 | os.makedirs("pretrained_models/mar/mar_large", exist_ok=True)
37 | r = requests.get("https://www.dropbox.com/scl/fi/pxacc5b2mrt3ifw4cah6k/checkpoint-last.pth?rlkey=m48ovo6g7ivcbosrbdaz0ehqt&dl=0", stream=True, headers=headers)
38 | print("Downloading MAR-L...")
39 | with open(download_path, 'wb') as f:
40 | for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=3650):
41 | if chunk:
42 | f.write(chunk)
43 |
44 |
45 | def download_pretrained_marh(overwrite=False):
46 | download_path = "pretrained_models/mar/mar_huge/checkpoint-last.pth"
47 | if not os.path.exists(download_path) or overwrite:
48 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}
49 | os.makedirs("pretrained_models/mar/mar_huge", exist_ok=True)
50 | r = requests.get("https://www.dropbox.com/scl/fi/1qmfx6fpy3k7j9vcjjs3s/checkpoint-last.pth?rlkey=4lae281yzxb406atp32vzc83o&dl=0", stream=True, headers=headers)
51 | print("Downloading MAR-H...")
52 | with open(download_path, 'wb') as f:
53 | for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=7191):
54 | if chunk:
55 | f.write(chunk)
56 |
57 |
58 | if __name__ == "__main__":
59 | download_pretrained_vae()
60 | download_pretrained_marb()
61 | download_pretrained_marl()
62 | download_pretrained_marh()
63 |
--------------------------------------------------------------------------------
/util/loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | import torch
5 | import torchvision.datasets as datasets
6 |
7 |
8 | class ImageFolderWithFilename(datasets.ImageFolder):
9 | def __getitem__(self, index: int):
10 | """
11 | Args:
12 | index (int): Index
13 |
14 | Returns:
15 | tuple: (sample, target, filename).
16 | """
17 | path, target = self.samples[index]
18 | sample = self.loader(path)
19 | if self.transform is not None:
20 | sample = self.transform(sample)
21 | if self.target_transform is not None:
22 | target = self.target_transform(target)
23 |
24 | filename = path.split(os.path.sep)[-2:]
25 | filename = os.path.join(*filename)
26 | return sample, target, filename
27 |
28 |
29 | class CachedFolder(datasets.DatasetFolder):
30 | def __init__(
31 | self,
32 | root: str,
33 | ):
34 | super().__init__(
35 | root,
36 | loader=None,
37 | extensions=(".npz",),
38 | )
39 |
40 | def __getitem__(self, index: int):
41 | """
42 | Args:
43 | index (int): Index
44 |
45 | Returns:
46 | tuple: (moments, target).
47 | """
48 | path, target = self.samples[index]
49 |
50 | data = np.load(path)
51 | if torch.rand(1) < 0.5: # randomly hflip
52 | moments = data['moments']
53 | else:
54 | moments = data['moments_flip']
55 |
56 | return moments, target
57 |
--------------------------------------------------------------------------------
/util/lr_sched.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 |
4 | def adjust_learning_rate(optimizer, epoch, args):
5 | """Decay the learning rate with half-cycle cosine after warmup"""
6 | if epoch < args.warmup_epochs:
7 | lr = args.lr * epoch / args.warmup_epochs
8 | else:
9 | if args.lr_schedule == "constant":
10 | lr = args.lr
11 | elif args.lr_schedule == "cosine":
12 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
13 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
14 | else:
15 | raise NotImplementedError
16 | for param_group in optimizer.param_groups:
17 | if "lr_scale" in param_group:
18 | param_group["lr"] = lr * param_group["lr_scale"]
19 | else:
20 | param_group["lr"] = lr
21 | return lr
22 |
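To see what `adjust_learning_rate` produces over a run, the sketch below restates its warmup and "cosine" branches as a pure function. The name `lr_at` and all numeric values are assumptions chosen for illustration; the "constant" branch and the per-group `lr_scale` handling are omitted.

```python
import math


def lr_at(epoch, base_lr, min_lr, warmup_epochs, total_epochs):
    """Learning rate at a (possibly fractional) epoch: linear warmup, then cosine decay."""
    if epoch < warmup_epochs:
        return base_lr * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + (base_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


for e in (0, 50, 100, 250, 400):
    print(e, lr_at(e, base_lr=1e-4, min_lr=0.0, warmup_epochs=100, total_epochs=400))
```

The rate rises linearly from zero to `base_lr` during warmup, then follows half a cosine cycle down to `min_lr` at the final epoch.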
--------------------------------------------------------------------------------
/util/misc.py:
--------------------------------------------------------------------------------
1 | import builtins
2 | import datetime
3 | import os
4 | import time
5 | from collections import defaultdict, deque
6 | from pathlib import Path
7 |
8 | import torch
9 | import torch.distributed as dist
10 | TORCH_MAJOR = int(torch.__version__.split('.')[0])
11 | TORCH_MINOR = int(torch.__version__.split('.')[1])
12 |
13 | if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
14 | from torch._six import inf
15 | else:
16 | from torch import inf
17 | import copy
18 |
19 |
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    """Collect named SmoothedValue meters and print them periodically."""

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only called when normal lookup fails, so meters can be read as attributes.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items from a sized iterable, printing progress every print_freq steps."""
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for the next item (data loading).
            data_time.update(time.time() - end)
            yield obj
            # Time for the full iteration, including the caller's work on obj.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        force = force or (get_world_size() > 8)
        if is_master or force:
            now = datetime.datetime.now().time()
            builtin_print('[{}] '.format(now), end='')  # print with time stamp
            builtin_print(*args, **kwargs)

    builtins.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if args.dist_on_itp:
        # Launched through Open MPI: read rank information from the OMPI_* variables.
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # RANK, WORLD_SIZE and LOCAL_RANK are already set by the launcher.
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # SLURM: derive the local GPU index from the global process id.
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        setup_for_distributed(is_master=True)  # hack
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


class NativeScalerWithGradNormCount:
    """Wrap torch.cuda.amp.GradScaler and return the gradient norm of each update."""

    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                # No clipping: unscale anyway so the reported norm is in true gradient units.
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            # Gradient accumulation step: backward only, no optimizer update.
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)


def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Return the total gradient norm over `parameters` without modifying the gradients."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        # Norm of the per-parameter norms equals the norm of all gradients concatenated.
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
    return total_norm


def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list or 'diffloss' in name:
            no_decay.append(param)  # no weight decay on bias, norm and diffloss
        else:
            decay.append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]


def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, ema_params=None, epoch_name=None):
    if epoch_name is None:
        epoch_name = str(epoch)
    output_dir = Path(args.output_dir)
    checkpoint_path = output_dir / ('checkpoint-%s.pth' % epoch_name)

    # ema: ema_params is indexed in the same order as model_without_ddp.named_parameters(),
    # so copy the state dict and overwrite each parameter entry with its EMA value.
    if ema_params is not None:
        ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
        for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
            assert name in ema_state_dict
            ema_state_dict[name] = ema_params[i]
    else:
        ema_state_dict = None

    to_save = {
        'model': model_without_ddp.state_dict(),
        'model_ema': ema_state_dict,
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'scaler': loss_scaler.state_dict(),
        'args': args,
    }
    save_on_master(to_save, checkpoint_path)


def all_reduce_mean(x):
    world_size = get_world_size()
    if world_size > 1:
        x_reduce = torch.tensor(x).cuda()
        dist.all_reduce(x_reduce)
        x_reduce /= world_size
        return x_reduce.item()
    else:
        return x
--------------------------------------------------------------------------------
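`SmoothedValue` in `util/misc.py` keeps two views of a metric: a bounded window of recent values (used for `median` and `avg`) and running totals (used for `global_avg`). The sketch below mirrors that bookkeeping in pure Python so it can be run without PyTorch. `WindowedMeter` is a hypothetical name that does not exist in the repository. It uses `statistics.median_low` because `torch.Tensor.median` returns the lower of the two middle values for an even-sized window.

```python
from collections import deque
from statistics import median_low


class WindowedMeter:
    """Pure-Python stand-in for SmoothedValue (hypothetical, for illustration)."""

    def __init__(self, window_size=20):
        self.window = deque(maxlen=window_size)  # keeps only the most recent values
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        self.window.append(value)  # the window stores the value once
        self.count += n            # the running totals weight it by n
        self.total += value * n

    @property
    def median(self):
        return median_low(self.window)

    @property
    def global_avg(self):
        return self.total / self.count


meter = WindowedMeter(window_size=3)
for loss in [4.0, 3.0, 2.0, 1.0]:
    meter.update(loss)

window_values = list(meter.window)  # the oldest value (4.0) has been evicted
recent_median = meter.median        # computed over the window only
overall_mean = meter.global_avg     # computed over every update
```

With a window of three, the first loss drops out of the window but still contributes to the global average, which is why a logged line such as `median (global_avg)` can show two noticeably different numbers early in training.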