├── README.md
├── datasets.py
├── pipeline_stable_video_diffusion_re.py
├── requirements.txt
├── scheduling_euler_discrete_resampling.py
├── svd_sequential_re.py
├── test_data
│   ├── Multiview_data
│   │   └── mipnerf360_lite
│   │       └── garden
│   │           ├── frame_00006.JPG
│   │           └── frame_00104.JPG
│   ├── gym_motion_2024_frames
│   │   └── arm_clip1
│   │       ├── frame_00018.jpg
│   │       └── frame_00060.jpg
│   └── video_frames
│       └── dolomite_clip3
│           ├── frame_00000.jpg
│           └── frame_00100.jpg
└── unet_spatio_temporal_condition.py
/README.md:
--------------------------------------------------------------------------------
1 | # Time_Reversal_Fusion
2 |
3 |
4 |
5 |
6 | This is the official PyTorch implementation of Time-Reversal Fusion (accepted at ECCV 2024).
7 | We propose a new sampling strategy called Time-Reversal Fusion (TRF), which enables an image-to-video model to generate sequences toward a given end frame without any tuning or back-propagated optimization. We define this new task as "Bounded Generation", which generalizes three scenarios in computer vision:
8 | 1) Generating subject motion with the two bound images capturing a moving subject.
9 | 2) Synthesizing camera motion using two images captured from different viewpoints of a static scene.
10 | 3) Achieving video looping by using the same image for both bounds.
11 |
12 | Please refer to the [arXiv paper](https://arxiv.org/abs/2403.14611) for more technical details and the [project page](https://time-reversal.github.io) for more video results.
13 |
14 | ## Todo
15 | - [x] TRF code release
16 | - [x] Bounded Generation Dataset release
17 | - [ ] Gradio demo
18 |
19 | ## Getting Started
20 | Clone the repo:
21 | ```bash
22 | git clone https://github.com/HavenFeng/time_reversal/
23 | cd time_reversal
24 | ```
25 |
26 | ### Requirements
27 | * Python 3.10 (numpy, skimage, scipy, opencv)
28 | * Diffusers
29 | * PyTorch >= 2.0.1 (Diffusers compatible)
30 | You can install them with:
31 | ```bash
32 | pip install -r requirements.txt
33 | ```
34 | If you encounter errors when installing Diffusers, please follow the [official installation guide](https://huggingface.co/docs/diffusers/en/installation) to reinstall the library.
35 |
36 | ### Usage
37 | 1. **Run inference with samples in paper**
38 | ```bash
39 | python svd_sequential_re.py multiview
40 | ```
41 | Check the other tasks by replacing "multiview" with "video_frames", "gym_motion", or "image2loop"; the generated results can be found in the ./output folder. A minimal Python sketch of calling the custom pipeline directly is included at the end of this section.
42 | 2. **TRF++ (add LoRA "patches" to enhance domain-specific task)**
43 | TRF was designed to probe SVD's bounded-generation capability without fine-tuning, but we have observed SVD's biases in subject and camera motion, as well as its sensitivity to conditioning factors such as FPS and motion intensity, which required careful parameter tuning for different inputs. To improve generation quality and robustness on other downstream tasks, we fine-tuned LoRA "patches" on various domain-specific datasets, which better support long-range linear motion and extreme 3D view generation.
44 | ```
45 | coming soon
46 | ```
47 |
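If you prefer to call the TRF pipeline directly from Python, the sketch below shows roughly how the custom components in this repository fit together. It is a minimal sketch, assuming the public SVD-XT checkpoint and the bundled `test_data` frames; see `svd_sequential_re.py` for the exact conditioning parameters used for the paper results.
```python
import torch
from diffusers.models import AutoencoderKLTemporalDecoder
from diffusers.utils import export_to_video, load_image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from pipeline_stable_video_diffusion_re import StableVideoDiffusionPipeline_Custom
from scheduling_euler_discrete_resampling import EulerDiscreteScheduler
from unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel

model_id = "stabilityai/stable-video-diffusion-img2vid-xt"
pipe = StableVideoDiffusionPipeline_Custom(
    vae=AutoencoderKLTemporalDecoder.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float16, variant="fp16"),
    image_encoder=CLIPVisionModelWithProjection.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float16),
    unet=UNetSpatioTemporalConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=torch.float16, variant="fp16"),
    scheduler=EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler"),
    feature_extractor=CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor"),
).to("cuda")

# Two bound frames: the start and end of the generated sequence (use the same image for looping).
start = load_image("test_data/video_frames/dolomite_clip3/frame_00000.jpg").resize((1024, 576))
end = load_image("test_data/video_frames/dolomite_clip3/frame_00100.jpg").resize((1024, 576))

frames = pipe(start, end_frame=end, decode_chunk_size=8, generator=torch.manual_seed(23)).frames[0]
export_to_video(frames, "dolomite_clip3_trf.mp4", fps=7)
```
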
48 | ## Evaluation
49 | We evaluate our method on the [Bounded Generation Dataset](https://drive.google.com/drive/folders/1qH4yx5954Bm6h1E4olEqgV0pSJicdkNu?usp=sharing) and compare against domain-specific state-of-the-art methods.
50 | For more details on the evaluation, please check our [arXiv paper](https://arxiv.org/abs/2403.14611).
51 |
52 |
53 | ## Citation
54 | If you find our work useful to your research, please consider citing:
55 | ```
56 | @inproceedings{Feng:TRF:ECCV2024,
57 |   title = {Explorative In-betweening of Time and Space},
58 |   author = {Feng, Haiwen and Ding, Zheng and Xia, Zhihao and Niklaus, Simon and Abrevaya, Victoria and Black, Michael J. and Zhang, Xuaner},
59 |   booktitle = {European Conference on Computer Vision},
60 |   year = {2024}
61 | }
62 | ```
63 |
64 | ## Notes
65 | The video form of our teaser image:
66 |
67 | https://github.com/user-attachments/assets/b984c57c-a450-4071-996c-dc3df1445e79
68 |
69 | More domain-specific LoRA patch models will be released soon.
70 |
71 | ## License
72 | This code and model are available for non-commercial scientific research purposes.
73 |
74 | ## Acknowledgements
75 | We would like to thank recent baseline works that allow us to easily perform quantitative and qualitative comparisons :)
76 | [FILM](https://github.com/google-research/frame-interpolation),
77 | [Wide-Baseline](https://github.com/yilundu/cross_attention_renderer),
78 | [Text2Cinemagraph](https://github.com/text2cinemagraph/text2cinemagraph/tree/master).
79 |
80 | This work was partly supported by the German Federal Ministry of Education and Research (BMBF): Tuebingen AI Center, FKZ: 01IS18039B
81 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import numpy as np
4 |
5 | def get_dataset(dataset_folder, data_type='frame', filter_keyword=None):
6 | """
7 | General function to retrieve datasets based on the data_type.
8 |
9 | Parameters:
10 | dataset_folder (str): The root directory of the dataset.
11 | data_type (str): Type of dataset to retrieve. Options are 'frame', 'multiview', 'loop'.
12 | filter_keyword (str, optional): A keyword to filter the dataset.
13 |
14 | Returns:
15 | np.ndarray: An array of image pairs.
16 | """
17 | if data_type == 'loop':
18 | # For loop dataset, pairs are the same image
19 | image_paths = sorted(glob.glob(os.path.join(dataset_folder, 'img2loop', '*', '*')))
20 | if filter_keyword:
21 | image_paths = [path for path in image_paths if filter_keyword in path]
22 | video_frames = np.array([[path, path] for path in image_paths])
23 | else:
24 | # For 'frame' and 'multiview' datasets
25 | if data_type == 'frame':
26 | video_frame_selected = [
27 | 'video_frames/dolomite_clip3/frame_00000.jpg',
28 | 'video_frames/dolomite_clip3/frame_00100.jpg',
29 |
30 | 'gym_motion_2024_frames/arm_clip1/frame_00018.jpg',
31 | 'gym_motion_2024_frames/arm_clip1/frame_00060.jpg',
32 | ]
33 | elif data_type == 'multiview':
34 | video_frame_selected = [
35 | 'Multiview_data/mipnerf360_lite/garden/frame_00006.JPG',
36 | 'Multiview_data/mipnerf360_lite/garden/frame_00104.JPG',
37 | ]
38 | else:
39 | raise ValueError(f"Unsupported data_type: {data_type}")
40 |
41 | # Prepend dataset_folder to paths
42 | video_frame_selected = [os.path.join(dataset_folder, path) for path in video_frame_selected]
43 | if filter_keyword:
44 | video_frame_selected = [path for path in video_frame_selected if filter_keyword in path]
45 |
46 | # Reshape into pairs
47 | video_frames = np.array(video_frame_selected).reshape(-1, 2)
48 |
49 | return video_frames
50 |
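51 | 
52 | if __name__ == "__main__":
53 |     # Minimal usage sketch (added for illustration; svd_sequential_re.py is the actual
54 |     # entry point). It assumes the test_data layout shown in the repository tree; the
55 |     # 'loop' type instead pairs every image under <dataset_folder>/img2loop with itself.
56 |     for task in ('frame', 'multiview'):
57 |         pairs = get_dataset('./test_data', data_type=task)
58 |         print(task, pairs)  # each row is a [start_frame, end_frame] path pair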
--------------------------------------------------------------------------------
/pipeline_stable_video_diffusion_re.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import inspect
16 | from dataclasses import dataclass
17 | from typing import Callable, Dict, List, Optional, Union
18 |
19 | import numpy as np
20 | import PIL.Image
21 | import torch
22 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
23 |
24 | from diffusers.image_processor import VaeImageProcessor
25 | from diffusers.models import AutoencoderKLTemporalDecoder
26 | from unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
27 | from scheduling_euler_discrete_resampling import EulerDiscreteScheduler
28 |
29 | from diffusers.utils import BaseOutput, logging
30 | from diffusers.utils.torch_utils import randn_tensor
31 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
32 |
33 |
34 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35 |
36 |
37 | def _append_dims(x, target_dims):
38 | """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
39 | dims_to_append = target_dims - x.ndim
40 | if dims_to_append < 0:
41 | raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
42 | return x[(...,) + (None,) * dims_to_append]
43 |
44 |
45 | def tensor2vid(video: torch.Tensor, processor, output_type="np"):
46 |
47 | batch_size, channels, num_frames, height, width = video.shape
48 | outputs = []
49 | for batch_idx in range(batch_size):
50 | batch_vid = video[batch_idx].permute(1, 0, 2, 3)
51 | batch_output = processor.postprocess(batch_vid, output_type)
52 |
53 | outputs.append(batch_output)
54 |
55 | return outputs
56 |
57 |
58 | def interpolate_spherical(p0, p1, fract_mixing: float):
59 | r"""
60 | borrowed from latentblending repo
61 | Helper function to correctly mix two random variables using spherical interpolation.
62 | See https://en.wikipedia.org/wiki/Slerp
63 |     The function will always cast up to float64 for the sake of extra precision.
64 | Args:
65 | p0:
66 | First tensor for interpolation
67 | p1:
68 | Second tensor for interpolation
69 | fract_mixing: float
70 | Mixing coefficient of interval [0, 1].
71 |             0 will return p0
72 |             1 will return p1
73 | 0.x will return a mix between both preserving angular velocity.
74 | """
75 |
76 | if p0.dtype == torch.float16:
77 | recast_to = 'fp16'
78 | else:
79 | recast_to = 'fp32'
80 |
81 | p0 = p0.double()
82 | p1 = p1.double()
83 | norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
84 | epsilon = 1e-7
85 | dot = torch.sum(p0 * p1) / norm
86 | dot = dot.clamp(-1 + epsilon, 1 - epsilon)
87 |
88 | theta_0 = torch.arccos(dot)
89 | sin_theta_0 = torch.sin(theta_0)
90 | theta_t = theta_0 * fract_mixing
91 | s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
92 | s1 = torch.sin(theta_t) / sin_theta_0
93 | interp = p0 * s0 + p1 * s1
94 |
95 | if recast_to == 'fp16':
96 | interp = interp.half()
97 | elif recast_to == 'fp32':
98 | interp = interp.float()
99 |
100 | return interp
101 |
102 |
103 | def interpolate_linear(p0, p1, fract_mixing):
104 | r"""
105 | Helper function to mix two variables using standard linear interpolation.
106 | Args:
107 | p0:
108 | First tensor / np.ndarray for interpolation
109 | p1:
110 | Second tensor / np.ndarray for interpolation
111 | fract_mixing: float
112 | Mixing coefficient of interval [0, 1].
113 |             0 will return p0
114 |             1 will return p1
115 | 0.x will return a linear mix between both.
116 | """
117 | reconvert_uint8 = False
118 | if type(p0) is np.ndarray and p0.dtype == 'uint8':
119 | reconvert_uint8 = True
120 | p0 = p0.astype(np.float64)
121 |
122 | if type(p1) is np.ndarray and p1.dtype == 'uint8':
123 | reconvert_uint8 = True
124 | p1 = p1.astype(np.float64)
125 |
126 | interp = (1 - fract_mixing) * p0 + fract_mixing * p1
127 |
128 | if reconvert_uint8:
129 | interp = np.clip(interp, 0, 255).astype(np.uint8)
130 |
131 | return interp
132 |
133 |
134 | @dataclass
135 | class StableVideoDiffusionPipelineOutput(BaseOutput):
136 | r"""
137 |     Output class for the Stable Video Diffusion pipeline.
138 |
139 | Args:
140 |         frames (`List[PIL.Image.Image]` or `np.ndarray`):
141 | List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
142 | num_channels)`.
143 | """
144 |
145 | frames: Union[List[PIL.Image.Image], np.ndarray]
146 |
147 |
148 | class StableVideoDiffusionPipeline_Custom(DiffusionPipeline):
149 | r"""
150 | Pipeline to generate video from an input image using Stable Video Diffusion.
151 |
152 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
153 | implemented for all pipelines (downloading, saving, running on a particular device, etc.).
154 |
155 | Args:
156 |         vae ([`AutoencoderKLTemporalDecoder`]):
157 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
158 | image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
159 | Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
160 | unet ([`UNetSpatioTemporalConditionModel`]):
161 | A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
162 | scheduler ([`EulerDiscreteScheduler`]):
163 | A scheduler to be used in combination with `unet` to denoise the encoded image latents.
164 | feature_extractor ([`~transformers.CLIPImageProcessor`]):
165 | A `CLIPImageProcessor` to extract features from generated images.
166 | """
167 |
168 | model_cpu_offload_seq = "image_encoder->unet->vae"
169 | _callback_tensor_inputs = ["latents"]
170 |
171 | def __init__(
172 | self,
173 | vae: AutoencoderKLTemporalDecoder,
174 | image_encoder: CLIPVisionModelWithProjection,
175 | unet: UNetSpatioTemporalConditionModel,
176 | scheduler: EulerDiscreteScheduler, #EulerDiscreteScheduler,
177 | feature_extractor: CLIPImageProcessor,
178 | ):
179 | super().__init__()
180 |
181 | self.register_modules(
182 | vae=vae,
183 | image_encoder=image_encoder,
184 | unet=unet,
185 | scheduler=scheduler,
186 | feature_extractor=feature_extractor,
187 | )
188 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
189 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
190 |
191 | def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance):
192 | dtype = next(self.image_encoder.parameters()).dtype
193 |
194 | if not isinstance(image, torch.Tensor):
195 | image = self.image_processor.pil_to_numpy(image)
196 | image = self.image_processor.numpy_to_pt(image)
197 |
198 | # We normalize the image before resizing to match with the original implementation.
199 | # Then we unnormalize it after resizing.
200 | image = image * 2.0 - 1.0
201 | image = _resize_with_antialiasing(image, (224, 224))
202 | image = (image + 1.0) / 2.0
203 |
204 |         # Normalize the image for CLIP input
205 | image = self.feature_extractor(
206 | images=image,
207 | do_normalize=True,
208 | do_center_crop=False,
209 | do_resize=False,
210 | do_rescale=False,
211 | return_tensors="pt",
212 | ).pixel_values
213 |
214 | image = image.to(device=device, dtype=dtype)
215 | image_embeddings = self.image_encoder(image).image_embeds
216 | image_embeddings = image_embeddings.unsqueeze(1)
217 |
218 | # duplicate image embeddings for each generation per prompt, using mps friendly method
219 | bs_embed, seq_len, _ = image_embeddings.shape
220 | image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
221 | image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
222 |
223 | if do_classifier_free_guidance:
224 | negative_image_embeddings = torch.zeros_like(image_embeddings)
225 |
226 | # For classifier free guidance, we need to do two forward passes.
227 | # Here we concatenate the unconditional and text embeddings into a single batch
228 | # to avoid doing two forward passes
229 | image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
230 |
231 | return image_embeddings
232 |
233 | def _encode_vae_image(
234 | self,
235 | image: torch.Tensor,
236 | device,
237 | num_videos_per_prompt,
238 | do_classifier_free_guidance,
239 | ):
240 | image = image.to(device=device)
241 | image_latents = self.vae.encode(image).latent_dist.mode()
242 |
243 | if do_classifier_free_guidance:
244 | negative_image_latents = torch.zeros_like(image_latents)
245 |
246 | # For classifier free guidance, we need to do two forward passes.
247 | # Here we concatenate the unconditional and text embeddings into a single batch
248 | # to avoid doing two forward passes
249 | image_latents = torch.cat([negative_image_latents, image_latents])
250 |
251 | # duplicate image_latents for each generation per prompt, using mps friendly method
252 | image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
253 |
254 | return image_latents
255 |
256 | def _get_add_time_ids(
257 | self,
258 | fps,
259 | motion_bucket_id,
260 | noise_aug_strength,
261 | dtype,
262 | batch_size,
263 | num_videos_per_prompt,
264 | do_classifier_free_guidance,
265 | ):
266 | add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
267 |
268 | passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
269 | expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
270 |
271 | if expected_add_embed_dim != passed_add_embed_dim:
272 | raise ValueError(
273 | f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
274 | )
275 |
276 | add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
277 | add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
278 |
279 | if do_classifier_free_guidance:
280 | add_time_ids = torch.cat([add_time_ids, add_time_ids])
281 |
282 | return add_time_ids
283 |
284 | def decode_latents(self, latents, num_frames, decode_chunk_size=14):
285 | # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
286 | latents = latents.flatten(0, 1)
287 |
288 | latents = 1 / self.vae.config.scaling_factor * latents
289 |
290 | accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys())
291 |
292 | # decode decode_chunk_size frames at a time to avoid OOM
293 | frames = []
294 | for i in range(0, latents.shape[0], decode_chunk_size):
295 | num_frames_in = latents[i : i + decode_chunk_size].shape[0]
296 | decode_kwargs = {}
297 | if accepts_num_frames:
298 | # we only pass num_frames_in if it's expected
299 | decode_kwargs["num_frames"] = num_frames_in
300 |
301 | frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
302 | frames.append(frame)
303 | frames = torch.cat(frames, dim=0)
304 |
305 | # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
306 | frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
307 |
308 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
309 | frames = frames.float()
310 | return frames
311 |
312 | def check_inputs(self, image, height, width):
313 | if (
314 | not isinstance(image, torch.Tensor)
315 | and not isinstance(image, PIL.Image.Image)
316 | and not isinstance(image, list)
317 | ):
318 | raise ValueError(
319 | "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
320 | f" {type(image)}"
321 | )
322 |
323 | if height % 8 != 0 or width % 8 != 0:
324 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
325 |
326 | def prepare_latents(
327 | self,
328 | batch_size,
329 | num_frames,
330 | num_channels_latents,
331 | height,
332 | width,
333 | dtype,
334 | device,
335 | generator,
336 | latents=None,
337 | ):
338 | shape = (
339 | batch_size,
340 | num_frames,
341 | num_channels_latents // 2,
342 | height // self.vae_scale_factor,
343 | width // self.vae_scale_factor,
344 | )
345 | if isinstance(generator, list) and len(generator) != batch_size:
346 | raise ValueError(
347 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
348 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
349 | )
350 | # import ipdb; ipdb.set_trace()
351 | if latents is None:
352 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
353 | else:
354 | latents = latents.to(device)
355 |
356 | # scale the initial noise by the standard deviation required by the scheduler
357 | latents = latents * self.scheduler.init_noise_sigma
358 | return latents
359 |
360 | def get_blend_weights(self, latents, step_id, num_frames, weight_type = 'non-linear', time_dir='forward'):
361 |         # weight_type options: 'linear', 'non-linear', 'atanh', 'non-linear-decay' (the decay variant progressively reduces the influence of backward frames)
362 | if weight_type == 'linear':
363 | blend_weights = torch.linspace(1, 0, steps=num_frames).unsqueeze(1).unsqueeze(2).unsqueeze(3).to(device=latents.device)
364 | elif weight_type == 'non-linear':
365 | # import ipdb; ipdb.set_trace()
366 | b_value = 0.85
367 | blend_weights = torch.tensor([b_value ** x for x in range(num_frames)]).unsqueeze(1).unsqueeze(2).unsqueeze(3).to(device=latents.device)
368 | if time_dir =='forward':
369 | blend_weights = torch.flip((1 - blend_weights), dims=[0])
370 |
371 | elif weight_type == 'atanh':
372 | # import ipdb; ipdb.set_trace()
373 | b_value = 0.95
374 | torch.linspace(-b_value, b_value, steps=num_frames)
375 | curve = torch.atanh(torch.linspace(-b_value, b_value, steps=num_frames)).unsqueeze(1).unsqueeze(2).unsqueeze(3).to(device=latents.device)
376 | curve_norm = (curve - curve.min()) / (curve.max() - curve.min())
377 | blend_weights = curve_norm
378 | if time_dir =='forward':
379 | blend_weights = torch.flip((1 - blend_weights), dims=[0])
380 |
381 | elif weight_type == 'non-linear-decay':
382 | b_list = torch.linspace(0.98, 0.5, steps=num_frames)
383 | a_value = 5
384 | b_value = b_list[step_id]
385 |             blend_weights = torch.stack([a_value * b_value ** x for x in range(num_frames)]).unsqueeze(1).unsqueeze(2).unsqueeze(3).to(device=latents.device)
386 |             if time_dir == 'forward':
387 |                 blend_weights = torch.flip((1 - blend_weights), dims=[0])
388 |
389 | return blend_weights
390 |
391 | @property
392 | def guidance_scale(self):
393 | return self._guidance_scale
394 |
395 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
396 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
397 | # corresponds to doing no classifier free guidance.
398 | @property
399 | def do_classifier_free_guidance(self):
400 | return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
401 |
402 | @property
403 | def num_timesteps(self):
404 | return self._num_timesteps
405 |
406 | @torch.no_grad()
407 | def __call__(
408 | self,
409 | image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
410 | end_frame: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None,
411 | height: int = 576,
412 | width: int = 1024,
413 | num_frames: Optional[int] = None,
414 | num_inference_steps: int = 25,
415 | min_guidance_scale: float = 1.0,
416 | max_guidance_scale: float = 3.0,
417 | fps: int = 7,
418 | motion_bucket_id: int = 127,
419 |         noise_aug_strength: float = 0.02,
420 | decode_chunk_size: Optional[int] = None,
421 | num_videos_per_prompt: Optional[int] = 1,
422 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
423 | latents: Optional[torch.FloatTensor] = None,
424 | output_type: Optional[str] = "pil",
425 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
426 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
427 | return_dict: bool = True,
428 | same_noise_per_frame: bool = True,
429 | jump_length: int = 3,
430 | jump_n_sample: int = 2,
431 | repeat_step_ratio: float = 0.5,
432 | noise_scale_ratio: float = 0.5,
433 | custom_timestep_spacing: bool = True
434 | ):
435 | r"""
436 | The call function to the pipeline for generation.
437 |
438 | Args:
439 | image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
440 | Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
441 | [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
442 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
443 | The height in pixels of the generated image.
444 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
445 | The width in pixels of the generated image.
446 | num_frames (`int`, *optional*):
447 | The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`
448 | num_inference_steps (`int`, *optional*, defaults to 25):
449 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
450 | expense of slower inference. This parameter is modulated by `strength`.
451 | min_guidance_scale (`float`, *optional*, defaults to 1.0):
452 | The minimum guidance scale. Used for the classifier free guidance with first frame.
453 | max_guidance_scale (`float`, *optional*, defaults to 3.0):
454 | The maximum guidance scale. Used for the classifier free guidance with last frame.
455 | fps (`int`, *optional*, defaults to 7):
456 | Frames per second. The rate at which the generated images shall be exported to a video after generation.
457 |             Note that Stable Video Diffusion's UNet was micro-conditioned on fps-1 during training.
458 | motion_bucket_id (`int`, *optional*, defaults to 127):
459 | The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
460 |             noise_aug_strength (`float`, *optional*, defaults to 0.02):
461 | The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
462 | decode_chunk_size (`int`, *optional*):
463 | The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
464 | between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once
465 | for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
466 | num_videos_per_prompt (`int`, *optional*, defaults to 1):
467 | The number of images to generate per prompt.
468 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
469 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
470 | generation deterministic.
471 | latents (`torch.FloatTensor`, *optional*):
472 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
473 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
474 | tensor is generated by sampling using the supplied random `generator`.
475 | output_type (`str`, *optional*, defaults to `"pil"`):
476 | The output format of the generated image. Choose between `PIL.Image` or `np.array`.
477 | callback_on_step_end (`Callable`, *optional*):
478 | A function that calls at the end of each denoising steps during the inference. The function is called
479 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
480 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
481 | `callback_on_step_end_tensor_inputs`.
482 | callback_on_step_end_tensor_inputs (`List`, *optional*):
483 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
484 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
485 | `._callback_tensor_inputs` attribute of your pipeline class.
486 | return_dict (`bool`, *optional*, defaults to `True`):
487 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
488 | plain tuple.
489 |
490 | Returns:
491 | [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
492 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
493 | otherwise a `tuple` is returned where the first element is a list of list with the generated frames.
494 |
495 | Examples:
496 |
497 | ```py
498 | from diffusers import StableVideoDiffusionPipeline
499 | from diffusers.utils import load_image, export_to_video
500 |
501 | pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
502 | pipe.to("cuda")
503 |
504 | image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
505 | image = image.resize((1024, 576))
506 |
507 | frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
508 | export_to_video(frames, "generated.mp4", fps=7)
509 | ```
510 | """
511 | # import ipdb; ipdb.set_trace()
512 | # 0. Default height and width to unet
513 | height = height or self.unet.config.sample_size * self.vae_scale_factor
514 | width = width or self.unet.config.sample_size * self.vae_scale_factor
515 |
516 | num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
517 | decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
518 |
519 | # 1. Check inputs. Raise error if not correct
520 | self.check_inputs(image, height, width)
521 |
522 | # 2. Define call parameters
523 | if isinstance(image, PIL.Image.Image):
524 | batch_size = 1
525 | elif isinstance(image, list):
526 | batch_size = len(image)
527 | else:
528 | batch_size = image.shape[0]
529 | device = self._execution_device
530 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
531 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
532 | # corresponds to doing no classifier free guidance.
533 | do_classifier_free_guidance = max_guidance_scale > 1.0
534 | emb_cond = {}
535 | # 3. Encode input image
536 | image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
537 | if end_frame is not None:
538 | image_embeddings_end_frame = self._encode_image(end_frame, device, num_videos_per_prompt, do_classifier_free_guidance)
539 | emb_cond['forward'] = image_embeddings
540 | emb_cond['backward'] = image_embeddings_end_frame
541 |         # NOTE: Stable Video Diffusion was conditioned on fps - 1, which
542 | # is why it is reduced here.
543 | # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
544 | fps = fps - 1
545 |
546 | # 4. Encode input image using VAE
547 | image = self.image_processor.preprocess(image, height=height, width=width)
548 | if end_frame is not None:
549 | end_frame = self.image_processor.preprocess(end_frame, height=height, width=width)
550 |
551 | noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
552 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
553 |
554 | frame_cond = {}
555 |
556 | def image_to_image_latents(image):
557 | image = image + noise_aug_strength * noise
558 | if needs_upcasting:
559 | self.vae.to(dtype=torch.float32)
560 | image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
561 | image_latents = image_latents.to(image_embeddings.dtype)
562 | # cast back to fp16 if needed
563 | if needs_upcasting:
564 | self.vae.to(dtype=torch.float16)
565 | # Repeat the image latents for each frame so we can concatenate them with the noise
566 | # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
567 | image_latents = image_latents.unsqueeze(1).repeat(1,num_frames,1,1,1) #torch.Size([2, 25, 4, 72, 128])
568 | return image_latents
569 |
570 | frame_cond['forward'] = image_to_image_latents(image)
571 | frame_cond['backward'] = image_to_image_latents(end_frame)
572 |
573 | # 5. Get Added Time IDs
574 | added_time_ids = self._get_add_time_ids(
575 | fps,
576 | motion_bucket_id,
577 | noise_aug_strength,
578 | image_embeddings.dtype,
579 | batch_size,
580 | num_videos_per_prompt,
581 | do_classifier_free_guidance,
582 | )
583 | added_time_ids = added_time_ids.to(device)
584 |
585 | # 4. Prepare timesteps
586 | self.scheduler.set_timesteps(num_inference_steps)
587 |
588 | if custom_timestep_spacing:
589 | #todo: gradually decrease the number of jumps
590 | reverse_list = np.array(list(range(num_inference_steps)))[::-1]
591 | timesteps_idx = []
592 | jumps = {}
593 | for j in range(int(num_inference_steps * repeat_step_ratio), num_inference_steps - jump_length, jump_length):
594 | jumps[j] = jump_n_sample - 1
595 |
596 | t = num_inference_steps
597 | while t >= 1:
598 | t = t - 1
599 | timesteps_idx.append(t)
600 |
601 | if jumps.get(t, 0) > 0:
602 | jumps[t] = jumps[t] - 1
603 | for _ in range(jump_length):
604 | t = t + 1
605 | timesteps_idx.append(t)
606 | timesteps_idx = reverse_list[timesteps_idx]
607 |
608 | step_id_last = timesteps_idx[0] - 1
609 | timesteps = self.scheduler.timesteps
610 |
611 | # 5. Prepare latent variables
612 | num_channels_latents = self.unet.config.in_channels
613 | latents = self.prepare_latents(
614 | batch_size * num_videos_per_prompt,
615 | num_frames,
616 | num_channels_latents,
617 | height,
618 | width,
619 | image_embeddings.dtype,
620 | device,
621 | generator,
622 | latents,
623 | )
624 |
625 | # 7. Prepare guidance scale
626 | guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
627 | guidance_scale = guidance_scale.to(device, latents.dtype)
628 | guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
629 | guidance_scale = _append_dims(guidance_scale, latents.ndim)
630 |
631 | self._guidance_scale = guidance_scale
632 |
633 | count = torch.zeros_like(latents)
634 | value = torch.zeros_like(latents)
635 |
636 | # 8. Denoising loop
637 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
638 | self._num_timesteps = len(timesteps)
639 |
640 | with self.progress_bar(total=len(timesteps_idx)) as progress_bar:
641 | for i, step_id in enumerate(timesteps_idx):
642 |
643 | # import ipdb; ipdb.set_trace()
644 | t = timesteps[step_id]
645 | # expand the latents if we are doing classifier free guidance
646 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
647 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
648 | # import ipdb; ipdb.set_trace()
649 | count.zero_()
650 | value.zero_()
651 | for time_dir in ['forward', 'backward']:
652 |                     # Concatenate image_latents over the channels dimension
653 |
654 | if step_id > step_id_last:
655 | image_latents = frame_cond[time_dir]
656 | image_embeddings = emb_cond[time_dir] if end_frame is not None else image_embeddings
657 | latent_model_input_new = torch.cat([latent_model_input, image_latents], dim=2)
658 | if time_dir == 'backward':
659 | latent_model_input_new = torch.flip(latent_model_input_new, dims=[1])
660 | latents = torch.flip(latents, dims=[1])
661 | # predict the noise residual
662 | noise_pred = self.unet(
663 | latent_model_input_new,
664 | t,
665 | encoder_hidden_states=image_embeddings,
666 | added_time_ids=added_time_ids,
667 | return_dict=False,
668 | )[0]
669 | # perform guidance
670 | if do_classifier_free_guidance:
671 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
672 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
673 | # compute the previous noisy sample x_t -> x_t-1
674 | self.scheduler._step_index = step_id
675 | latents_view_denoised = self.scheduler.step(noise_pred, t, latents, step_type=time_dir).prev_sample
676 | blend_weights = self.get_blend_weights(latents, step_id, num_frames, weight_type = 'linear', time_dir=time_dir)
677 | latents_view_denoised = latents_view_denoised * blend_weights
678 | if time_dir == 'backward':
679 | # import ipdb; ipdb.set_trace()
680 | latents_view_denoised = torch.flip(latents_view_denoised, dims=[1])
681 |
682 | else:
683 | latents_view_denoised = self.scheduler.undo_step(latents, step_id_last, ratio=noise_scale_ratio)
684 | if time_dir == 'backward':
685 | latents_view_denoised = 0.
686 |
687 | value += latents_view_denoised
688 | count += 0.5
689 |
690 | step_id_last = step_id
691 |
692 | latents = torch.where(count > 0, value / count, value)
693 |
694 | if callback_on_step_end is not None:
695 | callback_kwargs = {}
696 | for k in callback_on_step_end_tensor_inputs:
697 | callback_kwargs[k] = locals()[k]
698 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
699 | latents = callback_outputs.pop("latents", latents)
700 |
701 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
702 | progress_bar.update()
703 |
704 | if not output_type == "latent":
705 | # cast back to fp16 if needed
706 | if needs_upcasting:
707 | self.vae.to(dtype=torch.float16)
708 | frames = self.decode_latents(latents, num_frames, decode_chunk_size)
709 | frames = tensor2vid(frames, self.image_processor, output_type=output_type)
710 | else:
711 | frames = latents
712 |
713 | self.maybe_free_model_hooks()
714 |
715 | if not return_dict:
716 | return frames
717 |
718 | return StableVideoDiffusionPipelineOutput(frames=frames)
719 |
720 |
721 | # resizing utils
722 | # TODO: clean up later
723 | def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
724 | h, w = input.shape[-2:]
725 | factors = (h / size[0], w / size[1])
726 |
727 | # First, we have to determine sigma
728 | # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
729 | sigmas = (
730 | max((factors[0] - 1.0) / 2.0, 0.001),
731 | max((factors[1] - 1.0) / 2.0, 0.001),
732 | )
733 |
734 | # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
735 | # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
736 | # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
737 | ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
738 |
739 | # Make sure it is odd
740 | if (ks[0] % 2) == 0:
741 | ks = ks[0] + 1, ks[1]
742 |
743 | if (ks[1] % 2) == 0:
744 | ks = ks[0], ks[1] + 1
745 |
746 | input = _gaussian_blur2d(input, ks, sigmas)
747 |
748 | output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
749 | return output
750 |
751 |
752 | def _compute_padding(kernel_size):
753 | """Compute padding tuple."""
754 | # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
755 | # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
756 | if len(kernel_size) < 2:
757 | raise AssertionError(kernel_size)
758 | computed = [k - 1 for k in kernel_size]
759 |
760 | # for even kernels we need to do asymmetric padding :(
761 | out_padding = 2 * len(kernel_size) * [0]
762 |
763 | for i in range(len(kernel_size)):
764 | computed_tmp = computed[-(i + 1)]
765 |
766 | pad_front = computed_tmp // 2
767 | pad_rear = computed_tmp - pad_front
768 |
769 | out_padding[2 * i + 0] = pad_front
770 | out_padding[2 * i + 1] = pad_rear
771 |
772 | return out_padding
773 |
774 |
775 | def _filter2d(input, kernel):
776 | # prepare kernel
777 | b, c, h, w = input.shape
778 | tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
779 |
780 | tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
781 |
782 | height, width = tmp_kernel.shape[-2:]
783 |
784 | padding_shape: list[int] = _compute_padding([height, width])
785 | input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
786 |
787 | # kernel and input tensor reshape to align element-wise or batch-wise params
788 | tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
789 | input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
790 |
791 | # convolve the tensor with the kernel.
792 | output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
793 |
794 | out = output.view(b, c, h, w)
795 | return out
796 |
797 |
798 | def _gaussian(window_size: int, sigma):
799 | if isinstance(sigma, float):
800 | sigma = torch.tensor([[sigma]])
801 |
802 | batch_size = sigma.shape[0]
803 |
804 | x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
805 |
806 | if window_size % 2 == 0:
807 | x = x + 0.5
808 |
809 | gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
810 |
811 | return gauss / gauss.sum(-1, keepdim=True)
812 |
813 |
814 | def _gaussian_blur2d(input, kernel_size, sigma):
815 | if isinstance(sigma, tuple):
816 | sigma = torch.tensor([sigma], dtype=input.dtype)
817 | else:
818 | sigma = sigma.to(dtype=input.dtype)
819 |
820 | ky, kx = int(kernel_size[0]), int(kernel_size[1])
821 | bs = sigma.shape[0]
822 | kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
823 | kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
824 | out_x = _filter2d(input, kernel_x[..., None, :])
825 | out = _filter2d(out_x, kernel_y[..., None])
826 |
827 | return out
828 |
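829 | 
830 | if __name__ == "__main__":
831 |     # Illustration only (added; not used by the pipeline itself): reproduce the resampling
832 |     # schedule that __call__ builds when custom_timestep_spacing=True. With the default
833 |     # settings assumed below (25 steps, repeat_step_ratio=0.5, jump_length=3,
834 |     # jump_n_sample=2), the early high-noise steps are re-noised and revisited, e.g.
835 |     # 0, 1, 2, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 5, 4, 3, ... before running to the final step.
836 |     num_inference_steps, repeat_step_ratio, jump_length, jump_n_sample = 25, 0.5, 3, 2
837 |     reverse_list = np.array(list(range(num_inference_steps)))[::-1]
838 |     jumps = {}
839 |     for j in range(int(num_inference_steps * repeat_step_ratio), num_inference_steps - jump_length, jump_length):
840 |         jumps[j] = jump_n_sample - 1
841 |     timesteps_idx = []
842 |     t = num_inference_steps
843 |     while t >= 1:
844 |         t = t - 1
845 |         timesteps_idx.append(t)
846 |         if jumps.get(t, 0) > 0:
847 |             jumps[t] = jumps[t] - 1
848 |             for _ in range(jump_length):
849 |                 t = t + 1
850 |                 timesteps_idx.append(t)
851 |     print(reverse_list[timesteps_idx])  # step indices in the order the denoising loop visits them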
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.25.0
2 | asttokens==2.2.1
3 | av==11.0.0
4 | backcall==0.2.0
5 | beautifulsoup4==4.12.2
6 | blessed==1.20.0
7 | brotlipy==0.7.0
8 | certifi==2024.2.2
9 | cffi==1.15.1
10 | charset-normalizer==2.0.4
11 | cmake==3.28.1
12 | colorama==0.4.6
13 | contourpy==1.1.0
14 | cryptography==41.0.2
15 | cycler==0.11.0
16 | decorator==5.1.1
17 | decord==0.6.0
18 | diffusers==0.24.0
19 | einops==0.7.0
20 | enlighten==1.12.4
21 | executing==1.2.0
22 | filelock==3.12.3
23 | fonttools==4.42.1
24 | fsspec==2023.6.0
25 | gdown==4.7.1
26 | huggingface-hub==0.19.4
27 | idna==3.4
28 | importlib-metadata==6.8.0
29 | importlib-resources==6.0.1
30 | ipdb==0.13.13
31 | ipython==8.12.2
32 | jedi==0.19.0
33 | Jinja2==3.1.2
34 | kiwisolver==1.4.5
35 | lit==17.0.6
36 | MarkupSafe==2.1.3
37 | matplotlib==3.7.2
38 | matplotlib-inline==0.1.6
39 | mkl-fft==1.3.6
40 | mkl-random==1.2.2
41 | mkl-service==2.4.0
42 | mpmath==1.3.0
43 | mypy-extensions==1.0.0
44 | networkx==3.1
45 | numpy==1.24.3
46 | nvidia-cublas-cu11==11.10.3.66
47 | nvidia-cublas-cu12==12.1.3.1
48 | nvidia-cuda-cupti-cu11==11.7.101
49 | nvidia-cuda-cupti-cu12==12.1.105
50 | nvidia-cuda-nvrtc-cu11==11.7.99
51 | nvidia-cuda-nvrtc-cu12==12.1.105
52 | nvidia-cuda-runtime-cu11==11.7.99
53 | nvidia-cuda-runtime-cu12==12.1.105
54 | nvidia-cudnn-cu11==8.5.0.96
55 | nvidia-cudnn-cu12==8.9.2.26
56 | nvidia-cufft-cu11==10.9.0.58
57 | nvidia-cufft-cu12==11.0.2.54
58 | nvidia-curand-cu11==10.2.10.91
59 | nvidia-curand-cu12==10.3.2.106
60 | nvidia-cusolver-cu11==11.4.0.1
61 | nvidia-cusolver-cu12==11.4.5.107
62 | nvidia-cusparse-cu11==11.7.4.91
63 | nvidia-cusparse-cu12==12.1.0.106
64 | nvidia-nccl-cu11==2.14.3
65 | nvidia-nccl-cu12==2.18.1
66 | nvidia-nvjitlink-cu12==12.3.101
67 | nvidia-nvtx-cu11==11.7.91
68 | nvidia-nvtx-cu12==12.1.105
69 | opencv-contrib-python==4.9.0.80
70 | opencv-python==4.8.0.76
71 | packaging==23.1
72 | pandas==2.0.3
73 | parso==0.8.3
74 | pexpect==4.8.0
75 | pickleshare==0.7.5
76 | Pillow==9.4.0
77 | pip==23.2.1
78 | prefixed==0.7.0
79 | prompt-toolkit==3.0.39
80 | protobuf==4.24.2
81 | psutil==5.9.5
82 | ptyprocess==0.7.0
83 | pure-eval==0.2.2
84 | pycolmap==0.6.1
85 | pycparser==2.21
86 | Pygments==2.16.1
87 | pyOpenSSL==23.2.0
88 | pyparsing==3.0.9
89 | pyre-extensions==0.0.29
90 | PySocks==1.7.1
91 | python-dateutil==2.8.2
92 | pytorch-fid==0.3.0
93 | pytz==2024.1
94 | PyYAML==5.1.2
95 | regex==2022.7.9
96 | requests==2.31.0
97 | safetensors==0.3.3
98 | scipy==1.10.1
99 | sentencepiece==0.1.99
100 | setuptools==68.0.0
101 | six==1.16.0
102 | soupsieve==2.5
103 | splatting==0.0.0
104 | stack-data==0.6.2
105 | sympy==1.12
106 | tokenizers==0.15.0
107 | tomli==2.0.1
108 | torch==2.0.1
109 | torchaudio==2.0.2
110 | torchvision==0.15.2
111 | tqdm==4.66.1
112 | traitlets==5.9.0
113 | transformers==4.35.2
114 | triton==2.0.0
115 | typing_extensions==4.7.1
116 | typing-inspect==0.9.0
117 | tzdata==2024.1
118 | urllib3==1.26.16
119 | wcwidth==0.2.6
120 | wheel==0.38.4
121 | xformers==0.0.20
122 | zipp==3.16.2
123 |
--------------------------------------------------------------------------------
/scheduling_euler_discrete_resampling.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 | from dataclasses import dataclass
17 | from typing import List, Optional, Tuple, Union
18 |
19 | import numpy as np
20 | import torch
21 |
22 | from diffusers.configuration_utils import ConfigMixin, register_to_config
23 | from diffusers.utils import BaseOutput, logging
24 | from diffusers.utils.torch_utils import randn_tensor
25 | from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
26 |
27 |
28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29 |
30 |
31 | @dataclass
32 | # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
33 | class EulerDiscreteSchedulerOutput(BaseOutput):
34 | """
35 | Output class for the scheduler's `step` function output.
36 |
37 | Args:
38 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
39 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
40 | denoising loop.
41 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
42 | The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
43 | `pred_original_sample` can be used to preview progress or for guidance.
44 | """
45 |
46 | prev_sample: torch.FloatTensor
47 | pred_original_sample: Optional[torch.FloatTensor] = None
48 |
49 |
50 | # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
51 | def betas_for_alpha_bar(
52 | num_diffusion_timesteps,
53 | max_beta=0.999,
54 | alpha_transform_type="cosine",
55 | ):
56 | """
57 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
58 | (1-beta) over time from t = [0,1].
59 |
60 | Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
61 | to that part of the diffusion process.
62 |
63 |
64 | Args:
65 | num_diffusion_timesteps (`int`): the number of betas to produce.
66 | max_beta (`float`): the maximum beta to use; use values lower than 1 to
67 | prevent singularities.
68 | alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
69 | Choose from `cosine` or `exp`
70 |
71 | Returns:
72 | betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
73 | """
74 | if alpha_transform_type == "cosine":
75 |
76 | def alpha_bar_fn(t):
77 | return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
78 |
79 | elif alpha_transform_type == "exp":
80 |
81 | def alpha_bar_fn(t):
82 | return math.exp(t * -12.0)
83 |
84 | else:
85 |         raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
86 |
87 | betas = []
88 | for i in range(num_diffusion_timesteps):
89 | t1 = i / num_diffusion_timesteps
90 | t2 = (i + 1) / num_diffusion_timesteps
91 | betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
92 | return torch.tensor(betas, dtype=torch.float32)
93 |
94 |
95 | class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
96 | """
97 | Euler scheduler.
98 |
99 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
100 | methods the library implements for all schedulers such as loading and saving.
101 |
102 | Args:
103 | num_train_timesteps (`int`, defaults to 1000):
104 | The number of diffusion steps to train the model.
105 | beta_start (`float`, defaults to 0.0001):
106 | The starting `beta` value of inference.
107 | beta_end (`float`, defaults to 0.02):
108 | The final `beta` value.
109 | beta_schedule (`str`, defaults to `"linear"`):
110 | The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
111 | `linear` or `scaled_linear`.
112 | trained_betas (`np.ndarray`, *optional*):
113 | Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
114 | prediction_type (`str`, defaults to `epsilon`, *optional*):
115 | Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
116 |             `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
117 |             Video](https://imagen.research.google/video/paper.pdf) paper).
118 |         interpolation_type (`str`, defaults to `"linear"`, *optional*):
119 |             The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
120 | `"linear"` or `"log_linear"`.
121 | use_karras_sigmas (`bool`, *optional*, defaults to `False`):
122 | Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
123 | the sigmas are determined according to a sequence of noise levels {σi}.
124 | timestep_spacing (`str`, defaults to `"linspace"`):
125 | The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
126 | Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
127 | steps_offset (`int`, defaults to 0):
128 | An offset added to the inference steps. You can use a combination of `offset=1` and
129 | `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
130 | Diffusion.
131 | """
132 |
133 | _compatibles = [e.name for e in KarrasDiffusionSchedulers]
134 | order = 1
135 |
136 | @register_to_config
137 | def __init__(
138 | self,
139 | num_train_timesteps: int = 1000,
140 | beta_start: float = 0.0001,
141 | beta_end: float = 0.02,
142 | beta_schedule: str = "linear",
143 | trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
144 | prediction_type: str = "epsilon",
145 | interpolation_type: str = "linear",
146 | use_karras_sigmas: Optional[bool] = False,
147 | sigma_min: Optional[float] = None,
148 | sigma_max: Optional[float] = None,
149 | timestep_spacing: str = "linspace",
150 | timestep_type: str = "discrete", # can be "discrete" or "continuous"
151 | steps_offset: int = 0,
152 | ):
153 | if trained_betas is not None:
154 | self.betas = torch.tensor(trained_betas, dtype=torch.float32)
155 | elif beta_schedule == "linear":
156 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
157 | elif beta_schedule == "scaled_linear":
158 | # this schedule is very specific to the latent diffusion model.
159 | self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
160 | elif beta_schedule == "squaredcos_cap_v2":
161 | # Glide cosine schedule
162 | self.betas = betas_for_alpha_bar(num_train_timesteps)
163 | else:
164 |             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
165 |
166 | self.alphas = 1.0 - self.betas
167 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
168 |
169 | sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
170 | timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
171 |
172 | sigmas = torch.from_numpy(sigmas[::-1].copy()).to(dtype=torch.float32)
173 | timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
174 |
175 | # setable values
176 | self.num_inference_steps = None
177 |
178 | # TODO: Support the full EDM scalings for all prediction types and timestep types
179 | if timestep_type == "continuous" and prediction_type == "v_prediction":
180 | self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
181 | else:
182 | self.timesteps = timesteps
183 |
184 | self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
185 |
186 | self.is_scale_input_called = False
187 | self.use_karras_sigmas = use_karras_sigmas
188 |
189 | self._step_index = None
190 |
191 | @property
192 | def init_noise_sigma(self):
193 | # standard deviation of the initial noise distribution
194 | if self.config.timestep_spacing in ["linspace", "trailing"]:
195 | return self.sigmas.max()
196 |
197 | return (self.sigmas.max() ** 2 + 1) ** 0.5
198 |
199 | @property
200 | def step_index(self):
201 | """
202 |         The index counter for current timestep. It will increase by 1 after each scheduler step.
203 | """
204 | return self._step_index
205 |
206 | def scale_model_input(
207 | self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
208 | ) -> torch.FloatTensor:
209 | """
210 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
211 | current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
212 |
213 | Args:
214 | sample (`torch.FloatTensor`):
215 | The input sample.
216 | timestep (`int`, *optional*):
217 | The current timestep in the diffusion chain.
218 |
219 | Returns:
220 | `torch.FloatTensor`:
221 | A scaled input sample.
222 | """
223 | if self.step_index is None:
224 | self._init_step_index(timestep)
225 |
226 | sigma = self.sigmas[self.step_index]
227 | sample = sample / ((sigma**2 + 1) ** 0.5)
228 |
229 | self.is_scale_input_called = True
230 | return sample
231 |
232 | def set_timesteps(
233 | self,
234 | num_inference_steps: int,
235 | device: Union[str, torch.device] = None):
236 | """
237 | Sets the discrete timesteps used for the diffusion chain (to be run before inference).
238 |
239 | Args:
240 | num_inference_steps (`int`):
241 | The number of diffusion steps used when generating samples with a pre-trained model.
242 | device (`str` or `torch.device`, *optional*):
243 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
244 | """
245 | self.num_inference_steps = num_inference_steps
246 | # import ipdb; ipdb.set_trace()
247 |
248 | # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
249 | if self.config.timestep_spacing == "linspace":
250 | timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
251 | ::-1
252 | ].copy()
253 | elif self.config.timestep_spacing == "leading":
254 | step_ratio = self.config.num_train_timesteps // self.num_inference_steps
255 | # creates integer timesteps by multiplying by ratio
256 | # casting to int to avoid issues when num_inference_step is power of 3
257 | timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
258 | timesteps += self.config.steps_offset
259 | elif self.config.timestep_spacing == "trailing":
260 | step_ratio = self.config.num_train_timesteps / self.num_inference_steps
261 | # creates integer timesteps by multiplying by ratio
262 | # casting to int to avoid issues when num_inference_step is power of 3
263 | timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
264 | timesteps -= 1
265 | else:
266 | raise ValueError(
267 | f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
268 | )
269 |
270 | sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
271 | log_sigmas = np.log(sigmas)
272 |
273 | if self.config.interpolation_type == "linear":
274 | sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
275 | elif self.config.interpolation_type == "log_linear":
276 | sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp()
277 | else:
278 | raise ValueError(
279 | f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
280 | " 'linear' or 'log_linear'"
281 | )
282 |
283 | if self.use_karras_sigmas:
284 | sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
285 | timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
286 |
287 |
288 | sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
289 |
290 | # TODO: Support the full EDM scalings for all prediction types and timestep types
291 | if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
292 | self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device)
293 | else:
294 | self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
295 |
296 | self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
297 | self._step_index = None
298 |
299 | def _sigma_to_t(self, sigma, log_sigmas):
300 | # get log sigma
301 | log_sigma = np.log(np.maximum(sigma, 1e-10))
302 |
303 | # get distribution
304 | dists = log_sigma - log_sigmas[:, np.newaxis]
305 |
306 | # get sigmas range
307 | low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
308 | high_idx = low_idx + 1
309 |
310 | low = log_sigmas[low_idx]
311 | high = log_sigmas[high_idx]
312 |
313 | # interpolate sigmas
314 | w = (low - log_sigma) / (low - high)
315 | w = np.clip(w, 0, 1)
316 |
317 | # transform interpolation to time range
318 | t = (1 - w) * low_idx + w * high_idx
319 | t = t.reshape(sigma.shape)
320 | return t
321 |
322 | # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
323 | def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
324 | """Constructs the noise schedule of Karras et al. (2022)."""
325 |
326 | # Hack to make sure that other schedulers which copy this function don't break
327 | # TODO: Add this logic to the other schedulers
328 | if hasattr(self.config, "sigma_min"):
329 | sigma_min = self.config.sigma_min
330 | else:
331 | sigma_min = None
332 |
333 | if hasattr(self.config, "sigma_max"):
334 | sigma_max = self.config.sigma_max
335 | else:
336 | sigma_max = None
337 |
338 | sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
339 | sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
340 |
341 | rho = 7.0 # 7.0 is the value used in the paper
342 | ramp = np.linspace(0, 1, num_inference_steps)
343 | min_inv_rho = sigma_min ** (1 / rho)
344 | max_inv_rho = sigma_max ** (1 / rho)
345 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
346 | return sigmas
347 |
348 | def _init_step_index(self, timestep):
349 | if isinstance(timestep, torch.Tensor):
350 | timestep = timestep.to(self.timesteps.device)
351 |
352 | index_candidates = (self.timesteps == timestep).nonzero()
353 |
354 | # The sigma index that is taken for the **very** first `step`
355 | # is always the second index (or the last index if there is only 1)
356 | # This way we can ensure we don't accidentally skip a sigma in
357 | # case we start in the middle of the denoising schedule (e.g. for image-to-image)
358 | if len(index_candidates) > 1:
359 | step_index = index_candidates[1]
360 | else:
361 | step_index = index_candidates[0]
362 |
363 | self._step_index = step_index.item()
364 |
365 | def step(
366 | self,
367 | model_output: torch.FloatTensor,
368 | timestep: Union[float, torch.FloatTensor],
369 | sample: torch.FloatTensor,
370 | step_type: str = "forward",
371 | s_churn: float = 0.0,
372 | s_tmin: float = 0.0,
373 | s_tmax: float = float("inf"),
374 | s_noise: float = 1.0,
375 | generator: Optional[torch.Generator] = None,
376 | return_dict: bool = True,
377 | ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
378 | """
379 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
380 | process from the learned model outputs (most often the predicted noise).
381 |
382 | Args:
383 | model_output (`torch.FloatTensor`):
384 | The direct output from learned diffusion model.
385 | timestep (`float`):
386 | The current discrete timestep in the diffusion chain.
387 | sample (`torch.FloatTensor`):
388 | A current instance of a sample created by the diffusion process.
389 | s_churn (`float`): Amount of stochastic churn (extra noise) to add at each step.
390 | s_tmin (`float`): Lower bound of the sigma range in which churn is applied.
391 | s_tmax (`float`): Upper bound of the sigma range in which churn is applied.
392 | s_noise (`float`, defaults to 1.0):
393 | Scaling factor for noise added to the sample.
394 | generator (`torch.Generator`, *optional*):
395 | A random number generator.
396 | return_dict (`bool`):
397 | Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
398 | tuple.
399 |
400 | Returns:
401 | [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
402 | If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
403 | returned, otherwise a tuple is returned where the first element is the sample tensor.
404 | """
405 |
406 | if (
407 | isinstance(timestep, int)
408 | or isinstance(timestep, torch.IntTensor)
409 | or isinstance(timestep, torch.LongTensor)
410 | ):
411 | raise ValueError(
412 | (
413 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
414 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
415 | " one of the `scheduler.timesteps` as a timestep."
416 | ),
417 | )
418 |
419 | if not self.is_scale_input_called:
420 | logger.warning(
421 | "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
422 | "See `StableDiffusionPipeline` for a usage example."
423 | )
424 |
425 | if self.step_index is None:
426 | self._init_step_index(timestep)
427 |
428 | sigma = self.sigmas[self.step_index]
429 |
430 | gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
431 |
432 | noise = randn_tensor(
433 | model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
434 | )
435 |
436 | eps = noise * s_noise
437 | sigma_hat = sigma * (gamma + 1)
438 |
439 | if gamma > 0:
440 | sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
441 |
442 | # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
443 | # NOTE: "original_sample" should not be an expected prediction_type but is left in for
444 | # backwards compatibility
445 | if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample":
446 | pred_original_sample = model_output
447 | elif self.config.prediction_type == "epsilon":
448 | pred_original_sample = sample - sigma_hat * model_output
449 | elif self.config.prediction_type == "v_prediction":
450 | # denoised = model_output * c_out + input * c_skip
451 | pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
452 | else:
453 | raise ValueError(
454 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
455 | )
456 |
457 | # 2. Convert to an ODE derivative
458 | derivative = (sample - pred_original_sample) / sigma_hat
459 | dt = self.sigmas[self.step_index + 1] - sigma_hat
460 |
461 | prev_sample = sample + derivative * dt
462 |
463 | # upon completion increase step index by one
464 | self._step_index += 1
465 |
466 | if not return_dict:
467 | return (prev_sample,)
468 |
469 | return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
470 |
471 |
472 | def undo_step(self, sample, step_id, generator=None, ratio=0.49):
473 | noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
474 | sample = sample + noise * ((self.sigmas[step_id]**2 - self.sigmas[step_id+1]**2) ** 0.5) * ratio
475 | return sample
476 |
477 |
478 | def add_noise(
479 | self,
480 | original_samples: torch.FloatTensor,
481 | noise: torch.FloatTensor,
482 | timesteps: torch.FloatTensor,
483 | ) -> torch.FloatTensor:
484 | # Make sure sigmas and timesteps have the same device and dtype as original_samples
485 | sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
486 | if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
487 | # mps does not support float64
488 | schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
489 | timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
490 | else:
491 | schedule_timesteps = self.timesteps.to(original_samples.device)
492 | timesteps = timesteps.to(original_samples.device)
493 |
494 | step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
495 |
496 | sigma = sigmas[step_indices].flatten()
497 | while len(sigma.shape) < len(original_samples.shape):
498 | sigma = sigma.unsqueeze(-1)
499 |
500 | noisy_samples = original_samples + noise * sigma
501 | return noisy_samples
502 |
503 | def __len__(self):
504 | return self.config.num_train_timesteps
505 |
--------------------------------------------------------------------------------
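
Note on the resampling scheduler above: on top of the stock Euler update in `step`, `undo_step` partially re-noises a sample from noise level `step_id + 1` back toward level `step_id`, which is what lets the pipeline "time-travel" and re-denoise. Below is a minimal sketch of the two primitives, assuming the class keeps the upstream `EulerDiscreteScheduler` defaults; the dummy noise prediction and latent shape are placeholders, and the actual jump schedule is driven by `pipeline_stable_video_diffusion_re.py` via `jump_length` and `jump_n_sample`.

```python
import torch

from scheduling_euler_discrete_resampling import EulerDiscreteScheduler

# Assumes the constructor keeps the upstream defaults (1000 training timesteps, epsilon prediction).
scheduler = EulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=25)

sample = torch.randn(1, 4, 8, 8) * scheduler.init_noise_sigma  # placeholder latent shape

i, t = 0, scheduler.timesteps[0]
model_input = scheduler.scale_model_input(sample, t)
noise_pred = torch.zeros_like(model_input)  # stand-in for the UNet's noise prediction
sample = scheduler.step(noise_pred, t, sample).prev_sample  # Euler move: sigma[0] -> sigma[1]

# Time-travel: re-inject noise so the sample moves from sigma[1] back toward sigma[0],
# after which the same denoising step can be repeated by the pipeline.
sample = scheduler.undo_step(sample, step_id=i)
```
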
/svd_sequential_re.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | import tqdm
4 | import torch
5 | from diffusers import StableVideoDiffusionPipeline
6 | from PIL import Image
7 | from diffusers.utils import load_image, export_to_video
8 | from pipeline_stable_video_diffusion_re import StableVideoDiffusionPipeline_Custom
9 | from unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
10 | from scheduling_euler_discrete_resampling import EulerDiscreteScheduler
11 | from datasets import get_dataset
12 |
13 |
14 | def img_preprocess(img_path, mode='crop', orig_aspect=None):
15 | image = load_image(img_path)
16 | w, h = image.size
17 | aspect_ratio = orig_aspect if orig_aspect else 16 / 9
18 |
19 | if mode == 'crop':
20 | # Crop the image to the specified aspect ratio
21 | if w / h > aspect_ratio:
22 | height = h
23 | width = int(height * aspect_ratio)
24 | else:
25 | width = w
26 | height = int(width / aspect_ratio)
27 | left = (w - width) // 2
28 | top = (h - height) // 2
29 | image = image.crop((left, top, left + width, top + height))
30 | image = image.resize((1024, 576))
31 | elif mode == 'padding':
32 | # Pad the image to the specified aspect ratio
33 | new_w = int(h * aspect_ratio) if w / h < aspect_ratio else w
34 | new_h = int(w / aspect_ratio) if w / h >= aspect_ratio else h
35 | new_image = Image.new("RGB", (new_w, new_h), (0, 0, 0))
36 | new_image.paste(image, ((new_w - w) // 2, (new_h - h) // 2))
37 | image = new_image.resize((1024, 576))
38 | return image
39 |
40 | def time_reversal_fusion(model_card, start_frame, end_frame, num_inference_steps, fps_value, jump_n_sample, jump_length, repeat_step_ratio, noise_scale_ratio, motion_id, generator):
41 | # Load model components
42 | original_pipe = StableVideoDiffusionPipeline.from_pretrained(
43 | model_card, torch_dtype=torch.float16, variant="fp16"
44 | )
45 | unet_custom = UNetSpatioTemporalConditionModel.from_pretrained(
46 | model_card, subfolder="unet", torch_dtype=torch.float16, variant="fp16"
47 | ).to('cuda')
48 | scheduler_custom = EulerDiscreteScheduler.from_pretrained(
49 | model_card, subfolder="scheduler", torch_dtype=torch.float16, variant="fp16"
50 | )
51 | # Generate frames
52 | pipe = StableVideoDiffusionPipeline_Custom(
53 | vae=original_pipe.vae,
54 | image_encoder=original_pipe.image_encoder,
55 | unet=unet_custom,
56 | scheduler=scheduler_custom,
57 | feature_extractor=original_pipe.feature_extractor,
58 | )
59 | pipe.enable_model_cpu_offload()
60 | frames = pipe(
61 | start_frame,
62 | end_frame,
63 | height=start_frame.height,
64 | width=start_frame.width,
65 | num_frames=25,
66 | num_inference_steps=num_inference_steps,
67 | fps=fps_value,
68 | jump_length=jump_length,
69 | jump_n_sample=jump_n_sample,
70 | repeat_step_ratio=repeat_step_ratio,
71 | noise_scale_ratio=noise_scale_ratio,
72 | decode_chunk_size=8,
73 | motion_bucket_id=motion_id,
74 | generator=generator,
75 | ).frames[0]
76 |
77 | return frames
78 |
79 | def main(data_type):
80 | # Configuration
81 | root_dir = os.path.dirname(os.path.abspath(__file__))
82 | # data_type = 'multiview' # Options: 'image2loop', 'video_frames', 'multiview'
83 | dataset_folder = os.path.join(root_dir, 'test_data')
84 | output_folder = os.path.join(root_dir, 'output', f'{data_type}_exp')
85 | os.makedirs(output_folder, exist_ok=True)
86 |
87 | # Model parameters
88 | model_card = "stabilityai/stable-video-diffusion-img2vid-xt"
89 | fps_value = 7
90 | motion_id = 127
91 | random_seed = 42
92 | num_inference_steps = 50
93 | jump_n_sample = 2
94 | jump_length = 5
95 | repeat_step_ratio = 0.8
96 | noise_scale_ratio = 1.0
97 | generator = torch.manual_seed(random_seed)
98 |
99 | # Load data-specific settings
100 | if data_type == 'image2loop':
101 | video_frames = get_dataset(dataset_folder, data_type='loop')
102 | fps_value = 7
103 | motion_id = 127
104 | jump_n_sample = 2
105 | jump_length = 5
106 | repeat_step_ratio = 0.8
107 | elif data_type == 'video_frames':
108 | video_frames = get_dataset(dataset_folder, data_type='frame', filter_keyword='video_frames')
109 | fps_value = 7
110 | motion_id = 127
111 | jump_n_sample = 2
112 | jump_length = 5
113 | repeat_step_ratio = 0.8
114 | noise_scale_ratio = .95
115 |
116 | elif data_type == 'gym_motion':
117 | video_frames = get_dataset(dataset_folder, data_type='frame', filter_keyword='gym_motion')
118 | fps_value = 17
119 | motion_id = 10
120 | jump_n_sample = 2
121 | jump_length = 5
122 | repeat_step_ratio = 0.8
123 | elif data_type == 'multiview':
124 | video_frames = get_dataset(dataset_folder, data_type='multiview')
125 | fps_value = 7
126 | motion_id = 127
127 | jump_n_sample = 2
128 | jump_length = 5
129 | repeat_step_ratio = 0.8
130 | noise_scale_ratio = 1
131 |
132 | else:
133 | raise ValueError(f"Unsupported data_type: {data_type}")
134 |
135 | # Model directory
136 | re_steps = int((1 - repeat_step_ratio) * num_inference_steps)
137 | model_folder_name = f"{model_card.split('/')[-1]}_fps{fps_value}_id{motion_id}_s-num{num_inference_steps}_re{re_steps}_{jump_length}_{jump_n_sample}_{noise_scale_ratio}"
138 | model_folder = os.path.join(output_folder, model_folder_name)
139 | os.makedirs(model_folder, exist_ok=True)
140 |
141 | # Process each image pair
142 | for idx in tqdm.tqdm(range(len(video_frames))):
143 | image_pair = video_frames[idx]
144 | start_frame = img_preprocess(image_pair[0])
145 | end_frame = img_preprocess(image_pair[1])
146 |
147 | # Generate frame folder name
148 | base_name_start = os.path.splitext(os.path.basename(image_pair[0]))[0]
149 | base_name_end = os.path.splitext(os.path.basename(image_pair[1]))[0]
150 | dir_name = os.path.basename(os.path.dirname(image_pair[0]))
151 |
152 | if data_type == 'image2loop':
153 | frame_folder_name = f"{dir_name}_{base_name_start}"
154 | else:
155 | frame_folder_name = f"{dir_name}_{base_name_start}_{base_name_end}"
156 |
157 | frame_folder = os.path.join(model_folder, frame_folder_name)
158 | os.makedirs(frame_folder, exist_ok=True)
159 | video_file = f"{frame_folder}.mp4"
160 |
161 | frames = time_reversal_fusion(
162 | model_card, start_frame, end_frame, num_inference_steps, fps_value, jump_n_sample, jump_length, repeat_step_ratio, noise_scale_ratio, motion_id, generator
163 | )
164 |
165 | # Save frames
166 | for i, frame in enumerate(frames):
167 | frame.save(os.path.join(frame_folder, f'{i}.png'))
168 |
169 | # Export to video
170 | export_to_video(frames, video_file, fps=fps_value)
171 |
172 | if __name__ == '__main__':
173 | # Dataset/task flag: one of 'multiview', 'video_frames', 'gym_motion', 'image2loop'
174 | args = sys.argv[1:]
175 | data_type = args[0]
176 |
177 | main(data_type)
178 |
--------------------------------------------------------------------------------
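
Beyond the bundled `test_data` tasks, the two helpers above can also be driven directly on a custom bounded pair. A minimal sketch follows; the frame paths are placeholders, the settings simply mirror the `video_frames` defaults in `main`, and a CUDA GPU is assumed, as in `time_reversal_fusion`.

```python
import torch
from diffusers.utils import export_to_video

from svd_sequential_re import img_preprocess, time_reversal_fusion

# Placeholder paths to the two bound frames of your own clip.
start = img_preprocess("my_clip/frame_start.jpg", mode="crop")
end = img_preprocess("my_clip/frame_end.jpg", mode="crop")

frames = time_reversal_fusion(
    model_card="stabilityai/stable-video-diffusion-img2vid-xt",
    start_frame=start,
    end_frame=end,
    num_inference_steps=50,
    fps_value=7,
    jump_n_sample=2,
    jump_length=5,
    repeat_step_ratio=0.8,
    noise_scale_ratio=0.95,
    motion_id=127,
    generator=torch.manual_seed(42),
)
export_to_video(frames, "my_clip_trf.mp4", fps=7)
```
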
/test_data/Multiview_data/mipnerf360_lite/garden/frame_00006.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HavenFeng/time_reversal/cdfb4ecfeaa8a2e50e30bf7a383b4b2d8ebcc2fb/test_data/Multiview_data/mipnerf360_lite/garden/frame_00006.JPG
--------------------------------------------------------------------------------
/test_data/Multiview_data/mipnerf360_lite/garden/frame_00104.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HavenFeng/time_reversal/cdfb4ecfeaa8a2e50e30bf7a383b4b2d8ebcc2fb/test_data/Multiview_data/mipnerf360_lite/garden/frame_00104.JPG
--------------------------------------------------------------------------------
/test_data/gym_motion_2024_frames/arm_clip1/frame_00018.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HavenFeng/time_reversal/cdfb4ecfeaa8a2e50e30bf7a383b4b2d8ebcc2fb/test_data/gym_motion_2024_frames/arm_clip1/frame_00018.jpg
--------------------------------------------------------------------------------
/test_data/gym_motion_2024_frames/arm_clip1/frame_00060.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HavenFeng/time_reversal/cdfb4ecfeaa8a2e50e30bf7a383b4b2d8ebcc2fb/test_data/gym_motion_2024_frames/arm_clip1/frame_00060.jpg
--------------------------------------------------------------------------------
/test_data/video_frames/dolomite_clip3/frame_00000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HavenFeng/time_reversal/cdfb4ecfeaa8a2e50e30bf7a383b4b2d8ebcc2fb/test_data/video_frames/dolomite_clip3/frame_00000.jpg
--------------------------------------------------------------------------------
/test_data/video_frames/dolomite_clip3/frame_00100.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HavenFeng/time_reversal/cdfb4ecfeaa8a2e50e30bf7a383b4b2d8ebcc2fb/test_data/video_frames/dolomite_clip3/frame_00100.jpg
--------------------------------------------------------------------------------
/unet_spatio_temporal_condition.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Dict, Optional, Tuple, Union
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from diffusers.configuration_utils import ConfigMixin, register_to_config
8 | from diffusers.loaders import UNet2DConditionLoadersMixin
9 | from diffusers.utils import BaseOutput, logging
10 | from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
11 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps
12 | from diffusers.models.modeling_utils import ModelMixin
13 | from diffusers.models.unets.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
14 |
15 |
16 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
17 |
18 |
19 | @dataclass
20 | class UNetSpatioTemporalConditionOutput(BaseOutput):
21 | """
22 | The output of [`UNetSpatioTemporalConditionModel`].
23 |
24 | Args:
25 | sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
26 | The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
27 | """
28 |
29 | sample: torch.FloatTensor = None
30 |
31 |
32 |
33 | class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
34 | r"""
35 | A conditional spatio-temporal UNet model that takes noisy video frames, a conditional state, and a timestep, and returns a
36 | sample-shaped output.
37 |
38 | This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
39 | for all models (such as downloading or saving).
40 |
41 | Parameters:
42 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
43 | Height and width of input/output sample.
44 | in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
45 | out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
46 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
47 | The tuple of downsample blocks to use.
48 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
49 | The tuple of upsample blocks to use.
50 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
51 | The tuple of output channels for each block.
52 | addition_time_embed_dim (`int`, defaults to 256):
53 | Dimension used to encode the additional time ids.
54 | projection_class_embeddings_input_dim (`int`, defaults to 768):
55 | The dimension of the projection of encoded `added_time_ids`.
56 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
57 | cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1024):
58 | The dimension of the cross attention features.
59 | transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
60 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
61 | [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
62 | [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
63 | num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
64 | The number of attention heads.
65 | num_frames (`int`, *optional*, defaults to 25): The number of video frames.
66 | """
67 |
68 | _supports_gradient_checkpointing = True
69 |
70 | @register_to_config
71 | def __init__(
72 | self,
73 | sample_size: Optional[int] = None,
74 | in_channels: int = 8,
75 | out_channels: int = 4,
76 | down_block_types: Tuple[str] = (
77 | "CrossAttnDownBlockSpatioTemporal",
78 | "CrossAttnDownBlockSpatioTemporal",
79 | "CrossAttnDownBlockSpatioTemporal",
80 | "DownBlockSpatioTemporal",
81 | ),
82 | up_block_types: Tuple[str] = (
83 | "UpBlockSpatioTemporal",
84 | "CrossAttnUpBlockSpatioTemporal",
85 | "CrossAttnUpBlockSpatioTemporal",
86 | "CrossAttnUpBlockSpatioTemporal",
87 | ),
88 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
89 | addition_time_embed_dim: int = 256,
90 | projection_class_embeddings_input_dim: int = 768,
91 | layers_per_block: Union[int, Tuple[int]] = 2,
92 | cross_attention_dim: Union[int, Tuple[int]] = 1024,
93 | transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
94 | num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
95 | num_frames: int = 25,
96 | ):
97 | super().__init__()
98 |
99 | self.sample_size = sample_size
100 |
101 | # Check inputs
102 | if len(down_block_types) != len(up_block_types):
103 | raise ValueError(
104 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
105 | )
106 |
107 | if len(block_out_channels) != len(down_block_types):
108 | raise ValueError(
109 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
110 | )
111 |
112 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
113 | raise ValueError(
114 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
115 | )
116 |
117 | if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
118 | raise ValueError(
119 | f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
120 | )
121 |
122 | if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
123 | raise ValueError(
124 | f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
125 | )
126 |
127 | # input
128 | self.conv_in = nn.Conv2d(
129 | in_channels,
130 | block_out_channels[0],
131 | kernel_size=3,
132 | padding=1,
133 | )
134 |
135 | # time
136 | time_embed_dim = block_out_channels[0] * 4
137 |
138 | self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
139 | timestep_input_dim = block_out_channels[0]
140 |
141 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
142 |
143 | self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
144 | self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
145 |
146 | self.down_blocks = nn.ModuleList([])
147 | self.up_blocks = nn.ModuleList([])
148 |
149 | if isinstance(num_attention_heads, int):
150 | num_attention_heads = (num_attention_heads,) * len(down_block_types)
151 |
152 | if isinstance(cross_attention_dim, int):
153 | cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
154 |
155 | if isinstance(layers_per_block, int):
156 | layers_per_block = [layers_per_block] * len(down_block_types)
157 |
158 | if isinstance(transformer_layers_per_block, int):
159 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
160 |
161 | blocks_time_embed_dim = time_embed_dim
162 |
163 | # down
164 | output_channel = block_out_channels[0]
165 | for i, down_block_type in enumerate(down_block_types):
166 | input_channel = output_channel
167 | output_channel = block_out_channels[i]
168 | is_final_block = i == len(block_out_channels) - 1
169 |
170 | down_block = get_down_block(
171 | down_block_type,
172 | num_layers=layers_per_block[i],
173 | transformer_layers_per_block=transformer_layers_per_block[i],
174 | in_channels=input_channel,
175 | out_channels=output_channel,
176 | temb_channels=blocks_time_embed_dim,
177 | add_downsample=not is_final_block,
178 | resnet_eps=1e-5,
179 | cross_attention_dim=cross_attention_dim[i],
180 | num_attention_heads=num_attention_heads[i],
181 | resnet_act_fn="silu",
182 | )
183 | self.down_blocks.append(down_block)
184 |
185 | # mid
186 | self.mid_block = UNetMidBlockSpatioTemporal(
187 | block_out_channels[-1],
188 | temb_channels=blocks_time_embed_dim,
189 | transformer_layers_per_block=transformer_layers_per_block[-1],
190 | cross_attention_dim=cross_attention_dim[-1],
191 | num_attention_heads=num_attention_heads[-1],
192 | )
193 |
194 | # count how many layers upsample the images
195 | self.num_upsamplers = 0
196 |
197 | # up
198 | reversed_block_out_channels = list(reversed(block_out_channels))
199 | reversed_num_attention_heads = list(reversed(num_attention_heads))
200 | reversed_layers_per_block = list(reversed(layers_per_block))
201 | reversed_cross_attention_dim = list(reversed(cross_attention_dim))
202 | reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
203 |
204 | output_channel = reversed_block_out_channels[0]
205 | for i, up_block_type in enumerate(up_block_types):
206 | is_final_block = i == len(block_out_channels) - 1
207 |
208 | prev_output_channel = output_channel
209 | output_channel = reversed_block_out_channels[i]
210 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
211 |
212 | # add upsample block for all BUT final layer
213 | if not is_final_block:
214 | add_upsample = True
215 | self.num_upsamplers += 1
216 | else:
217 | add_upsample = False
218 |
219 | up_block = get_up_block(
220 | up_block_type,
221 | num_layers=reversed_layers_per_block[i] + 1,
222 | transformer_layers_per_block=reversed_transformer_layers_per_block[i],
223 | in_channels=input_channel,
224 | out_channels=output_channel,
225 | prev_output_channel=prev_output_channel,
226 | temb_channels=blocks_time_embed_dim,
227 | add_upsample=add_upsample,
228 | resnet_eps=1e-5,
229 | resolution_idx=i,
230 | cross_attention_dim=reversed_cross_attention_dim[i],
231 | num_attention_heads=reversed_num_attention_heads[i],
232 | resnet_act_fn="silu",
233 | )
234 | self.up_blocks.append(up_block)
235 | prev_output_channel = output_channel
236 |
237 | # out
238 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
239 | self.conv_act = nn.SiLU()
240 |
241 | self.conv_out = nn.Conv2d(
242 | block_out_channels[0],
243 | out_channels,
244 | kernel_size=3,
245 | padding=1,
246 | )
247 |
248 | @property
249 | def attn_processors(self) -> Dict[str, AttentionProcessor]:
250 | r"""
251 | Returns:
252 | `dict` of attention processors: A dictionary containing all attention processors used in the model,
253 | indexed by their weight names.
254 | """
255 | # set recursively
256 | processors = {}
257 |
258 | def fn_recursive_add_processors(
259 | name: str,
260 | module: torch.nn.Module,
261 | processors: Dict[str, AttentionProcessor],
262 | ):
263 | if hasattr(module, "get_processor"):
264 | processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
265 |
266 | for sub_name, child in module.named_children():
267 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
268 |
269 | return processors
270 |
271 | for name, module in self.named_children():
272 | fn_recursive_add_processors(name, module, processors)
273 |
274 | return processors
275 |
276 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
277 | r"""
278 | Sets the attention processor to use to compute attention.
279 |
280 | Parameters:
281 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
282 | The instantiated processor class or a dictionary of processor classes that will be set as the processor
283 | for **all** `Attention` layers.
284 |
285 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention
286 | processor. This is strongly recommended when setting trainable attention processors.
287 |
288 | """
289 | count = len(self.attn_processors.keys())
290 |
291 | if isinstance(processor, dict) and len(processor) != count:
292 | raise ValueError(
293 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
294 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
295 | )
296 |
297 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
298 | if hasattr(module, "set_processor"):
299 | if not isinstance(processor, dict):
300 | module.set_processor(processor)
301 | else:
302 | module.set_processor(processor.pop(f"{name}.processor"))
303 |
304 | for sub_name, child in module.named_children():
305 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
306 |
307 | for name, module in self.named_children():
308 | fn_recursive_attn_processor(name, module, processor)
309 |
310 | def set_default_attn_processor(self):
311 | """
312 | Disables custom attention processors and sets the default attention implementation.
313 | """
314 | if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
315 | processor = AttnProcessor()
316 | else:
317 | raise ValueError(
318 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
319 | )
320 |
321 | self.set_attn_processor(processor)
322 |
323 | def _set_gradient_checkpointing(self, module, value=False):
324 | if hasattr(module, "gradient_checkpointing"):
325 | module.gradient_checkpointing = value
326 |
327 | # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
328 | def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
329 | """
330 | Enables [feed forward
331 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) for the model's feed-forward layers.
332 |
333 | Parameters:
334 | chunk_size (`int`, *optional*):
335 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
336 | over each tensor of dim=`dim`.
337 | dim (`int`, *optional*, defaults to `0`):
338 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
339 | or dim=1 (sequence length).
340 | """
341 | if dim not in [0, 1]:
342 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
343 |
344 | # By default chunk size is 1
345 | chunk_size = chunk_size or 1
346 |
347 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
348 | if hasattr(module, "set_chunk_feed_forward"):
349 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
350 |
351 | for child in module.children():
352 | fn_recursive_feed_forward(child, chunk_size, dim)
353 |
354 | for module in self.children():
355 | fn_recursive_feed_forward(module, chunk_size, dim)
356 |
357 | def forward(
358 | self,
359 | sample: torch.FloatTensor,
360 | timestep: Union[torch.Tensor, float, int],
361 | encoder_hidden_states: torch.Tensor,
362 | added_time_ids: torch.Tensor,
363 | return_dict: bool = True,
364 | ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
365 | r"""
366 | The [`UNetSpatioTemporalConditionModel`] forward method.
367 |
368 | Args:
369 | sample (`torch.FloatTensor`):
370 | The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
371 | timestep (`torch.FloatTensor` or `float` or `int`): The current denoising timestep.
372 | encoder_hidden_states (`torch.FloatTensor`):
373 | The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
374 | added_time_ids: (`torch.FloatTensor`):
375 | The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
376 | embeddings and added to the time embeddings.
377 | return_dict (`bool`, *optional*, defaults to `True`):
378 | Whether or not to return a [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
379 | tuple.
380 | Returns:
381 | [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
382 | If `return_dict` is True, an [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
383 | a `tuple` is returned where the first element is the sample tensor.
384 | """
385 | # 1. time
386 | timesteps = timestep
387 | if not torch.is_tensor(timesteps):
388 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
389 | # This would be a good case for the `match` statement (Python 3.10+)
390 | is_mps = sample.device.type == "mps"
391 | if isinstance(timestep, float):
392 | dtype = torch.float32 if is_mps else torch.float64
393 | else:
394 | dtype = torch.int32 if is_mps else torch.int64
395 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
396 | elif len(timesteps.shape) == 0:
397 | timesteps = timesteps[None].to(sample.device)
398 |
399 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
400 | batch_size, num_frames = sample.shape[:2]
401 | timesteps = timesteps.expand(batch_size)
402 |
403 | t_emb = self.time_proj(timesteps)
404 |
405 | # `Timesteps` does not contain any weights and will always return f32 tensors
406 | # but time_embedding might actually be running in fp16. so we need to cast here.
407 | # there might be better ways to encapsulate this.
408 | t_emb = t_emb.to(dtype=sample.dtype)
409 |
410 | emb = self.time_embedding(t_emb)
411 |
412 | time_embeds = self.add_time_proj(added_time_ids.flatten())
413 | time_embeds = time_embeds.reshape((batch_size, -1))
414 | time_embeds = time_embeds.to(emb.dtype)
415 | aug_emb = self.add_embedding(time_embeds)
416 | emb = emb + aug_emb
417 |
418 | # Flatten the batch and frames dimensions
419 | # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
420 | sample = sample.flatten(0, 1)
421 | # Repeat the embeddings num_video_frames times
422 | # emb: [batch, channels] -> [batch * frames, channels]
423 | emb = emb.repeat_interleave(num_frames, dim=0)
424 | # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
425 |
426 | if encoder_hidden_states.shape[0] == batch_size:
427 | encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
428 |
429 | # 2. pre-process
430 | sample = self.conv_in(sample)
431 |
432 | image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
433 |
434 | down_block_res_samples = (sample,)
435 | for downsample_block in self.down_blocks:
436 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
437 | sample, res_samples = downsample_block(
438 | hidden_states=sample,
439 | temb=emb,
440 | encoder_hidden_states=encoder_hidden_states,
441 | image_only_indicator=image_only_indicator,
442 | )
443 | else:
444 | sample, res_samples = downsample_block(
445 | hidden_states=sample,
446 | temb=emb,
447 | image_only_indicator=image_only_indicator,
448 | )
449 |
450 | down_block_res_samples += res_samples
451 |
452 | # 4. mid
453 | sample = self.mid_block(
454 | hidden_states=sample,
455 | temb=emb,
456 | encoder_hidden_states=encoder_hidden_states,
457 | image_only_indicator=image_only_indicator,
458 | )
459 |
460 | # 5. up
461 | for i, upsample_block in enumerate(self.up_blocks):
462 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
463 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
464 |
465 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
466 | sample = upsample_block(
467 | hidden_states=sample,
468 | temb=emb,
469 | res_hidden_states_tuple=res_samples,
470 | encoder_hidden_states=encoder_hidden_states,
471 | image_only_indicator=image_only_indicator,
472 | )
473 | else:
474 | sample = upsample_block(
475 | hidden_states=sample,
476 | temb=emb,
477 | res_hidden_states_tuple=res_samples,
478 | image_only_indicator=image_only_indicator,
479 | )
480 |
481 | # 6. post-process
482 | sample = self.conv_norm_out(sample)
483 | sample = self.conv_act(sample)
484 | sample = self.conv_out(sample)
485 |
486 | # 7. Reshape back to original shape
487 | sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
488 |
489 | if not return_dict:
490 | return (sample,)
491 |
492 |
493 | return UNetSpatioTemporalConditionOutput(sample=sample)
494 |
--------------------------------------------------------------------------------
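
For reference, the shape bookkeeping in `UNetSpatioTemporalConditionModel.forward` can be reproduced with plain tensors: the video batch is flattened to `batch * frames` for the 2D convolutions, per-batch embeddings are repeated once per frame, and the output is reshaped back at the end. A small self-contained sketch (the tensor sizes are arbitrary illustrative values):

```python
import torch

batch_size, num_frames, channels, height, width = 2, 25, 8, 72, 128
sample = torch.randn(batch_size, num_frames, channels, height, width)
emb = torch.randn(batch_size, 1280)                  # per-batch time embedding
encoder_hidden_states = torch.randn(batch_size, 1, 1024)

# [batch, frames, C, H, W] -> [batch * frames, C, H, W] so 2D blocks see one frame per row.
sample = sample.flatten(0, 1)
# Repeat per-batch conditioning so every frame row gets its own copy.
emb = emb.repeat_interleave(num_frames, dim=0)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)

# ... the UNet down/mid/up blocks operate on the flattened sample here ...

# Restore the video layout at the end, exactly as in step 7 of forward().
out = sample.reshape(batch_size, num_frames, *sample.shape[1:])
assert out.shape == (batch_size, num_frames, channels, height, width)
```
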