├── NestedPipeline.py ├── NestedScheduler.py ├── README.md ├── core └── util.py ├── diffusion ├── __init__.py ├── diffusion_utils.py ├── gaussian_diffusion.py ├── respace.py └── timestep_sampler.py ├── download.py ├── environment.yml ├── figures ├── Nested_Egg.png ├── arxiv.svg └── generation.gif ├── generate_imagenet.py ├── main.py ├── models.py └── sample_text_to_image.py /NestedPipeline.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Any, Callable, Dict, List, Optional, Union 3 | 4 | import torch 5 | from diffusers.utils import replace_example_docstring 6 | from transformers import CLIPTokenizer 7 | 8 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline, EXAMPLE_DOC_STRING 9 | 10 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput 11 | 12 | from NestedScheduler import NestedScheduler, NestedDPMSolverSinglestepScheduler 13 | 14 | 15 | class NestedStableDiffusionPipeline(StableDiffusionPipeline): 16 | r""" 17 | Pipeline for text-to-image generation using Nested Stable Diffusion. 18 | 19 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 20 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 21 | 22 | In addition the pipeline inherits the following loading methods: 23 | - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] 24 | - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] 25 | - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`] 26 | 27 | as well as the following saving methods: 28 | - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] 29 | 30 | Args: 31 | vae ([`AutoencoderKL`]): 32 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 33 | text_encoder ([`CLIPTextModel`]): 34 | Frozen text-encoder. Stable Diffusion uses the text portion of 35 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 36 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 37 | tokenizer (`CLIPTokenizer`): 38 | Tokenizer of class 39 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 40 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. 41 | scheduler ([`SchedulerMixin`]): 42 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 43 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 44 | safety_checker ([`StableDiffusionSafetyChecker`]): 45 | Classification module that estimates whether generated images could be considered offensive or harmful. 46 | Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. 47 | feature_extractor ([`CLIPImageProcessor`]): 48 | Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
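    Example:
        An illustrative sketch, not an official example from this repository. It assumes the
        Stable Diffusion v1-5 weights and that the outer scheduler is a `NestedScheduler`
        configured with `prediction_type="sample"` (an assumption based on how
        `self.scheduler.step` below is fed the inner loop's clean-image estimate); see
        `sample_text_to_image.py` for the repository's actual setup.

        ```py
        import torch
        from NestedPipeline import NestedStableDiffusionPipeline
        from NestedScheduler import NestedScheduler

        pipe = NestedStableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
        ).to("cuda")
        # outer scheduler; betas mirror the inner-scheduler construction used in __call__ below
        pipe.scheduler = NestedScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
            clip_sample=False, set_alpha_to_one=False, prediction_type="sample",
        )
        # __call__ is a generator (it uses `yield`), so iterate over the call to obtain outputs
        for out in pipe("a photo of a nested egg", num_inference_steps=5, num_inner_steps=20):
            image = out.images[0]
        ```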
49 | """ 50 | 51 | @torch.no_grad() 52 | @replace_example_docstring(EXAMPLE_DOC_STRING) 53 | def __call__( 54 | self, 55 | prompt: Union[str, List[str]] = None, 56 | height: Optional[int] = None, 57 | width: Optional[int] = None, 58 | num_inference_steps: int = 5, 59 | num_inner_steps: int = 20, 60 | guidance_scale: float = 7.5, 61 | negative_prompt: Optional[Union[str, List[str]]] = None, 62 | num_images_per_prompt: Optional[int] = 1, 63 | eta: float = 0.0, 64 | inner_scheduler: str = "DDIM", 65 | inner_eta: float = 0.85, 66 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 67 | latents: Optional[torch.FloatTensor] = None, 68 | prompt_embeds: Optional[torch.FloatTensor] = None, 69 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 70 | output_type: Optional[str] = "pil", 71 | return_dict: bool = True, 72 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 73 | callback_steps: int = 1, 74 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 75 | ): 76 | r""" 77 | Function invoked when calling the pipeline for generation. 78 | 79 | Args: 80 | prompt (`str` or `List[str]`, *optional*): 81 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 82 | instead. 83 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 84 | The height in pixels of the generated image. 85 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 86 | The width in pixels of the generated image. 87 | num_inference_steps (`int`, *optional*, defaults to 5): 88 | The number of outer denoising steps. 89 | num_inner_steps (`int`, *optional*, defaults to 20): 90 | The number of inner denoising steps. 91 | guidance_scale (`float`, *optional*, defaults to 7.5): 92 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 93 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 94 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 95 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 96 | usually at the expense of lower image quality. 97 | negative_prompt (`str` or `List[str]`, *optional*): 98 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 99 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 100 | less than `1`). 101 | num_images_per_prompt (`int`, *optional*, defaults to 1): 102 | The number of images to generate per prompt. 103 | eta (`float`, *optional*, defaults to 0.0): 104 | Corresponds to parameter eta (η) in the outer diffusion process 105 | inner_scheduler (`str`, *optional*, defaults to 'DDIM'): 106 | The scheduler to use for the inner diffusion process 107 | inner_eta (`float`, *optional*, defaults to 0.85): 108 | Corresponds to parameter eta (η) in the inner diffusion process 109 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 110 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 111 | to make generation deterministic. 112 | latents (`torch.FloatTensor`, *optional*): 113 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 114 | generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents 115 | tensor will ge generated by sampling using the supplied random `generator`. 116 | prompt_embeds (`torch.FloatTensor`, *optional*): 117 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 118 | provided, text embeddings will be generated from `prompt` input argument. 119 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 120 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 121 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 122 | argument. 123 | output_type (`str`, *optional*, defaults to `"pil"`): 124 | The output format of the generate image. Choose between 125 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 126 | return_dict (`bool`, *optional*, defaults to `True`): 127 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 128 | plain tuple. 129 | callback (`Callable`, *optional*): 130 | A function that will be called every `callback_steps` steps during inference. The function will be 131 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 132 | callback_steps (`int`, *optional*, defaults to 1): 133 | The frequency at which the `callback` function will be called. If not specified, the callback will be 134 | called at every step. 135 | cross_attention_kwargs (`dict`, *optional*): 136 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 137 | `self.processor` in 138 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 139 | 140 | Examples: 141 | 142 | Returns: 143 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 144 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 145 | When returning a tuple, the first element is a list with the generated images, and the second element is a 146 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" 147 | (nsfw) content, according to the `safety_checker`. 148 | """ 149 | # 0. Default height and width to unet 150 | height = height or self.unet.config.sample_size * self.vae_scale_factor 151 | width = width or self.unet.config.sample_size * self.vae_scale_factor 152 | 153 | # 1. Check inputs. Raise error if not correct 154 | self.check_inputs( 155 | prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds 156 | ) 157 | 158 | # 2. Define call parameters 159 | if prompt is not None and isinstance(prompt, str): 160 | batch_size = 1 161 | elif prompt is not None and isinstance(prompt, list): 162 | batch_size = len(prompt) 163 | else: 164 | batch_size = prompt_embeds.shape[0] 165 | 166 | device = self._execution_device 167 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 168 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 169 | # corresponds to doing no classifier free guidance. 170 | do_classifier_free_guidance = guidance_scale > 1.0 171 | 172 | # 3. 
Encode input prompt 173 | prompt_embeds = self._encode_prompt( 174 | prompt, 175 | device, 176 | num_images_per_prompt, 177 | do_classifier_free_guidance, 178 | negative_prompt, 179 | prompt_embeds=prompt_embeds, 180 | negative_prompt_embeds=negative_prompt_embeds, 181 | ) 182 | 183 | # 4. Prepare timesteps 184 | self.scheduler.set_timesteps(num_inference_steps + 1, device=device) 185 | timesteps = self.scheduler.timesteps[:-1] 186 | 187 | # 5. Prepare latent variables 188 | num_channels_latents = self.unet.config.in_channels 189 | latents = self.prepare_latents( 190 | batch_size * num_images_per_prompt, 191 | num_channels_latents, 192 | height, 193 | width, 194 | prompt_embeds.dtype, 195 | device, 196 | generator, 197 | latents, 198 | ) 199 | 200 | # 6. Prepare extra step kwargs. 201 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 202 | inner_extra_step_kwargs = self.prepare_extra_step_kwargs(generator, inner_eta) 203 | 204 | # 7. Denoising loop 205 | outer_latents = latents.clone() 206 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 207 | with self.progress_bar(total=num_inference_steps) as progress_bar: 208 | # running the outer diffusion procees 209 | for i, t in enumerate(timesteps): 210 | # creating the inner diffusion process 211 | if inner_scheduler == "DDIM": 212 | self.inner_scheduler = NestedScheduler(beta_start=0.00085, beta_end=0.012, 213 | beta_schedule="scaled_linear", clip_sample=False, 214 | set_alpha_to_one=False, thresholding=False) 215 | elif inner_scheduler == "DPMSolver": 216 | self.inner_scheduler = NestedDPMSolverSinglestepScheduler(beta_start=0.00085, beta_end=0.012, 217 | beta_schedule="scaled_linear") 218 | self.inner_scheduler.set_timesteps(num_inner_steps, max_timestep=t.item(), device=device) 219 | inner_timesteps = self.inner_scheduler.timesteps 220 | latents = outer_latents.clone() 221 | # running the inner diffusion procees 222 | for j, t_tag in enumerate(inner_timesteps): 223 | # expand the latents if we are doing classifier free guidance 224 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 225 | latent_model_input = self.inner_scheduler.scale_model_input(latent_model_input, t_tag) 226 | 227 | # predict the noise residual 228 | noise_pred = self.unet(latent_model_input, t_tag, encoder_hidden_states=prompt_embeds).sample 229 | 230 | # perform guidance 231 | if do_classifier_free_guidance: 232 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 233 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 234 | 235 | latents = self.inner_scheduler.step(noise_pred, t_tag, latents, **inner_extra_step_kwargs).prev_sample 236 | 237 | # compute the previous noisy sample x_t -> x_t-1 238 | outer_latents = self.scheduler.step(latents, t, outer_latents, **extra_step_kwargs).prev_sample 239 | 240 | # call the callback, if provided 241 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 242 | progress_bar.update() 243 | if callback is not None and i % callback_steps == 0: 244 | callback(i, t, latents) 245 | 246 | if output_type == "latent": 247 | image = latents 248 | has_nsfw_concept = None 249 | elif output_type == "pil": 250 | # 8. Post-processing 251 | image = self.decode_latents(latents) 252 | 253 | # 9. Run safety checker 254 | image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) 255 | 256 | # 10. 
Convert to PIL 257 | image = self.numpy_to_pil(image) 258 | else: 259 | # 8. Post-processing 260 | image = self.decode_latents(latents) 261 | 262 | # 9. Run safety checker 263 | image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) 264 | 265 | # Offload last model to CPU 266 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 267 | self.final_offload_hook.offload() 268 | 269 | if not return_dict: 270 | yield (image, has_nsfw_concept) 271 | 272 | yield StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) -------------------------------------------------------------------------------- /NestedScheduler.py: -------------------------------------------------------------------------------- 1 | 2 | from dataclasses import dataclass 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import numpy as np 6 | import torch 7 | from diffusers import DDIMScheduler, DPMSolverSinglestepScheduler 8 | 9 | from diffusers.utils import BaseOutput 10 | 11 | 12 | @dataclass 13 | class NestedSchedulerOutput(BaseOutput): 14 | """ 15 | Output class for the scheduler's step function output. 16 | 17 | Args: 18 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 19 | Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the 20 | denoising loop. 21 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 22 | The predicted denoised sample (x_{0}) based on the model output from the current timestep. 23 | `pred_original_sample` can be used to preview progress or for guidance. 24 | """ 25 | 26 | prev_sample: torch.FloatTensor 27 | pred_original_sample: Optional[torch.FloatTensor] = None 28 | 29 | 30 | 31 | class NestedScheduler(DDIMScheduler): 32 | 33 | def set_timesteps(self, num_inference_steps: int, max_timestep: int = 1000, device: Union[str, torch.device] = None): 34 | """ 35 | Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. 36 | 37 | Args: 38 | num_inference_steps (`int`): 39 | the number of diffusion steps used when generating figures with a pre-trained model. 40 | max_timestep (`int`): 41 | the highest timestep to use for choosing the timesteps 42 | """ 43 | 44 | if num_inference_steps > self.config.num_train_timesteps: 45 | raise ValueError( 46 | f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" 47 | f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" 48 | f" maximal {self.config.num_train_timesteps} timesteps." 49 | ) 50 | 51 | self.num_inference_steps = num_inference_steps 52 | max_timestep = min(self.config.num_train_timesteps - 1, max_timestep) 53 | timesteps = np.linspace(1, max_timestep, min(num_inference_steps, max_timestep)).round()[::-1].copy().astype(np.int64) 54 | self.timesteps = torch.from_numpy(timesteps).to(device) 55 | 56 | def step( 57 | self, 58 | model_output: torch.FloatTensor, 59 | timestep: int, 60 | sample: torch.FloatTensor, 61 | eta: float = 0.0, 62 | use_clipped_model_output: bool = False, 63 | generator=None, 64 | variance_noise: Optional[torch.FloatTensor] = None, 65 | return_dict: bool = True, 66 | override_prediction_type = '', 67 | ) -> Union[NestedSchedulerOutput, Tuple]: 68 | """ 69 | Predict the sample at the previous timestep by reversing the SDE. 
Core function to propagate the diffusion 70 | process from the learned model outputs (most often the predicted noise). 71 | 72 | Args: 73 | model_output (`torch.FloatTensor`): direct output from learned diffusion model. 74 | timestep (`int`): current discrete timestep in the diffusion chain. 75 | sample (`torch.FloatTensor`): 76 | current instance of sample being created by diffusion process. 77 | eta (`float`): weight of noise for added noise in diffusion step. 78 | use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped 79 | predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when 80 | `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would 81 | coincide with the one provided as input and `use_clipped_model_output` will have not effect. 82 | generator: random number generator. 83 | variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we 84 | can directly provide the noise for the variance itself. This is useful for methods such as 85 | CycleDiffusion. (https://arxiv.org/abs/2210.05559) 86 | return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class 87 | 88 | Returns: 89 | [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: 90 | [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When 91 | returning a tuple, the first element is the sample tensor. 92 | 93 | """ 94 | if self.num_inference_steps is None: 95 | raise ValueError( 96 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" 97 | ) 98 | 99 | # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf 100 | # Ideally, read DDIM paper in-detail understanding 101 | 102 | # Notation ( -> 103 | # - pred_noise_t -> e_theta(x_t, t) 104 | # - pred_original_sample -> f_theta(x_t, t) or x_0 105 | # - std_dev_t -> sigma_t 106 | # - eta -> η 107 | # - pred_sample_direction -> "direction pointing to x_t" 108 | # - pred_prev_sample -> "x_t-1" 109 | 110 | # 1. get previous step value (=t-1) 111 | # prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps 112 | cur_idx = (self.timesteps == timestep).nonzero().item() 113 | prev_timestep = self.timesteps[cur_idx + 1] if cur_idx < len(self.timesteps) - 1 else 0 114 | 115 | # 2. compute alphas, betas 116 | alpha_prod_t = self.alphas_cumprod[timestep] 117 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod 118 | 119 | beta_prod_t = 1 - alpha_prod_t 120 | 121 | # 3. 
compute predicted original sample from predicted noise also called 122 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 123 | prediction_type = override_prediction_type if override_prediction_type else self.config.prediction_type 124 | if prediction_type == "epsilon": 125 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) 126 | pred_epsilon = model_output 127 | elif prediction_type == "sample": 128 | pred_original_sample = model_output 129 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) 130 | elif prediction_type == "v_prediction": 131 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output 132 | pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample 133 | else: 134 | raise ValueError( 135 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" 136 | " `v_prediction`" 137 | ) 138 | 139 | # 4. Clip or threshold "predicted x_0" 140 | if self.config.thresholding: 141 | pred_original_sample = self._threshold_sample(pred_original_sample) 142 | elif self.config.clip_sample: 143 | pred_original_sample = pred_original_sample.clamp( 144 | -self.config.clip_sample_range, self.config.clip_sample_range 145 | ) 146 | 147 | # 5. compute variance: "sigma_t(η)" -> see formula (16) 148 | # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) 149 | variance = self._get_variance(timestep, prev_timestep) 150 | std_dev_t = eta * variance ** (0.5) 151 | 152 | if use_clipped_model_output: 153 | # the pred_epsilon is always re-derived from the clipped x_0 in Glide 154 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) 155 | 156 | # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 157 | pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon 158 | 159 | # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 160 | prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction 161 | 162 | if eta > 0: 163 | if variance_noise is not None and generator is not None: 164 | raise ValueError( 165 | "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" 166 | " `variance_noise` stays `None`." 167 | ) 168 | 169 | if variance_noise is None: 170 | variance_noise = torch.randn( 171 | model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype 172 | ) 173 | variance = std_dev_t * variance_noise 174 | 175 | prev_sample = prev_sample + variance 176 | 177 | if not return_dict: 178 | return (prev_sample,) 179 | 180 | return NestedSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) 181 | 182 | class NestedDPMSolverSinglestepScheduler(DPMSolverSinglestepScheduler): 183 | def set_timesteps(self, num_inference_steps: int, max_timestep: int = 1000, device: Union[str, torch.device] = None): 184 | """ 185 | Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. 186 | 187 | Args: 188 | num_inference_steps (`int`): 189 | the number of diffusion steps used when generating samples with a pre-trained model. 190 | device (`str` or `torch.device`, optional): 191 | the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
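            max_timestep (`int`, *optional*, defaults to 1000):
                the highest timestep to select timesteps from; the pipeline passes the current outer
                timestep here so that the inner solver only covers timesteps up to `max_timestep`.

        Worked example (illustrative): with `num_inference_steps=5` and `max_timestep=400`, the
        `np.linspace` call below yields the timesteps `[400, 320, 240, 160, 80]`.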
192 | """ 193 | self.num_inference_steps = num_inference_steps 194 | timesteps = np.linspace(0, max_timestep, min(num_inference_steps, max_timestep)+1).round()[::-1][:-1].copy().astype(np.int64) 195 | # timesteps = ( 196 | # np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) 197 | # .round()[::-1][:-1] 198 | # .copy() 199 | # .astype(np.int64) 200 | # ) 201 | self.timesteps = torch.from_numpy(timesteps).to(device) 202 | self.model_outputs = [ 203 | None, 204 | ] * self.config.solver_order 205 | self.lower_order_nums = 0 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Nested Diffusion Processes for Anytime Image Generation 2 | Noam Elata, Bahjat Kawar, Tomer Michaeli, and Michael Elad, Technion - Israel Institute of Technology.
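
Nested Diffusion interleaves two diffusion processes: each step of a short outer process runs an inner diffusion chain whose result serves both as an intermediate, viewable prediction and as the input to the next outer step. With the default Stable Diffusion configuration used below (5 outer steps, 20 inner steps), sampling performs 5 × 20 = 100 inner denoising steps in total, and an intermediate image is available after every outer step.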
3 | 4 | Nested Diffusion 5 | 6 | 🔗 [`Project Webpage`](https://noamelata.github.io/NestedDiffusion/) | 🤗 [`Huggingface Demo`](https://huggingface.co/spaces/noamelata/Nested-Diffusion) | [`ArXiv`](https://arxiv.org/abs/2305.19066) 7 | 8 | ## Preparation 9 | 10 | Please refer to `environment.yml` for packages that can be used to run the code. 11 | 12 | Pretrained models will download automatically upon running the relevant scripts. 13 | 14 | ## Sampling from Nested Stable Diffusion 15 | This code implements Nested Diffusion, used with Stable Diffusion v1-5 and the Diffusers library. 16 | 17 | Image generation can be run with the following script: 18 | ``` 19 | python sample_text_to_image.py --outer --inner \ 20 | --outdir samples --prompt "" 21 | ``` 22 | Default parameters reproduce the image shown at the top. 23 | 24 | To run Nested Diffusion with accelerated DPM-Solver++ inner diffusion scheduler, please add the `--dpm-solver` argument to the script. 25 | It is possible to use fewer inner steps with DPM-Solver++. 26 | 27 | If you have less than 10GB of GPU RAM available please add the `--fp16` argument to the script. 28 | 29 | 30 | ## ImageNet Generation with Nested Diffusion 31 | This code implements Nested Diffusion, used with DiT as a pretrained model. 32 | 33 | Image generation with Nested Diffusion can be run with the following script: 34 | ``` 35 | python generate_imagenet.py --steps --nested \ 36 | -n -b \ 37 | -o --cfg 38 | ``` 39 | For Vanilla DiT diffusion process please run with the following arguments: 40 | ``` 41 | python generate_imagenet.py --steps \ 42 | -n -b \ 43 | -o --cfg 44 | ``` 45 | This script will save all intermediate predictions in separate directories. 46 | 47 | For multi-GPU sampling or ddim sampling, please see the usage of `generate_imagenet.py`. 48 | 49 | 50 | 51 | ## References and Acknowledgements 52 | ```BibTeX 53 | @InProceedings{Elata_2024_WACV, 54 | author = {Elata, Noam and Kawar, Bahjat and Michaeli, Tomer and Elad, Michael}, 55 | title = {Nested Diffusion Processes for Anytime Image Generation}, 56 | booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)}, 57 | month = {January}, 58 | year = {2024}, 59 | pages = {5018-5027} 60 | } 61 | ``` 62 | -------------------------------------------------------------------------------- /core/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import math 4 | import torch 5 | from torch.nn.parallel import DistributedDataParallel as DDP 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)): 10 | ''' 11 | Converts a torch Tensor into an image Numpy array 12 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 13 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) 14 | ''' 15 | tensor = tensor.squeeze().clamp_(*min_max) # clamp 16 | n_dim = tensor.dim() 17 | if n_dim == 4: 18 | n_img = len(tensor) 19 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() 20 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB 21 | elif n_dim == 3: 22 | img_np = tensor.numpy() 23 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB 24 | elif n_dim == 2: 25 | img_np = tensor.numpy() 26 | else: 27 | raise TypeError('Only support 4D, 3D and 2D tensor. 
But received with dimension: {:d}'.format(n_dim)) 28 | if out_type == np.uint8: 29 | img_np = ((img_np+1) * 127.5).round() 30 | # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. 31 | return img_np.astype(out_type) 32 | 33 | def postprocess(images): 34 | return [tensor2img(image) for image in images] 35 | 36 | 37 | def set_seed(seed, gl_seed=0): 38 | """ set random seed, gl_seed used in worker_init_fn function """ 39 | if seed >=0 and gl_seed>=0: 40 | seed += gl_seed 41 | torch.manual_seed(seed) 42 | torch.cuda.manual_seed_all(seed) 43 | np.random.seed(seed) 44 | random.seed(seed) 45 | 46 | ''' change the deterministic and benchmark maybe cause uncertain convolution behavior. 47 | speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html ''' 48 | if seed >=0 and gl_seed>=0: # slower, more reproducible 49 | torch.backends.cudnn.deterministic = True 50 | torch.backends.cudnn.benchmark = False 51 | else: # faster, less reproducible 52 | torch.backends.cudnn.deterministic = False 53 | torch.backends.cudnn.benchmark = True 54 | 55 | def set_gpu(args, distributed=False, rank=0): 56 | """ set parameter to gpu or ddp """ 57 | if args is None: 58 | return None 59 | if distributed and isinstance(args, torch.nn.Module): 60 | return DDP(args.cuda(), device_ids=[rank], output_device=rank, broadcast_buffers=True, find_unused_parameters=True) 61 | else: 62 | return args.cuda() 63 | 64 | def set_device(args, distributed=False, rank=0): 65 | """ set parameter to gpu or cpu """ 66 | if torch.cuda.is_available(): 67 | if isinstance(args, list): 68 | return (set_gpu(item, distributed, rank) for item in args) 69 | elif isinstance(args, dict): 70 | return {key:set_gpu(args[key], distributed, rank) for key in args} 71 | else: 72 | args = set_gpu(args, distributed, rank) 73 | return args 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import gaussian_diffusion as gd 2 | from .respace import SpacedDiffusion, space_timesteps 3 | 4 | 5 | def create_diffusion( 6 | timestep_respacing, 7 | noise_schedule="linear", 8 | use_kl=False, 9 | sigma_small=False, 10 | predict_xstart=False, 11 | learn_sigma=True, 12 | rescale_learned_sigmas=False, 13 | diffusion_steps=1000, 14 | max_step=None 15 | ): 16 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 17 | if use_kl: 18 | loss_type = gd.LossType.RESCALED_KL 19 | elif rescale_learned_sigmas: 20 | loss_type = gd.LossType.RESCALED_MSE 21 | else: 22 | loss_type = gd.LossType.MSE 23 | if timestep_respacing is None or timestep_respacing == "": 24 | timestep_respacing = [diffusion_steps] 25 | return SpacedDiffusion( 26 | # the max_step argument sets the timesteps to be selected from a subset of timesteps - those below max_steps 27 | use_timesteps=space_timesteps(diffusion_steps if max_step is None else max_step, timestep_respacing), 28 | max_step=max_step, 29 | betas=betas, 30 | model_mean_type=( 31 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 32 | ), 33 | model_var_type=( 34 | ( 35 | gd.ModelVarType.FIXED_LARGE 36 | if not sigma_small 37 | else gd.ModelVarType.FIXED_SMALL 38 | ) 39 | if not learn_sigma 40 | else gd.ModelVarType.LEARNED_RANGE 41 | ), 42 | loss_type=loss_type 43 | # rescale_timesteps=rescale_timesteps, 44 | ) 45 | -------------------------------------------------------------------------------- /diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | 4 | 5 | def normal_kl(mean1, logvar1, mean2, logvar2): 6 | """ 7 | Compute the KL divergence between two gaussians. 8 | Shapes are automatically broadcasted, so batches can be compared to 9 | scalars, among other use cases. 10 | """ 11 | tensor = None 12 | for obj in (mean1, logvar1, mean2, logvar2): 13 | if isinstance(obj, th.Tensor): 14 | tensor = obj 15 | break 16 | assert tensor is not None, "at least one argument must be a Tensor" 17 | 18 | # Force variances to be Tensors. Broadcasting helps convert scalars to 19 | # Tensors, but it does not work for th.exp(). 20 | logvar1, logvar2 = [ 21 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 22 | for x in (logvar1, logvar2) 23 | ] 24 | 25 | return 0.5 * ( 26 | -1.0 27 | + logvar2 28 | - logvar1 29 | + th.exp(logvar1 - logvar2) 30 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 31 | ) 32 | 33 | 34 | def approx_standard_normal_cdf(x): 35 | """ 36 | A fast approximation of the cumulative distribution function of the 37 | standard normal. 38 | """ 39 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 40 | 41 | 42 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 43 | """ 44 | Compute the log-likelihood of a continuous Gaussian distribution. 45 | :param x: the targets 46 | :param means: the Gaussian mean Tensor. 47 | :param log_scales: the Gaussian log stddev Tensor. 48 | :return: a tensor like x of log probabilities (in nats). 49 | """ 50 | centered_x = x - means 51 | inv_stdv = th.exp(-log_scales) 52 | normalized_x = centered_x * inv_stdv 53 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 54 | return log_probs 55 | 56 | 57 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 58 | """ 59 | Compute the log-likelihood of a Gaussian distribution discretizing to a 60 | given image. 
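    Each target value is assigned the probability mass that the predicted Gaussian places on its
    discretization bin (one uint8 level wide, evaluated with `approx_standard_normal_cdf`), with
    open-ended bins for extreme values (below -0.999 or above 0.999).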
61 | :param x: the target images. It is assumed that this was uint8 values, 62 | rescaled to the range [-1, 1]. 63 | :param means: the Gaussian mean Tensor. 64 | :param log_scales: the Gaussian log stddev Tensor. 65 | :return: a tensor like x of log probabilities (in nats). 66 | """ 67 | assert x.shape == means.shape == log_scales.shape 68 | centered_x = x - means 69 | inv_stdv = th.exp(-log_scales) 70 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 71 | cdf_plus = approx_standard_normal_cdf(plus_in) 72 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 73 | cdf_min = approx_standard_normal_cdf(min_in) 74 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 75 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 76 | cdf_delta = cdf_plus - cdf_min 77 | log_probs = th.where( 78 | x < -0.999, 79 | log_cdf_plus, 80 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 81 | ) 82 | assert log_probs.shape == x.shape 83 | return log_probs 84 | -------------------------------------------------------------------------------- /diffusion/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import math 4 | 5 | import numpy as np 6 | import torch 7 | import torch as th 8 | import enum 9 | 10 | import torchvision 11 | 12 | from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl 13 | 14 | def mean_flat(tensor): 15 | """ 16 | Take the mean over all non-batch dimensions. 17 | """ 18 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 19 | 20 | 21 | class ModelMeanType(enum.Enum): 22 | """ 23 | Which type of output the model predicts. 24 | """ 25 | 26 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1} 27 | START_X = enum.auto() # the model predicts x_0 28 | EPSILON = enum.auto() # the model predicts epsilon 29 | 30 | 31 | class ModelVarType(enum.Enum): 32 | """ 33 | What is used as the model's output variance. 34 | The LEARNED_RANGE option has been added to allow the model to predict 35 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 36 | """ 37 | 38 | LEARNED = enum.auto() 39 | FIXED_SMALL = enum.auto() 40 | FIXED_LARGE = enum.auto() 41 | LEARNED_RANGE = enum.auto() 42 | 43 | 44 | class LossType(enum.Enum): 45 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances) 46 | RESCALED_MSE = ( 47 | enum.auto() 48 | ) # use raw MSE loss (with RESCALED_KL when learning variances) 49 | KL = enum.auto() # use the variational lower-bound 50 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB 51 | 52 | def is_vb(self): 53 | return self == LossType.KL or self == LossType.RESCALED_KL 54 | 55 | 56 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): 57 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 58 | warmup_time = int(num_diffusion_timesteps * warmup_frac) 59 | betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) 60 | return betas 61 | 62 | 63 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 64 | """ 65 | This is the deprecated API for creating beta schedules. 66 | See get_named_beta_schedule() for the new library of schedules. 
67 | """ 68 | if beta_schedule == "quad": 69 | betas = ( 70 | np.linspace( 71 | beta_start ** 0.5, 72 | beta_end ** 0.5, 73 | num_diffusion_timesteps, 74 | dtype=np.float64, 75 | ) 76 | ** 2 77 | ) 78 | elif beta_schedule == "linear": 79 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 80 | elif beta_schedule == "warmup10": 81 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) 82 | elif beta_schedule == "warmup50": 83 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) 84 | elif beta_schedule == "const": 85 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 86 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 87 | betas = 1.0 / np.linspace( 88 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 89 | ) 90 | else: 91 | raise NotImplementedError(beta_schedule) 92 | assert betas.shape == (num_diffusion_timesteps,) 93 | return betas 94 | 95 | 96 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 97 | """ 98 | Get a pre-defined beta schedule for the given name. 99 | The beta schedule library consists of beta schedules which remain similar 100 | in the limit of num_diffusion_timesteps. 101 | Beta schedules may be added, but should not be removed or changed once 102 | they are committed to maintain backwards compatibility. 103 | """ 104 | if schedule_name == "linear": 105 | # Linear schedule from Ho et al, extended to work for any number of 106 | # diffusion steps. 107 | scale = 1000 / num_diffusion_timesteps 108 | return get_beta_schedule( 109 | "linear", 110 | beta_start=scale * 0.0001, 111 | beta_end=scale * 0.02, 112 | num_diffusion_timesteps=num_diffusion_timesteps, 113 | ) 114 | elif schedule_name == "squaredcos_cap_v2": 115 | return betas_for_alpha_bar( 116 | num_diffusion_timesteps, 117 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 118 | ) 119 | else: 120 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 121 | 122 | 123 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 124 | """ 125 | Create a beta schedule that discretizes the given alpha_t_bar function, 126 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 127 | :param num_diffusion_timesteps: the number of betas to produce. 128 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 129 | produces the cumulative product of (1-beta) up to that 130 | part of the diffusion process. 131 | :param max_beta: the maximum beta to use; use values lower than 1 to 132 | prevent singularities. 133 | """ 134 | betas = [] 135 | for i in range(num_diffusion_timesteps): 136 | t1 = i / num_diffusion_timesteps 137 | t2 = (i + 1) / num_diffusion_timesteps 138 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 139 | return np.array(betas) 140 | 141 | 142 | class GaussianDiffusion: 143 | """ 144 | Utilities for training and sampling diffusion models. 145 | Original ported from this codebase: 146 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 147 | :param betas: a 1-D numpy array of betas for each diffusion timestep, 148 | starting at T and going to 1. 
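    Example (an illustrative sketch; in this repository instances are normally created through
    `create_diffusion` in `diffusion/__init__.py`, which also handles timestep respacing):

        betas = get_named_beta_schedule("linear", 1000)
        diffusion = GaussianDiffusion(
            betas=betas,
            model_mean_type=ModelMeanType.EPSILON,
            model_var_type=ModelVarType.LEARNED_RANGE,
            loss_type=LossType.MSE,
        )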
149 | """ 150 | 151 | def __init__( 152 | self, 153 | *, 154 | betas, 155 | model_mean_type, 156 | model_var_type, 157 | loss_type 158 | ): 159 | 160 | self.model_mean_type = model_mean_type 161 | self.model_var_type = model_var_type 162 | self.loss_type = loss_type 163 | 164 | # Use float64 for accuracy. 165 | betas = np.array(betas, dtype=np.float64) 166 | self.betas = betas 167 | assert len(betas.shape) == 1, "betas must be 1-D" 168 | assert (betas > 0).all() and (betas <= 1).all() 169 | 170 | self.num_timesteps = int(betas.shape[0]) 171 | 172 | alphas = 1.0 - betas 173 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 174 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) 175 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) 176 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) 177 | 178 | # calculations for diffusion q(x_t | x_{t-1}) and others 179 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 180 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 181 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) 182 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) 183 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 184 | 185 | # calculations for posterior q(x_{t-1} | x_t, x_0) 186 | self.posterior_variance = ( 187 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 188 | ) 189 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 190 | self.posterior_log_variance_clipped = np.log( 191 | np.append(self.posterior_variance[1], self.posterior_variance[1:]) 192 | ) if len(self.posterior_variance) > 1 else np.array([]) 193 | 194 | self.posterior_mean_coef1 = ( 195 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 196 | ) 197 | self.posterior_mean_coef2 = ( 198 | (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) 199 | ) 200 | 201 | def q_mean_variance(self, x_start, t): 202 | """ 203 | Get the distribution q(x_t | x_0). 204 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 205 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 206 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 207 | """ 208 | mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 209 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) 210 | log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) 211 | return mean, variance, log_variance 212 | 213 | def q_sample(self, x_start, t, noise=None): 214 | """ 215 | Diffuse the data for a given number of diffusion steps. 216 | In other words, sample from q(x_t | x_0). 217 | :param x_start: the initial data batch. 218 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 219 | :param noise: if specified, the split-out normal noise. 220 | :return: A noisy version of x_start. 
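        Example (an illustrative sketch; `diffusion` is an instance of this class):

            x_start = th.randn(8, 3, 64, 64)
            t = th.randint(0, diffusion.num_timesteps, (8,))
            x_t = diffusion.q_sample(x_start, t)
            # equals sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise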
221 | """ 222 | if noise is None: 223 | noise = th.randn_like(x_start) 224 | assert noise.shape == x_start.shape 225 | return ( 226 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 227 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 228 | ) 229 | 230 | def q_posterior_mean_variance(self, x_start, x_t, t): 231 | """ 232 | Compute the mean and variance of the diffusion posterior: 233 | q(x_{t-1} | x_t, x_0) 234 | """ 235 | assert x_start.shape == x_t.shape 236 | posterior_mean = ( 237 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 238 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 239 | ) 240 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) 241 | posterior_log_variance_clipped = _extract_into_tensor( 242 | self.posterior_log_variance_clipped, t, x_t.shape 243 | ) 244 | assert ( 245 | posterior_mean.shape[0] 246 | == posterior_variance.shape[0] 247 | == posterior_log_variance_clipped.shape[0] 248 | == x_start.shape[0] 249 | ) 250 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 251 | 252 | def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): 253 | """ 254 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 255 | the initial x, x_0. 256 | :param model: the model, which takes a signal and a batch of timesteps 257 | as input. 258 | :param x: the [N x C x ...] tensor at time t. 259 | :param t: a 1-D Tensor of timesteps. 260 | :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 261 | :param denoised_fn: if not None, a function which applies to the 262 | x_start prediction before it is used to sample. Applies before 263 | clip_denoised. 264 | :param model_kwargs: if not None, a dict of extra keyword arguments to 265 | pass to the model. This can be used for conditioning. 266 | :return: a dict with the following keys: 267 | - 'mean': the model mean output. 268 | - 'variance': the model variance output. 269 | - 'log_variance': the log of 'variance'. 270 | - 'pred_xstart': the prediction for x_0. 271 | """ 272 | if model_kwargs is None: 273 | model_kwargs = {} 274 | 275 | B, C = x.shape[:2] 276 | assert t.shape == (B,) 277 | model_output = model(x, t, **model_kwargs) 278 | if isinstance(model_output, tuple): 279 | model_output, extra = model_output 280 | else: 281 | extra = None 282 | 283 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: 284 | assert model_output.shape == (B, C * 2, *x.shape[2:]) 285 | model_output, model_var_values = th.split(model_output, C, dim=1) 286 | # model_output = model_output + (t > 1).reshape(-1, 1, 1, 1) * extra 287 | min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) 288 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) 289 | # The model_var_values is [-1, 1] for [min_var, max_var]. 290 | frac = (model_var_values + 1) / 2 291 | model_log_variance = frac * max_log + (1 - frac) * min_log 292 | model_variance = th.exp(model_log_variance) 293 | else: 294 | model_variance, model_log_variance = { 295 | # for fixedlarge, we set the initial (log-)variance like so 296 | # to get a better decoder log likelihood. 
297 | ModelVarType.FIXED_LARGE: ( 298 | np.append(self.posterior_variance[1], self.betas[1:]), 299 | np.log(np.append(self.posterior_variance[1], self.betas[1:])), 300 | ), 301 | ModelVarType.FIXED_SMALL: ( 302 | self.posterior_variance, 303 | self.posterior_log_variance_clipped, 304 | ), 305 | }[self.model_var_type] 306 | model_variance = _extract_into_tensor(model_variance, t, x.shape) 307 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) 308 | 309 | def process_xstart(x): 310 | if denoised_fn is not None: 311 | x = denoised_fn(x) 312 | if clip_denoised: 313 | return x.clamp(-1, 1) 314 | return x 315 | 316 | if self.model_mean_type == ModelMeanType.START_X: 317 | pred_xstart = process_xstart(model_output) 318 | else: 319 | pred_xstart = process_xstart( 320 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) 321 | ) 322 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) 323 | 324 | assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 325 | return { 326 | "mean": model_mean, 327 | "variance": model_variance, 328 | "log_variance": model_log_variance, 329 | "pred_xstart": pred_xstart, 330 | "extra": extra, 331 | } 332 | 333 | def _predict_xstart_from_eps(self, x_t, t, eps): 334 | assert x_t.shape == eps.shape 335 | return ( 336 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 337 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 338 | ) 339 | 340 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 341 | return ( 342 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart 343 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 344 | 345 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 346 | """ 347 | Compute the mean for the previous step, given a function cond_fn that 348 | computes the gradient of a conditional log probability with respect to 349 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to 350 | condition on y. 351 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 352 | """ 353 | gradient = cond_fn(x, t, **model_kwargs) 354 | new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() 355 | return new_mean 356 | 357 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 358 | """ 359 | Compute what the p_mean_variance output would have been, should the 360 | model's score function be conditioned by cond_fn. 361 | See condition_mean() for details on cond_fn. 362 | Unlike condition_mean(), this instead uses the conditioning strategy 363 | from Song et al (2020). 364 | """ 365 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 366 | 367 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) 368 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) 369 | 370 | out = p_mean_var.copy() 371 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) 372 | out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) 373 | return out 374 | 375 | def p_sample( 376 | self, 377 | model, 378 | x, 379 | t, 380 | clip_denoised=True, 381 | denoised_fn=None, 382 | cond_fn=None, 383 | model_kwargs=None, 384 | ): 385 | """ 386 | Sample x_{t-1} from the model at the given timestep. 387 | :param model: the model to sample from. 388 | :param x: the current tensor at x_{t-1}. 
389 | :param t: the value of t, starting at 0 for the first diffusion step. 390 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 391 | :param denoised_fn: if not None, a function which applies to the 392 | x_start prediction before it is used to sample. 393 | :param cond_fn: if not None, this is a gradient function that acts 394 | similarly to the model. 395 | :param model_kwargs: if not None, a dict of extra keyword arguments to 396 | pass to the model. This can be used for conditioning. 397 | :return: a dict containing the following keys: 398 | - 'sample': a random sample from the model. 399 | - 'pred_xstart': a prediction of x_0. 400 | """ 401 | out = self.p_mean_variance( 402 | model, 403 | x, 404 | t, 405 | clip_denoised=clip_denoised, 406 | denoised_fn=denoised_fn, 407 | model_kwargs=model_kwargs, 408 | ) 409 | noise = th.randn_like(x) 410 | nonzero_mask = ( 411 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 412 | ) # no noise when t == 0 413 | if cond_fn is not None: 414 | out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) 415 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 416 | return {"sample": sample, "pred_xstart": out["pred_xstart"], "log_variance": out["log_variance"]} 417 | 418 | def p_sample_const( 419 | self, 420 | model, 421 | x, 422 | t, 423 | get_eps, 424 | clip_denoised=True, 425 | denoised_fn=None, 426 | cond_fn=None, 427 | model_kwargs=None, 428 | ): 429 | """ 430 | Sample x_{t-1} from the model at the given timestep. 431 | :param model: the model to sample from. 432 | :param x: the current tensor at x_{t-1}. 433 | :param t: the value of t, starting at 0 for the first diffusion step. 434 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 435 | :param denoised_fn: if not None, a function which applies to the 436 | x_start prediction before it is used to sample. 437 | :param cond_fn: if not None, this is a gradient function that acts 438 | similarly to the model. 439 | :param model_kwargs: if not None, a dict of extra keyword arguments to 440 | pass to the model. This can be used for conditioning. 441 | :return: a dict containing the following keys: 442 | - 'sample': a random sample from the model. 443 | - 'pred_xstart': a prediction of x_0. 
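        Unlike `p_sample`, the model mean is not computed from `x` directly: `get_eps` is a callable
        invoked as `get_eps(pred_xstart=...)`, its output is used to re-noise the predicted x_0 via
        `q_sample`, and the posterior mean is then computed from that re-noised sample.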
444 | """ 445 | out = self.p_mean_variance( 446 | model, 447 | x, 448 | t, 449 | clip_denoised=clip_denoised, 450 | denoised_fn=denoised_fn, 451 | model_kwargs=model_kwargs, 452 | ) 453 | const_x = self.q_sample(x_start=out["pred_xstart"], t=t, noise=get_eps(pred_xstart=out["pred_xstart"])) 454 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=const_x, t=t) 455 | out["mean"] = model_mean 456 | noise = th.randn_like(x) 457 | nonzero_mask = ( 458 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 459 | ) # no noise when t == 0 460 | if cond_fn is not None: 461 | out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) 462 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 463 | return {"sample": sample, "pred_xstart": out["pred_xstart"], "log_variance": out["log_variance"]} 464 | 465 | def p_sample_loop_progressive( 466 | self, 467 | model, 468 | shape, 469 | noise=None, 470 | clip_denoised=True, 471 | denoised_fn=None, 472 | cond_fn=None, 473 | model_kwargs=None, 474 | device=None, 475 | progress=False, 476 | ): 477 | """ 478 | Generate samples from the model and yield intermediate samples from 479 | each timestep of diffusion. 480 | Arguments are the same as p_sample_loop(). 481 | Returns a generator over dicts, where each dict is the return value of 482 | p_sample(). 483 | """ 484 | if device is None: 485 | device = next(model.parameters()).device 486 | assert isinstance(shape, (tuple, list)) 487 | if noise is not None: 488 | img = noise 489 | else: 490 | img = th.randn(*shape, device=device) 491 | indices = list(range(self.num_timesteps))[::-1] 492 | 493 | if progress: 494 | # Lazy import so that we don't depend on tqdm. 495 | from tqdm.auto import tqdm 496 | 497 | indices = tqdm(indices, leave=False) 498 | 499 | for i in indices: 500 | t = th.tensor([i] * shape[0], device=device) 501 | with th.no_grad(): 502 | out = self.p_sample( 503 | model, 504 | img, 505 | t, 506 | clip_denoised=clip_denoised, 507 | denoised_fn=denoised_fn, 508 | cond_fn=cond_fn, 509 | model_kwargs=model_kwargs, 510 | ) 511 | yield out 512 | img = out["sample"] 513 | 514 | def p_sample_loop( 515 | self, 516 | model, 517 | shape, 518 | noise=None, 519 | clip_denoised=True, 520 | denoised_fn=None, 521 | cond_fn=None, 522 | model_kwargs=None, 523 | device=None, 524 | progress=False, 525 | ): 526 | """ 527 | Generate samples from the model. 528 | :param model: the model module. 529 | :param shape: the shape of the samples, (N, C, H, W). 530 | :param noise: if specified, the noise from the encoder to sample. 531 | Should be of the same shape as `shape`. 532 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 533 | :param denoised_fn: if not None, a function which applies to the 534 | x_start prediction before it is used to sample. 535 | :param cond_fn: if not None, this is a gradient function that acts 536 | similarly to the model. 537 | :param model_kwargs: if not None, a dict of extra keyword arguments to 538 | pass to the model. This can be used for conditioning. 539 | :param device: if specified, the device to create the samples on. 540 | If not specified, use a model parameter's device. 541 | :param progress: if True, show a tqdm progress bar. 542 | :return: a non-differentiable batch of samples. 
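        Example (an illustrative sketch; `model` and `diffusion` are assumed to exist already, e.g.
        a class-conditional latent-space model and a diffusion object from `create_diffusion("250")`;
        the shape and `model_kwargs` below are placeholders, not values prescribed by this repository):

            out = diffusion.p_sample_loop(
                model, (4, 4, 32, 32), model_kwargs={"y": labels}, progress=True
            )
            samples = out["sample"]  # as written, the loop returns the final p_sample() dict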
543 | """ 544 | final = None 545 | for i, sample in enumerate(self.p_sample_loop_progressive( 546 | model, 547 | shape, 548 | noise=noise, 549 | clip_denoised=clip_denoised, 550 | denoised_fn=denoised_fn, 551 | cond_fn=cond_fn, 552 | model_kwargs=model_kwargs, 553 | device=device, 554 | progress=progress, 555 | )): 556 | final = sample 557 | return final 558 | 559 | def inner_sample( 560 | self, 561 | model, 562 | x, 563 | t, 564 | inner_steps=1, 565 | ddim=False, 566 | clip_denoised=True, 567 | denoised_fn=None, 568 | cond_fn=None, 569 | model_kwargs=None, 570 | ): 571 | """ 572 | Sample x_{t-1} from the model at the given timestep. 573 | :param model: the model to sample from. 574 | :param x: the current tensor at x_{t-1}. 575 | :param t: the value of t, starting at 0 for the first diffusion step. 576 | :param inner_steps: the number of inner steps to use. 577 | :param ddim: sample inner diffusion with deterministic DDIM. 578 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 579 | :param denoised_fn: if not None, a function which applies to the 580 | x_start prediction before it is used to sample. 581 | :param cond_fn: if not None, this is a gradient function that acts 582 | similarly to the model. 583 | :param model_kwargs: if not None, a dict of extra keyword arguments to 584 | pass to the model. This can be used for conditioning. 585 | :return: a dict containing the following keys: 586 | - 'sample': a random sample from the model. 587 | - 'pred_xstart': a prediction of x_0. 588 | """ 589 | 590 | from diffusion import create_diffusion 591 | size = min(inner_steps, self.timestep_map[t[0].item()] + 1) 592 | diffusion = create_diffusion(str(size), max_step=self.timestep_map[t[0].item()] + 1) 593 | indices = list(range(diffusion.num_timesteps))[::-1] 594 | img = x 595 | for iteration, i in enumerate(indices): 596 | t_tag = th.tensor([i] * x.shape[0], device=x.device) 597 | with th.no_grad(): 598 | if ddim: 599 | out = diffusion.ddim_sample( 600 | model, 601 | img, 602 | t_tag, 603 | clip_denoised=clip_denoised, 604 | denoised_fn=denoised_fn, 605 | cond_fn=cond_fn, 606 | model_kwargs=model_kwargs, 607 | ) 608 | else: 609 | out = diffusion.p_sample( 610 | model, 611 | img, 612 | t_tag, 613 | clip_denoised=clip_denoised, 614 | denoised_fn=denoised_fn, 615 | cond_fn=cond_fn, 616 | model_kwargs=model_kwargs, 617 | ) 618 | img = out["sample"] 619 | x_hat = out["pred_xstart"] 620 | eps = self._predict_eps_from_xstart(x, t, x_hat) 621 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) 622 | pred_xstart = x_hat 623 | sample = ( 624 | x_hat * th.sqrt(alpha_bar_prev) 625 | + th.sqrt(1 - alpha_bar_prev) * eps 626 | ) 627 | 628 | return {"sample": sample, "pred_xstart": pred_xstart} 629 | 630 | def nested_sample_loop_progressive( 631 | self, 632 | model, 633 | shape, 634 | noise=None, 635 | inner_steps=1, 636 | ddim=False, 637 | clip_denoised=True, 638 | denoised_fn=None, 639 | cond_fn=None, 640 | model_kwargs=None, 641 | device=None, 642 | progress=False, 643 | ): 644 | """ 645 | Generate samples from the model and yield intermediate samples from 646 | each timestep of diffusion. 647 | Arguments are the same as nested_sample_loop(). 648 | Returns a generator over dicts, where each dict is the return value of 649 | nested_sample(). 
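        Example (an illustrative sketch; `model`, `shape`, and `model_kwargs` are assumed as in
        `p_sample_loop` above). Note that, as implemented below, the generator yields
        `out["pred_xstart"]`, i.e. the current clean-sample estimate after each outer step, so every
        iterate is an "anytime" prediction:

            for pred_x0 in diffusion.nested_sample_loop_progressive(
                model, shape, inner_steps=20, ddim=False, model_kwargs=model_kwargs
            ):
                ...  # decode / save pred_x0 immediately if desired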
650 | """ 651 | if device is None: 652 | device = next(model.parameters()).device 653 | assert isinstance(shape, (tuple, list)) 654 | if noise is not None: 655 | img = noise 656 | else: 657 | img = th.randn(*shape, device=device) 658 | indices = list(range(self.num_timesteps))[::-1][:-1] 659 | 660 | if progress: 661 | # Lazy import so that we don't depend on tqdm. 662 | from tqdm.auto import tqdm 663 | indices = tqdm(indices, leave=False) 664 | 665 | with th.no_grad(): 666 | for i in indices: 667 | t = th.tensor([i] * img.shape[0], device=device) 668 | out = self.inner_sample( 669 | model, 670 | img, 671 | t, 672 | inner_steps=int(inner_steps), 673 | ddim=ddim, 674 | clip_denoised=clip_denoised, 675 | denoised_fn=denoised_fn, 676 | cond_fn=cond_fn, 677 | model_kwargs=model_kwargs, 678 | ) 679 | 680 | yield out["pred_xstart"] 681 | img = out['sample'] 682 | 683 | 684 | def nested_sample_loop( 685 | self, 686 | model, 687 | shape, 688 | noise=None, 689 | inner_steps=1, 690 | ddim=False, 691 | clip_denoised=True, 692 | denoised_fn=None, 693 | cond_fn=None, 694 | model_kwargs=None, 695 | device=None, 696 | progress=False, 697 | ): 698 | """ 699 | Generate samples from the model. 700 | :param model: the model module. 701 | :param shape: the shape of the samples, (N, C, H, W). 702 | :param noise: if specified, the noise from the encoder to sample. 703 | Should be of the same shape as `shape`. 704 | :param inner_steps: number of inner steps to use. 705 | :param ddim: sample inner diffusion with deterministic DDIM. 706 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 707 | :param denoised_fn: if not None, a function which applies to the 708 | x_start prediction before it is used to sample. 709 | :param cond_fn: if not None, this is a gradient function that acts 710 | similarly to the model. 711 | :param model_kwargs: if not None, a dict of extra keyword arguments to 712 | pass to the model. This can be used for conditioning. 713 | :param device: if specified, the device to create the samples on. 714 | If not specified, use a model parameter's device. 715 | :param progress: if True, show a tqdm progress bar. 716 | :return: a non-differentiable batch of samples. 717 | """ 718 | res = [] 719 | for sample in self.nested_sample_loop_progressive( 720 | model, 721 | shape, 722 | noise=noise, 723 | inner_steps=inner_steps, 724 | ddim=ddim, 725 | clip_denoised=clip_denoised, 726 | denoised_fn=denoised_fn, 727 | cond_fn=cond_fn, 728 | model_kwargs=model_kwargs, 729 | device=device, 730 | progress=progress, 731 | ): 732 | res += [sample] 733 | return res 734 | 735 | 736 | 737 | def ddim_sample( 738 | self, 739 | model, 740 | x, 741 | t, 742 | clip_denoised=True, 743 | denoised_fn=None, 744 | cond_fn=None, 745 | model_kwargs=None, 746 | eta=0.0, 747 | ): 748 | """ 749 | Sample x_{t-1} from the model using DDIM. 750 | Same usage as p_sample(). 751 | """ 752 | out = self.p_mean_variance( 753 | model, 754 | x, 755 | t, 756 | clip_denoised=clip_denoised, 757 | denoised_fn=denoised_fn, 758 | model_kwargs=model_kwargs, 759 | ) 760 | if cond_fn is not None: 761 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 762 | 763 | # Usually our model outputs epsilon, but we re-derive it 764 | # in case we used x_start or x_prev prediction. 
765 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) 766 | 767 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 768 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) 769 | sigma = ( 770 | eta 771 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) 772 | * th.sqrt(1 - alpha_bar / alpha_bar_prev) 773 | ) 774 | # Equation 12. 775 | noise = th.randn_like(x) 776 | mean_pred = ( 777 | out["pred_xstart"] * th.sqrt(alpha_bar_prev) 778 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps 779 | ) 780 | nonzero_mask = ( 781 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 782 | ) # no noise when t == 0 783 | sample = mean_pred + nonzero_mask * sigma * noise 784 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 785 | 786 | def ddim_reverse_sample( 787 | self, 788 | model, 789 | x, 790 | t, 791 | clip_denoised=True, 792 | denoised_fn=None, 793 | cond_fn=None, 794 | model_kwargs=None, 795 | eta=0.0, 796 | ): 797 | """ 798 | Sample x_{t+1} from the model using DDIM reverse ODE. 799 | """ 800 | assert eta == 0.0, "Reverse ODE only for deterministic path" 801 | out = self.p_mean_variance( 802 | model, 803 | x, 804 | t, 805 | clip_denoised=clip_denoised, 806 | denoised_fn=denoised_fn, 807 | model_kwargs=model_kwargs, 808 | ) 809 | if cond_fn is not None: 810 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 811 | # Usually our model outputs epsilon, but we re-derive it 812 | # in case we used x_start or x_prev prediction. 813 | eps = ( 814 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x 815 | - out["pred_xstart"] 816 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) 817 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) 818 | 819 | # Equation 12. reversed 820 | mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps 821 | 822 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} 823 | 824 | def ddim_sample_loop( 825 | self, 826 | model, 827 | shape, 828 | noise=None, 829 | clip_denoised=True, 830 | denoised_fn=None, 831 | cond_fn=None, 832 | model_kwargs=None, 833 | device=None, 834 | progress=False, 835 | eta=0.0, 836 | ): 837 | """ 838 | Generate samples from the model using DDIM. 839 | Same usage as p_sample_loop(). 840 | """ 841 | final = None 842 | for sample in self.ddim_sample_loop_progressive( 843 | model, 844 | shape, 845 | noise=noise, 846 | clip_denoised=clip_denoised, 847 | denoised_fn=denoised_fn, 848 | cond_fn=cond_fn, 849 | model_kwargs=model_kwargs, 850 | device=device, 851 | progress=progress, 852 | eta=eta, 853 | ): 854 | final = sample 855 | return final["sample"] 856 | 857 | def ddim_sample_loop_progressive( 858 | self, 859 | model, 860 | shape, 861 | noise=None, 862 | clip_denoised=True, 863 | denoised_fn=None, 864 | cond_fn=None, 865 | model_kwargs=None, 866 | device=None, 867 | progress=False, 868 | eta=0.0, 869 | ): 870 | """ 871 | Use DDIM to sample from the model and yield intermediate samples from 872 | each timestep of DDIM. 873 | Same usage as p_sample_loop_progressive(). 874 | """ 875 | if device is None: 876 | device = next(model.parameters()).device 877 | assert isinstance(shape, (tuple, list)) 878 | if noise is not None: 879 | img = noise 880 | else: 881 | img = th.randn(*shape, device=device) 882 | indices = list(range(self.num_timesteps))[::-1] 883 | 884 | if progress: 885 | # Lazy import so that we don't depend on tqdm. 
886 |             from tqdm.auto import tqdm
887 | 
888 |             indices = tqdm(indices, leave=False)
889 | 
890 |         for i in indices:
891 |             t = th.tensor([i] * shape[0], device=device)
892 |             with th.no_grad():
893 |                 out = self.ddim_sample(
894 |                     model,
895 |                     img,
896 |                     t,
897 |                     clip_denoised=clip_denoised,
898 |                     denoised_fn=denoised_fn,
899 |                     cond_fn=cond_fn,
900 |                     model_kwargs=model_kwargs,
901 |                     eta=eta,
902 |                 )
903 |                 yield out
904 |                 img = out["sample"]
905 | 
906 | 
907 | 
908 |     def _vb_terms_bpd(
909 |         self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
910 |     ):
911 |         """
912 |         Get a term for the variational lower-bound.
913 |         The resulting units are bits (rather than nats, as one might expect).
914 |         This allows for comparison to other papers.
915 |         :return: a dict with the following keys:
916 |                  - 'output': a shape [N] tensor of NLLs or KLs.
917 |                  - 'pred_xstart': the x_0 predictions.
918 |         """
919 |         true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
920 |             x_start=x_start, x_t=x_t, t=t
921 |         )
922 |         out = self.p_mean_variance(
923 |             model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
924 |         )
925 |         kl = normal_kl(
926 |             true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
927 |         )
928 |         kl = mean_flat(kl) / np.log(2.0)
929 | 
930 |         decoder_nll = -discretized_gaussian_log_likelihood(
931 |             x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
932 |         )
933 |         assert decoder_nll.shape == x_start.shape
934 |         decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
935 | 
936 |         # At the first timestep return the decoder NLL,
937 |         # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
938 |         output = th.where((t == 0), decoder_nll, kl)
939 |         return {"output": output, "pred_xstart": out["pred_xstart"]}
940 | 
941 |     def training_losses(self, model, x_start, class_label, loss_fn):
942 |         """
943 |         Compute the training loss for a batch of inputs at randomly
944 |         sampled timesteps.
945 |         :param model: the model to evaluate loss on.
946 |         :param x_start: the [N x C x ...] tensor of inputs.
947 |         :param class_label: a batch of class labels used to condition the model.
948 |         :param loss_fn: a callable applied to the predicted noise and the true
949 |                         noise, e.g. a mean-squared-error criterion.
950 |         :return: the value of `loss_fn` comparing the model's noise prediction
951 |                  (the first 4 output channels) to the noise added by q_sample.
952 |         """
953 |         noise = th.randn_like(x_start, device=x_start.device)
954 |         b, c, h, w = x_start.shape
955 |         t = th.randint(0, self.num_timesteps, (b,), device=x_start.device)
956 |         x_t = self.q_sample(x_start, t, noise=noise)
957 | 
958 |         noise_hat = model(x_t, t, class_label)[:, :4]
959 | 
960 |         loss = loss_fn(noise_hat, noise)
961 |         return loss
962 | 
963 |     def _prior_bpd(self, x_start):
964 |         """
965 |         Get the prior KL term for the variational lower-bound, measured in
966 |         bits-per-dim.
967 |         This term can't be optimized, as it only depends on the encoder.
968 |         :param x_start: the [N x C x ...] tensor of inputs.
969 |         :return: a batch of [N] KL values (in bits), one per batch element.
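        Concretely, this is KL(q(x_T | x_0) || N(0, I)), averaged over dimensions
        with mean_flat and converted from nats to bits by dividing by ln(2).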
970 | """ 971 | batch_size = x_start.shape[0] 972 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) 973 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) 974 | kl_prior = normal_kl( 975 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 976 | ) 977 | return mean_flat(kl_prior) / np.log(2.0) 978 | 979 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): 980 | """ 981 | Compute the entire variational lower-bound, measured in bits-per-dim, 982 | as well as other related quantities. 983 | :param model: the model to evaluate loss on. 984 | :param x_start: the [N x C x ...] tensor of inputs. 985 | :param clip_denoised: if True, clip denoised samples. 986 | :param model_kwargs: if not None, a dict of extra keyword arguments to 987 | pass to the model. This can be used for conditioning. 988 | :return: a dict containing the following keys: 989 | - total_bpd: the total variational lower-bound, per batch element. 990 | - prior_bpd: the prior term in the lower-bound. 991 | - vb: an [N x T] tensor of terms in the lower-bound. 992 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. 993 | - mse: an [N x T] tensor of epsilon MSEs for each timestep. 994 | """ 995 | device = x_start.device 996 | batch_size = x_start.shape[0] 997 | 998 | vb = [] 999 | xstart_mse = [] 1000 | mse = [] 1001 | for t in list(range(self.num_timesteps))[::-1]: 1002 | t_batch = th.tensor([t] * batch_size, device=device) 1003 | noise = th.randn_like(x_start) 1004 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) 1005 | # Calculate VLB term at the current timestep 1006 | with th.no_grad(): 1007 | out = self._vb_terms_bpd( 1008 | model, 1009 | x_start=x_start, 1010 | x_t=x_t, 1011 | t=t_batch, 1012 | clip_denoised=clip_denoised, 1013 | model_kwargs=model_kwargs, 1014 | ) 1015 | vb.append(out["output"]) 1016 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) 1017 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) 1018 | mse.append(mean_flat((eps - noise) ** 2)) 1019 | 1020 | vb = th.stack(vb, dim=1) 1021 | xstart_mse = th.stack(xstart_mse, dim=1) 1022 | mse = th.stack(mse, dim=1) 1023 | 1024 | prior_bpd = self._prior_bpd(x_start) 1025 | total_bpd = vb.sum(dim=1) + prior_bpd 1026 | return { 1027 | "total_bpd": total_bpd, 1028 | "prior_bpd": prior_bpd, 1029 | "vb": vb, 1030 | "xstart_mse": xstart_mse, 1031 | "mse": mse, 1032 | } 1033 | 1034 | 1035 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 1036 | """ 1037 | Extract values from a 1-D numpy array for a batch of indices. 1038 | :param arr: the 1-D numpy array. 1039 | :param timesteps: a tensor of indices into the array to extract. 1040 | :param broadcast_shape: a larger shape of K dimensions with the batch 1041 | dimension equal to the length of timesteps. 1042 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
1043 | """ 1044 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 1045 | while len(res.shape) < len(broadcast_shape): 1046 | res = res[..., None] 1047 | return res + th.zeros(broadcast_shape, device=timesteps.device) 1048 | -------------------------------------------------------------------------------- /diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | For example, if there's 300 timesteps and the section counts are [10,15,20] 13 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 14 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 15 | If the stride is a string starting with "ddim", then the fixed striding 16 | from the DDIM paper is used, and only one section is allowed. 17 | :param num_timesteps: the number of diffusion steps in the original 18 | process to divide up. 19 | :param section_counts: either a list of numbers, or a string containing 20 | comma-separated numbers, indicating the step count 21 | per section. As a special case, use "ddimN" where N 22 | is a number of steps to use the striding from the 23 | DDIM paper. 24 | :return: a set of diffusion steps from the original process to use. 25 | """ 26 | if isinstance(section_counts, str): 27 | if section_counts.startswith("ddim"): 28 | desired_count = int(section_counts[len("ddim") :]) 29 | for i in range(1, num_timesteps): 30 | if len(range(0, num_timesteps, i)) == desired_count: 31 | return set(range(0, num_timesteps, i)) 32 | raise ValueError( 33 | f"cannot create exactly {num_timesteps} steps with an integer stride" 34 | ) 35 | section_counts = [int(x) for x in section_counts.split(",")] 36 | size_per = num_timesteps // len(section_counts) 37 | extra = num_timesteps % len(section_counts) 38 | start_idx = 0 39 | all_steps = [] 40 | for i, section_count in enumerate(section_counts): 41 | size = size_per + (1 if i < extra else 0) 42 | if size < section_count: 43 | raise ValueError( 44 | f"cannot divide section of {size} steps into {section_count}" 45 | ) 46 | if section_count <= 1: 47 | frac_stride = 1 48 | else: 49 | frac_stride = (size - 1) / (section_count - 1) 50 | cur_idx = 0.0 51 | taken_steps = [] 52 | for _ in range(section_count): 53 | taken_steps.append(start_idx + round(cur_idx)) 54 | cur_idx += frac_stride 55 | all_steps += taken_steps 56 | start_idx += size 57 | return set(all_steps) 58 | 59 | 60 | class SpacedDiffusion(GaussianDiffusion): 61 | """ 62 | A diffusion process which can skip steps in a base diffusion process. 63 | :param use_timesteps: a collection (sequence or set) of timesteps from the 64 | original diffusion process to retain. 65 | :param kwargs: the kwargs to create the base diffusion process. 
66 | """ 67 | 68 | def __init__(self, use_timesteps, max_step=None, **kwargs): 69 | self.use_timesteps = set(use_timesteps) 70 | self.timestep_map = [] 71 | self.original_num_steps = len(kwargs["betas"]) 72 | 73 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 74 | last_alpha_cumprod = 1.0 75 | new_betas = [] 76 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod[:max_step]): 77 | if i in self.use_timesteps: 78 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 79 | last_alpha_cumprod = alpha_cumprod 80 | self.timestep_map.append(i) 81 | kwargs["betas"] = np.array(new_betas) 82 | super().__init__(**kwargs) 83 | 84 | def p_mean_variance( 85 | self, model, *args, **kwargs 86 | ): # pylint: disable=signature-differs 87 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 88 | 89 | def training_losses( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 93 | 94 | def condition_mean(self, cond_fn, *args, **kwargs): 95 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 96 | 97 | def condition_score(self, cond_fn, *args, **kwargs): 98 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 99 | 100 | def _wrap_model(self, model): 101 | if isinstance(model, _WrappedModel): 102 | return model 103 | return _WrappedModel( 104 | model, self.timestep_map, self.original_num_steps 105 | ) 106 | 107 | def _scale_timesteps(self, t): 108 | # Scaling is done by the wrapped model. 109 | return t 110 | 111 | 112 | class _WrappedModel: 113 | def __init__(self, model, timestep_map, original_num_steps): 114 | self.model = model 115 | self.timestep_map = timestep_map 116 | # self.rescale_timesteps = rescale_timesteps 117 | self.original_num_steps = original_num_steps 118 | 119 | def __call__(self, x, ts, **kwargs): 120 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 121 | new_ts = map_tensor[ts] 122 | # if self.rescale_timesteps: 123 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 124 | return self.model(x, new_ts, **kwargs) 125 | -------------------------------------------------------------------------------- /diffusion/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | :param name: the name of the sampler. 12 | :param diffusion: the diffusion object to sample for. 13 | """ 14 | if name == "uniform": 15 | return UniformSampler(diffusion) 16 | elif name == "loss-second-moment": 17 | return LossSecondMomentResampler(diffusion) 18 | else: 19 | raise NotImplementedError(f"unknown schedule sampler: {name}") 20 | 21 | 22 | class ScheduleSampler(ABC): 23 | """ 24 | A distribution over timesteps in the diffusion process, intended to reduce 25 | variance of the objective. 26 | By default, samplers perform unbiased importance sampling, in which the 27 | objective's mean is unchanged. 28 | However, subclasses may override sample() to change how the resampled 29 | terms are reweighted, allowing for actual changes in the objective. 
30 | """ 31 | 32 | @abstractmethod 33 | def weights(self): 34 | """ 35 | Get a numpy array of weights, one per diffusion step. 36 | The weights needn't be normalized, but must be positive. 37 | """ 38 | 39 | def sample(self, batch_size, device): 40 | """ 41 | Importance-sample timesteps for a batch. 42 | :param batch_size: the number of timesteps. 43 | :param device: the torch device to save to. 44 | :return: a tuple (timesteps, weights): 45 | - timesteps: a tensor of timestep indices. 46 | - weights: a tensor of weights to scale the resulting losses. 47 | """ 48 | w = self.weights() 49 | p = w / np.sum(w) 50 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 51 | indices = th.from_numpy(indices_np).long().to(device) 52 | weights_np = 1 / (len(p) * p[indices_np]) 53 | weights = th.from_numpy(weights_np).float().to(device) 54 | return indices, weights 55 | 56 | 57 | class UniformSampler(ScheduleSampler): 58 | def __init__(self, diffusion): 59 | self.diffusion = diffusion 60 | self._weights = np.ones([diffusion.num_timesteps]) 61 | 62 | def weights(self): 63 | return self._weights 64 | 65 | 66 | class LossAwareSampler(ScheduleSampler): 67 | def update_with_local_losses(self, local_ts, local_losses): 68 | """ 69 | Update the reweighting using losses from a model. 70 | Call this method from each rank with a batch of timesteps and the 71 | corresponding losses for each of those timesteps. 72 | This method will perform synchronization to make sure all of the ranks 73 | maintain the exact same reweighting. 74 | :param local_ts: an integer Tensor of timesteps. 75 | :param local_losses: a 1D Tensor of losses. 76 | """ 77 | batch_sizes = [ 78 | th.tensor([0], dtype=th.int32, device=local_ts.device) 79 | for _ in range(dist.get_world_size()) 80 | ] 81 | dist.all_gather( 82 | batch_sizes, 83 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 84 | ) 85 | 86 | # Pad all_gather batches to be the maximum batch size. 87 | batch_sizes = [x.item() for x in batch_sizes] 88 | max_bs = max(batch_sizes) 89 | 90 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 91 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 92 | dist.all_gather(timestep_batches, local_ts) 93 | dist.all_gather(loss_batches, local_losses) 94 | timesteps = [ 95 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 96 | ] 97 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 98 | self.update_with_all_losses(timesteps, losses) 99 | 100 | @abstractmethod 101 | def update_with_all_losses(self, ts, losses): 102 | """ 103 | Update the reweighting using losses from a model. 104 | Sub-classes should override this method to update the reweighting 105 | using losses from the model. 106 | This method directly updates the reweighting without synchronizing 107 | between workers. It is called by update_with_local_losses from all 108 | ranks with identical arguments. Thus, it should have deterministic 109 | behavior to maintain state across workers. 110 | :param ts: a list of int timesteps. 111 | :param losses: a list of float losses, one per timestep. 
112 | """ 113 | 114 | 115 | class LossSecondMomentResampler(LossAwareSampler): 116 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 117 | self.diffusion = diffusion 118 | self.history_per_term = history_per_term 119 | self.uniform_prob = uniform_prob 120 | self._loss_history = np.zeros( 121 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 122 | ) 123 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 124 | 125 | def weights(self): 126 | if not self._warmed_up(): 127 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 128 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 129 | weights /= np.sum(weights) 130 | weights *= 1 - self.uniform_prob 131 | weights += self.uniform_prob / len(weights) 132 | return weights 133 | 134 | def update_with_all_losses(self, ts, losses): 135 | for t, loss in zip(ts, losses): 136 | if self._loss_counts[t] == self.history_per_term: 137 | # Shift out the oldest loss term. 138 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 139 | self._loss_history[t, -1] = loss 140 | else: 141 | self._loss_history[t, self._loss_counts[t]] = loss 142 | self._loss_counts[t] += 1 143 | 144 | def _warmed_up(self): 145 | return (self._loss_counts == self.history_per_term).all() 146 | -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Functions for downloading pre-trained DiT models 4 | """ 5 | from torchvision.datasets.utils import download_url 6 | import torch 7 | import os 8 | 9 | 10 | pretrained_models = {'DiT-XL-2-512x512.pt', 'DiT-XL-2-256x256.pt'} 11 | 12 | 13 | def find_model(model_name): 14 | """ 15 | Finds a pre-trained G.pt model, downloading it if necessary. Alternatively, loads a model from a local path. 16 | """ 17 | if model_name in pretrained_models: # Find/download our pre-trained G.pt checkpoints 18 | return download_model(model_name) 19 | else: # Load a custom DiT checkpoint: 20 | assert os.path.isfile(model_name), f'Could not find DiT checkpoint at {model_name}' 21 | return torch.load(model_name, map_location=lambda storage, loc: storage) 22 | 23 | 24 | def download_model(model_name): 25 | """ 26 | Downloads a pre-trained DiT model from the web. 
27 | """ 28 | assert model_name in pretrained_models 29 | local_path = f'pretrained_models/{model_name}' 30 | if not os.path.isfile(local_path): 31 | os.makedirs('pretrained_models', exist_ok=True) 32 | web_path = f'https://dl.fbaipublicfiles.com/DiT/models/{model_name}' 33 | download_url(web_path, 'pretrained_models') 34 | model = torch.load(local_path, map_location=lambda storage, loc: storage) 35 | return model 36 | 37 | 38 | if __name__ == "__main__": 39 | # Download all DiT checkpoints 40 | for model in pretrained_models: 41 | download_model(model) 42 | print('Done.') 43 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: nested-diffusion 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - python=3.10 9 | - mamba 10 | - pytorch 11 | - torchvision 12 | - torchaudio 13 | - pytorch-cuda=11.7 14 | - tensorflow 15 | - tqdm 16 | - pyyaml 17 | - tensorboard 18 | - tensorboardx 19 | - ca-certificates 20 | - certifi 21 | - openssl 22 | - pip 23 | - pip: 24 | - accelerate==0.19.0 25 | - diffusers==0.16.1 26 | - fsspec==2023.5.0 27 | - huggingface-hub==0.14.1 28 | - psutil==5.9.5 29 | - regex==2023.5.5 30 | - safetensors==0.3.1 31 | - timm==0.9.2 32 | - tokenizers==0.13.3 33 | - transformers==4.29.2 34 | 35 | 36 | -------------------------------------------------------------------------------- /figures/Nested_Egg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/noamelata/NestedDiffusion/25c2700162d079927d33d57dd5c8e05a565be646/figures/Nested_Egg.png -------------------------------------------------------------------------------- /figures/arxiv.svg: -------------------------------------------------------------------------------- 1 | 2 | 17 | arXiv 19 | 21 | 46 | 50 | 57 | 58 | 60 | 61 | 63 | image/svg+xml 64 | 66 | arXiv 67 | 69 | 70 | 71 | arXiv 72 | 73 | 74 | 75 | 77 | 79 | 81 | 83 | 85 | 87 | 89 | 91 | 93 | 95 | 96 | 97 | 98 | 102 | 106 | 111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /figures/generation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/noamelata/NestedDiffusion/25c2700162d079927d33d57dd5c8e05a565be646/figures/generation.gif -------------------------------------------------------------------------------- /generate_imagenet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | from functools import partial 5 | 6 | import torch 7 | 8 | torch.backends.cuda.matmul.allow_tf32 = True 9 | torch.backends.cudnn.allow_tf32 = True 10 | 11 | from torch.utils import data 12 | from torch.utils.data import DistributedSampler, TensorDataset 13 | import torch.multiprocessing as mp 14 | from torchvision.utils import save_image 15 | from tqdm import tqdm 16 | 17 | import core.util as Util 18 | from diffusion import create_diffusion 19 | from download import find_model 20 | from models import DiT_XL_2 21 | from diffusers.models import AutoencoderKL 22 | 23 | 24 | def main_worker(gpu, ngpus_per_node, opt): 25 | if 'local_rank' not in opt: 26 | opt.local_rank = opt.global_rank = gpu 27 | if opt.distributed: 28 | torch.cuda.set_device(int(opt.local_rank)) 29 | print('using GPU {} for generation'.format(int(opt.local_rank))) 30 | 
torch.distributed.init_process_group(backend = 'nccl', 31 | init_method = opt.init_method, 32 | world_size = opt.world_size, 33 | rank = opt.global_rank, 34 | group_name='mtorch' 35 | ) 36 | set_device = partial(Util.set_device, rank=opt.global_rank) 37 | '''set seed and and cuDNN environment ''' 38 | torch.backends.cudnn.enabled = True 39 | torch.backends.cudnn.deterministic = False 40 | torch.backends.cudnn.benchmark = True 41 | torch.set_grad_enabled(False) 42 | num_sampling_steps = opt.steps 43 | # Nested Diffusion skips the last sampling step from $t = 1$ to $t = 0$, as this is computed in the last inner step 44 | num_sampling_steps = num_sampling_steps + 1 if opt.nested is not None else num_sampling_steps 45 | cfg_scale = opt.cfg 46 | 47 | # Load model: 48 | image_size = 256 49 | latent_channels = 4 50 | assert image_size in [256, 512], "We only provide pre-trained models for 256x256 and 512x512 resolutions." 51 | latent_size = image_size // 8 52 | model = DiT_XL_2(input_size=latent_size) 53 | state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt") 54 | model.load_state_dict(state_dict) 55 | model = set_device(model, distributed=opt.distributed) 56 | model.eval() 57 | diffusion = create_diffusion(str(num_sampling_steps)) 58 | vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema") 59 | vae = set_device(vae, distributed=opt.distributed) 60 | vae.eval() 61 | 62 | images_in_output = opt.nested if opt.nested is not None else 1 63 | 64 | 65 | if opt.global_rank == 0: 66 | if os.path.exists(opt.output): 67 | shutil.rmtree(opt.output) 68 | os.makedirs(opt.output, exist_ok=True) 69 | for i in range(images_in_output): 70 | os.makedirs(os.path.join(opt.output, f"{(opt.nested if opt.nested is not None else 1) * i}"), exist_ok=True) 71 | if opt.distributed: 72 | torch.distributed.barrier() 73 | else: 74 | if opt.distributed: 75 | torch.distributed.barrier() 76 | 77 | assert opt.number % 1000 == 0 78 | per_class = (opt.number // 1000) 79 | 80 | dataset = TensorDataset(torch.arange(1000).repeat_interleave(per_class), torch.arange(1000 * per_class)) 81 | data_sampler = None 82 | if opt.distributed: 83 | data_sampler = DistributedSampler(dataset, shuffle=False, num_replicas=opt.world_size, rank=opt.global_rank) 84 | class_loader = data.DataLoader( 85 | dataset, 86 | sampler=data_sampler, 87 | batch_size=opt.batch, 88 | shuffle=False, 89 | num_workers=opt.num_workers, 90 | drop_last=False 91 | ) 92 | 93 | for class_, n in tqdm(class_loader): 94 | b = class_.shape[0] 95 | z = torch.randn(b, latent_channels, latent_size, latent_size) 96 | y = class_.clone().int() 97 | z, y = set_device(z, distributed=opt.distributed), set_device(y, distributed=opt.distributed) 98 | func = (model.module.forward_with_cfg if opt.distributed else model.forward_with_cfg) if cfg_scale != 1.0 else model 99 | model_kwargs = dict(y=y, cfg_scale=cfg_scale) if cfg_scale != 1.0 else dict(y=y) 100 | 101 | if opt.nested: 102 | samples = diffusion.nested_sample_loop(func, z.shape, z, inner_steps=opt.nested, 103 | ddim=opt.ddim, clip_denoised=False, 104 | model_kwargs=model_kwargs, progress=True, device=z.device) 105 | elif opt.ddim: 106 | samples = diffusion.ddim_sample_loop(func, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, 107 | progress=True, eta=opt.eta, device=z.device) 108 | else: 109 | samples = diffusion.p_sample_loop(func, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, 110 | progress=True, device=z.device) 111 | 112 | decode = vae.module.decode if opt.distributed else vae.decode 
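    # The DiT latents follow the Stable Diffusion AutoencoderKL convention and are scaled
    # by 0.18215, so that factor is divided back out before decoding to pixel space.
    # nested_sample_loop returns a list of intermediate x_0 predictions, while the other
    # sampling loops return a single final result, hence the list normalization below.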
113 | if not isinstance(samples, list): 114 | samples = [samples] 115 | samples = [decode(s / 0.18215).sample for s in samples[:-1]] 116 | for i, sample in enumerate(samples): 117 | for j in range(b): 118 | number = n[j].item() 119 | save_image((0.5 + 0.5 * sample[j]), 120 | os.path.join(opt.output, 121 | f"{(opt.nested if opt.nested is not None else 1) * i}", 122 | f"{number:06d}.png")) 123 | 124 | 125 | if __name__ == '__main__': 126 | parser = argparse.ArgumentParser() 127 | 128 | parser.add_argument('-n', '--number', type=int, default=10000) 129 | parser.add_argument('-b', '--batch', type=int, default=64) 130 | parser.add_argument('-o', '--output', type=str, required=True) 131 | parser.add_argument('--seed', type=int, default=0) 132 | parser.add_argument('-s', '--steps', type=int, default=50) 133 | parser.add_argument('-e', '--eta', type=float, default=1.0) 134 | parser.add_argument('--nested', type=int, default=0) 135 | parser.add_argument('--ddim', action='store_true', default=False) 136 | parser.add_argument('--num-workers', type=int, default=0) 137 | parser.add_argument('--cfg', type=float, default=1.0) 138 | parser.add_argument('-P', '--port', default='21012', type=str) 139 | parser.add_argument('-gpu', '--gpu_ids', type=str, default="0") 140 | 141 | ''' parser configs ''' 142 | args = parser.parse_args() 143 | 144 | ''' cuda devices ''' 145 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids 146 | print('export CUDA_VISIBLE_DEVICES={}'.format(args.gpu_ids)) 147 | args.gpu_ids = [int(x) for x in args.gpu_ids.split(",")] 148 | args.distributed = len(args.gpu_ids) > 1 149 | 150 | ''' use DistributedDataParallel(DDP) and multiprocessing for multi-gpu training''' 151 | if args.distributed: 152 | ngpus_per_node = len(args.gpu_ids) # or torch.cuda.device_count() 153 | args.world_size = ngpus_per_node 154 | args.init_method = 'tcp://127.0.0.1:' + args.port 155 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 156 | else: 157 | args.world_size = 1 158 | main_worker(0, 1, args) 159 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | 6 | from NestedPipeline import NestedStableDiffusionPipeline 7 | from NestedScheduler import NestedScheduler 8 | 9 | def main(seed, inner, outer, fp16, prompt, outdir): 10 | scheduler = NestedScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", 11 | prediction_type='sample', clip_sample=False, set_alpha_to_one=False) 12 | if fp16: 13 | pipe = NestedStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", revision="fp16", 14 | torch_dtype=torch.float16, scheduler=scheduler) 15 | else: 16 | pipe = NestedStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler) 17 | os.makedirs(outdir, exist_ok=True) 18 | device = "cuda" if torch.cuda.is_available() else "cpu" 19 | generator = torch.Generator(device).manual_seed(seed) 20 | pipe.to(device) 21 | 22 | for i, im in enumerate(pipe(prompt, num_inference_steps=outer, num_inner_steps=inner, generator=generator)): 23 | im.images[0].save(os.path.join(outdir, f"ND_{(i+1) * inner}NFEs_{prompt}.png")) 24 | 25 | 26 | if __name__ == "__main__": 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--seed', type=int, default=24) 29 | parser.add_argument('--inner', type=int, default=4) 30 | parser.add_argument('--outer', type=int, 
default=25) 31 | parser.add_argument('--outdir', type=str, default="figures") 32 | parser.add_argument('--fp16', action='store_true', default=False) 33 | parser.add_argument('-p', '--prompt', type=str, default="a photograph of a nest with a blue egg inside") 34 | args = parser.parse_args() 35 | main(**vars(args)) 36 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | import math 16 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp 17 | 18 | 19 | def modulate(x, shift, scale): 20 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 21 | 22 | 23 | ################################################################################# 24 | # Embedding Layers for Timesteps and Class Labels # 25 | ################################################################################# 26 | 27 | class TimestepEmbedder(nn.Module): 28 | """ 29 | Embeds scalar timesteps into vector representations. 30 | """ 31 | def __init__(self, hidden_size, frequency_embedding_size=256): 32 | super().__init__() 33 | self.mlp = nn.Sequential( 34 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 35 | nn.SiLU(), 36 | nn.Linear(hidden_size, hidden_size, bias=True), 37 | ) 38 | self.frequency_embedding_size = frequency_embedding_size 39 | 40 | @staticmethod 41 | def timestep_embedding(t, dim, max_period=10000): 42 | """ 43 | Create sinusoidal timestep embeddings. 44 | :param t: a 1-D Tensor of N indices, one per batch element. 45 | These may be fractional. 46 | :param dim: the dimension of the output. 47 | :param max_period: controls the minimum frequency of the embeddings. 48 | :return: an (N, D) Tensor of positional embeddings. 49 | """ 50 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 51 | half = dim // 2 52 | freqs = torch.exp( 53 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 54 | ).to(device=t.device) 55 | args = t[:, None].float() * freqs[None] 56 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 57 | if dim % 2: 58 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 59 | return embedding 60 | 61 | def forward(self, t): 62 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 63 | t_emb = self.mlp(t_freq) 64 | return t_emb 65 | 66 | 67 | class LabelEmbedder(nn.Module): 68 | """ 69 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 
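    During training, each label is independently replaced with the extra index `num_classes`
    with probability `dropout_prob` (see token_drop below); the embedding of that index acts
    as the unconditional token for classifier-free guidance.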
70 | """ 71 | def __init__(self, num_classes, hidden_size, dropout_prob): 72 | super().__init__() 73 | use_cfg_embedding = dropout_prob > 0 74 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 75 | self.num_classes = num_classes 76 | self.dropout_prob = dropout_prob 77 | 78 | def token_drop(self, labels, force_drop_ids=None): 79 | """ 80 | Drops labels to enable classifier-free guidance. 81 | """ 82 | if force_drop_ids is None: 83 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 84 | else: 85 | drop_ids = force_drop_ids == 1 86 | labels = torch.where(drop_ids, self.num_classes, labels) 87 | return labels 88 | 89 | def forward(self, labels, train, force_drop_ids=None): 90 | use_dropout = self.dropout_prob > 0 91 | if (train and use_dropout) or (force_drop_ids is not None): 92 | labels = self.token_drop(labels, force_drop_ids) 93 | embeddings = self.embedding_table(labels) 94 | return embeddings 95 | 96 | 97 | ################################################################################# 98 | # Core DiT Model # 99 | ################################################################################# 100 | 101 | class DiTBlock(nn.Module): 102 | """ 103 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 104 | """ 105 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): 106 | super().__init__() 107 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 108 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) 109 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 110 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 111 | approx_gelu = lambda: nn.GELU(approximate="tanh") 112 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) 113 | self.adaLN_modulation = nn.Sequential( 114 | nn.SiLU(), 115 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 116 | ) 117 | 118 | def forward(self, x, c): 119 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) 120 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) 121 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) 122 | return x 123 | 124 | 125 | class FinalLayer(nn.Module): 126 | """ 127 | The final layer of DiT. 128 | """ 129 | def __init__(self, hidden_size, patch_size, out_channels): 130 | super().__init__() 131 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 132 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 133 | self.adaLN_modulation = nn.Sequential( 134 | nn.SiLU(), 135 | nn.Linear(hidden_size, 2 * hidden_size, bias=True) 136 | ) 137 | 138 | def forward(self, x, c): 139 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 140 | x = modulate(self.norm_final(x), shift, scale) 141 | x = self.linear(x) 142 | return x 143 | 144 | 145 | class DiT(nn.Module): 146 | """ 147 | Diffusion model with a Transformer backbone. 
148 | """ 149 | def __init__( 150 | self, 151 | input_size=32, 152 | patch_size=2, 153 | in_channels=4, 154 | hidden_size=1152, 155 | depth=28, 156 | num_heads=16, 157 | mlp_ratio=4.0, 158 | class_dropout_prob=0.1, 159 | num_classes=1000, 160 | learn_sigma=True, 161 | ): 162 | super().__init__() 163 | self.learn_sigma = learn_sigma 164 | self.in_channels = in_channels 165 | self.out_channels = in_channels * 2 if learn_sigma else in_channels 166 | self.patch_size = patch_size 167 | self.num_heads = num_heads 168 | 169 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) 170 | self.t_embedder = TimestepEmbedder(hidden_size) 171 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) 172 | num_patches = self.x_embedder.num_patches 173 | # Will use fixed sin-cos embedding: 174 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) 175 | 176 | self.blocks = nn.ModuleList([ 177 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) 178 | ]) 179 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) 180 | self.initialize_weights() 181 | 182 | def initialize_weights(self): 183 | # Initialize transformer layers: 184 | def _basic_init(module): 185 | if isinstance(module, nn.Linear): 186 | torch.nn.init.xavier_uniform_(module.weight) 187 | if module.bias is not None: 188 | nn.init.constant_(module.bias, 0) 189 | self.apply(_basic_init) 190 | 191 | # Initialize (and freeze) pos_embed by sin-cos embedding: 192 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) 193 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 194 | 195 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 196 | w = self.x_embedder.proj.weight.data 197 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 198 | 199 | # Initialize label embedding table: 200 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 201 | 202 | # Initialize timestep embedding MLP: 203 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 204 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 205 | 206 | # Zero-out adaLN modulation layers in DiT blocks: 207 | for block in self.blocks: 208 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 209 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 210 | 211 | # Zero-out output layers: 212 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 213 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 214 | nn.init.constant_(self.final_layer.linear.weight, 0) 215 | nn.init.constant_(self.final_layer.linear.bias, 0) 216 | 217 | def unpatchify(self, x): 218 | """ 219 | x: (N, T, patch_size**2 * C) 220 | imgs: (N, H, W, C) 221 | """ 222 | c = self.out_channels 223 | p = self.x_embedder.patch_size[0] 224 | h = w = int(x.shape[1] ** 0.5) 225 | assert h * w == x.shape[1] 226 | 227 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 228 | x = torch.einsum('nhwpqc->nchpwq', x) 229 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) 230 | return imgs 231 | 232 | def forward(self, x, t, y, return_mid=False): 233 | """ 234 | Forward pass of DiT. 
235 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 236 | t: (N,) tensor of diffusion timesteps 237 | y: (N,) tensor of class labels 238 | """ 239 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 240 | embedded = x.clone() 241 | t = self.t_embedder(t) # (N, D) 242 | y = self.y_embedder(y, self.training) # (N, D) 243 | c = t + y # (N, D) 244 | for block in self.blocks: 245 | x = block(x, c) # (N, T, D) 246 | features = x.clone() 247 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) 248 | x = self.unpatchify(x) # (N, out_channels, H, W) 249 | if not return_mid: 250 | return x 251 | return x, features, embedded, c, self.pos_embed 252 | 253 | 254 | def forward_with_cfg(self, x, t, y, cfg_scale): 255 | """ 256 | Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. 257 | """ 258 | model_out = self.forward(x.tile([2, 1, 1, 1]), torch.cat([t, t], dim=0), torch.cat([y, 1000 * torch.ones_like(y)], dim=0)) 259 | # this is probably an error in DiT as the latent representation has 4 channels. 260 | # we did not modify the application of CFG, and continue applying it to 3 channels only. 261 | eps, rest = model_out[:, :3], model_out[:, 3:] 262 | cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0) 263 | cond_rest, uncond_rest = torch.chunk(rest, 2, dim=0) 264 | cfg_nh = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 265 | res = torch.cat([cfg_nh, cond_rest], dim=1) 266 | return res 267 | 268 | 269 | ################################################################################# 270 | # Sine/Cosine Positional Embedding Functions # 271 | ################################################################################# 272 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 273 | 274 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 275 | """ 276 | grid_size: int of the grid height and width 277 | return: 278 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 279 | """ 280 | grid_h = np.arange(grid_size, dtype=np.float32) 281 | grid_w = np.arange(grid_size, dtype=np.float32) 282 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 283 | grid = np.stack(grid, axis=0) 284 | 285 | grid = grid.reshape([2, 1, grid_size, grid_size]) 286 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 287 | if cls_token and extra_tokens > 0: 288 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 289 | return pos_embed 290 | 291 | 292 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 293 | assert embed_dim % 2 == 0 294 | 295 | # use half of dimensions to encode grid_h 296 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 297 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 298 | 299 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 300 | return emb 301 | 302 | 303 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 304 | """ 305 | embed_dim: output dimension for each position 306 | pos: a list of positions to be encoded: size (M,) 307 | out: (M, D) 308 | """ 309 | assert embed_dim % 2 == 0 310 | omega = np.arange(embed_dim // 2, dtype=np.float64) 311 | omega /= embed_dim / 2. 312 | omega = 1. 
/ 10000**omega # (D/2,) 313 | 314 | pos = pos.reshape(-1) # (M,) 315 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 316 | 317 | emb_sin = np.sin(out) # (M, D/2) 318 | emb_cos = np.cos(out) # (M, D/2) 319 | 320 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 321 | return emb 322 | 323 | 324 | ################################################################################# 325 | # DiT Configs # 326 | ################################################################################# 327 | 328 | def DiT_XL_2(**kwargs): 329 | return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) 330 | 331 | def DiT_XL_4(**kwargs): 332 | return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) 333 | 334 | def DiT_XL_8(**kwargs): 335 | return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) 336 | 337 | def DiT_L_2(**kwargs): 338 | return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) 339 | 340 | def DiT_L_4(**kwargs): 341 | return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) 342 | 343 | def DiT_L_8(**kwargs): 344 | return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) 345 | 346 | def DiT_B_2(**kwargs): 347 | return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) 348 | 349 | def DiT_B_4(**kwargs): 350 | return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) 351 | 352 | def DiT_B_8(**kwargs): 353 | return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) 354 | 355 | def DiT_S_2(**kwargs): 356 | return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) 357 | 358 | def DiT_S_4(**kwargs): 359 | return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) 360 | 361 | def DiT_S_8(**kwargs): 362 | return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) 363 | -------------------------------------------------------------------------------- /sample_text_to_image.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | 6 | from NestedPipeline import NestedStableDiffusionPipeline 7 | from NestedScheduler import NestedScheduler 8 | 9 | def main(seed, inner, outer, fp16, dpm_solver, prompt, outdir): 10 | scheduler = NestedScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", 11 | prediction_type='sample', clip_sample=False, set_alpha_to_one=False) 12 | if fp16: 13 | pipe = NestedStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", revision="fp16", 14 | torch_dtype=torch.float16, scheduler=scheduler) 15 | else: 16 | pipe = NestedStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler) 17 | os.makedirs(outdir, exist_ok=True) 18 | device = "cuda" if torch.cuda.is_available() else "cpu" 19 | generator = torch.Generator(device).manual_seed(seed) 20 | pipe.to(device) 21 | 22 | for i, im in enumerate(pipe(prompt, num_inference_steps=outer, num_inner_steps=inner, 23 | inner_scheduler="DPMSolver" if dpm_solver else "DDIM", generator=generator)): 24 | im.images[0].save(os.path.join(outdir, f"ND_{(i+1) * inner}NFEs_{prompt}.png")) 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--seed', type=int, default=24) 30 | parser.add_argument('--inner', type=int, default=4) 31 | parser.add_argument('--outer', type=int, default=25) 32 | 
parser.add_argument('--outdir', type=str, default="figures") 33 | parser.add_argument('--fp16', action='store_true', default=False) 34 | parser.add_argument('--dpm-solver', action='store_true', default=False) 35 | parser.add_argument('-p', '--prompt', type=str, default="a photograph of a nest with a blue egg inside") 36 | args = parser.parse_args() 37 | main(**vars(args)) 38 | --------------------------------------------------------------------------------