├── README.md
├── datasets.py
├── pipeline_stable_video_diffusion_re.py
├── requirements.txt
├── scheduling_euler_discrete_resampling.py
├── svd_sequential_re.py
├── test_data
│   ├── Multiview_data
│   │   └── mipnerf360_lite
│   │       └── garden
│   │           ├── frame_00006.JPG
│   │           └── frame_00104.JPG
│   ├── gym_motion_2024_frames
│   │   └── arm_clip1
│   │       ├── frame_00018.jpg
│   │       └── frame_00060.jpg
│   └── video_frames
│       └── dolomite_clip3
│           ├── frame_00000.jpg
│           └── frame_00100.jpg
└── unet_spatio_temporal_condition.py
/README.md:
--------------------------------------------------------------------------------
1 | # Time_Reversal_Fusion
2 |

3 | [TRF_teaser_figure]
4 |

5 |
6 | This is the official PyTorch implementation of Time Reversal Fusion (accepted at ECCV 2024).
7 | We propose a new sampling strategy called Time-Reversal Fusion (TRF), which enables an image-to-video model to generate sequences toward a given end frame without any tuning or back-propagated optimization. We define this new task as "Bounded Generation", and it generalizes three scenarios in computer vision:
8 | 1) Generating subject motion, with the two bound images capturing a moving subject.
9 | 2) Synthesizing camera motion, using two images captured from different viewpoints of a static scene.
10 | 3) Achieving video looping, by using the same image for both bounds.
11 |
12 | Please refer to the [arXiv paper](https://arxiv.org/abs/2403.14611) for more technical details and the [project page](https://time-reversal.github.io) for more video results.
13 |
14 | ## Todo
15 | - [x] TRF code release
16 | - [x] Bounded Generation Dataset release
17 | - [ ] Gradio demo
18 |
19 | ## Getting Started
20 | Clone the repo:
21 | ```bash
22 | git clone https://github.com/HavenFeng/time_reversal/
23 | cd time_reversal
24 | ```
25 |
26 | ### Requirements
27 | * Python 3.10 (numpy, skimage, scipy, opencv)
28 | * Diffusers
29 | * PyTorch >= 2.0.1 (Diffusers compatible)
30 | You can install the dependencies with
31 | ```bash
32 | pip install -r requirements.txt
33 | ```
34 | If you encounter errors when installing Diffusers, please follow the [official installation guide](https://huggingface.co/docs/diffusers/en/installation) to reinstall the library.
35 |
36 | ### Usage
37 | 1. **Run inference with the samples used in the paper**
38 | ```bash
39 | python svd_sequential_re.py multiview
40 | ```
41 | Switch between tasks by passing "multiview", "video_frames", "gym_motion", or "image2loop"; the generated results can be found in the ./output folder. A minimal programmatic sketch of the pipeline call is included further down in this README.
42 | 2. **TRF++ (add LoRA "patches" to enhance domain-specific tasks)**
43 | TRF was designed to probe SVD's bounded generation capability without fine-tuning, but we have observed SVD's biases in subject and camera motion, as well as its sensitivity to conditioning factors such as FPS and motion intensity, which require careful parameter tuning for different inputs. To improve generation quality and robustness on other downstream tasks, we fine-tuned LoRA "patches" on various domain-specific datasets, which better support long-range linear motion and extreme 3D view generation.
44 | ```
45 | coming soon
46 | ```
47 |
48 | ## Evaluation
49 | We evaluate our method on the [Bounded Generation Dataset](https://drive.google.com/drive/folders/1qH4yx5954Bm6h1E4olEqgV0pSJicdkNu?usp=sharing) and compare against domain-specific state-of-the-art methods.
50 | For more details on the evaluation, please check our [arXiv paper](https://arxiv.org/abs/2403.14611).
51 |
52 |
53 | ## Citation
54 | If you find our work useful to your research, please consider citing:
55 | ```
56 | @inproceedings{Feng:TRF:ECCV2024,
57 |   title = {Explorative In-betweening of Time and Space},
58 |   author = {Feng, Haiwen and Ding, Zheng and Xia, Zhihao and Niklaus, Simon and Abrevaya, Victoria and Black, Michael J. and Zhang, Xuaner},
59 |   booktitle = {European Conference on Computer Vision},
60 |   year = {2024}
61 | }
62 | ```
63 |
64 | ## Notes
65 | The video form of our teaser image:
66 |
67 | https://github.com/user-attachments/assets/b984c57c-a450-4071-996c-dc3df1445e79
68 |
69 | More domain-specific LoRA patch models will be released soon.
70 |
71 | ## License
72 | This code and model are available for non-commercial scientific research purposes.
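
## Programmatic usage (sketch)
The supported entry point is `svd_sequential_re.py`, as shown in the Usage section above. The snippet below is only a rough sketch of how the custom pipeline in this repo could be driven directly; the checkpoint id, fp16 loading, and parameter values are illustrative assumptions rather than the exact settings used for the paper results.

```python
import torch
from diffusers import AutoencoderKLTemporalDecoder
from diffusers.utils import load_image, export_to_video
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from datasets import get_dataset
from pipeline_stable_video_diffusion_re import StableVideoDiffusionPipeline_Custom
from scheduling_euler_discrete_resampling import EulerDiscreteScheduler
from unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel

# Assumed base checkpoint; load the UNet/scheduler through the classes in this repo.
model_id = "stabilityai/stable-video-diffusion-img2vid-xt"
unet = UNetSpatioTemporalConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=torch.float16, variant="fp16")
vae = AutoencoderKLTemporalDecoder.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float16, variant="fp16")
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
image_encoder = CLIPVisionModelWithProjection.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float16, variant="fp16")
feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")

pipe = StableVideoDiffusionPipeline_Custom(
    vae=vae, image_encoder=image_encoder, unet=unet,
    scheduler=scheduler, feature_extractor=feature_extractor,
).to("cuda")

# Pick a bounded pair (start frame, end frame) from the bundled test data.
start_path, end_path = get_dataset("./test_data", data_type="multiview")[0]
start = load_image(start_path).resize((1024, 576))
end = load_image(end_path).resize((1024, 576))

frames = pipe(
    image=start, end_frame=end,
    height=576, width=1024, num_frames=25,
    fps=7, motion_bucket_id=127, noise_aug_strength=0.02,
    decode_chunk_size=8, generator=torch.manual_seed(42),
).frames[0]
export_to_video(frames, "trf_sample.mp4", fps=7)
```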
73 | 74 | ## Acknowledgements 75 | We would like to thank recent baseline works that allow us to easily perform quantitative and qualitative comparisons :) 76 | [FILM](https://github.com/google-research/frame-interpolation), 77 | [Wide-Baseline](https://github.com/yilundu/cross_attention_renderer), 78 | [Text2Cinemagraph](https://github.com/text2cinemagraph/text2cinemagraph/tree/master), 79 | 80 | This work was partly supported by the German Federal Ministry of Education and Research (BMBF): Tuebingen AI Center, FKZ: 01IS18039B 81 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | 5 | def get_dataset(dataset_folder, data_type='frame', filter_keyword=None): 6 | """ 7 | General function to retrieve datasets based on the data_type. 8 | 9 | Parameters: 10 | dataset_folder (str): The root directory of the dataset. 11 | data_type (str): Type of dataset to retrieve. Options are 'frame', 'multiview', 'loop'. 12 | filter_keyword (str, optional): A keyword to filter the dataset. 13 | 14 | Returns: 15 | np.ndarray: An array of image pairs. 16 | """ 17 | if data_type == 'loop': 18 | # For loop dataset, pairs are the same image 19 | image_paths = sorted(glob.glob(os.path.join(dataset_folder, 'img2loop', '*', '*'))) 20 | if filter_keyword: 21 | image_paths = [path for path in image_paths if filter_keyword in path] 22 | video_frames = np.array([[path, path] for path in image_paths]) 23 | else: 24 | # For 'frame' and 'multiview' datasets 25 | if data_type == 'frame': 26 | video_frame_selected = [ 27 | 'video_frames/dolomite_clip3/frame_00000.jpg', 28 | 'video_frames/dolomite_clip3/frame_00100.jpg', 29 | 30 | 'gym_motion_2024_frames/arm_clip1/frame_00018.jpg', 31 | 'gym_motion_2024_frames/arm_clip1/frame_00060.jpg', 32 | ] 33 | elif data_type == 'multiview': 34 | video_frame_selected = [ 35 | 'Multiview_data/mipnerf360_lite/garden/frame_00006.JPG', 36 | 'Multiview_data/mipnerf360_lite/garden/frame_00104.JPG', 37 | ] 38 | else: 39 | raise ValueError(f"Unsupported data_type: {data_type}") 40 | 41 | # Prepend dataset_folder to paths 42 | video_frame_selected = [os.path.join(dataset_folder, path) for path in video_frame_selected] 43 | if filter_keyword: 44 | video_frame_selected = [path for path in video_frame_selected if filter_keyword in path] 45 | 46 | # Reshape into pairs 47 | video_frames = np.array(video_frame_selected).reshape(-1, 2) 48 | 49 | return video_frames 50 | -------------------------------------------------------------------------------- /pipeline_stable_video_diffusion_re.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import inspect 16 | from dataclasses import dataclass 17 | from typing import Callable, Dict, List, Optional, Union 18 | 19 | import numpy as np 20 | import PIL.Image 21 | import torch 22 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection 23 | 24 | from diffusers.image_processor import VaeImageProcessor 25 | from diffusers.models import AutoencoderKLTemporalDecoder 26 | from unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel 27 | from scheduling_euler_discrete_resampling import EulerDiscreteScheduler 28 | 29 | from diffusers.utils import BaseOutput, logging 30 | from diffusers.utils.torch_utils import randn_tensor 31 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 32 | 33 | 34 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 35 | 36 | 37 | def _append_dims(x, target_dims): 38 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 39 | dims_to_append = target_dims - x.ndim 40 | if dims_to_append < 0: 41 | raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") 42 | return x[(...,) + (None,) * dims_to_append] 43 | 44 | 45 | def tensor2vid(video: torch.Tensor, processor, output_type="np"): 46 | 47 | batch_size, channels, num_frames, height, width = video.shape 48 | outputs = [] 49 | for batch_idx in range(batch_size): 50 | batch_vid = video[batch_idx].permute(1, 0, 2, 3) 51 | batch_output = processor.postprocess(batch_vid, output_type) 52 | 53 | outputs.append(batch_output) 54 | 55 | return outputs 56 | 57 | 58 | def interpolate_spherical(p0, p1, fract_mixing: float): 59 | r""" 60 | borrowed from latentblending repo 61 | Helper function to correctly mix two random variables using spherical interpolation. 62 | See https://en.wikipedia.org/wiki/Slerp 63 | The function will always cast up to float64 for sake of extra 4. 64 | Args: 65 | p0: 66 | First tensor for interpolation 67 | p1: 68 | Second tensor for interpolation 69 | fract_mixing: float 70 | Mixing coefficient of interval [0, 1]. 71 | 0 will return in p0 72 | 1 will return in p1 73 | 0.x will return a mix between both preserving angular velocity. 74 | """ 75 | 76 | if p0.dtype == torch.float16: 77 | recast_to = 'fp16' 78 | else: 79 | recast_to = 'fp32' 80 | 81 | p0 = p0.double() 82 | p1 = p1.double() 83 | norm = torch.linalg.norm(p0) * torch.linalg.norm(p1) 84 | epsilon = 1e-7 85 | dot = torch.sum(p0 * p1) / norm 86 | dot = dot.clamp(-1 + epsilon, 1 - epsilon) 87 | 88 | theta_0 = torch.arccos(dot) 89 | sin_theta_0 = torch.sin(theta_0) 90 | theta_t = theta_0 * fract_mixing 91 | s0 = torch.sin(theta_0 - theta_t) / sin_theta_0 92 | s1 = torch.sin(theta_t) / sin_theta_0 93 | interp = p0 * s0 + p1 * s1 94 | 95 | if recast_to == 'fp16': 96 | interp = interp.half() 97 | elif recast_to == 'fp32': 98 | interp = interp.float() 99 | 100 | return interp 101 | 102 | 103 | def interpolate_linear(p0, p1, fract_mixing): 104 | r""" 105 | Helper function to mix two variables using standard linear interpolation. 106 | Args: 107 | p0: 108 | First tensor / np.ndarray for interpolation 109 | p1: 110 | Second tensor / np.ndarray for interpolation 111 | fract_mixing: float 112 | Mixing coefficient of interval [0, 1]. 113 | 0 will return in p0 114 | 1 will return in p1 115 | 0.x will return a linear mix between both. 
116 | """ 117 | reconvert_uint8 = False 118 | if type(p0) is np.ndarray and p0.dtype == 'uint8': 119 | reconvert_uint8 = True 120 | p0 = p0.astype(np.float64) 121 | 122 | if type(p1) is np.ndarray and p1.dtype == 'uint8': 123 | reconvert_uint8 = True 124 | p1 = p1.astype(np.float64) 125 | 126 | interp = (1 - fract_mixing) * p0 + fract_mixing * p1 127 | 128 | if reconvert_uint8: 129 | interp = np.clip(interp, 0, 255).astype(np.uint8) 130 | 131 | return interp 132 | 133 | 134 | @dataclass 135 | class StableVideoDiffusionPipelineOutput(BaseOutput): 136 | r""" 137 | Output class for zero-shot text-to-video pipeline. 138 | 139 | Args: 140 | frames (`[List[PIL.Image.Image]`, `np.ndarray`]): 141 | List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, 142 | num_channels)`. 143 | """ 144 | 145 | frames: Union[List[PIL.Image.Image], np.ndarray] 146 | 147 | 148 | class StableVideoDiffusionPipeline_Custom(DiffusionPipeline): 149 | r""" 150 | Pipeline to generate video from an input image using Stable Video Diffusion. 151 | 152 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 153 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 154 | 155 | Args: 156 | vae ([`AutoencoderKL`]): 157 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 158 | image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): 159 | Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)). 160 | unet ([`UNetSpatioTemporalConditionModel`]): 161 | A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents. 162 | scheduler ([`EulerDiscreteScheduler`]): 163 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. 164 | feature_extractor ([`~transformers.CLIPImageProcessor`]): 165 | A `CLIPImageProcessor` to extract features from generated images. 166 | """ 167 | 168 | model_cpu_offload_seq = "image_encoder->unet->vae" 169 | _callback_tensor_inputs = ["latents"] 170 | 171 | def __init__( 172 | self, 173 | vae: AutoencoderKLTemporalDecoder, 174 | image_encoder: CLIPVisionModelWithProjection, 175 | unet: UNetSpatioTemporalConditionModel, 176 | scheduler: EulerDiscreteScheduler, #EulerDiscreteScheduler, 177 | feature_extractor: CLIPImageProcessor, 178 | ): 179 | super().__init__() 180 | 181 | self.register_modules( 182 | vae=vae, 183 | image_encoder=image_encoder, 184 | unet=unet, 185 | scheduler=scheduler, 186 | feature_extractor=feature_extractor, 187 | ) 188 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 189 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) 190 | 191 | def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance): 192 | dtype = next(self.image_encoder.parameters()).dtype 193 | 194 | if not isinstance(image, torch.Tensor): 195 | image = self.image_processor.pil_to_numpy(image) 196 | image = self.image_processor.numpy_to_pt(image) 197 | 198 | # We normalize the image before resizing to match with the original implementation. 199 | # Then we unnormalize it after resizing. 
200 | image = image * 2.0 - 1.0 201 | image = _resize_with_antialiasing(image, (224, 224)) 202 | image = (image + 1.0) / 2.0 203 | 204 | # Normalize the image with for CLIP input 205 | image = self.feature_extractor( 206 | images=image, 207 | do_normalize=True, 208 | do_center_crop=False, 209 | do_resize=False, 210 | do_rescale=False, 211 | return_tensors="pt", 212 | ).pixel_values 213 | 214 | image = image.to(device=device, dtype=dtype) 215 | image_embeddings = self.image_encoder(image).image_embeds 216 | image_embeddings = image_embeddings.unsqueeze(1) 217 | 218 | # duplicate image embeddings for each generation per prompt, using mps friendly method 219 | bs_embed, seq_len, _ = image_embeddings.shape 220 | image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) 221 | image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) 222 | 223 | if do_classifier_free_guidance: 224 | negative_image_embeddings = torch.zeros_like(image_embeddings) 225 | 226 | # For classifier free guidance, we need to do two forward passes. 227 | # Here we concatenate the unconditional and text embeddings into a single batch 228 | # to avoid doing two forward passes 229 | image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) 230 | 231 | return image_embeddings 232 | 233 | def _encode_vae_image( 234 | self, 235 | image: torch.Tensor, 236 | device, 237 | num_videos_per_prompt, 238 | do_classifier_free_guidance, 239 | ): 240 | image = image.to(device=device) 241 | image_latents = self.vae.encode(image).latent_dist.mode() 242 | 243 | if do_classifier_free_guidance: 244 | negative_image_latents = torch.zeros_like(image_latents) 245 | 246 | # For classifier free guidance, we need to do two forward passes. 247 | # Here we concatenate the unconditional and text embeddings into a single batch 248 | # to avoid doing two forward passes 249 | image_latents = torch.cat([negative_image_latents, image_latents]) 250 | 251 | # duplicate image_latents for each generation per prompt, using mps friendly method 252 | image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1) 253 | 254 | return image_latents 255 | 256 | def _get_add_time_ids( 257 | self, 258 | fps, 259 | motion_bucket_id, 260 | noise_aug_strength, 261 | dtype, 262 | batch_size, 263 | num_videos_per_prompt, 264 | do_classifier_free_guidance, 265 | ): 266 | add_time_ids = [fps, motion_bucket_id, noise_aug_strength] 267 | 268 | passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids) 269 | expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features 270 | 271 | if expected_add_embed_dim != passed_add_embed_dim: 272 | raise ValueError( 273 | f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
274 | ) 275 | 276 | add_time_ids = torch.tensor([add_time_ids], dtype=dtype) 277 | add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) 278 | 279 | if do_classifier_free_guidance: 280 | add_time_ids = torch.cat([add_time_ids, add_time_ids]) 281 | 282 | return add_time_ids 283 | 284 | def decode_latents(self, latents, num_frames, decode_chunk_size=14): 285 | # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] 286 | latents = latents.flatten(0, 1) 287 | 288 | latents = 1 / self.vae.config.scaling_factor * latents 289 | 290 | accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys()) 291 | 292 | # decode decode_chunk_size frames at a time to avoid OOM 293 | frames = [] 294 | for i in range(0, latents.shape[0], decode_chunk_size): 295 | num_frames_in = latents[i : i + decode_chunk_size].shape[0] 296 | decode_kwargs = {} 297 | if accepts_num_frames: 298 | # we only pass num_frames_in if it's expected 299 | decode_kwargs["num_frames"] = num_frames_in 300 | 301 | frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample 302 | frames.append(frame) 303 | frames = torch.cat(frames, dim=0) 304 | 305 | # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] 306 | frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4) 307 | 308 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 309 | frames = frames.float() 310 | return frames 311 | 312 | def check_inputs(self, image, height, width): 313 | if ( 314 | not isinstance(image, torch.Tensor) 315 | and not isinstance(image, PIL.Image.Image) 316 | and not isinstance(image, list) 317 | ): 318 | raise ValueError( 319 | "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" 320 | f" {type(image)}" 321 | ) 322 | 323 | if height % 8 != 0 or width % 8 != 0: 324 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 325 | 326 | def prepare_latents( 327 | self, 328 | batch_size, 329 | num_frames, 330 | num_channels_latents, 331 | height, 332 | width, 333 | dtype, 334 | device, 335 | generator, 336 | latents=None, 337 | ): 338 | shape = ( 339 | batch_size, 340 | num_frames, 341 | num_channels_latents // 2, 342 | height // self.vae_scale_factor, 343 | width // self.vae_scale_factor, 344 | ) 345 | if isinstance(generator, list) and len(generator) != batch_size: 346 | raise ValueError( 347 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 348 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
349 |             )
350 |         # import ipdb; ipdb.set_trace()
351 |         if latents is None:
352 |             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
353 |         else:
354 |             latents = latents.to(device)
355 |
356 |         # scale the initial noise by the standard deviation required by the scheduler
357 |         latents = latents * self.scheduler.init_noise_sigma
358 |         return latents
359 |
360 |     def get_blend_weights(self, latents, step_id, num_frames, weight_type='non-linear', time_dir='forward'):
361 |         # weight_type: 'linear', 'non-linear', 'atanh', or 'non-linear-decay' (progressively reduce the influence of backward frames)
362 |         if weight_type == 'linear':
363 |             blend_weights = torch.linspace(1, 0, steps=num_frames).unsqueeze(1).unsqueeze(2).unsqueeze(3).to(device=latents.device)
364 |         elif weight_type == 'non-linear':
365 |             # import ipdb; ipdb.set_trace()
366 |             b_value = 0.85
367 |             blend_weights = torch.tensor([b_value ** x for x in range(num_frames)]).unsqueeze(1).unsqueeze(2).unsqueeze(3).to(device=latents.device)
368 |             if time_dir == 'forward':
369 |                 blend_weights = torch.flip((1 - blend_weights), dims=[0])
370 |
371 |         elif weight_type == 'atanh':
372 |             # import ipdb; ipdb.set_trace()
373 |             b_value = 0.95
374 |             curve = torch.atanh(torch.linspace(-b_value, b_value, steps=num_frames))
375 |             curve = curve.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(device=latents.device)
376 |             curve_norm = (curve - curve.min()) / (curve.max() - curve.min())
377 |             blend_weights = curve_norm
378 |             if time_dir == 'forward':
379 |                 blend_weights = torch.flip((1 - blend_weights), dims=[0])
380 |
381 |         elif weight_type == 'non-linear-decay':
382 |             b_list = torch.linspace(0.98, 0.5, steps=num_frames)
383 |             a_value = 5
384 |             b_value = b_list[step_id]
385 |             blend_weights = torch.tensor([a_value * b_value ** x for x in range(num_frames)]).unsqueeze(1).unsqueeze(2).unsqueeze(3).to(device=latents.device)
386 |             if time_dir == 'forward':
387 |                 blend_weights = torch.flip((1 - blend_weights), dims=[0])
388 |
389 |         return blend_weights
390 |
391 |     @property
392 |     def guidance_scale(self):
393 |         return self._guidance_scale
394 |
395 |     # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
396 |     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
397 |     # corresponds to doing no classifier free guidance.
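# --- Illustration (editorial sketch, not part of the original file) ---------------
# get_blend_weights() above produces per-frame weights for fusing the forward
# (start-frame-conditioned) and backward (end-frame-conditioned) denoising paths.
# For the 'linear' setting used in __call__ below, the forward pass is weighted
# 1 -> 0 across the clip, while the backward pass (computed on time-flipped latents
# and flipped back afterwards) effectively contributes 0 -> 1, so each frame's two
# contributions sum to 1:
#
#   import torch
#   num_frames = 25
#   w_fwd = torch.linspace(1, 0, steps=num_frames)                        # forward weight per frame
#   w_bwd = torch.flip(torch.linspace(1, 0, steps=num_frames), dims=[0])  # backward weight after un-flipping
#   assert torch.allclose(w_fwd + w_bwd, torch.ones(num_frames))
#
# The 'non-linear' option replaces the ramp with a geometric curve (0.85 ** x),
# which keeps more frames close to their nearer bound image.
# -----------------------------------------------------------------------------------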
398 | @property 399 | def do_classifier_free_guidance(self): 400 | return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None 401 | 402 | @property 403 | def num_timesteps(self): 404 | return self._num_timesteps 405 | 406 | @torch.no_grad() 407 | def __call__( 408 | self, 409 | image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], 410 | end_frame: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None, 411 | height: int = 576, 412 | width: int = 1024, 413 | num_frames: Optional[int] = None, 414 | num_inference_steps: int = 25, 415 | min_guidance_scale: float = 1.0, 416 | max_guidance_scale: float = 3.0, 417 | fps: int = 7, 418 | motion_bucket_id: int = 127, 419 | noise_aug_strength: int = 0.02, 420 | decode_chunk_size: Optional[int] = None, 421 | num_videos_per_prompt: Optional[int] = 1, 422 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 423 | latents: Optional[torch.FloatTensor] = None, 424 | output_type: Optional[str] = "pil", 425 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 426 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 427 | return_dict: bool = True, 428 | same_noise_per_frame: bool = True, 429 | jump_length: int = 3, 430 | jump_n_sample: int = 2, 431 | repeat_step_ratio: float = 0.5, 432 | noise_scale_ratio: float = 0.5, 433 | custom_timestep_spacing: bool = True 434 | ): 435 | r""" 436 | The call function to the pipeline for generation. 437 | 438 | Args: 439 | image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): 440 | Image or images to guide image generation. If you provide a tensor, it needs to be compatible with 441 | [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). 442 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 443 | The height in pixels of the generated image. 444 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 445 | The width in pixels of the generated image. 446 | num_frames (`int`, *optional*): 447 | The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt` 448 | num_inference_steps (`int`, *optional*, defaults to 25): 449 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 450 | expense of slower inference. This parameter is modulated by `strength`. 451 | min_guidance_scale (`float`, *optional*, defaults to 1.0): 452 | The minimum guidance scale. Used for the classifier free guidance with first frame. 453 | max_guidance_scale (`float`, *optional*, defaults to 3.0): 454 | The maximum guidance scale. Used for the classifier free guidance with last frame. 455 | fps (`int`, *optional*, defaults to 7): 456 | Frames per second. The rate at which the generated images shall be exported to a video after generation. 457 | Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training. 458 | motion_bucket_id (`int`, *optional*, defaults to 127): 459 | The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video. 460 | noise_aug_strength (`int`, *optional*, defaults to 0.02): 461 | The amount of noise added to the init image, the higher it is the less the video will look like the init image. 
Increase it for more motion. 462 | decode_chunk_size (`int`, *optional*): 463 | The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency 464 | between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once 465 | for maximal quality. Reduce `decode_chunk_size` to reduce memory usage. 466 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 467 | The number of images to generate per prompt. 468 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 469 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 470 | generation deterministic. 471 | latents (`torch.FloatTensor`, *optional*): 472 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 473 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 474 | tensor is generated by sampling using the supplied random `generator`. 475 | output_type (`str`, *optional*, defaults to `"pil"`): 476 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 477 | callback_on_step_end (`Callable`, *optional*): 478 | A function that calls at the end of each denoising steps during the inference. The function is called 479 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 480 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 481 | `callback_on_step_end_tensor_inputs`. 482 | callback_on_step_end_tensor_inputs (`List`, *optional*): 483 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 484 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 485 | `._callback_tensor_inputs` attribute of your pipeline class. 486 | return_dict (`bool`, *optional*, defaults to `True`): 487 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 488 | plain tuple. 489 | 490 | Returns: 491 | [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`: 492 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned, 493 | otherwise a `tuple` is returned where the first element is a list of list with the generated frames. 494 | 495 | Examples: 496 | 497 | ```py 498 | from diffusers import StableVideoDiffusionPipeline 499 | from diffusers.utils import load_image, export_to_video 500 | 501 | pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16") 502 | pipe.to("cuda") 503 | 504 | image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200") 505 | image = image.resize((1024, 576)) 506 | 507 | frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0] 508 | export_to_video(frames, "generated.mp4", fps=7) 509 | ``` 510 | """ 511 | # import ipdb; ipdb.set_trace() 512 | # 0. 
Default height and width to unet 513 | height = height or self.unet.config.sample_size * self.vae_scale_factor 514 | width = width or self.unet.config.sample_size * self.vae_scale_factor 515 | 516 | num_frames = num_frames if num_frames is not None else self.unet.config.num_frames 517 | decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames 518 | 519 | # 1. Check inputs. Raise error if not correct 520 | self.check_inputs(image, height, width) 521 | 522 | # 2. Define call parameters 523 | if isinstance(image, PIL.Image.Image): 524 | batch_size = 1 525 | elif isinstance(image, list): 526 | batch_size = len(image) 527 | else: 528 | batch_size = image.shape[0] 529 | device = self._execution_device 530 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 531 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 532 | # corresponds to doing no classifier free guidance. 533 | do_classifier_free_guidance = max_guidance_scale > 1.0 534 | emb_cond = {} 535 | # 3. Encode input image 536 | image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance) 537 | if end_frame is not None: 538 | image_embeddings_end_frame = self._encode_image(end_frame, device, num_videos_per_prompt, do_classifier_free_guidance) 539 | emb_cond['forward'] = image_embeddings 540 | emb_cond['backward'] = image_embeddings_end_frame 541 | # NOTE: Stable Diffusion Video was conditioned on fps - 1, which 542 | # is why it is reduced here. 543 | # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188 544 | fps = fps - 1 545 | 546 | # 4. Encode input image using VAE 547 | image = self.image_processor.preprocess(image, height=height, width=width) 548 | if end_frame is not None: 549 | end_frame = self.image_processor.preprocess(end_frame, height=height, width=width) 550 | 551 | noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype) 552 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast 553 | 554 | frame_cond = {} 555 | 556 | def image_to_image_latents(image): 557 | image = image + noise_aug_strength * noise 558 | if needs_upcasting: 559 | self.vae.to(dtype=torch.float32) 560 | image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance) 561 | image_latents = image_latents.to(image_embeddings.dtype) 562 | # cast back to fp16 if needed 563 | if needs_upcasting: 564 | self.vae.to(dtype=torch.float16) 565 | # Repeat the image latents for each frame so we can concatenate them with the noise 566 | # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width] 567 | image_latents = image_latents.unsqueeze(1).repeat(1,num_frames,1,1,1) #torch.Size([2, 25, 4, 72, 128]) 568 | return image_latents 569 | 570 | frame_cond['forward'] = image_to_image_latents(image) 571 | frame_cond['backward'] = image_to_image_latents(end_frame) 572 | 573 | # 5. Get Added Time IDs 574 | added_time_ids = self._get_add_time_ids( 575 | fps, 576 | motion_bucket_id, 577 | noise_aug_strength, 578 | image_embeddings.dtype, 579 | batch_size, 580 | num_videos_per_prompt, 581 | do_classifier_free_guidance, 582 | ) 583 | added_time_ids = added_time_ids.to(device) 584 | 585 | # 4. 
Prepare timesteps 586 | self.scheduler.set_timesteps(num_inference_steps) 587 | 588 | if custom_timestep_spacing: 589 | #todo: gradually decrease the number of jumps 590 | reverse_list = np.array(list(range(num_inference_steps)))[::-1] 591 | timesteps_idx = [] 592 | jumps = {} 593 | for j in range(int(num_inference_steps * repeat_step_ratio), num_inference_steps - jump_length, jump_length): 594 | jumps[j] = jump_n_sample - 1 595 | 596 | t = num_inference_steps 597 | while t >= 1: 598 | t = t - 1 599 | timesteps_idx.append(t) 600 | 601 | if jumps.get(t, 0) > 0: 602 | jumps[t] = jumps[t] - 1 603 | for _ in range(jump_length): 604 | t = t + 1 605 | timesteps_idx.append(t) 606 | timesteps_idx = reverse_list[timesteps_idx] 607 | 608 | step_id_last = timesteps_idx[0] - 1 609 | timesteps = self.scheduler.timesteps 610 | 611 | # 5. Prepare latent variables 612 | num_channels_latents = self.unet.config.in_channels 613 | latents = self.prepare_latents( 614 | batch_size * num_videos_per_prompt, 615 | num_frames, 616 | num_channels_latents, 617 | height, 618 | width, 619 | image_embeddings.dtype, 620 | device, 621 | generator, 622 | latents, 623 | ) 624 | 625 | # 7. Prepare guidance scale 626 | guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0) 627 | guidance_scale = guidance_scale.to(device, latents.dtype) 628 | guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1) 629 | guidance_scale = _append_dims(guidance_scale, latents.ndim) 630 | 631 | self._guidance_scale = guidance_scale 632 | 633 | count = torch.zeros_like(latents) 634 | value = torch.zeros_like(latents) 635 | 636 | # 8. Denoising loop 637 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 638 | self._num_timesteps = len(timesteps) 639 | 640 | with self.progress_bar(total=len(timesteps_idx)) as progress_bar: 641 | for i, step_id in enumerate(timesteps_idx): 642 | 643 | # import ipdb; ipdb.set_trace() 644 | t = timesteps[step_id] 645 | # expand the latents if we are doing classifier free guidance 646 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 647 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 648 | # import ipdb; ipdb.set_trace() 649 | count.zero_() 650 | value.zero_() 651 | for time_dir in ['forward', 'backward']: 652 | # Concatenate image_latents over channels dimention 653 | 654 | if step_id > step_id_last: 655 | image_latents = frame_cond[time_dir] 656 | image_embeddings = emb_cond[time_dir] if end_frame is not None else image_embeddings 657 | latent_model_input_new = torch.cat([latent_model_input, image_latents], dim=2) 658 | if time_dir == 'backward': 659 | latent_model_input_new = torch.flip(latent_model_input_new, dims=[1]) 660 | latents = torch.flip(latents, dims=[1]) 661 | # predict the noise residual 662 | noise_pred = self.unet( 663 | latent_model_input_new, 664 | t, 665 | encoder_hidden_states=image_embeddings, 666 | added_time_ids=added_time_ids, 667 | return_dict=False, 668 | )[0] 669 | # perform guidance 670 | if do_classifier_free_guidance: 671 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) 672 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) 673 | # compute the previous noisy sample x_t -> x_t-1 674 | self.scheduler._step_index = step_id 675 | latents_view_denoised = self.scheduler.step(noise_pred, t, latents, step_type=time_dir).prev_sample 676 | blend_weights = 
self.get_blend_weights(latents, step_id, num_frames, weight_type = 'linear', time_dir=time_dir) 677 | latents_view_denoised = latents_view_denoised * blend_weights 678 | if time_dir == 'backward': 679 | # import ipdb; ipdb.set_trace() 680 | latents_view_denoised = torch.flip(latents_view_denoised, dims=[1]) 681 | 682 | else: 683 | latents_view_denoised = self.scheduler.undo_step(latents, step_id_last, ratio=noise_scale_ratio) 684 | if time_dir == 'backward': 685 | latents_view_denoised = 0. 686 | 687 | value += latents_view_denoised 688 | count += 0.5 689 | 690 | step_id_last = step_id 691 | 692 | latents = torch.where(count > 0, value / count, value) 693 | 694 | if callback_on_step_end is not None: 695 | callback_kwargs = {} 696 | for k in callback_on_step_end_tensor_inputs: 697 | callback_kwargs[k] = locals()[k] 698 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 699 | latents = callback_outputs.pop("latents", latents) 700 | 701 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 702 | progress_bar.update() 703 | 704 | if not output_type == "latent": 705 | # cast back to fp16 if needed 706 | if needs_upcasting: 707 | self.vae.to(dtype=torch.float16) 708 | frames = self.decode_latents(latents, num_frames, decode_chunk_size) 709 | frames = tensor2vid(frames, self.image_processor, output_type=output_type) 710 | else: 711 | frames = latents 712 | 713 | self.maybe_free_model_hooks() 714 | 715 | if not return_dict: 716 | return frames 717 | 718 | return StableVideoDiffusionPipelineOutput(frames=frames) 719 | 720 | 721 | # resizing utils 722 | # TODO: clean up later 723 | def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): 724 | h, w = input.shape[-2:] 725 | factors = (h / size[0], w / size[1]) 726 | 727 | # First, we have to determine sigma 728 | # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 729 | sigmas = ( 730 | max((factors[0] - 1.0) / 2.0, 0.001), 731 | max((factors[1] - 1.0) / 2.0, 0.001), 732 | ) 733 | 734 | # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma 735 | # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 736 | # But they do it in the 2 passes, which gives better results. 
Let's try 2 sigmas for now 737 | ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) 738 | 739 | # Make sure it is odd 740 | if (ks[0] % 2) == 0: 741 | ks = ks[0] + 1, ks[1] 742 | 743 | if (ks[1] % 2) == 0: 744 | ks = ks[0], ks[1] + 1 745 | 746 | input = _gaussian_blur2d(input, ks, sigmas) 747 | 748 | output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) 749 | return output 750 | 751 | 752 | def _compute_padding(kernel_size): 753 | """Compute padding tuple.""" 754 | # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) 755 | # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad 756 | if len(kernel_size) < 2: 757 | raise AssertionError(kernel_size) 758 | computed = [k - 1 for k in kernel_size] 759 | 760 | # for even kernels we need to do asymmetric padding :( 761 | out_padding = 2 * len(kernel_size) * [0] 762 | 763 | for i in range(len(kernel_size)): 764 | computed_tmp = computed[-(i + 1)] 765 | 766 | pad_front = computed_tmp // 2 767 | pad_rear = computed_tmp - pad_front 768 | 769 | out_padding[2 * i + 0] = pad_front 770 | out_padding[2 * i + 1] = pad_rear 771 | 772 | return out_padding 773 | 774 | 775 | def _filter2d(input, kernel): 776 | # prepare kernel 777 | b, c, h, w = input.shape 778 | tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) 779 | 780 | tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) 781 | 782 | height, width = tmp_kernel.shape[-2:] 783 | 784 | padding_shape: list[int] = _compute_padding([height, width]) 785 | input = torch.nn.functional.pad(input, padding_shape, mode="reflect") 786 | 787 | # kernel and input tensor reshape to align element-wise or batch-wise params 788 | tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) 789 | input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) 790 | 791 | # convolve the tensor with the kernel. 
792 | output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) 793 | 794 | out = output.view(b, c, h, w) 795 | return out 796 | 797 | 798 | def _gaussian(window_size: int, sigma): 799 | if isinstance(sigma, float): 800 | sigma = torch.tensor([[sigma]]) 801 | 802 | batch_size = sigma.shape[0] 803 | 804 | x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) 805 | 806 | if window_size % 2 == 0: 807 | x = x + 0.5 808 | 809 | gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) 810 | 811 | return gauss / gauss.sum(-1, keepdim=True) 812 | 813 | 814 | def _gaussian_blur2d(input, kernel_size, sigma): 815 | if isinstance(sigma, tuple): 816 | sigma = torch.tensor([sigma], dtype=input.dtype) 817 | else: 818 | sigma = sigma.to(dtype=input.dtype) 819 | 820 | ky, kx = int(kernel_size[0]), int(kernel_size[1]) 821 | bs = sigma.shape[0] 822 | kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) 823 | kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) 824 | out_x = _filter2d(input, kernel_x[..., None, :]) 825 | out = _filter2d(out_x, kernel_y[..., None]) 826 | 827 | return out 828 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.25.0 2 | asttokens==2.2.1 3 | av==11.0.0 4 | backcall==0.2.0 5 | beautifulsoup4==4.12.2 6 | blessed==1.20.0 7 | brotlipy==0.7.0 8 | certifi==2024.2.2 9 | cffi==1.15.1 10 | charset-normalizer==2.0.4 11 | cmake==3.28.1 12 | colorama==0.4.6 13 | contourpy==1.1.0 14 | cryptography==41.0.2 15 | cycler==0.11.0 16 | decorator==5.1.1 17 | decord==0.6.0 18 | diffusers==0.24.0 19 | einops==0.7.0 20 | enlighten==1.12.4 21 | executing==1.2.0 22 | filelock==3.12.3 23 | fonttools==4.42.1 24 | fsspec==2023.6.0 25 | gdown==4.7.1 26 | huggingface-hub==0.19.4 27 | idna==3.4 28 | importlib-metadata==6.8.0 29 | importlib-resources==6.0.1 30 | ipdb==0.13.13 31 | ipython==8.12.2 32 | jedi==0.19.0 33 | Jinja2==3.1.2 34 | kiwisolver==1.4.5 35 | lit==17.0.6 36 | MarkupSafe==2.1.3 37 | matplotlib==3.7.2 38 | matplotlib-inline==0.1.6 39 | mkl-fft==1.3.6 40 | mkl-random==1.2.2 41 | mkl-service==2.4.0 42 | mpmath==1.3.0 43 | mypy-extensions==1.0.0 44 | networkx==3.1 45 | numpy==1.24.3 46 | nvidia-cublas-cu11==11.10.3.66 47 | nvidia-cublas-cu12==12.1.3.1 48 | nvidia-cuda-cupti-cu11==11.7.101 49 | nvidia-cuda-cupti-cu12==12.1.105 50 | nvidia-cuda-nvrtc-cu11==11.7.99 51 | nvidia-cuda-nvrtc-cu12==12.1.105 52 | nvidia-cuda-runtime-cu11==11.7.99 53 | nvidia-cuda-runtime-cu12==12.1.105 54 | nvidia-cudnn-cu11==8.5.0.96 55 | nvidia-cudnn-cu12==8.9.2.26 56 | nvidia-cufft-cu11==10.9.0.58 57 | nvidia-cufft-cu12==11.0.2.54 58 | nvidia-curand-cu11==10.2.10.91 59 | nvidia-curand-cu12==10.3.2.106 60 | nvidia-cusolver-cu11==11.4.0.1 61 | nvidia-cusolver-cu12==11.4.5.107 62 | nvidia-cusparse-cu11==11.7.4.91 63 | nvidia-cusparse-cu12==12.1.0.106 64 | nvidia-nccl-cu11==2.14.3 65 | nvidia-nccl-cu12==2.18.1 66 | nvidia-nvjitlink-cu12==12.3.101 67 | nvidia-nvtx-cu11==11.7.91 68 | nvidia-nvtx-cu12==12.1.105 69 | opencv-contrib-python==4.9.0.80 70 | opencv-python==4.8.0.76 71 | packaging==23.1 72 | pandas==2.0.3 73 | parso==0.8.3 74 | pexpect==4.8.0 75 | pickleshare==0.7.5 76 | Pillow==9.4.0 77 | pip==23.2.1 78 | prefixed==0.7.0 79 | prompt-toolkit==3.0.39 80 | protobuf==4.24.2 81 | psutil==5.9.5 82 | ptyprocess==0.7.0 83 | pure-eval==0.2.2 84 | pycolmap==0.6.1 85 | pycparser==2.21 86 
| Pygments==2.16.1 87 | pyOpenSSL==23.2.0 88 | pyparsing==3.0.9 89 | pyre-extensions==0.0.29 90 | PySocks==1.7.1 91 | python-dateutil==2.8.2 92 | pytorch-fid==0.3.0 93 | pytz==2024.1 94 | PyYAML==5.1.2 95 | regex==2022.7.9 96 | requests==2.31.0 97 | safetensors==0.3.3 98 | scipy==1.10.1 99 | sentencepiece==0.1.99 100 | setuptools==68.0.0 101 | six==1.16.0 102 | soupsieve==2.5 103 | splatting==0.0.0 104 | stack-data==0.6.2 105 | sympy==1.12 106 | tokenizers==0.15.0 107 | tomli==2.0.1 108 | torch==2.0.1 109 | torchaudio==0.13.1 110 | torchvision==0.14.1 111 | tqdm==4.66.1 112 | traitlets==5.9.0 113 | transformers==4.35.2 114 | triton==2.0.0 115 | typing_extensions==4.7.1 116 | typing-inspect==0.9.0 117 | tzdata==2024.1 118 | urllib3==1.26.16 119 | wcwidth==0.2.6 120 | wheel==0.38.4 121 | xformers==0.0.20 122 | zipp==3.16.2 123 | -------------------------------------------------------------------------------- /scheduling_euler_discrete_resampling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | from dataclasses import dataclass 17 | from typing import List, Optional, Tuple, Union 18 | 19 | import numpy as np 20 | import torch 21 | 22 | from diffusers.configuration_utils import ConfigMixin, register_to_config 23 | from diffusers.utils import BaseOutput, logging 24 | from diffusers.utils.torch_utils import randn_tensor 25 | from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin 26 | 27 | 28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 29 | 30 | 31 | @dataclass 32 | # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete 33 | class EulerDiscreteSchedulerOutput(BaseOutput): 34 | """ 35 | Output class for the scheduler's `step` function output. 36 | 37 | Args: 38 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 39 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the 40 | denoising loop. 41 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 42 | The predicted denoised sample `(x_{0})` based on the model output from the current timestep. 43 | `pred_original_sample` can be used to preview progress or for guidance. 44 | """ 45 | 46 | prev_sample: torch.FloatTensor 47 | pred_original_sample: Optional[torch.FloatTensor] = None 48 | 49 | 50 | # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar 51 | def betas_for_alpha_bar( 52 | num_diffusion_timesteps, 53 | max_beta=0.999, 54 | alpha_transform_type="cosine", 55 | ): 56 | """ 57 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of 58 | (1-beta) over time from t = [0,1]. 
59 | 60 | Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up 61 | to that part of the diffusion process. 62 | 63 | 64 | Args: 65 | num_diffusion_timesteps (`int`): the number of betas to produce. 66 | max_beta (`float`): the maximum beta to use; use values lower than 1 to 67 | prevent singularities. 68 | alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 69 | Choose from `cosine` or `exp` 70 | 71 | Returns: 72 | betas (`np.ndarray`): the betas used by the scheduler to step the model outputs 73 | """ 74 | if alpha_transform_type == "cosine": 75 | 76 | def alpha_bar_fn(t): 77 | return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 78 | 79 | elif alpha_transform_type == "exp": 80 | 81 | def alpha_bar_fn(t): 82 | return math.exp(t * -12.0) 83 | 84 | else: 85 | raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") 86 | 87 | betas = [] 88 | for i in range(num_diffusion_timesteps): 89 | t1 = i / num_diffusion_timesteps 90 | t2 = (i + 1) / num_diffusion_timesteps 91 | betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) 92 | return torch.tensor(betas, dtype=torch.float32) 93 | 94 | 95 | class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): 96 | """ 97 | Euler scheduler. 98 | 99 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic 100 | methods the library implements for all schedulers such as loading and saving. 101 | 102 | Args: 103 | num_train_timesteps (`int`, defaults to 1000): 104 | The number of diffusion steps to train the model. 105 | beta_start (`float`, defaults to 0.0001): 106 | The starting `beta` value of inference. 107 | beta_end (`float`, defaults to 0.02): 108 | The final `beta` value. 109 | beta_schedule (`str`, defaults to `"linear"`): 110 | The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from 111 | `linear` or `scaled_linear`. 112 | trained_betas (`np.ndarray`, *optional*): 113 | Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. 114 | prediction_type (`str`, defaults to `epsilon`, *optional*): 115 | Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), 116 | `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen 117 | Video](https://imagen.research.google/video/paper.pdf) paper). 118 | interpolation_type(`str`, defaults to `"linear"`, *optional*): 119 | The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be on of 120 | `"linear"` or `"log_linear"`. 121 | use_karras_sigmas (`bool`, *optional*, defaults to `False`): 122 | Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, 123 | the sigmas are determined according to a sequence of noise levels {σi}. 124 | timestep_spacing (`str`, defaults to `"linspace"`): 125 | The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and 126 | Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. 127 | steps_offset (`int`, defaults to 0): 128 | An offset added to the inference steps. You can use a combination of `offset=1` and 129 | `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable 130 | Diffusion. 
131 | """ 132 | 133 | _compatibles = [e.name for e in KarrasDiffusionSchedulers] 134 | order = 1 135 | 136 | @register_to_config 137 | def __init__( 138 | self, 139 | num_train_timesteps: int = 1000, 140 | beta_start: float = 0.0001, 141 | beta_end: float = 0.02, 142 | beta_schedule: str = "linear", 143 | trained_betas: Optional[Union[np.ndarray, List[float]]] = None, 144 | prediction_type: str = "epsilon", 145 | interpolation_type: str = "linear", 146 | use_karras_sigmas: Optional[bool] = False, 147 | sigma_min: Optional[float] = None, 148 | sigma_max: Optional[float] = None, 149 | timestep_spacing: str = "linspace", 150 | timestep_type: str = "discrete", # can be "discrete" or "continuous" 151 | steps_offset: int = 0, 152 | ): 153 | if trained_betas is not None: 154 | self.betas = torch.tensor(trained_betas, dtype=torch.float32) 155 | elif beta_schedule == "linear": 156 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) 157 | elif beta_schedule == "scaled_linear": 158 | # this schedule is very specific to the latent diffusion model. 159 | self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 160 | elif beta_schedule == "squaredcos_cap_v2": 161 | # Glide cosine schedule 162 | self.betas = betas_for_alpha_bar(num_train_timesteps) 163 | else: 164 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") 165 | 166 | self.alphas = 1.0 - self.betas 167 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 168 | 169 | sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) 170 | timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() 171 | 172 | sigmas = torch.from_numpy(sigmas[::-1].copy()).to(dtype=torch.float32) 173 | timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) 174 | 175 | # setable values 176 | self.num_inference_steps = None 177 | 178 | # TODO: Support the full EDM scalings for all prediction types and timestep types 179 | if timestep_type == "continuous" and prediction_type == "v_prediction": 180 | self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]) 181 | else: 182 | self.timesteps = timesteps 183 | 184 | self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) 185 | 186 | self.is_scale_input_called = False 187 | self.use_karras_sigmas = use_karras_sigmas 188 | 189 | self._step_index = None 190 | 191 | @property 192 | def init_noise_sigma(self): 193 | # standard deviation of the initial noise distribution 194 | if self.config.timestep_spacing in ["linspace", "trailing"]: 195 | return self.sigmas.max() 196 | 197 | return (self.sigmas.max() ** 2 + 1) ** 0.5 198 | 199 | @property 200 | def step_index(self): 201 | """ 202 | The index counter for current timestep. It will increae 1 after each scheduler step. 203 | """ 204 | return self._step_index 205 | 206 | def scale_model_input( 207 | self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] 208 | ) -> torch.FloatTensor: 209 | """ 210 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the 211 | current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. 212 | 213 | Args: 214 | sample (`torch.FloatTensor`): 215 | The input sample. 216 | timestep (`int`, *optional*): 217 | The current timestep in the diffusion chain. 
218 | 219 | Returns: 220 | `torch.FloatTensor`: 221 | A scaled input sample. 222 | """ 223 | if self.step_index is None: 224 | self._init_step_index(timestep) 225 | 226 | sigma = self.sigmas[self.step_index] 227 | sample = sample / ((sigma**2 + 1) ** 0.5) 228 | 229 | self.is_scale_input_called = True 230 | return sample 231 | 232 | def set_timesteps( 233 | self, 234 | num_inference_steps: int, 235 | device: Union[str, torch.device] = None): 236 | """ 237 | Sets the discrete timesteps used for the diffusion chain (to be run before inference). 238 | 239 | Args: 240 | num_inference_steps (`int`): 241 | The number of diffusion steps used when generating samples with a pre-trained model. 242 | device (`str` or `torch.device`, *optional*): 243 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 244 | """ 245 | self.num_inference_steps = num_inference_steps 246 | # import ipdb; ipdb.set_trace() 247 | 248 | # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 249 | if self.config.timestep_spacing == "linspace": 250 | timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[ 251 | ::-1 252 | ].copy() 253 | elif self.config.timestep_spacing == "leading": 254 | step_ratio = self.config.num_train_timesteps // self.num_inference_steps 255 | # creates integer timesteps by multiplying by ratio 256 | # casting to int to avoid issues when num_inference_step is power of 3 257 | timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) 258 | timesteps += self.config.steps_offset 259 | elif self.config.timestep_spacing == "trailing": 260 | step_ratio = self.config.num_train_timesteps / self.num_inference_steps 261 | # creates integer timesteps by multiplying by ratio 262 | # casting to int to avoid issues when num_inference_step is power of 3 263 | timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) 264 | timesteps -= 1 265 | else: 266 | raise ValueError( 267 | f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 268 | ) 269 | 270 | sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) 271 | log_sigmas = np.log(sigmas) 272 | 273 | if self.config.interpolation_type == "linear": 274 | sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) 275 | elif self.config.interpolation_type == "log_linear": 276 | sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp() 277 | else: 278 | raise ValueError( 279 | f"{self.config.interpolation_type} is not implemented. 
Please specify interpolation_type to either" 280 | " 'linear' or 'log_linear'" 281 | ) 282 | 283 | if self.use_karras_sigmas: 284 | sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) 285 | timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) 286 | 287 | 288 | sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) 289 | 290 | # TODO: Support the full EDM scalings for all prediction types and timestep types 291 | if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction": 292 | self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device) 293 | else: 294 | self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device) 295 | 296 | self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) 297 | self._step_index = None 298 | 299 | def _sigma_to_t(self, sigma, log_sigmas): 300 | # get log sigma 301 | log_sigma = np.log(np.maximum(sigma, 1e-10)) 302 | 303 | # get distribution 304 | dists = log_sigma - log_sigmas[:, np.newaxis] 305 | 306 | # get sigmas range 307 | low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) 308 | high_idx = low_idx + 1 309 | 310 | low = log_sigmas[low_idx] 311 | high = log_sigmas[high_idx] 312 | 313 | # interpolate sigmas 314 | w = (low - log_sigma) / (low - high) 315 | w = np.clip(w, 0, 1) 316 | 317 | # transform interpolation to time range 318 | t = (1 - w) * low_idx + w * high_idx 319 | t = t.reshape(sigma.shape) 320 | return t 321 | 322 | # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 323 | def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: 324 | """Constructs the noise schedule of Karras et al. (2022).""" 325 | 326 | # Hack to make sure that other schedulers which copy this function don't break 327 | # TODO: Add this logic to the other schedulers 328 | if hasattr(self.config, "sigma_min"): 329 | sigma_min = self.config.sigma_min 330 | else: 331 | sigma_min = None 332 | 333 | if hasattr(self.config, "sigma_max"): 334 | sigma_max = self.config.sigma_max 335 | else: 336 | sigma_max = None 337 | 338 | sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() 339 | sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() 340 | 341 | rho = 7.0 # 7.0 is the value used in the paper 342 | ramp = np.linspace(0, 1, num_inference_steps) 343 | min_inv_rho = sigma_min ** (1 / rho) 344 | max_inv_rho = sigma_max ** (1 / rho) 345 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho 346 | return sigmas 347 | 348 | def _init_step_index(self, timestep): 349 | if isinstance(timestep, torch.Tensor): 350 | timestep = timestep.to(self.timesteps.device) 351 | 352 | index_candidates = (self.timesteps == timestep).nonzero() 353 | 354 | # The sigma index that is taken for the **very** first `step` 355 | # is always the second index (or the last index if there is only 1) 356 | # This way we can ensure we don't accidentally skip a sigma in 357 | # case we start in the middle of the denoising schedule (e.g. 
for image-to-image) 358 | if len(index_candidates) > 1: 359 | step_index = index_candidates[1] 360 | else: 361 | step_index = index_candidates[0] 362 | 363 | self._step_index = step_index.item() 364 | 365 | def step( 366 | self, 367 | model_output: torch.FloatTensor, 368 | timestep: Union[float, torch.FloatTensor], 369 | sample: torch.FloatTensor, 370 | step_type: str = "forward", 371 | s_churn: float = 0.0, 372 | s_tmin: float = 0.0, 373 | s_tmax: float = float("inf"), 374 | s_noise: float = 1.0, 375 | generator: Optional[torch.Generator] = None, 376 | return_dict: bool = True, 377 | ) -> Union[EulerDiscreteSchedulerOutput, Tuple]: 378 | """ 379 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 380 | process from the learned model outputs (most often the predicted noise). 381 | 382 | Args: 383 | model_output (`torch.FloatTensor`): 384 | The direct output from learned diffusion model. 385 | timestep (`float`): 386 | The current discrete timestep in the diffusion chain. 387 | sample (`torch.FloatTensor`): 388 | A current instance of a sample created by the diffusion process. 389 | s_churn (`float`): 390 | s_tmin (`float`): 391 | s_tmax (`float`): 392 | s_noise (`float`, defaults to 1.0): 393 | Scaling factor for noise added to the sample. 394 | generator (`torch.Generator`, *optional*): 395 | A random number generator. 396 | return_dict (`bool`): 397 | Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or 398 | tuple. 399 | 400 | Returns: 401 | [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: 402 | If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is 403 | returned, otherwise a tuple is returned where the first element is the sample tensor. 404 | """ 405 | 406 | if ( 407 | isinstance(timestep, int) 408 | or isinstance(timestep, torch.IntTensor) 409 | or isinstance(timestep, torch.LongTensor) 410 | ): 411 | raise ValueError( 412 | ( 413 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 414 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" 415 | " one of the `scheduler.timesteps` as a timestep." 416 | ), 417 | ) 418 | 419 | if not self.is_scale_input_called: 420 | logger.warning( 421 | "The `scale_model_input` function should be called before `step` to ensure correct denoising. " 422 | "See `StableDiffusionPipeline` for a usage example." 423 | ) 424 | # import ipdb; ipdb.set_trace() 425 | if self.step_index is None: 426 | self._init_step_index(timestep) 427 | 428 | sigma = self.sigmas[self.step_index] 429 | 430 | gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 431 | 432 | noise = randn_tensor( 433 | model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator 434 | ) 435 | 436 | eps = noise * s_noise 437 | sigma_hat = sigma * (gamma + 1) 438 | 439 | if gamma > 0: 440 | sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 441 | 442 | # 1. 
compute predicted original sample (x_0) from sigma-scaled predicted noise
443 |         # NOTE: "original_sample" should not be an expected prediction_type but is left in for
444 |         # backwards compatibility
445 |         if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample":
446 |             pred_original_sample = model_output
447 |         elif self.config.prediction_type == "epsilon":
448 |             pred_original_sample = sample - sigma_hat * model_output
449 |         elif self.config.prediction_type == "v_prediction":
450 |             # denoised = model_output * c_out + input * c_skip
451 |             pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
452 |         else:
453 |             raise ValueError(
454 |                 f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`"
455 |             )
456 | 
457 |         # 2. Convert to an ODE derivative
458 |         derivative = (sample - pred_original_sample) / sigma_hat
459 |         dt = self.sigmas[self.step_index + 1] - sigma_hat
460 | 
461 |         prev_sample = sample + derivative * dt
462 | 
463 |         # upon completion increase step index by one
464 |         self._step_index += 1
465 | 
466 |         if not return_dict:
467 |             return (prev_sample,)
468 | 
469 |         return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
470 | 
471 | 
472 |     def undo_step(self, sample, step_id, generator=None, ratio=0.49):
473 |         noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
474 |         sample = sample + noise * ((self.sigmas[step_id]**2 - self.sigmas[step_id+1]**2) ** 0.5) * ratio
475 |         return sample
476 | 
477 | 
478 |     def add_noise(
479 |         self,
480 |         original_samples: torch.FloatTensor,
481 |         noise: torch.FloatTensor,
482 |         timesteps: torch.FloatTensor,
483 |     ) -> torch.FloatTensor:
484 |         # Make sure sigmas and timesteps have the same device and dtype as original_samples
485 |         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
486 |         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
487 |             # mps does not support float64
488 |             schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
489 |             timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
490 |         else:
491 |             schedule_timesteps = self.timesteps.to(original_samples.device)
492 |             timesteps = timesteps.to(original_samples.device)
493 | 
494 |         step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
495 | 
496 |         sigma = sigmas[step_indices].flatten()
497 |         while len(sigma.shape) < len(original_samples.shape):
498 |             sigma = sigma.unsqueeze(-1)
499 | 
500 |         noisy_samples = original_samples + noise * sigma
501 |         return noisy_samples
502 | 
503 |     def __len__(self):
504 |         return self.config.num_train_timesteps
505 | 
--------------------------------------------------------------------------------
/svd_sequential_re.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | 
3 | import tqdm._tqdm
4 | import torch
5 | from diffusers import StableVideoDiffusionPipeline
6 | from PIL import Image
7 | from diffusers.utils import load_image, export_to_video
8 | from pipeline_stable_video_diffusion_re import StableVideoDiffusionPipeline_Custom
9 | from unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
10 | from scheduling_euler_discrete_resampling import EulerDiscreteScheduler
11 | from datasets import get_dataset
12 | import tqdm
13 | 
14 | def img_preprocess(img_path,
mode='crop', orig_aspect=None): 15 | image = load_image(img_path) 16 | w, h = image.size 17 | aspect_ratio = orig_aspect if orig_aspect else 16 / 9 18 | 19 | if mode == 'crop': 20 | # Crop the image to the specified aspect ratio 21 | if w / h > aspect_ratio: 22 | height = h 23 | width = int(height * aspect_ratio) 24 | else: 25 | width = w 26 | height = int(width / aspect_ratio) 27 | left = (w - width) // 2 28 | top = (h - height) // 2 29 | image = image.crop((left, top, left + width, top + height)) 30 | image = image.resize((1024, 576)) 31 | elif mode == 'padding': 32 | # Pad the image to the specified aspect ratio 33 | new_w = int(h * aspect_ratio) if w / h < aspect_ratio else w 34 | new_h = int(w / aspect_ratio) if w / h >= aspect_ratio else h 35 | new_image = Image.new("RGB", (new_w, new_h), (0, 0, 0)) 36 | new_image.paste(image, ((new_w - w) // 2, (new_h - h) // 2)) 37 | image = new_image.resize((1024, 576)) 38 | return image 39 | 40 | def time_reversal_fusion(model_card, start_frame, end_frame, num_inference_steps, fps_value, jump_n_sample, jump_length, repeat_step_ratio, noise_scale_ratio, motion_id, generator): 41 | # Load model components 42 | original_pipe = StableVideoDiffusionPipeline.from_pretrained( 43 | model_card, torch_dtype=torch.float16, variant="fp16" 44 | ) 45 | unet_custom = UNetSpatioTemporalConditionModel.from_pretrained( 46 | model_card, subfolder="unet", torch_dtype=torch.float16, variant="fp16" 47 | ).to('cuda') 48 | scheduler_custom = EulerDiscreteScheduler.from_pretrained( 49 | model_card, subfolder="scheduler", torch_dtype=torch.float16, variant="fp16" 50 | ) 51 | # Generate frames 52 | pipe = StableVideoDiffusionPipeline_Custom( 53 | vae=original_pipe.vae, 54 | image_encoder=original_pipe.image_encoder, 55 | unet=unet_custom, 56 | scheduler=scheduler_custom, 57 | feature_extractor=original_pipe.feature_extractor, 58 | ) 59 | pipe.enable_model_cpu_offload() 60 | frames = pipe( 61 | start_frame, 62 | end_frame, 63 | height=start_frame.height, 64 | width=start_frame.width, 65 | num_frames=25, 66 | num_inference_steps=num_inference_steps, 67 | fps=fps_value, 68 | jump_length=jump_length, 69 | jump_n_sample=jump_n_sample, 70 | repeat_step_ratio=repeat_step_ratio, 71 | noise_scale_ratio=noise_scale_ratio, 72 | decode_chunk_size=8, 73 | motion_bucket_id=motion_id, 74 | generator=generator, 75 | ).frames[0] 76 | 77 | return frames 78 | 79 | def main(data_type): 80 | # Configuration 81 | root_dir = os.path.dirname(os.path.abspath(__file__)) 82 | # data_type = 'multiview' # Options: 'image2loop', 'video_frames', 'multiview' 83 | dataset_folder = os.path.join(root_dir, 'test_data') 84 | output_folder = os.path.join(root_dir, 'output', f'{data_type}_exp') 85 | os.makedirs(output_folder, exist_ok=True) 86 | 87 | # Model parameters 88 | model_card = "stabilityai/stable-video-diffusion-img2vid-xt" 89 | fps_value = 7 90 | motion_id = 127 91 | random_seed = 42 92 | num_inference_steps = 50 93 | jump_n_sample = 2 94 | jump_length = 5 95 | repeat_step_ratio = 0.8 96 | noise_scale_ratio = 1.0 97 | generator = torch.manual_seed(random_seed) 98 | 99 | # Load data-specific settings 100 | if data_type == 'image2loop': 101 | video_frames = get_dataset(dataset_folder, data_type='loop') 102 | fps_value = 7 103 | motion_id = 127 104 | jump_n_sample = 2 105 | jump_length = 5 106 | repeat_step_ratio = 0.8 107 | elif data_type == 'video_frames': 108 | video_frames = get_dataset(dataset_folder, data_type='frame', filter_keyword='video_frames') 109 | fps_value = 7 110 | motion_id = 127 111 | 
jump_n_sample = 2 112 | jump_length = 5 113 | repeat_step_ratio = 0.8 114 | noise_scale_ratio = .95 115 | 116 | elif data_type == 'gym_motion': 117 | video_frames = get_dataset(dataset_folder, data_type='frame', filter_keyword='gym_motion') 118 | fps_value = 17 119 | motion_id = 10 120 | jump_n_sample = 2 121 | jump_length = 5 122 | repeat_step_ratio = 0.8 123 | elif data_type == 'multiview': 124 | video_frames = get_dataset(dataset_folder, data_type='multiview') 125 | fps_value = 7 126 | motion_id = 127 127 | jump_n_sample = 2 128 | jump_length = 5 129 | repeat_step_ratio = 0.8 130 | noise_scale_ratio = 1 131 | 132 | else: 133 | raise ValueError(f"Unsupported data_type: {data_type}") 134 | 135 | # Model directory 136 | re_steps = int((1 - repeat_step_ratio) * num_inference_steps) 137 | model_folder_name = f"{model_card.split('/')[-1]}_fps{fps_value}_id{motion_id}_s-num{num_inference_steps}_re{re_steps}_{jump_length}_{jump_n_sample}_{noise_scale_ratio}" 138 | model_folder = os.path.join(output_folder, model_folder_name) 139 | os.makedirs(model_folder, exist_ok=True) 140 | 141 | # Process each image pair 142 | for idx in tqdm.tqdm(range(len(video_frames))): 143 | image_pair = video_frames[idx] 144 | start_frame = img_preprocess(image_pair[0]) 145 | end_frame = img_preprocess(image_pair[1]) 146 | 147 | # Generate frame folder name 148 | base_name_start = os.path.splitext(os.path.basename(image_pair[0]))[0] 149 | base_name_end = os.path.splitext(os.path.basename(image_pair[1]))[0] 150 | dir_name = os.path.basename(os.path.dirname(image_pair[0])) 151 | 152 | if data_type == 'image2loop': 153 | frame_folder_name = f"{dir_name}_{base_name_start}" 154 | else: 155 | frame_folder_name = f"{dir_name}_{base_name_start}_{base_name_end}" 156 | 157 | frame_folder = os.path.join(model_folder, frame_folder_name) 158 | os.makedirs(frame_folder, exist_ok=True) 159 | video_file = f"{frame_folder}.mp4" 160 | 161 | frames = time_reversal_fusion( 162 | model_card, start_frame, end_frame, num_inference_steps, fps_value, jump_n_sample, jump_length, repeat_step_ratio, noise_scale_ratio, motion_id, generator 163 | ) 164 | 165 | # Save frames 166 | for i, frame in enumerate(frames): 167 | frame.save(os.path.join(frame_folder, f'{i}.png')) 168 | 169 | # Export to video 170 | export_to_video(frames, video_file, fps=fps_value) 171 | 172 | if __name__ == '__main__': 173 | #different input flags for different datasets 174 | args = sys.argv[1:] 175 | data_type = args[0] 176 | 177 | main(data_type) 178 | -------------------------------------------------------------------------------- /test_data/Multiview_data/mipnerf360_lite/garden/frame_00006.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HavenFeng/time_reversal/cdfb4ecfeaa8a2e50e30bf7a383b4b2d8ebcc2fb/test_data/Multiview_data/mipnerf360_lite/garden/frame_00006.JPG -------------------------------------------------------------------------------- /test_data/Multiview_data/mipnerf360_lite/garden/frame_00104.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HavenFeng/time_reversal/cdfb4ecfeaa8a2e50e30bf7a383b4b2d8ebcc2fb/test_data/Multiview_data/mipnerf360_lite/garden/frame_00104.JPG -------------------------------------------------------------------------------- /test_data/gym_motion_2024_frames/arm_clip1/frame_00018.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HavenFeng/time_reversal/cdfb4ecfeaa8a2e50e30bf7a383b4b2d8ebcc2fb/test_data/gym_motion_2024_frames/arm_clip1/frame_00018.jpg
--------------------------------------------------------------------------------
/test_data/gym_motion_2024_frames/arm_clip1/frame_00060.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HavenFeng/time_reversal/cdfb4ecfeaa8a2e50e30bf7a383b4b2d8ebcc2fb/test_data/gym_motion_2024_frames/arm_clip1/frame_00060.jpg
--------------------------------------------------------------------------------
/test_data/video_frames/dolomite_clip3/frame_00000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HavenFeng/time_reversal/cdfb4ecfeaa8a2e50e30bf7a383b4b2d8ebcc2fb/test_data/video_frames/dolomite_clip3/frame_00000.jpg
--------------------------------------------------------------------------------
/test_data/video_frames/dolomite_clip3/frame_00100.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HavenFeng/time_reversal/cdfb4ecfeaa8a2e50e30bf7a383b4b2d8ebcc2fb/test_data/video_frames/dolomite_clip3/frame_00100.jpg
--------------------------------------------------------------------------------
/unet_spatio_temporal_condition.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Dict, Optional, Tuple, Union
3 | 
4 | import torch
5 | import torch.nn as nn
6 | 
7 | from diffusers.configuration_utils import ConfigMixin, register_to_config
8 | from diffusers.loaders import UNet2DConditionLoadersMixin
9 | from diffusers.utils import BaseOutput, logging
10 | from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
11 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps
12 | from diffusers.models.modeling_utils import ModelMixin
13 | from diffusers.models.unets.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
14 | 
15 | 
16 | logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
17 | 
18 | 
19 | @dataclass
20 | class UNetSpatioTemporalConditionOutput(BaseOutput):
21 |     """
22 |     The output of [`UNetSpatioTemporalConditionModel`].
23 | 
24 |     Args:
25 |         sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
26 |             The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
27 |     """
28 | 
29 |     sample: torch.FloatTensor = None
30 |     haha: str = 'haha'  # placeholder field; never set anywhere in this file
31 | 
32 | 
33 | class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
34 |     r"""
35 |     A conditional Spatio-Temporal UNet model that takes noisy video frames, a conditional state, and a timestep, and returns a
36 |     sample-shaped output.
37 | 
38 |     This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
39 |     for all models (such as downloading or saving).
40 | 
41 |     Parameters:
42 |         sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
43 |             Height and width of input/output sample.
44 |         in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
45 |         out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
46 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): 47 | The tuple of downsample blocks to use. 48 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): 49 | The tuple of upsample blocks to use. 50 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): 51 | The tuple of output channels for each block. 52 | addition_time_embed_dim: (`int`, defaults to 256): 53 | Dimension to to encode the additional time ids. 54 | projection_class_embeddings_input_dim (`int`, defaults to 768): 55 | The dimension of the projection of encoded `added_time_ids`. 56 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. 57 | cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): 58 | The dimension of the cross attention features. 59 | transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): 60 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for 61 | [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], 62 | [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. 63 | num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`): 64 | The number of attention heads. 65 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 66 | """ 67 | 68 | _supports_gradient_checkpointing = True 69 | 70 | @register_to_config 71 | def __init__( 72 | self, 73 | sample_size: Optional[int] = None, 74 | in_channels: int = 8, 75 | out_channels: int = 4, 76 | down_block_types: Tuple[str] = ( 77 | "CrossAttnDownBlockSpatioTemporal", 78 | "CrossAttnDownBlockSpatioTemporal", 79 | "CrossAttnDownBlockSpatioTemporal", 80 | "DownBlockSpatioTemporal", 81 | ), 82 | up_block_types: Tuple[str] = ( 83 | "UpBlockSpatioTemporal", 84 | "CrossAttnUpBlockSpatioTemporal", 85 | "CrossAttnUpBlockSpatioTemporal", 86 | "CrossAttnUpBlockSpatioTemporal", 87 | ), 88 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 89 | addition_time_embed_dim: int = 256, 90 | projection_class_embeddings_input_dim: int = 768, 91 | layers_per_block: Union[int, Tuple[int]] = 2, 92 | cross_attention_dim: Union[int, Tuple[int]] = 1024, 93 | transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, 94 | num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20), 95 | num_frames: int = 25, 96 | ): 97 | super().__init__() 98 | 99 | self.sample_size = sample_size 100 | 101 | # Check inputs 102 | if len(down_block_types) != len(up_block_types): 103 | raise ValueError( 104 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." 105 | ) 106 | 107 | if len(block_out_channels) != len(down_block_types): 108 | raise ValueError( 109 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 
110 | ) 111 | 112 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): 113 | raise ValueError( 114 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 115 | ) 116 | 117 | if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): 118 | raise ValueError( 119 | f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." 120 | ) 121 | 122 | if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): 123 | raise ValueError( 124 | f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 125 | ) 126 | 127 | # input 128 | self.conv_in = nn.Conv2d( 129 | in_channels, 130 | block_out_channels[0], 131 | kernel_size=3, 132 | padding=1, 133 | ) 134 | 135 | # time 136 | time_embed_dim = block_out_channels[0] * 4 137 | # import pdb; pdb.set_trace() 138 | self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0) 139 | timestep_input_dim = block_out_channels[0] 140 | 141 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 142 | 143 | self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0) 144 | self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 145 | 146 | self.down_blocks = nn.ModuleList([]) 147 | self.up_blocks = nn.ModuleList([]) 148 | 149 | if isinstance(num_attention_heads, int): 150 | num_attention_heads = (num_attention_heads,) * len(down_block_types) 151 | 152 | if isinstance(cross_attention_dim, int): 153 | cross_attention_dim = (cross_attention_dim,) * len(down_block_types) 154 | 155 | if isinstance(layers_per_block, int): 156 | layers_per_block = [layers_per_block] * len(down_block_types) 157 | 158 | if isinstance(transformer_layers_per_block, int): 159 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) 160 | 161 | blocks_time_embed_dim = time_embed_dim 162 | 163 | # down 164 | output_channel = block_out_channels[0] 165 | for i, down_block_type in enumerate(down_block_types): 166 | input_channel = output_channel 167 | output_channel = block_out_channels[i] 168 | is_final_block = i == len(block_out_channels) - 1 169 | 170 | down_block = get_down_block( 171 | down_block_type, 172 | num_layers=layers_per_block[i], 173 | transformer_layers_per_block=transformer_layers_per_block[i], 174 | in_channels=input_channel, 175 | out_channels=output_channel, 176 | temb_channels=blocks_time_embed_dim, 177 | add_downsample=not is_final_block, 178 | resnet_eps=1e-5, 179 | cross_attention_dim=cross_attention_dim[i], 180 | num_attention_heads=num_attention_heads[i], 181 | resnet_act_fn="silu", 182 | ) 183 | self.down_blocks.append(down_block) 184 | 185 | # mid 186 | self.mid_block = UNetMidBlockSpatioTemporal( 187 | block_out_channels[-1], 188 | temb_channels=blocks_time_embed_dim, 189 | transformer_layers_per_block=transformer_layers_per_block[-1], 190 | cross_attention_dim=cross_attention_dim[-1], 191 | num_attention_heads=num_attention_heads[-1], 192 | ) 193 | 194 | # count how many layers upsample the images 195 | self.num_upsamplers = 0 196 | 197 | # up 198 | reversed_block_out_channels = 
list(reversed(block_out_channels)) 199 | reversed_num_attention_heads = list(reversed(num_attention_heads)) 200 | reversed_layers_per_block = list(reversed(layers_per_block)) 201 | reversed_cross_attention_dim = list(reversed(cross_attention_dim)) 202 | reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) 203 | 204 | output_channel = reversed_block_out_channels[0] 205 | for i, up_block_type in enumerate(up_block_types): 206 | is_final_block = i == len(block_out_channels) - 1 207 | 208 | prev_output_channel = output_channel 209 | output_channel = reversed_block_out_channels[i] 210 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 211 | 212 | # add upsample block for all BUT final layer 213 | if not is_final_block: 214 | add_upsample = True 215 | self.num_upsamplers += 1 216 | else: 217 | add_upsample = False 218 | 219 | up_block = get_up_block( 220 | up_block_type, 221 | num_layers=reversed_layers_per_block[i] + 1, 222 | transformer_layers_per_block=reversed_transformer_layers_per_block[i], 223 | in_channels=input_channel, 224 | out_channels=output_channel, 225 | prev_output_channel=prev_output_channel, 226 | temb_channels=blocks_time_embed_dim, 227 | add_upsample=add_upsample, 228 | resnet_eps=1e-5, 229 | resolution_idx=i, 230 | cross_attention_dim=reversed_cross_attention_dim[i], 231 | num_attention_heads=reversed_num_attention_heads[i], 232 | resnet_act_fn="silu", 233 | ) 234 | self.up_blocks.append(up_block) 235 | prev_output_channel = output_channel 236 | 237 | # out 238 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5) 239 | self.conv_act = nn.SiLU() 240 | 241 | self.conv_out = nn.Conv2d( 242 | block_out_channels[0], 243 | out_channels, 244 | kernel_size=3, 245 | padding=1, 246 | ) 247 | 248 | @property 249 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 250 | r""" 251 | Returns: 252 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 253 | indexed by its weight name. 254 | """ 255 | # set recursively 256 | processors = {} 257 | 258 | def fn_recursive_add_processors( 259 | name: str, 260 | module: torch.nn.Module, 261 | processors: Dict[str, AttentionProcessor], 262 | ): 263 | if hasattr(module, "get_processor"): 264 | processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) 265 | 266 | for sub_name, child in module.named_children(): 267 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 268 | 269 | return processors 270 | 271 | for name, module in self.named_children(): 272 | fn_recursive_add_processors(name, module, processors) 273 | 274 | return processors 275 | 276 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 277 | r""" 278 | Sets the attention processor to use to compute attention. 279 | 280 | Parameters: 281 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 282 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 283 | for **all** `Attention` layers. 284 | 285 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 286 | processor. This is strongly recommended when setting trainable attention processors. 
287 | 288 | """ 289 | count = len(self.attn_processors.keys()) 290 | 291 | if isinstance(processor, dict) and len(processor) != count: 292 | raise ValueError( 293 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 294 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 295 | ) 296 | 297 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 298 | if hasattr(module, "set_processor"): 299 | if not isinstance(processor, dict): 300 | module.set_processor(processor) 301 | else: 302 | module.set_processor(processor.pop(f"{name}.processor")) 303 | 304 | for sub_name, child in module.named_children(): 305 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 306 | 307 | for name, module in self.named_children(): 308 | fn_recursive_attn_processor(name, module, processor) 309 | 310 | def set_default_attn_processor(self): 311 | """ 312 | Disables custom attention processors and sets the default attention implementation. 313 | """ 314 | if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 315 | processor = AttnProcessor() 316 | else: 317 | raise ValueError( 318 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" 319 | ) 320 | 321 | self.set_attn_processor(processor) 322 | 323 | def _set_gradient_checkpointing(self, module, value=False): 324 | if hasattr(module, "gradient_checkpointing"): 325 | module.gradient_checkpointing = value 326 | 327 | # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking 328 | def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: 329 | """ 330 | Sets the attention processor to use [feed forward 331 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). 332 | 333 | Parameters: 334 | chunk_size (`int`, *optional*): 335 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually 336 | over each tensor of dim=`dim`. 337 | dim (`int`, *optional*, defaults to `0`): 338 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) 339 | or dim=1 (sequence length). 340 | """ 341 | if dim not in [0, 1]: 342 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") 343 | 344 | # By default chunk size is 1 345 | chunk_size = chunk_size or 1 346 | 347 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 348 | if hasattr(module, "set_chunk_feed_forward"): 349 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 350 | 351 | for child in module.children(): 352 | fn_recursive_feed_forward(child, chunk_size, dim) 353 | 354 | for module in self.children(): 355 | fn_recursive_feed_forward(module, chunk_size, dim) 356 | 357 | def forward( 358 | self, 359 | sample: torch.FloatTensor, 360 | timestep: Union[torch.Tensor, float, int], 361 | encoder_hidden_states: torch.Tensor, 362 | added_time_ids: torch.Tensor, 363 | return_dict: bool = True, 364 | ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]: 365 | r""" 366 | The [`UNetSpatioTemporalConditionModel`] forward method. 367 | 368 | Args: 369 | sample (`torch.FloatTensor`): 370 | The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. 
371 |             timestep (`torch.FloatTensor` or `float` or `int`): The current timestep used to denoise the input.
372 |             encoder_hidden_states (`torch.FloatTensor`):
373 |                 The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
374 |             added_time_ids (`torch.FloatTensor`):
375 |                 The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
376 |                 embeddings and added to the time embeddings.
377 |             return_dict (`bool`, *optional*, defaults to `True`):
378 |                 Whether or not to return a [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
379 |                 tuple.
380 |         Returns:
381 |             [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
382 |                 If `return_dict` is True, an [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
383 |                 a `tuple` is returned where the first element is the sample tensor.
384 |         """
385 |         # 1. time
386 |         timesteps = timestep
387 |         if not torch.is_tensor(timesteps):
388 |             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
389 |             # This would be a good case for the `match` statement (Python 3.10+)
390 |             is_mps = sample.device.type == "mps"
391 |             if isinstance(timestep, float):
392 |                 dtype = torch.float32 if is_mps else torch.float64
393 |             else:
394 |                 dtype = torch.int32 if is_mps else torch.int64
395 |             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
396 |         elif len(timesteps.shape) == 0:
397 |             timesteps = timesteps[None].to(sample.device)
398 | 
399 |         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
400 |         batch_size, num_frames = sample.shape[:2]
401 |         timesteps = timesteps.expand(batch_size)
402 | 
403 |         t_emb = self.time_proj(timesteps)
404 | 
405 |         # `Timesteps` does not contain any weights and will always return f32 tensors
406 |         # but time_embedding might actually be running in fp16. so we need to cast here.
407 |         # there might be better ways to encapsulate this.
408 |         t_emb = t_emb.to(dtype=sample.dtype)
409 | 
410 |         emb = self.time_embedding(t_emb)
411 | 
412 |         time_embeds = self.add_time_proj(added_time_ids.flatten())
413 |         time_embeds = time_embeds.reshape((batch_size, -1))
414 |         time_embeds = time_embeds.to(emb.dtype)
415 |         aug_emb = self.add_embedding(time_embeds)
416 |         emb = emb + aug_emb
417 | 
418 |         # Flatten the batch and frames dimensions
419 |         # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
420 |         sample = sample.flatten(0, 1)
421 |         # Repeat the embeddings num_video_frames times
422 |         # emb: [batch, channels] -> [batch * frames, channels]
423 |         emb = emb.repeat_interleave(num_frames, dim=0)
424 |         # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
425 | 
426 |         if encoder_hidden_states.shape[0] == batch_size:
427 |             encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
428 | 
429 |         # 2.
pre-process 430 | sample = self.conv_in(sample) 431 | 432 | image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) 433 | 434 | down_block_res_samples = (sample,) 435 | for downsample_block in self.down_blocks: 436 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 437 | sample, res_samples = downsample_block( 438 | hidden_states=sample, 439 | temb=emb, 440 | encoder_hidden_states=encoder_hidden_states, 441 | image_only_indicator=image_only_indicator, 442 | ) 443 | else: 444 | sample, res_samples = downsample_block( 445 | hidden_states=sample, 446 | temb=emb, 447 | image_only_indicator=image_only_indicator, 448 | ) 449 | 450 | down_block_res_samples += res_samples 451 | 452 | # 4. mid 453 | sample = self.mid_block( 454 | hidden_states=sample, 455 | temb=emb, 456 | encoder_hidden_states=encoder_hidden_states, 457 | image_only_indicator=image_only_indicator, 458 | ) 459 | 460 | # 5. up 461 | for i, upsample_block in enumerate(self.up_blocks): 462 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 463 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 464 | 465 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 466 | sample = upsample_block( 467 | hidden_states=sample, 468 | temb=emb, 469 | res_hidden_states_tuple=res_samples, 470 | encoder_hidden_states=encoder_hidden_states, 471 | image_only_indicator=image_only_indicator, 472 | ) 473 | else: 474 | sample = upsample_block( 475 | hidden_states=sample, 476 | temb=emb, 477 | res_hidden_states_tuple=res_samples, 478 | image_only_indicator=image_only_indicator, 479 | ) 480 | 481 | # 6. post-process 482 | sample = self.conv_norm_out(sample) 483 | sample = self.conv_act(sample) 484 | sample = self.conv_out(sample) 485 | 486 | # 7. Reshape back to original shape 487 | sample = sample.reshape(batch_size, num_frames, *sample.shape[1:]) 488 | 489 | if not return_dict: 490 | 491 | return (sample,) 492 | 493 | return UNetSpatioTemporalConditionOutput(sample=sample) 494 | --------------------------------------------------------------------------------
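
The resampling hooks defined in `scheduling_euler_discrete_resampling.py` (`set_timesteps`, `scale_model_input`, `step`, and the added `undo_step`) are driven by `pipeline_stable_video_diffusion_re.py`, which is not reproduced in this listing. The sketch below only illustrates how these scheduler methods fit together in a plain denoising loop; it is not the repository's actual fusion logic. Here `toy_denoiser` is a stand-in for the UNet, the latent shape is arbitrary, and instantiating `EulerDiscreteScheduler()` with its default config is assumed to behave like the upstream diffusers scheduler it is derived from.

```python
import torch
from scheduling_euler_discrete_resampling import EulerDiscreteScheduler


def toy_denoiser(x, t):
    # Stand-in for the spatio-temporal UNet: predicts epsilon for sample x at timestep t.
    return torch.zeros_like(x)


scheduler = EulerDiscreteScheduler()             # default config, epsilon prediction
scheduler.set_timesteps(num_inference_steps=50)

# Start from pure noise at the scheduler's initial sigma (toy latent shape).
sample = torch.randn(1, 4, 8, 8) * scheduler.init_noise_sigma

for i, t in enumerate(scheduler.timesteps):
    model_input = scheduler.scale_model_input(sample, t)
    noise_pred = toy_denoiser(model_input, t)
    sample = scheduler.step(noise_pred, t, sample).prev_sample

    # undo_step re-noises the sample from the noise level at index i + 1 back toward
    # the level at index i, using Gaussian noise scaled by
    # sqrt(sigma_i**2 - sigma_{i+1}**2) * ratio. This is the hook that allows part of
    # the denoising schedule to be repeated ("time-travelled") instead of only moving
    # forward; the real jump schedule is decided by the pipeline, not shown here.
    if i == 40:  # arbitrary illustration point
        sample = scheduler.undo_step(sample, step_id=i, ratio=0.49)
```

In `svd_sequential_re.py`, `re_steps = int((1 - repeat_step_ratio) * num_inference_steps)` suggests that this re-noising is applied only over the tail of the schedule (roughly the last 20% at the default `repeat_step_ratio = 0.8`), with `jump_length` and `jump_n_sample` controlling how far back and how often `undo_step` is invoked; the exact scheduling lives in the pipeline file.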