# ├── CogVideoX
# │   ├── cli.py
# │   ├── pipeline_rgba.py
# │   └── rgba_utils.py
# ├── LICENSE.md
# ├── Mochi
# │   ├── README.md
# │   ├── args.py
# │   ├── cli.py
# │   ├── dataset_simple.py
# │   ├── embed.py
# │   ├── pipeline_mochi_rgba.py
# │   ├── prepare_dataset.sh
# │   ├── rgba_utils.py
# │   ├── train.py
# │   ├── train.sh
# │   ├── trim_and_crop_videos.py
# │   └── utils.py
# ├── README.md
# ├── app.py
# ├── requirements.txt
# └── wechat_group.jpg
#
# /CogVideoX/cli.py:
# --------------------------------------------------------------------------------
"""Generate an RGB video and a matching alpha video from a text prompt with CogVideoX + an RGBA LoRA."""
import argparse
import os

import numpy as np
import torch
from diffusers import CogVideoXDPMScheduler
from diffusers.utils import export_to_video

from pipeline_rgba import CogVideoXPipeline
from rgba_utils import decode_latents, prepare_for_rgba_inference


def _rgba_seq_length(height, width, num_frames, spatial_scale, temporal_scale, patch_size):
    """Return the joint RGB+alpha video-token count used to size the RGBA attention mask.

    The RGBA pipeline doubles the latent frame axis (an RGB half and an alpha half),
    so the transformer sees twice as many video tokens as a plain RGB video of the
    same size. With the script defaults (480x720, 49 frames, 8x spatial / 4x temporal
    VAE compression, patch size 2) this is 2 * 30 * 45 * 13 = 35100.
    """
    tokens_h = height // spatial_scale // patch_size
    tokens_w = width // spatial_scale // patch_size
    latent_frames = (num_frames - 1) // temporal_scale + 1
    return 2 * tokens_h * tokens_w * latent_frames


def main(args):
    """Generate an RGBA video for ``args.prompt`` and write it into ``args.output_path``.

    Two files are written: ``rgb.mp4`` (RGB multiplied by the alpha matte) and
    ``alpha.mp4`` (the alpha matte replicated across three channels).

    Args:
        args: Parsed command-line namespace produced by ``_build_arg_parser``.
    """
    # 1. Load the pipeline in bfloat16 with memory-saving options enabled.
    pipe = CogVideoXPipeline.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
    pipe.enable_sequential_cpu_offload()
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()

    # 2. Sampling arguments. Latents are requested so that the RGB and alpha
    #    halves can be split and decoded separately below.
    pipeline_args = {
        "prompt": args.prompt,
        "guidance_scale": args.guidance_scale,
        "num_inference_steps": args.num_inference_steps,
        "height": args.height,
        "width": args.width,
        "num_frames": args.num_frames,
        "output_type": "latent",
        "use_dynamic_cfg": True,
    }

    # 3. Prepare the transformer for joint RGB + alpha generation.
    seq_length = _rgba_seq_length(
        args.height,
        args.width,
        args.num_frames,
        pipe.vae_scale_factor_spatial,
        pipe.vae_scale_factor_temporal,
        pipe.transformer.config.patch_size,
    )
    prepare_for_rgba_inference(
        pipe.transformer,
        rgba_weights_path=args.lora_path,
        device="cuda",
        dtype=torch.bfloat16,
        text_length=226,  # matches the pipeline's default max_sequence_length
        seq_length=seq_length,  # used to build the attention mask
    )

    # 4. Run inference. Compare against None so that an explicit --seed 0 is honoured.
    generator = torch.manual_seed(args.seed) if args.seed is not None else None
    frames_latents = pipe(**pipeline_args, generator=generator).frames

    # Split the doubled frame axis (dim 1) into the RGB half and the alpha half.
    frames_latents_rgb, frames_latents_alpha = frames_latents.chunk(2, dim=1)
    frames_rgb = decode_latents(pipe, frames_latents_rgb)
    frames_alpha = decode_latents(pipe, frames_latents_alpha)

    # Collapse the decoded alpha to a single channel (max over the last axis) and
    # broadcast it back to 3 channels.
    # NOTE(review): assumes decode_latents returns channel-last numpy arrays -- confirm in rgba_utils.
    pooled_alpha = np.max(frames_alpha, axis=-1, keepdims=True)
    frames_alpha_pooled = np.repeat(pooled_alpha, 3, axis=-1)
    premultiplied_rgb = frames_rgb * frames_alpha_pooled

    # exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.output_path, exist_ok=True)

    # Index 0: the pipeline generates a single video per prompt.
    export_to_video(premultiplied_rgb[0], os.path.join(args.output_path, "rgb.mp4"), fps=args.fps)
    export_to_video(frames_alpha_pooled[0], os.path.join(args.output_path, "alpha.mp4"), fps=args.fps)


def _build_arg_parser():
    """Return the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(description="Generate a video from a text prompt")
    parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
    # NOTE(review): this default is a machine-specific checkpoint path; elsewhere, pass --lora_path explicitly.
    parser.add_argument(
        "--lora_path",
        type=str,
        default="/hpc2hdd/home/lwang592/projects/CogVideo/sat/outputs/training/ckpts-5b-attn_rebias-partial_lora-8gpu-wo_t2a/lora-rgba-12-21-19-11/5000/rgba_lora.safetensors",
        help="The path of the LoRA weights to be used",
    )
    parser.add_argument(
        "--model_path", type=str, default="THUDM/CogVideoX-5B", help="Path or Hub ID of the pre-trained model to use"
    )
    parser.add_argument(
        "--output_path", type=str, default="./output", help="Directory in which to save the generated videos"
    )
    parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
    parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of denoising steps")
    parser.add_argument("--num_frames", type=int, default=49, help="Number of frames to generate (at most 49)")
    parser.add_argument("--width", type=int, default=720, help="Width of the generated video in pixels")
    parser.add_argument("--height", type=int, default=480, help="Height of the generated video in pixels")
    parser.add_argument("--fps", type=int, default=8, help="Frames per second of the exported videos")
    parser.add_argument("--seed", type=int, default=None, help="The seed for reproducibility")
    return parser


if __name__ == "__main__":
    main(_build_arg_parser().parse_args())

# --------------------------------------------------------------------------------
# /CogVideoX/pipeline_rgba.py:
# --------------------------------------------------------------------------------
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15 | 16 | import inspect 17 | import math 18 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 19 | 20 | import torch 21 | from transformers import T5EncoderModel, T5Tokenizer 22 | 23 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback 24 | from diffusers.loaders import CogVideoXLoraLoaderMixin 25 | from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel 26 | from diffusers.models.embeddings import get_3d_rotary_pos_embed 27 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 28 | from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler 29 | from diffusers.utils import logging, replace_example_docstring 30 | from diffusers.utils.torch_utils import randn_tensor 31 | from diffusers.video_processor import VideoProcessor 32 | from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput 33 | 34 | 35 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 36 | 37 | 38 | EXAMPLE_DOC_STRING = """ 39 | Examples: 40 | ```python 41 | >>> import torch 42 | >>> from diffusers import CogVideoXPipeline 43 | >>> from diffusers.utils import export_to_video 44 | 45 | >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" 46 | >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") 47 | >>> prompt = ( 48 | ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " 49 | ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " 50 | ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " 51 | ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " 52 | ... 
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " 53 | ... "atmosphere of this unique musical performance." 54 | ... ) 55 | >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] 56 | >>> export_to_video(video, "output.mp4", fps=8) 57 | ``` 58 | """ 59 | 60 | 61 | # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid 62 | def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): 63 | tw = tgt_width 64 | th = tgt_height 65 | h, w = src 66 | r = h / w 67 | if r > (th / tw): 68 | resize_height = th 69 | resize_width = int(round(th / h * w)) 70 | else: 71 | resize_width = tw 72 | resize_height = int(round(tw / w * h)) 73 | 74 | crop_top = int(round((th - resize_height) / 2.0)) 75 | crop_left = int(round((tw - resize_width) / 2.0)) 76 | 77 | return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) 78 | 79 | 80 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 81 | def retrieve_timesteps( 82 | scheduler, 83 | num_inference_steps: Optional[int] = None, 84 | device: Optional[Union[str, torch.device]] = None, 85 | timesteps: Optional[List[int]] = None, 86 | sigmas: Optional[List[float]] = None, 87 | **kwargs, 88 | ): 89 | r""" 90 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 91 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 92 | 93 | Args: 94 | scheduler (`SchedulerMixin`): 95 | The scheduler to get timesteps from. 96 | num_inference_steps (`int`): 97 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 98 | must be `None`. 99 | device (`str` or `torch.device`, *optional*): 100 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
101 | timesteps (`List[int]`, *optional*): 102 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 103 | `num_inference_steps` and `sigmas` must be `None`. 104 | sigmas (`List[float]`, *optional*): 105 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 106 | `num_inference_steps` and `timesteps` must be `None`. 107 | 108 | Returns: 109 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 110 | second element is the number of inference steps. 111 | """ 112 | if timesteps is not None and sigmas is not None: 113 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 114 | if timesteps is not None: 115 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 116 | if not accepts_timesteps: 117 | raise ValueError( 118 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 119 | f" timestep schedules. Please check whether you are using the correct scheduler." 120 | ) 121 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 122 | timesteps = scheduler.timesteps 123 | num_inference_steps = len(timesteps) 124 | elif sigmas is not None: 125 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 126 | if not accept_sigmas: 127 | raise ValueError( 128 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 129 | f" sigmas schedules. Please check whether you are using the correct scheduler." 
130 | ) 131 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 132 | timesteps = scheduler.timesteps 133 | num_inference_steps = len(timesteps) 134 | else: 135 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 136 | timesteps = scheduler.timesteps 137 | return timesteps, num_inference_steps 138 | 139 | 140 | class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): 141 | r""" 142 | Pipeline for text-to-video generation using CogVideoX. 143 | 144 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 145 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 146 | 147 | Args: 148 | vae ([`AutoencoderKL`]): 149 | Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 150 | text_encoder ([`T5EncoderModel`]): 151 | Frozen text-encoder. CogVideoX uses 152 | [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the 153 | [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. 154 | tokenizer (`T5Tokenizer`): 155 | Tokenizer of class 156 | [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). 157 | transformer ([`CogVideoXTransformer3DModel`]): 158 | A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. 159 | scheduler ([`SchedulerMixin`]): 160 | A scheduler to be used in combination with `transformer` to denoise the encoded video latents. 
161 | """ 162 | 163 | _optional_components = [] 164 | model_cpu_offload_seq = "text_encoder->transformer->vae" 165 | 166 | _callback_tensor_inputs = [ 167 | "latents", 168 | "prompt_embeds", 169 | "negative_prompt_embeds", 170 | ] 171 | 172 | def __init__( 173 | self, 174 | tokenizer: T5Tokenizer, 175 | text_encoder: T5EncoderModel, 176 | vae: AutoencoderKLCogVideoX, 177 | transformer: CogVideoXTransformer3DModel, 178 | scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], 179 | ): 180 | super().__init__() 181 | 182 | self.register_modules( 183 | tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler 184 | ) 185 | self.vae_scale_factor_spatial = ( 186 | 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 187 | ) 188 | self.vae_scale_factor_temporal = ( 189 | self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 190 | ) 191 | self.vae_scaling_factor_image = ( 192 | self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7 193 | ) 194 | 195 | self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) 196 | 197 | def _get_t5_prompt_embeds( 198 | self, 199 | prompt: Union[str, List[str]] = None, 200 | num_videos_per_prompt: int = 1, 201 | max_sequence_length: int = 226, 202 | device: Optional[torch.device] = None, 203 | dtype: Optional[torch.dtype] = None, 204 | ): 205 | device = device or self._execution_device 206 | dtype = dtype or self.text_encoder.dtype 207 | 208 | prompt = [prompt] if isinstance(prompt, str) else prompt 209 | batch_size = len(prompt) 210 | 211 | text_inputs = self.tokenizer( 212 | prompt, 213 | padding="max_length", 214 | max_length=max_sequence_length, 215 | truncation=True, 216 | add_special_tokens=True, 217 | return_tensors="pt", 218 | ) 219 | text_input_ids = text_inputs.input_ids 220 | untruncated_ids = self.tokenizer(prompt, 
padding="longest", return_tensors="pt").input_ids 221 | 222 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 223 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) 224 | logger.warning( 225 | "The following part of your input was truncated because `max_sequence_length` is set to " 226 | f" {max_sequence_length} tokens: {removed_text}" 227 | ) 228 | 229 | prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] 230 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 231 | 232 | # duplicate text embeddings for each generation per prompt, using mps friendly method 233 | _, seq_len, _ = prompt_embeds.shape 234 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) 235 | prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) 236 | 237 | return prompt_embeds 238 | 239 | def encode_prompt( 240 | self, 241 | prompt: Union[str, List[str]], 242 | negative_prompt: Optional[Union[str, List[str]]] = None, 243 | do_classifier_free_guidance: bool = True, 244 | num_videos_per_prompt: int = 1, 245 | prompt_embeds: Optional[torch.Tensor] = None, 246 | negative_prompt_embeds: Optional[torch.Tensor] = None, 247 | max_sequence_length: int = 226, 248 | device: Optional[torch.device] = None, 249 | dtype: Optional[torch.dtype] = None, 250 | ): 251 | r""" 252 | Encodes the prompt into text encoder hidden states. 253 | 254 | Args: 255 | prompt (`str` or `List[str]`, *optional*): 256 | prompt to be encoded 257 | negative_prompt (`str` or `List[str]`, *optional*): 258 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 259 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 260 | less than `1`). 261 | do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): 262 | Whether to use classifier free guidance or not. 
263 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 264 | Number of videos that should be generated per prompt. torch device to place the resulting embeddings on 265 | prompt_embeds (`torch.Tensor`, *optional*): 266 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 267 | provided, text embeddings will be generated from `prompt` input argument. 268 | negative_prompt_embeds (`torch.Tensor`, *optional*): 269 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 270 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 271 | argument. 272 | device: (`torch.device`, *optional*): 273 | torch device 274 | dtype: (`torch.dtype`, *optional*): 275 | torch dtype 276 | """ 277 | device = device or self._execution_device 278 | 279 | prompt = [prompt] if isinstance(prompt, str) else prompt 280 | if prompt is not None: 281 | batch_size = len(prompt) 282 | else: 283 | batch_size = prompt_embeds.shape[0] 284 | 285 | if prompt_embeds is None: 286 | prompt_embeds = self._get_t5_prompt_embeds( 287 | prompt=prompt, 288 | num_videos_per_prompt=num_videos_per_prompt, 289 | max_sequence_length=max_sequence_length, 290 | device=device, 291 | dtype=dtype, 292 | ) 293 | 294 | if do_classifier_free_guidance and negative_prompt_embeds is None: 295 | negative_prompt = negative_prompt or "" 296 | negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt 297 | 298 | if prompt is not None and type(prompt) is not type(negative_prompt): 299 | raise TypeError( 300 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 301 | f" {type(prompt)}." 
302 | ) 303 | elif batch_size != len(negative_prompt): 304 | raise ValueError( 305 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 306 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 307 | " the batch size of `prompt`." 308 | ) 309 | 310 | negative_prompt_embeds = self._get_t5_prompt_embeds( 311 | prompt=negative_prompt, 312 | num_videos_per_prompt=num_videos_per_prompt, 313 | max_sequence_length=max_sequence_length, 314 | device=device, 315 | dtype=dtype, 316 | ) 317 | 318 | return prompt_embeds, negative_prompt_embeds 319 | 320 | def prepare_latents( 321 | self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None 322 | ): 323 | if isinstance(generator, list) and len(generator) != batch_size: 324 | raise ValueError( 325 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 326 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
327 | ) 328 | 329 | shape = ( 330 | batch_size, 331 | (num_frames - 1) // self.vae_scale_factor_temporal + 1, 332 | num_channels_latents, 333 | height // self.vae_scale_factor_spatial, 334 | width // self.vae_scale_factor_spatial, 335 | ) 336 | 337 | if latents is None: 338 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 339 | else: 340 | latents = latents.to(device) 341 | 342 | # scale the initial noise by the standard deviation required by the scheduler 343 | latents = latents * self.scheduler.init_noise_sigma 344 | return latents 345 | 346 | def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: 347 | latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] 348 | latents = 1 / self.vae_scaling_factor_image * latents 349 | 350 | frames = self.vae.decode(latents).sample 351 | return frames 352 | 353 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 354 | def prepare_extra_step_kwargs(self, generator, eta): 355 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 356 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
357 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 358 | # and should be between [0, 1] 359 | 360 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 361 | extra_step_kwargs = {} 362 | if accepts_eta: 363 | extra_step_kwargs["eta"] = eta 364 | 365 | # check if the scheduler accepts generator 366 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 367 | if accepts_generator: 368 | extra_step_kwargs["generator"] = generator 369 | return extra_step_kwargs 370 | 371 | # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs 372 | def check_inputs( 373 | self, 374 | prompt, 375 | height, 376 | width, 377 | negative_prompt, 378 | callback_on_step_end_tensor_inputs, 379 | prompt_embeds=None, 380 | negative_prompt_embeds=None, 381 | ): 382 | if height % 8 != 0 or width % 8 != 0: 383 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 384 | 385 | if callback_on_step_end_tensor_inputs is not None and not all( 386 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs 387 | ): 388 | raise ValueError( 389 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 390 | ) 391 | if prompt is not None and prompt_embeds is not None: 392 | raise ValueError( 393 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 394 | " only forward one of the two." 395 | ) 396 | elif prompt is None and prompt_embeds is None: 397 | raise ValueError( 398 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
399 | ) 400 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 401 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 402 | 403 | if prompt is not None and negative_prompt_embeds is not None: 404 | raise ValueError( 405 | f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" 406 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 407 | ) 408 | 409 | if negative_prompt is not None and negative_prompt_embeds is not None: 410 | raise ValueError( 411 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" 412 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 413 | ) 414 | 415 | if prompt_embeds is not None and negative_prompt_embeds is not None: 416 | if prompt_embeds.shape != negative_prompt_embeds.shape: 417 | raise ValueError( 418 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" 419 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" 420 | f" {negative_prompt_embeds.shape}." 421 | ) 422 | 423 | def fuse_qkv_projections(self) -> None: 424 | r"""Enables fused QKV projections.""" 425 | self.fusing_transformer = True 426 | self.transformer.fuse_qkv_projections() 427 | 428 | def unfuse_qkv_projections(self) -> None: 429 | r"""Disable QKV projection fusion if enabled.""" 430 | if not self.fusing_transformer: 431 | logger.warning("The Transformer was not initially fused for QKV projections. 
Doing nothing.") 432 | else: 433 | self.transformer.unfuse_qkv_projections() 434 | self.fusing_transformer = False 435 | 436 | def _prepare_rotary_positional_embeddings( 437 | self, 438 | height: int, 439 | width: int, 440 | num_frames: int, 441 | device: torch.device, 442 | ) -> Tuple[torch.Tensor, torch.Tensor]: 443 | grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) 444 | grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) 445 | base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) 446 | base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) 447 | 448 | grid_crops_coords = get_resize_crop_region_for_grid( 449 | (grid_height, grid_width), base_size_width, base_size_height 450 | ) 451 | freqs_cos, freqs_sin = get_3d_rotary_pos_embed( 452 | embed_dim=self.transformer.config.attention_head_dim, 453 | crops_coords=grid_crops_coords, 454 | grid_size=(grid_height, grid_width), 455 | temporal_size=num_frames, 456 | ) 457 | 458 | freqs_cos = freqs_cos.to(device=device) 459 | freqs_sin = freqs_sin.to(device=device) 460 | return freqs_cos, freqs_sin 461 | 462 | @property 463 | def guidance_scale(self): 464 | return self._guidance_scale 465 | 466 | @property 467 | def num_timesteps(self): 468 | return self._num_timesteps 469 | 470 | @property 471 | def attention_kwargs(self): 472 | return self._attention_kwargs 473 | 474 | @property 475 | def interrupt(self): 476 | return self._interrupt 477 | 478 | @torch.no_grad() 479 | @replace_example_docstring(EXAMPLE_DOC_STRING) 480 | def __call__( 481 | self, 482 | prompt: Optional[Union[str, List[str]]] = None, 483 | negative_prompt: Optional[Union[str, List[str]]] = None, 484 | height: int = 480, 485 | width: int = 720, 486 | num_frames: int = 49, 487 | num_inference_steps: int = 50, 488 | timesteps: Optional[List[int]] = None, 489 | guidance_scale: float = 6, 490 | 
use_dynamic_cfg: bool = False, 491 | num_videos_per_prompt: int = 1, 492 | eta: float = 0.0, 493 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 494 | latents: Optional[torch.FloatTensor] = None, 495 | prompt_embeds: Optional[torch.FloatTensor] = None, 496 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 497 | output_type: str = "pil", 498 | return_dict: bool = True, 499 | attention_kwargs: Optional[Dict[str, Any]] = None, 500 | callback_on_step_end: Optional[ 501 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] 502 | ] = None, 503 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 504 | max_sequence_length: int = 226, 505 | ) -> Union[CogVideoXPipelineOutput, Tuple]: 506 | """ 507 | Function invoked when calling the pipeline for generation. 508 | 509 | Args: 510 | prompt (`str` or `List[str]`, *optional*): 511 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 512 | instead. 513 | negative_prompt (`str` or `List[str]`, *optional*): 514 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 515 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 516 | less than `1`). 517 | height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): 518 | The height in pixels of the generated image. This is set to 480 by default for the best results. 519 | width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): 520 | The width in pixels of the generated image. This is set to 720 by default for the best results. 521 | num_frames (`int`, defaults to `48`): 522 | Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. 
Generated video will 523 | contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where 524 | num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that 525 | needs to be satisfied is that of divisibility mentioned above. 526 | num_inference_steps (`int`, *optional*, defaults to 50): 527 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 528 | expense of slower inference. 529 | timesteps (`List[int]`, *optional*): 530 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument 531 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is 532 | passed will be used. Must be in descending order. 533 | guidance_scale (`float`, *optional*, defaults to 7.0): 534 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 535 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 536 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 537 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 538 | usually at the expense of lower image quality. 539 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 540 | The number of videos to generate per prompt. 541 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 542 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 543 | to make generation deterministic. 544 | latents (`torch.FloatTensor`, *optional*): 545 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 546 | generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents 547 | tensor will ge generated by sampling using the supplied random `generator`. 548 | prompt_embeds (`torch.FloatTensor`, *optional*): 549 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 550 | provided, text embeddings will be generated from `prompt` input argument. 551 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 552 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 553 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 554 | argument. 555 | output_type (`str`, *optional*, defaults to `"pil"`): 556 | The output format of the generate image. Choose between 557 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 558 | return_dict (`bool`, *optional*, defaults to `True`): 559 | Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead 560 | of a plain tuple. 561 | attention_kwargs (`dict`, *optional*): 562 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 563 | `self.processor` in 564 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 565 | callback_on_step_end (`Callable`, *optional*): 566 | A function that calls at the end of each denoising steps during the inference. The function is called 567 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 568 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 569 | `callback_on_step_end_tensor_inputs`. 570 | callback_on_step_end_tensor_inputs (`List`, *optional*): 571 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 572 | will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the 573 | `._callback_tensor_inputs` attribute of your pipeline class. 574 | max_sequence_length (`int`, defaults to `226`): 575 | Maximum sequence length in encoded prompt. Must be consistent with 576 | `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. 577 | 578 | Examples: 579 | 580 | Returns: 581 | [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: 582 | [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a 583 | `tuple`. When returning a tuple, the first element is a list with the generated images. 584 | """ 585 | 586 | if num_frames > 49: 587 | raise ValueError( 588 | "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." 589 | ) 590 | 591 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): 592 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs 593 | 594 | num_videos_per_prompt = 1 595 | 596 | # 1. Check inputs. Raise error if not correct 597 | self.check_inputs( 598 | prompt, 599 | height, 600 | width, 601 | negative_prompt, 602 | callback_on_step_end_tensor_inputs, 603 | prompt_embeds, 604 | negative_prompt_embeds, 605 | ) 606 | self._guidance_scale = guidance_scale 607 | self._attention_kwargs = attention_kwargs 608 | self._interrupt = False 609 | 610 | # 2. Default call parameters 611 | if prompt is not None and isinstance(prompt, str): 612 | batch_size = 1 613 | elif prompt is not None and isinstance(prompt, list): 614 | batch_size = len(prompt) 615 | else: 616 | batch_size = prompt_embeds.shape[0] 617 | 618 | device = self._execution_device 619 | 620 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 621 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . 
`guidance_scale = 1` 622 | # corresponds to doing no classifier free guidance. 623 | do_classifier_free_guidance = guidance_scale > 1.0 624 | 625 | # 3. Encode input prompt 626 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 627 | prompt, 628 | negative_prompt, 629 | do_classifier_free_guidance, 630 | num_videos_per_prompt=num_videos_per_prompt, 631 | prompt_embeds=prompt_embeds, 632 | negative_prompt_embeds=negative_prompt_embeds, 633 | max_sequence_length=max_sequence_length, 634 | device=device, 635 | ) 636 | if do_classifier_free_guidance: 637 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 638 | 639 | # 4. Prepare timesteps 640 | timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) 641 | self._num_timesteps = len(timesteps) 642 | 643 | # 5. Prepare latents. 644 | latent_channels = self.transformer.config.in_channels 645 | latents = self.prepare_latents( 646 | batch_size * num_videos_per_prompt, 647 | latent_channels, 648 | num_frames, 649 | height, 650 | width, 651 | prompt_embeds.dtype, 652 | device, 653 | generator, 654 | latents, 655 | ).repeat(1,2,1,1,1) # Luozhou 656 | 657 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 658 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 659 | 660 | # 7. Create rotary embeds if required 661 | image_rotary_emb = ( 662 | self._prepare_rotary_positional_embeddings(height, width, latents.size(1) // 2, device) # Luozhou 663 | if self.transformer.config.use_rotary_positional_embeddings 664 | else None 665 | ) 666 | 667 | # 8. 
Denoising loop 668 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 669 | 670 | with self.progress_bar(total=num_inference_steps) as progress_bar: 671 | # for DPM-solver++ 672 | old_pred_original_sample = None 673 | for i, t in enumerate(timesteps): 674 | if self.interrupt: 675 | continue 676 | 677 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 678 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 679 | 680 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 681 | timestep = t.expand(latent_model_input.shape[0]) 682 | 683 | # predict noise model_output 684 | noise_pred = self.transformer( 685 | hidden_states=latent_model_input, 686 | encoder_hidden_states=prompt_embeds, 687 | timestep=timestep, 688 | image_rotary_emb=image_rotary_emb, 689 | attention_kwargs=attention_kwargs, 690 | return_dict=False, 691 | )[0] 692 | noise_pred = noise_pred.float() 693 | 694 | # perform guidance 695 | if use_dynamic_cfg: 696 | self._guidance_scale = 1 + guidance_scale * ( 697 | (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 698 | ) 699 | if do_classifier_free_guidance: 700 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 701 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) 702 | 703 | # compute the previous noisy sample x_t -> x_t-1 704 | if not isinstance(self.scheduler, CogVideoXDPMScheduler): 705 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 706 | else: 707 | latents, old_pred_original_sample = self.scheduler.step( 708 | noise_pred, 709 | old_pred_original_sample, 710 | t, 711 | timesteps[i - 1] if i > 0 else None, 712 | latents, 713 | **extra_step_kwargs, 714 | return_dict=False, 715 | ) 716 | latents = latents.to(prompt_embeds.dtype) 717 | 718 | # call the callback, if provided 719 | if 
callback_on_step_end is not None: 720 | callback_kwargs = {} 721 | for k in callback_on_step_end_tensor_inputs: 722 | callback_kwargs[k] = locals()[k] 723 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 724 | 725 | latents = callback_outputs.pop("latents", latents) 726 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 727 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 728 | 729 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 730 | progress_bar.update() 731 | 732 | if not output_type == "latent": 733 | video = self.decode_latents(latents) 734 | video = self.video_processor.postprocess_video(video=video, output_type=output_type) 735 | else: 736 | video = latents 737 | 738 | # Offload all models 739 | self.maybe_free_model_hooks() 740 | 741 | if not return_dict: 742 | return (video,) 743 | 744 | return CogVideoXPipelineOutput(frames=video) 745 | -------------------------------------------------------------------------------- /CogVideoX/rgba_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import Any, Dict, Optional, Tuple, Union 5 | from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers 6 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 7 | from safetensors.torch import load_file 8 | 9 | logger = logging.get_logger(__name__) 10 | 11 | @torch.no_grad() 12 | def decode_latents(pipe, latents): 13 | video = pipe.decode_latents(latents) 14 | video = pipe.video_processor.postprocess_video(video=video, output_type="np") 15 | return video 16 | 17 | def create_attention_mask(text_length: int, seq_length: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: 18 | """ 19 | Create an attention mask to block text from attending 
to alpha. 20 | 21 | Args: 22 | text_length: Length of the text sequence. 23 | seq_length: Length of the other sequence. 24 | device: The device where the mask will be stored. 25 | dtype: The data type of the mask tensor. 26 | 27 | Returns: 28 | An attention mask tensor. 29 | """ 30 | total_length = text_length + seq_length 31 | dense_mask = torch.ones((total_length, total_length), dtype=torch.bool) 32 | dense_mask[:text_length, text_length + seq_length // 2:] = False 33 | return dense_mask.to(device=device, dtype=dtype) 34 | 35 | class RGBALoRACogVideoXAttnProcessor: 36 | r""" 37 | Processor for implementing scaled dot-product attention for the CogVideoX model. 38 | It applies a rotary embedding on query and key vectors, but does not include spatial normalization. 39 | """ 40 | 41 | def __init__(self, device, dtype, attention_mask, lora_rank=128, lora_alpha=1.0, latent_dim=3072): 42 | if not hasattr(F, "scaled_dot_product_attention"): 43 | raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0 or later.") 44 | 45 | # Initialize LoRA layers 46 | self.lora_alpha = lora_alpha 47 | self.lora_rank = lora_rank 48 | 49 | # Helper function to create LoRA layers 50 | def create_lora_layer(in_dim, mid_dim, out_dim): 51 | return nn.Sequential( 52 | nn.Linear(in_dim, mid_dim, bias=False, device=device, dtype=dtype), 53 | nn.Linear(mid_dim, out_dim, bias=False, device=device, dtype=dtype) 54 | ) 55 | 56 | self.to_q_lora = create_lora_layer(latent_dim, lora_rank, latent_dim) 57 | self.to_k_lora = create_lora_layer(latent_dim, lora_rank, latent_dim) 58 | self.to_v_lora = create_lora_layer(latent_dim, lora_rank, latent_dim) 59 | self.to_out_lora = create_lora_layer(latent_dim, lora_rank, latent_dim) 60 | 61 | # Store attention mask 62 | self.attention_mask = attention_mask 63 | 64 | def _apply_lora(self, hidden_states, seq_len, query, key, value, scaling): 65 | """Applies LoRA updates to query, key, and value tensors.""" 66 | query_delta = 
self.to_q_lora(hidden_states).to(query.device) 67 | query[:, -seq_len // 2:, :] += query_delta[:, -seq_len // 2:, :] * scaling 68 | 69 | key_delta = self.to_k_lora(hidden_states).to(key.device) 70 | key[:, -seq_len // 2:, :] += key_delta[:, -seq_len // 2:, :] * scaling 71 | 72 | value_delta = self.to_v_lora(hidden_states).to(value.device) 73 | value[:, -seq_len // 2:, :] += value_delta[:, -seq_len // 2:, :] * scaling 74 | 75 | return query, key, value 76 | 77 | def _apply_rotary_embedding(self, query, key, image_rotary_emb, seq_len, text_seq_length, attn): 78 | """Applies rotary embeddings to query and key tensors.""" 79 | from diffusers.models.embeddings import apply_rotary_emb 80 | 81 | # Apply rotary embedding to RGB and alpha sections 82 | query[:, :, text_seq_length:text_seq_length + seq_len // 2] = apply_rotary_emb( 83 | query[:, :, text_seq_length:text_seq_length + seq_len // 2], image_rotary_emb) 84 | query[:, :, text_seq_length + seq_len // 2:] = apply_rotary_emb( 85 | query[:, :, text_seq_length + seq_len // 2:], image_rotary_emb) 86 | 87 | if not attn.is_cross_attention: 88 | key[:, :, text_seq_length:text_seq_length + seq_len // 2] = apply_rotary_emb( 89 | key[:, :, text_seq_length:text_seq_length + seq_len // 2], image_rotary_emb) 90 | key[:, :, text_seq_length + seq_len // 2:] = apply_rotary_emb( 91 | key[:, :, text_seq_length + seq_len // 2:], image_rotary_emb) 92 | 93 | return query, key 94 | 95 | def __call__( 96 | self, 97 | attn, 98 | hidden_states: torch.Tensor, 99 | encoder_hidden_states: torch.Tensor, 100 | attention_mask: Optional[torch.Tensor] = None, 101 | image_rotary_emb: Optional[torch.Tensor] = None, 102 | ) -> torch.Tensor: 103 | # Concatenate encoder and decoder hidden states 104 | text_seq_length = encoder_hidden_states.size(1) 105 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 106 | 107 | batch_size, sequence_length, _ = hidden_states.shape 108 | seq_len = hidden_states.shape[1] - text_seq_length 109 | 
scaling = self.lora_alpha / self.lora_rank 110 | 111 | # Apply LoRA to query, key, value 112 | query = attn.to_q(hidden_states) 113 | key = attn.to_k(hidden_states) 114 | value = attn.to_v(hidden_states) 115 | 116 | query, key, value = self._apply_lora(hidden_states, seq_len, query, key, value, scaling) 117 | 118 | # Reshape query, key, value for multi-head attention 119 | inner_dim = key.shape[-1] 120 | head_dim = inner_dim // attn.heads 121 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 122 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 123 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 124 | 125 | # Normalize query and key if required 126 | if attn.norm_q is not None: 127 | query = attn.norm_q(query) 128 | if attn.norm_k is not None: 129 | key = attn.norm_k(key) 130 | 131 | # Apply rotary embeddings if provided 132 | if image_rotary_emb is not None: 133 | query, key = self._apply_rotary_embedding(query, key, image_rotary_emb, seq_len, text_seq_length, attn) 134 | 135 | # Compute scaled dot-product attention 136 | hidden_states = F.scaled_dot_product_attention( 137 | query, key, value, attn_mask=self.attention_mask, dropout_p=0.0, is_causal=False 138 | ) 139 | 140 | # Reshape the output tensor back to the original shape 141 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 142 | 143 | # Apply linear projection and LoRA to the output 144 | original_hidden_states = attn.to_out[0](hidden_states) 145 | hidden_states_delta = self.to_out_lora(hidden_states).to(hidden_states.device) 146 | original_hidden_states[:, -seq_len // 2:, :] += hidden_states_delta[:, -seq_len // 2:, :] * scaling 147 | 148 | # Apply dropout 149 | hidden_states = attn.to_out[1](original_hidden_states) 150 | 151 | # Split back into encoder and decoder hidden states 152 | encoder_hidden_states, hidden_states = hidden_states.split( 153 | [text_seq_length, hidden_states.size(1) 
- text_seq_length], dim=1 154 | ) 155 | 156 | return hidden_states, encoder_hidden_states 157 | 158 | def prepare_for_rgba_inference( 159 | model, rgba_weights_path: str, device: torch.device, dtype: torch.dtype, 160 | lora_rank: int = 128, lora_alpha: float = 1.0, text_length: int = 226, seq_length: int = 35100 161 | ): 162 | def load_lora_sequential_weights(lora_layer, lora_layers, prefix): 163 | lora_layer[0].load_state_dict({'weight': lora_layers[f"{prefix}.lora_A.weight"]}) 164 | lora_layer[1].load_state_dict({'weight': lora_layers[f"{prefix}.lora_B.weight"]}) 165 | 166 | 167 | rgba_weights = load_file(rgba_weights_path) 168 | aux_emb = rgba_weights['domain_emb'] 169 | 170 | attention_mask = create_attention_mask(text_length, seq_length, device, dtype) 171 | attn_procs = {} 172 | 173 | for name in model.attn_processors.keys(): 174 | attn_processor = RGBALoRACogVideoXAttnProcessor( 175 | device=device, dtype=dtype, attention_mask=attention_mask, 176 | lora_rank=lora_rank, lora_alpha=lora_alpha 177 | ) 178 | 179 | index = name.split('.')[1] 180 | base_prefix = f'transformer.transformer_blocks.{index}.attn1' 181 | 182 | for lora_layer, prefix in [ 183 | (attn_processor.to_q_lora, f'{base_prefix}.to_q'), 184 | (attn_processor.to_k_lora, f'{base_prefix}.to_k'), 185 | (attn_processor.to_v_lora, f'{base_prefix}.to_v'), 186 | (attn_processor.to_out_lora, f'{base_prefix}.to_out.0'), 187 | ]: 188 | load_lora_sequential_weights(lora_layer, rgba_weights, prefix) 189 | 190 | attn_procs[name] = attn_processor 191 | 192 | model.set_attn_processor(attn_procs) 193 | 194 | def custom_forward(self): 195 | def forward( 196 | hidden_states: torch.Tensor, 197 | encoder_hidden_states: torch.Tensor, 198 | timestep: Union[int, float, torch.LongTensor], 199 | timestep_cond: Optional[torch.Tensor] = None, 200 | ofs: Optional[Union[int, float, torch.LongTensor]] = None, 201 | image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 202 | attention_kwargs: Optional[Dict[str, 
Any]] = None, 203 | return_dict: bool = True, 204 | ): 205 | if attention_kwargs is not None: 206 | attention_kwargs = attention_kwargs.copy() 207 | lora_scale = attention_kwargs.pop("scale", 1.0) 208 | else: 209 | lora_scale = 1.0 210 | 211 | if USE_PEFT_BACKEND: 212 | # weight the lora layers by setting `lora_scale` for each PEFT layer 213 | scale_lora_layers(self, lora_scale) 214 | else: 215 | if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: 216 | logger.warning( 217 | "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 218 | ) 219 | 220 | batch_size, num_frames, channels, height, width = hidden_states.shape 221 | 222 | # 1. Time embedding 223 | timesteps = timestep 224 | t_emb = self.time_proj(timesteps) 225 | 226 | # timesteps does not contain any weights and will always return f32 tensors 227 | # but time_embedding might actually be running in fp16. so we need to cast here. 228 | # there might be better ways to encapsulate this. 229 | t_emb = t_emb.to(dtype=hidden_states.dtype) 230 | emb = self.time_embedding(t_emb, timestep_cond) 231 | 232 | if self.ofs_embedding is not None: 233 | ofs_emb = self.ofs_proj(ofs) 234 | ofs_emb = ofs_emb.to(dtype=hidden_states.dtype) 235 | ofs_emb = self.ofs_embedding(ofs_emb) 236 | emb = emb + ofs_emb 237 | 238 | # 2. Patch embedding 239 | hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) 240 | hidden_states = self.embedding_dropout(hidden_states) 241 | 242 | text_seq_length = encoder_hidden_states.shape[1] 243 | encoder_hidden_states = hidden_states[:, :text_seq_length] 244 | hidden_states = hidden_states[:, text_seq_length:] 245 | 246 | hidden_states[:, hidden_states.size(1) // 2:, :] += aux_emb.expand(batch_size, -1, -1).to(hidden_states.device, dtype=hidden_states.dtype) 247 | 248 | # 3. 
Transformer blocks 249 | for i, block in enumerate(self.transformer_blocks): 250 | if torch.is_grad_enabled() and self.gradient_checkpointing: 251 | 252 | def create_custom_forward(module): 253 | def custom_forward(*inputs): 254 | return module(*inputs) 255 | 256 | return custom_forward 257 | 258 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 259 | hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( 260 | create_custom_forward(block), 261 | hidden_states, 262 | encoder_hidden_states, 263 | emb, 264 | image_rotary_emb, 265 | **ckpt_kwargs, 266 | ) 267 | else: 268 | hidden_states, encoder_hidden_states = block( 269 | hidden_states=hidden_states, 270 | encoder_hidden_states=encoder_hidden_states, 271 | temb=emb, 272 | image_rotary_emb=image_rotary_emb, 273 | ) 274 | 275 | if not self.config.use_rotary_positional_embeddings: 276 | # CogVideoX-2B 277 | hidden_states = self.norm_final(hidden_states) 278 | else: 279 | # CogVideoX-5B 280 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 281 | hidden_states = self.norm_final(hidden_states) 282 | hidden_states = hidden_states[:, text_seq_length:] 283 | 284 | # 4. Final block 285 | hidden_states = self.norm_out(hidden_states, temb=emb) 286 | hidden_states = self.proj_out(hidden_states) 287 | 288 | # 5. 
Unpatchify 289 | p = self.config.patch_size 290 | p_t = self.config.patch_size_t 291 | 292 | if p_t is None: 293 | output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) 294 | output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) 295 | else: 296 | output = hidden_states.reshape( 297 | batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p 298 | ) 299 | output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) 300 | 301 | if USE_PEFT_BACKEND: 302 | # remove `lora_scale` from each PEFT layer 303 | unscale_lora_layers(self, lora_scale) 304 | 305 | if not return_dict: 306 | return (output,) 307 | return Transformer2DModelOutput(sample=output) 308 | 309 | 310 | return forward 311 | 312 | model.forward = custom_forward(model) 313 | 314 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | **ADOBE RESEARCH LICENSE** 2 | 3 | This license agreement (the “License”) between Adobe Inc., having a place of business at 345 Park Avenue, San Jose, California 95110-2704 (“Adobe”), and you, the individual or entity exercising rights under this License (“you” or “your”), sets forth the terms for your use of certain research materials that are owned by Adobe (the “Licensed Materials”). By exercising rights under this License, you accept and agree to be bound by its terms. If you are exercising rights under this License on behalf of an entity, then “you” means you and such entity, and you (personally) represent and warrant that you (personally) have all necessary authority to bind that entity to the terms of this License. 4 | 5 | 1. **GRANT OF LICENSE.**
6 | 1.1 Adobe grants you a nonexclusive, worldwide, royalty-free, revocable, fully paid license to (A) reproduce, use, modify, and publicly display the Licensed Materials for noncommercial research purposes only; and (B) redistribute the Licensed Materials, and modifications or derivative works thereof, for noncommercial research purposes only, provided that you give recipients a copy of this License upon redistribution.
7 | 1.2 You may add your own copyright statement to your modifications and/or provide additional or different license terms for use, reproduction, modification, public display, and redistribution of your modifications and derivative works, provided that such license terms limit the use, reproduction, modification, public display, and redistribution of such modifications and derivative works to noncommercial research purposes only.
8 | 1.3 For purposes of this License, noncommercial research purposes include academic research and teaching only. Noncommercial research purposes do not include commercial licensing or distribution, development of commercial products, or any other activity that results in commercial gain.
9 | 2. **OWNERSHIP AND ATTRIBUTION.** Adobe and its licensors own all right, title, and interest in the Licensed Materials. You must retain all copyright notices and/or disclaimers in the Licensed Materials. 10 | 3. **DISCLAIMER OF WARRANTIES.** THE LICENSED MATERIALS ARE PROVIDED “AS IS” WITHOUT WARRANTY OF ANY KIND. THE ENTIRE RISK AS TO THE USE, RESULTS, AND PERFORMANCE OF THE LICENSED MATERIALS IS ASSUMED BY YOU. ADOBE DISCLAIMS ALL WARRANTIES, EXPRESS, IMPLIED OR STATUTORY, WITH REGARD TO YOUR USE OF THE LICENSED MATERIALS, INCLUDING, BUT NOT LIMITED TO, NONINFRINGEMENT OF THIRD-PARTY RIGHTS. 11 | 4. **LIMITATION OF LIABILITY.** IN NO EVENT WILL ADOBE BE LIABLE FOR ANY ACTUAL, INCIDENTAL, SPECIAL OR CONSEQUENTIAL DAMAGES, INCLUDING WITHOUT LIMITATION, LOSS OF PROFITS OR OTHER COMMERCIAL LOSS, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THE LICENSED MATERIALS, EVEN IF ADOBE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 12 | 5. **TERM AND TERMINATION.**
13 | 5.1 The License is effective upon acceptance by you and will remain in effect unless terminated earlier in accordance with Section 5.2.
14 | 5.2 Any breach of any material provision of this License will automatically terminate the rights granted herein.
15 | 5.3 Sections 2 (Ownership and Attribution), 3 (Disclaimer of Warranties), 4 (Limitation of Liability) will survive termination of this License.
16 | -------------------------------------------------------------------------------- /Mochi/README.md: -------------------------------------------------------------------------------- 1 | # RGBA LoRA Training Instructions 2 | 3 | 13 | 21 | 22 | We follow the same steps as the original [finetrainers](https://github.com/a-r-r-o-w/finetrainers/blob/main/training/mochi-1/README.md) to prepare the [RGBA dataset](https://grail.cs.washington.edu/projects/background-matting-v2/#/datasets). 23 | For RGBA dataset, you can follow the instructions above to preprocess the dataset yourself. 24 | 25 | Here are some detailed steps to prepare the dataset for Mochi-1 fine-tuning: 26 | 27 | 1. Download our preprocessed [Video RGBA dataset](https://hkustgz-my.sharepoint.com/:u:/g/personal/lwang592_connect_hkust-gz_edu_cn/EezKQoum3IVJiJ9c8GebNfYBe-xN0OS5mVUvAwyL_rQLuw?e=1obdbA), which has undergone preprocessing operations such as color decontamination and background blur. 28 | 2. Use `trim_and_crop_videos.py` to crop and trim the RGB and Alpha videos as needed. 29 | 3. Use `embed.py` to encode the RGB videos into latent representations and embed the video captions into embeddings. 30 | 4. Use `embed.py` to encode the Alpha videos into latent representations. 31 | 5. Concatenate the RGB and Alpha latent representations along the frames dimension. 32 | 33 | Finally, the dataset should be in the following format: 34 | ``` 35 | .latent.pt 36 | .embed.pt 37 | .latent.pt 38 | .embed.pt 39 | ``` 40 | 41 | 42 | Now, we're ready to fine-tune. To launch, run: 43 | 44 | ```bash 45 | bash train.sh 46 | ``` 47 | **Note:** 48 | 49 | The arg `--num_frames` is used to specify the number of frames of generated **RGB** video. During generation, we will actually double the number of frames to generate the **RGB** video and **Alpha** video jointly. This double operation is automatically handled by our implementation. 
50 | 51 | For an 80GB GPU, we support processing RGB videos with dimensions of 480 × 848 × 79 (Height × Width × Frames) at a batch size of 1 using bfloat16 precision for training. However, the training is relatively slow (over one minute per iteration) because the model processes a total of 79 × 2 frames as input. 52 | 53 | 54 | 55 | 56 | 57 | ~~We haven't rigorously tested but without validation enabled, this script should run under 40 GB of GPU VRAM.~~ 58 | 59 | ## Inference 60 | 61 | To generate the RGBA video, run: 62 | 63 | ```bash 64 | python cli.py \ 65 | --lora_path /path/to/lora \ 66 | --prompt "..." 67 | ``` 68 | 69 | This command generates the RGB and Alpha videos simultaneously and saves them. Specifically, the RGB video is saved in its premultiplied form. To blend this video with any background image, you can simply use the following formula: 70 | 71 | ```python 72 | com = rgb + (1 - alpha) * bgr 73 | ``` 74 | 75 | ## Known limitations 76 | 77 | (Contributions are welcome 🤗) 78 | 79 | Our script currently doesn't leverage `accelerate` and some of its consequences are detailed below: 80 | 81 | * No support for distributed training. 82 | * `train_batch_size > 1` is supported but can lead to OOMs because we currently don't have gradient accumulation support. 83 | * No support for 8bit optimizers (but should be relatively easy to add). 84 | 85 | **Misc**: 86 | 87 | * We're aware of the quality issues in the `diffusers` implementation of Mochi-1. This is being fixed in [this PR](https://github.com/huggingface/diffusers/pull/10033). 88 | * `embed.py` script is non-batched. 89 | -------------------------------------------------------------------------------- /Mochi/args.py: -------------------------------------------------------------------------------- 1 | """ 2 | Default values taken from 3 | https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/configs/lora.yaml 4 | when applicable.
5 | """ 6 | 7 | import argparse 8 | 9 | 10 | def _get_model_args(parser: argparse.ArgumentParser) -> None: 11 | parser.add_argument( 12 | "--pretrained_model_name_or_path", 13 | type=str, 14 | default=None, 15 | required=True, 16 | help="Path to pretrained model or model identifier from huggingface.co/models.", 17 | ) 18 | parser.add_argument( 19 | "--revision", 20 | type=str, 21 | default=None, 22 | required=False, 23 | help="Revision of pretrained model identifier from huggingface.co/models.", 24 | ) 25 | parser.add_argument( 26 | "--variant", 27 | type=str, 28 | default=None, 29 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", 30 | ) 31 | parser.add_argument( 32 | "--cache_dir", 33 | type=str, 34 | default=None, 35 | help="The directory where the downloaded models and datasets will be stored.", 36 | ) 37 | parser.add_argument( 38 | "--cast_dit", 39 | action="store_true", 40 | help="If we should cast DiT params to a lower precision.", 41 | ) 42 | parser.add_argument( 43 | "--compile_dit", 44 | action="store_true", 45 | help="If we should compile the DiT.", 46 | ) 47 | 48 | 49 | def _get_dataset_args(parser: argparse.ArgumentParser) -> None: 50 | parser.add_argument( 51 | "--data_root", 52 | type=str, 53 | default=None, 54 | help=("A folder containing the training data."), 55 | ) 56 | parser.add_argument( 57 | "--caption_dropout", 58 | type=float, 59 | default=None, 60 | help=("Probability to drop out captions randomly."), 61 | ) 62 | 63 | parser.add_argument( 64 | "--dataloader_num_workers", 65 | type=int, 66 | default=0, 67 | help="Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.", 68 | ) 69 | parser.add_argument( 70 | "--pin_memory", 71 | action="store_true", 72 | help="Whether or not to use the pinned memory setting in pytorch dataloader.", 73 | ) 74 | 75 | 76 | def _get_validation_args(parser: argparse.ArgumentParser) -> None: 77 | parser.add_argument( 78 | "--validation_prompt", 79 | type=str, 80 | default=None, 81 | help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", 82 | ) 83 | parser.add_argument( 84 | "--validation_images", 85 | type=str, 86 | default=None, 87 | help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.", 88 | ) 89 | parser.add_argument( 90 | "--validation_prompt_separator", 91 | type=str, 92 | default=":::", 93 | help="String that separates multiple validation prompts", 94 | ) 95 | parser.add_argument( 96 | "--num_validation_videos", 97 | type=int, 98 | default=1, 99 | help="Number of videos that should be generated during validation per `validation_prompt`.", 100 | ) 101 | parser.add_argument( 102 | "--validation_epochs", 103 | type=int, 104 | default=50, 105 | help="Run validation every X training steps. 
Validation consists of running the validation prompt `args.num_validation_videos` times.", 106 | ) 107 | parser.add_argument( 108 | "--enable_slicing", 109 | action="store_true", 110 | default=False, 111 | help="Whether or not to use VAE slicing for saving memory.", 112 | ) 113 | parser.add_argument( 114 | "--enable_tiling", 115 | action="store_true", 116 | default=False, 117 | help="Whether or not to use VAE tiling for saving memory.", 118 | ) 119 | parser.add_argument( 120 | "--enable_model_cpu_offload", 121 | action="store_true", 122 | default=False, 123 | help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.", 124 | ) 125 | parser.add_argument( 126 | "--fps", 127 | type=int, 128 | default=30, 129 | help="FPS to use when serializing the output videos.", 130 | ) 131 | parser.add_argument( 132 | "--height", 133 | type=int, 134 | default=480, 135 | ) 136 | parser.add_argument( 137 | "--width", 138 | type=int, 139 | default=848, 140 | ) 141 | 142 | 143 | def _get_training_args(parser: argparse.ArgumentParser) -> None: 144 | parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") 145 | parser.add_argument("--rank", type=int, default=16, help="The rank for LoRA matrices.") 146 | parser.add_argument( 147 | "--lora_alpha", 148 | type=int, 149 | default=16, 150 | help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.", 151 | ) 152 | parser.add_argument( 153 | "--target_modules", 154 | nargs="+", 155 | type=str, 156 | default=["to_k", "to_q", "to_v", "to_out.0"], 157 | help="Target modules to train LoRA for.", 158 | ) 159 | parser.add_argument( 160 | "--output_dir", 161 | type=str, 162 | default="mochi-lora", 163 | help="The output directory where the model predictions and checkpoints will be written.", 164 | ) 165 | parser.add_argument( 166 | "--train_batch_size", 167 | type=int, 168 | default=4, 169 | help="Batch size (per device) for the training 
dataloader.", 170 | ) 171 | parser.add_argument("--num_train_epochs", type=int, default=1) 172 | parser.add_argument( 173 | "--max_train_steps", 174 | type=int, 175 | default=None, 176 | help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.", 177 | ) 178 | parser.add_argument( 179 | "--gradient_checkpointing", 180 | action="store_true", 181 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 182 | ) 183 | parser.add_argument( 184 | "--learning_rate", 185 | type=float, 186 | default=2e-4, 187 | help="Initial learning rate (after the potential warmup period) to use.", 188 | ) 189 | parser.add_argument( 190 | "--scale_lr", 191 | action="store_true", 192 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 193 | ) 194 | parser.add_argument( 195 | "--lr_warmup_steps", 196 | type=int, 197 | default=200, 198 | help="Number of steps for the warmup in the lr scheduler.", 199 | ) 200 | parser.add_argument( 201 | "--checkpointing_steps", 202 | type=int, 203 | default=1000, 204 | ) 205 | parser.add_argument( 206 | "--resume_from_checkpoint", 207 | type=str, 208 | default=None, 209 | ) 210 | 211 | 212 | def _get_optimizer_args(parser: argparse.ArgumentParser) -> None: 213 | parser.add_argument( 214 | "--optimizer", 215 | type=lambda s: s.lower(), 216 | default="adam", 217 | choices=["adam", "adamw"], 218 | help=("The optimizer type to use."), 219 | ) 220 | parser.add_argument( 221 | "--weight_decay", 222 | type=float, 223 | default=0.01, 224 | help="Weight decay to use for optimizer.", 225 | ) 226 | 227 | 228 | def _get_configuration_args(parser: argparse.ArgumentParser) -> None: 229 | parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name") 230 | parser.add_argument( 231 | "--push_to_hub", 232 | action="store_true", 233 | help="Whether or not to push the model to the Hub.", 234 | ) 235 | 
parser.add_argument( 236 | "--hub_token", 237 | type=str, 238 | default=None, 239 | help="The token to use to push to the Model Hub.", 240 | ) 241 | parser.add_argument( 242 | "--hub_model_id", 243 | type=str, 244 | default=None, 245 | help="The name of the repository to keep in sync with the local `output_dir`.", 246 | ) 247 | parser.add_argument( 248 | "--allow_tf32", 249 | action="store_true", 250 | help=( 251 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 252 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 253 | ), 254 | ) 255 | parser.add_argument("--report_to", type=str, default=None, help="If logging to wandb.") 256 | 257 | 258 | def get_args(): 259 | parser = argparse.ArgumentParser(description="Simple example of a training script for Mochi-1.") 260 | 261 | _get_model_args(parser) 262 | _get_dataset_args(parser) 263 | _get_training_args(parser) 264 | _get_validation_args(parser) 265 | _get_optimizer_args(parser) 266 | _get_configuration_args(parser) 267 | 268 | return parser.parse_args() 269 | -------------------------------------------------------------------------------- /Mochi/cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | # from diffusers import MochiPipeline 4 | from pipeline_mochi_rgba import MochiPipeline 5 | from diffusers.utils import export_to_video 6 | import argparse 7 | from rgba_utils import * 8 | import numpy as np 9 | 10 | 11 | def main(args): 12 | # 1. load pipeline 13 | pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16).to("cuda") 14 | pipe.enable_vae_tiling() 15 | 16 | # 2. 
define prompt and arguments 17 | pipeline_args = { 18 | "prompt": args.prompt, 19 | "guidance_scale": args.guidance_scale, 20 | "num_inference_steps": args.num_inference_steps, 21 | "height": args.height, 22 | "width": args.width, 23 | "num_frames": args.num_frames, 24 | "max_sequence_length": 256, 25 | "output_type": "latent", 26 | } 27 | 28 | # 3. prepare rgbx utils 29 | prepare_for_rgba_inference( 30 | pipe.transformer, 31 | device="cuda", 32 | dtype=torch.bfloat16, 33 | ) 34 | 35 | if args.lora_path is not None: 36 | checkpoint = torch.load(args.lora_path, map_location="cpu") 37 | processor_state_dict = checkpoint["state_dict"] 38 | load_processor_state_dict(pipe.transformer, processor_state_dict) 39 | 40 | 41 | # 4. inference 42 | generator = torch.manual_seed(args.seed) if args.seed else None 43 | frames_latents = pipe(**pipeline_args, generator=generator).frames 44 | 45 | frames_latents_rgb, frames_latents_alpha = frames_latents.chunk(2, dim=2) 46 | 47 | frames_rgb = decode_latents(pipe, frames_latents_rgb) 48 | frames_alpha = decode_latents(pipe, frames_latents_alpha) 49 | 50 | pooled_alpha = np.max(frames_alpha, axis=-1, keepdims=True) 51 | frames_alpha_pooled = np.repeat(pooled_alpha, 3, axis=-1) 52 | premultiplied_rgb = frames_rgb * frames_alpha_pooled 53 | 54 | if os.path.exists(args.output_path) == False: 55 | os.makedirs(args.output_path) 56 | 57 | export_to_video(premultiplied_rgb[0], os.path.join(args.output_path, "rgb.mp4"), fps=args.fps) 58 | export_to_video(frames_alpha_pooled[0], os.path.join(args.output_path, "alpha.mp4"), fps=args.fps) 59 | export_to_video(frames_rgb[0], os.path.join(args.output_path, "original_rgb.mp4"), fps=args.fps) 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser(description="Generate a video from a text prompt") 63 | parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated") 64 | parser.add_argument("--lora_path", type=str, default=None, 
help="The path of the LoRA weights to be used") 65 | 66 | parser.add_argument( 67 | "--model_path", type=str, default="genmo/mochi-1-preview", help="Path of the pre-trained model use" 68 | ) 69 | parser.add_argument("--output_path", type=str, default="./output", help="The path save generated video") 70 | parser.add_argument("--guidance_scale", type=float, default=6, help="The scale for classifier-free guidance") 71 | parser.add_argument("--num_inference_steps", type=int, default=64, help="Inference steps") 72 | parser.add_argument("--num_frames", type=int, default=79, help="Number of steps for the inference process") 73 | parser.add_argument("--width", type=int, default=848, help="Number of steps for the inference process") 74 | parser.add_argument("--height", type=int, default=480, help="Number of steps for the inference process") 75 | parser.add_argument("--fps", type=int, default=30, help="Number of steps for the inference process") 76 | parser.add_argument("--seed", type=int, default=None, help="The seed for reproducibility") 77 | args = parser.parse_args() 78 | 79 | main(args) 80 | -------------------------------------------------------------------------------- /Mochi/dataset_simple.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from 3 | https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/dataset.py 4 | """ 5 | 6 | from pathlib import Path 7 | 8 | import click 9 | import torch 10 | from torch.utils.data import DataLoader, Dataset 11 | 12 | 13 | def load_to_cpu(x): 14 | return torch.load(x, map_location=torch.device("cpu"), weights_only=True) 15 | 16 | 17 | class LatentEmbedDataset(Dataset): 18 | def __init__(self, file_paths, repeat=1): 19 | self.items = [ 20 | (Path(p).with_suffix(".latent.pt"), Path(p).with_suffix(".embed.pt")) 21 | for p in file_paths 22 | if Path(p).with_suffix(".latent.pt").is_file() and Path(p).with_suffix(".embed.pt").is_file() 23 | ] 24 | self.items = self.items * repeat 
25 | print(f"Loaded {len(self.items)}/{len(file_paths)} valid file pairs.") 26 | 27 | def __len__(self): 28 | return len(self.items) 29 | 30 | def __getitem__(self, idx): 31 | latent_path, embed_path = self.items[idx] 32 | return load_to_cpu(latent_path), load_to_cpu(embed_path) 33 | 34 | 35 | @click.command() 36 | @click.argument("directory", type=click.Path(exists=True, file_okay=False)) 37 | def process_videos(directory): 38 | dir_path = Path(directory) 39 | mp4_files = [str(f) for f in dir_path.glob("**/*.mp4") if not f.name.endswith(".recon.mp4")] 40 | assert mp4_files, f"No mp4 files found" 41 | 42 | dataset = LatentEmbedDataset(mp4_files) 43 | dataloader = DataLoader(dataset, batch_size=4, shuffle=True) 44 | 45 | for latents, embeds in dataloader: 46 | print([(k, v.shape) for k, v in latents.items()]) 47 | 48 | 49 | if __name__ == "__main__": 50 | process_videos() 51 | -------------------------------------------------------------------------------- /Mochi/embed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from: 3 | https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/encode_videos.py 4 | https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/embed_captions.py 5 | """ 6 | 7 | import click 8 | import torch 9 | import torchvision 10 | from pathlib import Path 11 | from diffusers import AutoencoderKLMochi, MochiPipeline 12 | from transformers import T5EncoderModel, T5Tokenizer 13 | from tqdm.auto import tqdm 14 | 15 | 16 | def encode_videos(model: torch.nn.Module, vid_path: Path, shape: str): 17 | T, H, W = [int(s) for s in shape.split("x")] 18 | assert (T - 1) % 6 == 0, "Expected T to be 1 mod 6" 19 | video, _, metadata = torchvision.io.read_video(str(vid_path), output_format="THWC", pts_unit="secs") 20 | fps = metadata["video_fps"] 21 | video = video.permute(3, 0, 1, 2) 22 | og_shape = video.shape 23 | assert video.shape[2] == H, f"Expected {vid_path} to have height {H}, got {video.shape}" 
24 | assert video.shape[3] == W, f"Expected {vid_path} to have width {W}, got {video.shape}" 25 | assert video.shape[1] >= T, f"Expected {vid_path} to have at least {T} frames, got {video.shape}" 26 | if video.shape[1] > T: 27 | video = video[:, :T] 28 | print(f"Trimmed video from {og_shape[1]} to first {T} frames") 29 | video = video.unsqueeze(0) 30 | video = video.float() / 127.5 - 1.0 31 | video = video.to(model.device) 32 | 33 | assert video.ndim == 5 34 | 35 | with torch.inference_mode(): 36 | with torch.autocast("cuda", dtype=torch.bfloat16): 37 | ldist = model._encode(video) 38 | 39 | torch.save(dict(ldist=ldist), vid_path.with_suffix(".latent.pt")) 40 | 41 | 42 | @click.command() 43 | @click.argument("output_dir", type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path)) 44 | @click.option( 45 | "--model_id", 46 | type=str, 47 | help="Repo id. Should be genmo/mochi-1-preview", 48 | default="genmo/mochi-1-preview", 49 | ) 50 | @click.option("--shape", default="163x480x848", help="Shape of the video to encode") 51 | @click.option("--overwrite", "-ow", is_flag=True, help="Overwrite existing latents and caption embeddings.") 52 | def batch_process(output_dir: Path, model_id: Path, shape: str, overwrite: bool) -> None: 53 | """Process all videos and captions in a directory using a single GPU.""" 54 | # comment out when running on unsupported hardware 55 | torch.backends.cuda.matmul.allow_tf32 = True 56 | torch.backends.cudnn.allow_tf32 = True 57 | 58 | # Get all video paths 59 | video_paths = list(output_dir.glob("**/*.mp4")) 60 | if not video_paths: 61 | print(f"No MP4 files found in {output_dir}") 62 | return 63 | 64 | text_paths = list(output_dir.glob("**/*.txt")) 65 | if not text_paths: 66 | print(f"No text files found in {output_dir}") 67 | return 68 | 69 | # load the models 70 | vae = AutoencoderKLMochi.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32).to("cuda") 71 | text_encoder = 
T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder") 72 | tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer") 73 | pipeline = MochiPipeline.from_pretrained( 74 | model_id, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None, vae=None 75 | ).to("cuda") 76 | 77 | for idx, video_path in tqdm(enumerate(sorted(video_paths))): 78 | print(f"Processing {video_path}") 79 | try: 80 | if video_path.with_suffix(".latent.pt").exists() and not overwrite: 81 | print(f"Skipping {video_path}") 82 | continue 83 | 84 | # encode videos. 85 | encode_videos(vae, vid_path=video_path, shape=shape) 86 | 87 | # embed captions. 88 | prompt_path = Path("/".join(str(video_path).split(".")[:-1]) + ".txt") 89 | embed_path = prompt_path.with_suffix(".embed.pt") 90 | 91 | if embed_path.exists() and not overwrite: 92 | print(f"Skipping {prompt_path} - embeddings already exist") 93 | continue 94 | 95 | with open(prompt_path) as f: 96 | text = f.read().strip() 97 | with torch.inference_mode(): 98 | conditioning = pipeline.encode_prompt(prompt=[text]) 99 | 100 | conditioning = {"prompt_embeds": conditioning[0], "prompt_attention_mask": conditioning[1]} 101 | torch.save(conditioning, embed_path) 102 | 103 | except Exception as e: 104 | import traceback 105 | 106 | traceback.print_exc() 107 | print(f"Error processing {video_path}: {str(e)}") 108 | 109 | 110 | if __name__ == "__main__": 111 | batch_process() -------------------------------------------------------------------------------- /Mochi/pipeline_mochi_rgba.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Genmo and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import inspect 16 | from typing import Any, Callable, Dict, List, Optional, Union 17 | 18 | import numpy as np 19 | import torch 20 | from transformers import T5EncoderModel, T5TokenizerFast 21 | 22 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback 23 | from diffusers.loaders import Mochi1LoraLoaderMixin 24 | from diffusers.models.autoencoders import AutoencoderKL 25 | from diffusers.models.transformers import MochiTransformer3DModel 26 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 27 | from diffusers.utils import ( 28 | is_torch_xla_available, 29 | logging, 30 | replace_example_docstring, 31 | ) 32 | from diffusers.utils.torch_utils import randn_tensor 33 | from diffusers.video_processor import VideoProcessor 34 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 35 | from diffusers.pipelines.mochi.pipeline_output import MochiPipelineOutput 36 | 37 | 38 | if is_torch_xla_available(): 39 | import torch_xla.core.xla_model as xm 40 | 41 | XLA_AVAILABLE = True 42 | else: 43 | XLA_AVAILABLE = False 44 | 45 | 46 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 47 | 48 | EXAMPLE_DOC_STRING = """ 49 | Examples: 50 | ```py 51 | >>> import torch 52 | >>> from diffusers import MochiPipeline 53 | >>> from diffusers.utils import export_to_video 54 | 55 | >>> pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16) 56 | >>> pipe.enable_model_cpu_offload() 57 | >>> pipe.enable_vae_tiling() 58 | >>> prompt = "Close-up 
of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k." 59 | >>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0] 60 | >>> export_to_video(frames, "mochi.mp4") 61 | ``` 62 | """ 63 | 64 | 65 | def calculate_shift( 66 | image_seq_len, 67 | base_seq_len: int = 256, 68 | max_seq_len: int = 4096, 69 | base_shift: float = 0.5, 70 | max_shift: float = 1.16, 71 | ): 72 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len) 73 | b = base_shift - m * base_seq_len 74 | mu = image_seq_len * m + b 75 | return mu 76 | 77 | 78 | # from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 79 | def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): 80 | if linear_steps is None: 81 | linear_steps = num_steps // 2 82 | linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] 83 | threshold_noise_step_diff = linear_steps - threshold_noise * num_steps 84 | quadratic_steps = num_steps - linear_steps 85 | quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) 86 | linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) 87 | const = quadratic_coef * (linear_steps**2) 88 | quadratic_sigma_schedule = [ 89 | quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps) 90 | ] 91 | sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule 92 | sigma_schedule = [1.0 - x for x in sigma_schedule] 93 | return sigma_schedule 94 | 95 | 96 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 97 | def retrieve_timesteps( 98 | scheduler, 99 | num_inference_steps: Optional[int] = None, 100 | device: Optional[Union[str, torch.device]] = None, 101 | timesteps: Optional[List[int]] = None, 102 | sigmas: Optional[List[float]] = None, 103 | **kwargs, 104 | ): 105 | r""" 106 | Calls 
the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 107 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 108 | 109 | Args: 110 | scheduler (`SchedulerMixin`): 111 | The scheduler to get timesteps from. 112 | num_inference_steps (`int`): 113 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 114 | must be `None`. 115 | device (`str` or `torch.device`, *optional*): 116 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 117 | timesteps (`List[int]`, *optional*): 118 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 119 | `num_inference_steps` and `sigmas` must be `None`. 120 | sigmas (`List[float]`, *optional*): 121 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 122 | `num_inference_steps` and `timesteps` must be `None`. 123 | 124 | Returns: 125 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 126 | second element is the number of inference steps. 127 | """ 128 | if timesteps is not None and sigmas is not None: 129 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 130 | if timesteps is not None: 131 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 132 | if not accepts_timesteps: 133 | raise ValueError( 134 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 135 | f" timestep schedules. Please check whether you are using the correct scheduler." 
136 | ) 137 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 138 | timesteps = scheduler.timesteps 139 | num_inference_steps = len(timesteps) 140 | elif sigmas is not None: 141 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 142 | if not accept_sigmas: 143 | raise ValueError( 144 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 145 | f" sigmas schedules. Please check whether you are using the correct scheduler." 146 | ) 147 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 148 | timesteps = scheduler.timesteps 149 | num_inference_steps = len(timesteps) 150 | else: 151 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 152 | timesteps = scheduler.timesteps 153 | return timesteps, num_inference_steps 154 | 155 | 156 | 157 | 158 | 159 | 160 | def prepare_attention_mask(prompt_attention_mask, latents): 161 | 162 | device = prompt_attention_mask.device 163 | 164 | (_, _, num_frames, height, width) = latents.shape # shape of two modalities 165 | seq_length = (height // 2) * (width // 2) * num_frames 166 | 167 | rect_attention_mask = [] 168 | for prompt_attention_mask_i in prompt_attention_mask: 169 | text_length = torch.sum(prompt_attention_mask_i).item() 170 | total_length = text_length + seq_length 171 | 172 | if text_length == 0: 173 | rect_attention_mask.append(None) 174 | else: 175 | dense_mask = torch.ones((total_length, total_length), dtype=torch.bool) 176 | dense_mask[seq_length:, seq_length // 2: seq_length] = False 177 | rect_attention_mask.append(dense_mask.to(device)) 178 | 179 | return { 180 | "prompt_attention_mask": prompt_attention_mask, 181 | "rect_attention_mask": rect_attention_mask, 182 | } 183 | 184 | 185 | class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin): 186 | r""" 187 | The mochi pipeline for text-to-video generation. 
188 | 189 | Reference: https://github.com/genmoai/models 190 | 191 | Args: 192 | transformer ([`MochiTransformer3DModel`]): 193 | Conditional Transformer architecture to denoise the encoded video latents. 194 | scheduler ([`FlowMatchEulerDiscreteScheduler`]): 195 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 196 | vae ([`AutoencoderKL`]): 197 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 198 | text_encoder ([`T5EncoderModel`]): 199 | [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically 200 | the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. 201 | tokenizer (`CLIPTokenizer`): 202 | Tokenizer of class 203 | [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 204 | tokenizer (`T5TokenizerFast`): 205 | Second Tokenizer of class 206 | [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). 
207 | """ 208 | 209 | model_cpu_offload_seq = "text_encoder->transformer->vae" 210 | _optional_components = [] 211 | _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] 212 | 213 | def __init__( 214 | self, 215 | scheduler: FlowMatchEulerDiscreteScheduler, 216 | vae: AutoencoderKL, 217 | text_encoder: T5EncoderModel, 218 | tokenizer: T5TokenizerFast, 219 | transformer: MochiTransformer3DModel, 220 | force_zeros_for_empty_prompt: bool = False, 221 | ): 222 | super().__init__() 223 | 224 | self.register_modules( 225 | vae=vae, 226 | text_encoder=text_encoder, 227 | tokenizer=tokenizer, 228 | transformer=transformer, 229 | scheduler=scheduler, 230 | ) 231 | # TODO: determine these scaling factors from model parameters 232 | self.vae_spatial_scale_factor = 8 233 | self.vae_temporal_scale_factor = 6 234 | self.patch_size = 2 235 | 236 | self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor) 237 | self.tokenizer_max_length = ( 238 | self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256 239 | ) 240 | self.default_height = 480 241 | self.default_width = 848 242 | self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) 243 | 244 | def _get_t5_prompt_embeds( 245 | self, 246 | prompt: Union[str, List[str]] = None, 247 | num_videos_per_prompt: int = 1, 248 | max_sequence_length: int = 256, 249 | device: Optional[torch.device] = None, 250 | dtype: Optional[torch.dtype] = None, 251 | ): 252 | device = device or self._execution_device 253 | dtype = dtype or self.text_encoder.dtype 254 | 255 | prompt = [prompt] if isinstance(prompt, str) else prompt 256 | batch_size = len(prompt) 257 | 258 | text_inputs = self.tokenizer( 259 | prompt, 260 | padding="max_length", 261 | max_length=max_sequence_length, 262 | truncation=True, 263 | add_special_tokens=True, 264 | return_tensors="pt", 265 | ) 266 | 267 | text_input_ids = text_inputs.input_ids 268 | 
prompt_attention_mask = text_inputs.attention_mask 269 | prompt_attention_mask = prompt_attention_mask.bool().to(device) 270 | 271 | # The original Mochi implementation zeros out empty negative prompts 272 | # but this can lead to overflow when placing the entire pipeline under the autocast context 273 | # adding this here so that we can enable zeroing prompts if necessary 274 | if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""): 275 | text_input_ids = torch.zeros_like(text_input_ids, device=device) 276 | prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device) 277 | 278 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 279 | 280 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 281 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) 282 | logger.warning( 283 | "The following part of your input was truncated because `max_sequence_length` is set to " 284 | f" {max_sequence_length} tokens: {removed_text}" 285 | ) 286 | 287 | prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0] 288 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 289 | 290 | # duplicate text embeddings for each generation per prompt, using mps friendly method 291 | _, seq_len, _ = prompt_embeds.shape 292 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) 293 | prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) 294 | 295 | prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) 296 | prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) 297 | 298 | return prompt_embeds, prompt_attention_mask 299 | 300 | # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt 301 | def encode_prompt( 302 | self, 303 | 
prompt: Union[str, List[str]], 304 | negative_prompt: Optional[Union[str, List[str]]] = None, 305 | do_classifier_free_guidance: bool = True, 306 | num_videos_per_prompt: int = 1, 307 | prompt_embeds: Optional[torch.Tensor] = None, 308 | negative_prompt_embeds: Optional[torch.Tensor] = None, 309 | prompt_attention_mask: Optional[torch.Tensor] = None, 310 | negative_prompt_attention_mask: Optional[torch.Tensor] = None, 311 | max_sequence_length: int = 256, 312 | device: Optional[torch.device] = None, 313 | dtype: Optional[torch.dtype] = None, 314 | ): 315 | r""" 316 | Encodes the prompt into text encoder hidden states. 317 | 318 | Args: 319 | prompt (`str` or `List[str]`, *optional*): 320 | prompt to be encoded 321 | negative_prompt (`str` or `List[str]`, *optional*): 322 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 323 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 324 | less than `1`). 325 | do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): 326 | Whether to use classifier free guidance or not. 327 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 328 | Number of videos that should be generated per prompt. torch device to place the resulting embeddings on 329 | prompt_embeds (`torch.Tensor`, *optional*): 330 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 331 | provided, text embeddings will be generated from `prompt` input argument. 332 | negative_prompt_embeds (`torch.Tensor`, *optional*): 333 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 334 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 335 | argument. 
336 | device: (`torch.device`, *optional*): 337 | torch device 338 | dtype: (`torch.dtype`, *optional*): 339 | torch dtype 340 | """ 341 | device = device or self._execution_device 342 | 343 | prompt = [prompt] if isinstance(prompt, str) else prompt 344 | if prompt is not None: 345 | batch_size = len(prompt) 346 | else: 347 | batch_size = prompt_embeds.shape[0] 348 | 349 | if prompt_embeds is None: 350 | prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( 351 | prompt=prompt, 352 | num_videos_per_prompt=num_videos_per_prompt, 353 | max_sequence_length=max_sequence_length, 354 | device=device, 355 | dtype=dtype, 356 | ) 357 | 358 | if do_classifier_free_guidance and negative_prompt_embeds is None: 359 | negative_prompt = negative_prompt or "" 360 | negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt 361 | 362 | if prompt is not None and type(prompt) is not type(negative_prompt): 363 | raise TypeError( 364 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 365 | f" {type(prompt)}." 366 | ) 367 | elif batch_size != len(negative_prompt): 368 | raise ValueError( 369 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 370 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 371 | " the batch size of `prompt`." 
372 | ) 373 | 374 | negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( 375 | prompt=negative_prompt, 376 | num_videos_per_prompt=num_videos_per_prompt, 377 | max_sequence_length=max_sequence_length, 378 | device=device, 379 | dtype=dtype, 380 | ) 381 | 382 | return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask 383 | 384 | def check_inputs( 385 | self, 386 | prompt, 387 | height, 388 | width, 389 | callback_on_step_end_tensor_inputs=None, 390 | prompt_embeds=None, 391 | negative_prompt_embeds=None, 392 | prompt_attention_mask=None, 393 | negative_prompt_attention_mask=None, 394 | ): 395 | if height % 8 != 0 or width % 8 != 0: 396 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 397 | 398 | if callback_on_step_end_tensor_inputs is not None and not all( 399 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs 400 | ): 401 | raise ValueError( 402 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 403 | ) 404 | 405 | if prompt is not None and prompt_embeds is not None: 406 | raise ValueError( 407 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 408 | " only forward one of the two." 409 | ) 410 | elif prompt is None and prompt_embeds is None: 411 | raise ValueError( 412 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
413 | ) 414 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 415 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 416 | 417 | if prompt_embeds is not None and prompt_attention_mask is None: 418 | raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") 419 | 420 | if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: 421 | raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") 422 | 423 | if prompt_embeds is not None and negative_prompt_embeds is not None: 424 | if prompt_embeds.shape != negative_prompt_embeds.shape: 425 | raise ValueError( 426 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" 427 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" 428 | f" {negative_prompt_embeds.shape}." 429 | ) 430 | if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: 431 | raise ValueError( 432 | "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" 433 | f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" 434 | f" {negative_prompt_attention_mask.shape}." 435 | ) 436 | 437 | def enable_vae_slicing(self): 438 | r""" 439 | Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to 440 | compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 441 | """ 442 | self.vae.enable_slicing() 443 | 444 | def disable_vae_slicing(self): 445 | r""" 446 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to 447 | computing decoding in one step.
448 | """ 449 | self.vae.disable_slicing() 450 | 451 | def enable_vae_tiling(self): 452 | r""" 453 | Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to 454 | compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow 455 | processing larger images. 456 | """ 457 | self.vae.enable_tiling() 458 | 459 | def disable_vae_tiling(self): 460 | r""" 461 | Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to 462 | computing decoding in one step. 463 | """ 464 | self.vae.disable_tiling() 465 | 466 | def prepare_latents( 467 | self, 468 | batch_size, 469 | num_channels_latents, 470 | height, 471 | width, 472 | num_frames, 473 | dtype, 474 | device, 475 | generator, 476 | latents=None, 477 | ): 478 | height = height // self.vae_spatial_scale_factor 479 | width = width // self.vae_spatial_scale_factor 480 | num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1 481 | # latent shape: (batch, channels, latent_frames, latent_height, latent_width) 482 | shape = (batch_size, num_channels_latents, num_frames, height, width) 483 | # caller-supplied latents are only moved/cast here; their shape is not checked against `shape` 484 | if latents is not None: 485 | return latents.to(device=device, dtype=dtype) 486 | if isinstance(generator, list) and len(generator) != batch_size: 487 | raise ValueError( 488 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 489 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
490 | ) 491 | 492 | latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32) 493 | latents = latents.to(dtype) 494 | return latents 495 | 496 | @property 497 | def guidance_scale(self): 498 | return self._guidance_scale 499 | 500 | @property 501 | def do_classifier_free_guidance(self): 502 | return self._guidance_scale > 1.0 503 | 504 | @property 505 | def num_timesteps(self): 506 | return self._num_timesteps 507 | 508 | @property 509 | def attention_kwargs(self): 510 | return self._attention_kwargs 511 | 512 | @property 513 | def interrupt(self): 514 | return self._interrupt 515 | 516 | 517 | @torch.no_grad() 518 | @replace_example_docstring(EXAMPLE_DOC_STRING) 519 | def __call__( 520 | self, 521 | prompt: Optional[Union[str, List[str]]] = None, 522 | negative_prompt: Optional[Union[str, List[str]]] = None, 523 | height: Optional[int] = None, 524 | width: Optional[int] = None, 525 | num_frames: int = 19, 526 | num_inference_steps: int = 64, 527 | timesteps: Optional[List[int]] = None, 528 | guidance_scale: float = 4.5, 529 | num_videos_per_prompt: Optional[int] = 1, 530 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 531 | latents: Optional[torch.Tensor] = None, 532 | prompt_embeds: Optional[torch.Tensor] = None, 533 | prompt_attention_mask: Optional[torch.Tensor] = None, 534 | negative_prompt_embeds: Optional[torch.Tensor] = None, 535 | negative_prompt_attention_mask: Optional[torch.Tensor] = None, 536 | output_type: Optional[str] = "pil", 537 | return_dict: bool = True, 538 | attention_kwargs: Optional[Dict[str, Any]] = None, 539 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 540 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 541 | max_sequence_length: int = 256, 542 | ): 543 | r""" 544 | Function invoked when calling the pipeline for generation. 545 | 546 | Args: 547 | prompt (`str` or `List[str]`, *optional*): 548 | The prompt or prompts to guide the video generation.
If not defined, one has to pass `prompt_embeds` 549 | instead. 550 | height (`int`, *optional*, defaults to `self.default_height`): 551 | The height in pixels of the generated video. This is set to 480 by default for the best results. 552 | width (`int`, *optional*, defaults to `self.default_width`): 553 | The width in pixels of the generated video. This is set to 848 by default for the best results. 554 | num_frames (`int`, defaults to `19`): 555 | The number of video frames to generate. The sampled latents are repeated 2x along the frame axis (presumably an RGB half and an alpha half). 556 | num_inference_steps (`int`, *optional*, defaults to 64): 557 | The number of denoising steps. More denoising steps usually lead to a higher quality video at the 558 | expense of slower inference. 559 | timesteps (`List[int]`, *optional*): 560 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument 561 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is 562 | passed will be used. Must be in descending order. 563 | guidance_scale (`float`, defaults to `4.5`): 564 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 565 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 566 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 567 | 1`. Higher guidance scale encourages the model to generate videos that are closely linked to the text `prompt`, 568 | usually at the expense of lower quality. 569 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 570 | The number of videos to generate per prompt. 571 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 572 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 573 | to make generation deterministic.
574 | latents (`torch.Tensor`, *optional*): 575 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video 576 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 577 | tensor will be generated by sampling using the supplied random `generator`. Supplied latents are not shape-checked and are still repeated 2x along the frame axis. 578 | prompt_embeds (`torch.Tensor`, *optional*): 579 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 580 | provided, text embeddings will be generated from `prompt` input argument. 581 | prompt_attention_mask (`torch.Tensor`, *optional*): 582 | Pre-generated attention mask for text embeddings. 583 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 584 | Pre-generated negative text embeddings. If not 585 | provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 586 | negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): 587 | Pre-generated attention mask for negative text embeddings. 588 | output_type (`str`, *optional*, defaults to `"pil"`): 589 | The output format of the generated video. Choose between 590 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`; `"latent"` returns the denoised latents without VAE decoding. 591 | return_dict (`bool`, *optional*, defaults to `True`): 592 | Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple. 593 | attention_kwargs (`dict`, *optional*): 594 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 595 | `self.processor` in 596 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 597 | callback_on_step_end (`Callable`, *optional*): 598 | A function that is called at the end of each denoising step during inference.
The function is called 599 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 600 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 601 | `callback_on_step_end_tensor_inputs`. Only `latents` and `prompt_embeds` returned by the callback are applied back. 602 | callback_on_step_end_tensor_inputs (`List`, *optional*): 603 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 604 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 605 | `._callback_tensor_inputs` attribute of your pipeline class. 606 | max_sequence_length (`int`, defaults to `256`): 607 | Maximum sequence length to use with the `prompt`. 608 | 609 | Examples: 610 | 611 | Returns: 612 | [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`: 613 | If `return_dict` is `True`, [`~pipelines.mochi.MochiPipelineOutput`] is returned, otherwise a `tuple` 614 | is returned where the first element is a list with the generated videos. 615 | """ 616 | 617 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): 618 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs 619 | 620 | height = height or self.default_height 621 | width = width or self.default_width 622 | 623 | # 1. Check inputs. Raise error if not correct 624 | self.check_inputs( 625 | prompt=prompt, 626 | height=height, 627 | width=width, 628 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 629 | prompt_embeds=prompt_embeds, 630 | negative_prompt_embeds=negative_prompt_embeds, 631 | prompt_attention_mask=prompt_attention_mask, 632 | negative_prompt_attention_mask=negative_prompt_attention_mask, 633 | ) 634 | 635 | self._guidance_scale = guidance_scale 636 | self._attention_kwargs = attention_kwargs 637 | self._interrupt = False 638 | 639 | # 2.
Define call parameters 640 | if prompt is not None and isinstance(prompt, str): 641 | batch_size = 1 642 | elif prompt is not None and isinstance(prompt, list): 643 | batch_size = len(prompt) 644 | else: 645 | batch_size = prompt_embeds.shape[0] 646 | 647 | device = self._execution_device 648 | # 3. Prepare text embeddings 649 | ( 650 | prompt_embeds, 651 | prompt_attention_mask, 652 | negative_prompt_embeds, 653 | negative_prompt_attention_mask, 654 | ) = self.encode_prompt( 655 | prompt=prompt, 656 | negative_prompt=negative_prompt, 657 | do_classifier_free_guidance=self.do_classifier_free_guidance, 658 | num_videos_per_prompt=num_videos_per_prompt, 659 | prompt_embeds=prompt_embeds, 660 | negative_prompt_embeds=negative_prompt_embeds, 661 | prompt_attention_mask=prompt_attention_mask, 662 | negative_prompt_attention_mask=negative_prompt_attention_mask, 663 | max_sequence_length=max_sequence_length, 664 | device=device, 665 | ) 666 | # 4. Prepare latent variables; the frame axis (dim 2) is repeated 2x, presumably so one half carries RGB and the other alpha 667 | num_channels_latents = self.transformer.config.in_channels 668 | latents = self.prepare_latents( 669 | batch_size * num_videos_per_prompt, 670 | num_channels_latents, 671 | height, 672 | width, 673 | num_frames, 674 | prompt_embeds.dtype, 675 | device, 676 | generator, 677 | latents, 678 | ).repeat(1,1,2,1,1) 679 | 680 | if self.do_classifier_free_guidance: 681 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 682 | prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) 683 | 684 | 685 | # 4.5 Prepare attention rectification masks from the prompt mask (CFG-concatenated above when enabled) and the doubled latents 686 | all_attention_mask = prepare_attention_mask(prompt_attention_mask, latents) 687 | 688 | # 5.
Prepare timesteps (sigmas follow the linear-quadratic schedule below) 689 | # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 690 | threshold_noise = 0.025 691 | sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise) 692 | sigmas = np.array(sigmas) 693 | 694 | timesteps, num_inference_steps = retrieve_timesteps( 695 | self.scheduler, 696 | num_inference_steps, 697 | device, 698 | timesteps, 699 | sigmas, 700 | ) 701 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 702 | self._num_timesteps = len(timesteps) 703 | 704 | # 6. Denoising loop 705 | with self.progress_bar(total=num_inference_steps) as progress_bar: 706 | for i, t in enumerate(timesteps): 707 | if self.interrupt: 708 | continue 709 | 710 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents 711 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 712 | timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) 713 | 714 | noise_pred = self.transformer( 715 | hidden_states=latent_model_input, 716 | encoder_hidden_states=prompt_embeds, 717 | timestep=timestep, 718 | encoder_attention_mask=all_attention_mask, 719 | attention_kwargs=attention_kwargs, 720 | return_dict=False, 721 | )[0] 722 | # Mochi CFG + Sampling runs in FP32 723 | noise_pred = noise_pred.to(torch.float32) 724 | 725 | if self.do_classifier_free_guidance: 726 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 727 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) 728 | 729 | # compute the previous noisy sample x_t -> x_t-1 730 | latents_dtype = latents.dtype 731 | latents = self.scheduler.step(noise_pred, t, latents.to(torch.float32), return_dict=False)[0] 732 | latents = latents.to(latents_dtype) 733 | 734 | if latents.dtype != latents_dtype: 735 | if torch.backends.mps.is_available(): 736 | # some platforms (eg.
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 — NOTE(review): this branch looks unreachable because latents were already cast back to latents_dtype just above 737 | latents = latents.to(latents_dtype) 738 | 739 | if callback_on_step_end is not None: 740 | callback_kwargs = {} 741 | for k in callback_on_step_end_tensor_inputs: 742 | callback_kwargs[k] = locals()[k] 743 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 744 | 745 | latents = callback_outputs.pop("latents", latents) 746 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 747 | 748 | # advance the progress bar on the last step and on every scheduler-order step after warmup 749 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 750 | progress_bar.update() 751 | 752 | if XLA_AVAILABLE: 753 | xm.mark_step() 754 | 755 | if output_type == "latent": 756 | video = latents 757 | else: 758 | # unscale/denormalize the latents. NOTE(review): the frame axis doubled in step 4 is decoded here as a single video; output_type="latent" returns it undecoded 759 | # denormalize with the mean and std if available and not None (assumes 12 latent channels, as hard-coded in the .view below) 760 | has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None 761 | has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None 762 | if has_latents_mean and has_latents_std: 763 | latents_mean = ( 764 | torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype) 765 | ) 766 | latents_std = ( 767 | torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype) 768 | ) 769 | latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean 770 | else: 771 | latents = latents / self.vae.config.scaling_factor 772 | 773 | video = self.vae.decode(latents, return_dict=False)[0] 774 | video = self.video_processor.postprocess_video(video, output_type=output_type) 775 | 776 | # Offload all models 777 | self.maybe_free_model_hooks() 778 | 779 | if not return_dict: 780 | return (video,) 781 | 782 | return MochiPipelineOutput(frames=video) 783 |
-------------------------------------------------------------------------------- /Mochi/prepare_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GPU_ID=0 4 | VIDEO_DIR=video-dataset-disney-organized 5 | OUTPUT_DIR=videos_prepared 6 | NUM_FRAMES=37 7 | RESOLUTION=480x848 8 | 9 | # Split RESOLUTION on 'x'. NOTE(review): RESOLUTION is HEIGHTxWIDTH here (480x848, Mochi's default height/width), so WIDTH actually receives the height (480) and HEIGHT the width (848); --shape below thus becomes NUM_FRAMESx480x848 — confirm this is the order embed.py expects 10 | WIDTH=$(echo $RESOLUTION | cut -dx -f1) 11 | HEIGHT=$(echo $RESOLUTION | cut -dx -f2) 12 | 13 | python trim_and_crop_videos.py $VIDEO_DIR $OUTPUT_DIR --num_frames=$NUM_FRAMES --resolution=$RESOLUTION --force_upsample 14 | 15 | CUDA_VISIBLE_DEVICES=$GPU_ID python embed.py $OUTPUT_DIR --shape=${NUM_FRAMES}x${WIDTH}x${HEIGHT} -------------------------------------------------------------------------------- /Mochi/rgba_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from typing import Any, Dict, Optional, Tuple, Union 6 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers 7 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 8 | 9 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 10 | 11 | @torch.no_grad() 12 | def decode_latents(pipe, latents): # Denormalize VAE latents (mean/std when configured, else scaling_factor only) and decode them to a numpy video via pipe.video_processor 13 | has_latents_mean = hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None 14 | has_latents_std = hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None 15 | if has_latents_mean and has_latents_std: 16 | latents_mean = ( 17 | torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype) 18 | ) 19 | latents_std = ( 20 | torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype) 21 | ) 22 | latents = latents * latents_std / pipe.vae.config.scaling_factor +
latents_mean 23 | else: 24 | latents = latents / pipe.vae.config.scaling_factor 25 | 26 | video = pipe.vae.decode(latents, return_dict=False)[0] 27 | video = pipe.video_processor.postprocess_video(video, output_type='np') 28 | 29 | return video 30 | 31 | 32 | class RGBALoRAMochiAttnProcessor: 33 | """Mochi joint-attention processor that adds LoRA deltas (on q/k/v and the output projection) only to the second half of the video tokens, presumably the alpha frames; the first half keeps the base weights.""" 34 | def __init__(self, device, dtype, lora_rank=128, lora_alpha=1.0, latent_dim=3072): # latent_dim must match both the hidden_states width and the q/k/v/out width (3072 by default) 35 | if not hasattr(F, "scaled_dot_product_attention"): 36 | raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") 37 | 38 | 39 | # Initialize LoRA layers 40 | self.lora_alpha = lora_alpha 41 | self.lora_rank = lora_rank 42 | 43 | # Helper function to create LoRA layers 44 | def create_lora_layer(in_dim, mid_dim, out_dim, device=device, dtype=dtype): 45 | # Define the LoRA layers 46 | lora_a = nn.Linear(in_dim, mid_dim, bias=False, device=device, dtype=dtype) 47 | lora_b = nn.Linear(mid_dim, out_dim, bias=False, device=device, dtype=dtype) 48 | 49 | # Initialize lora_a with random parameters (default initialization) 50 | nn.init.kaiming_uniform_(lora_a.weight, a=math.sqrt(5)) # or another suitable initialization 51 | 52 | # Initialize lora_b with zeros so the LoRA delta is zero at initialization 53 | nn.init.zeros_(lora_b.weight) 54 | 55 | lora_a.weight.requires_grad = True 56 | lora_b.weight.requires_grad = True 57 | 58 | # Combine the layers into a sequential module 59 | return nn.Sequential(lora_a, lora_b) 60 | 61 | self.to_q_lora = create_lora_layer(latent_dim, lora_rank, latent_dim) 62 | self.to_k_lora = create_lora_layer(latent_dim, lora_rank, latent_dim) 63 | self.to_v_lora = create_lora_layer(latent_dim, lora_rank, latent_dim) 64 | self.to_out_lora = create_lora_layer(latent_dim, lora_rank, latent_dim) 65 | 66 | def _apply_lora(self, hidden_states, seq_len, query, key, value, scaling): 67 | """Add scaled LoRA deltas in place to the last `seq_len // 2` tokens of query, key and value (assumes seq_len is even) and return them.""" 68 | query_delta = self.to_q_lora(hidden_states).to(query.device) 69 |
query[:, -seq_len // 2:, :] += query_delta[:, -seq_len // 2:, :] * scaling 70 | 71 | key_delta = self.to_k_lora(hidden_states).to(key.device) 72 | key[:, -seq_len // 2:, :] += key_delta[:, -seq_len // 2:, :] * scaling 73 | 74 | value_delta = self.to_v_lora(hidden_states).to(value.device) 75 | value[:, -seq_len // 2:, :] += value_delta[:, -seq_len // 2:, :] * scaling 76 | 77 | return query, key, value 78 | 79 | def __call__( 80 | self, 81 | attn, 82 | hidden_states: torch.Tensor, 83 | encoder_hidden_states: torch.Tensor, 84 | attention_mask: Optional[Dict[str, torch.Tensor]] = None, 85 | image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 86 | ) -> Tuple[torch.Tensor, torch.Tensor]: 87 | query = attn.to_q(hidden_states) 88 | key = attn.to_k(hidden_states) 89 | value = attn.to_v(hidden_states) 90 | 91 | scaling = self.lora_alpha / self.lora_rank 92 | sequence_length = query.size(1) # number of video tokens; the LoRA and rotary code below splits this sequence in half 93 | query, key, value = self._apply_lora(hidden_states, sequence_length, query, key, value, scaling) 94 | 95 | query = query.unflatten(2, (attn.heads, -1)) 96 | key = key.unflatten(2, (attn.heads, -1)) 97 | value = value.unflatten(2, (attn.heads, -1)) 98 | 99 | if attn.norm_q is not None: 100 | query = attn.norm_q(query) 101 | if attn.norm_k is not None: 102 | key = attn.norm_k(key) 103 | 104 | encoder_query = attn.add_q_proj(encoder_hidden_states) 105 | encoder_key = attn.add_k_proj(encoder_hidden_states) 106 | encoder_value = attn.add_v_proj(encoder_hidden_states) 107 | 108 | encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) 109 | encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) 110 | encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) 111 | 112 | if attn.norm_added_q is not None: 113 | encoder_query = attn.norm_added_q(encoder_query) 114 | if attn.norm_added_k is not None: 115 | encoder_key = attn.norm_added_k(encoder_key) 116 | 117 | if image_rotary_emb is not None: 118 | 119 | def apply_rotary_emb(x, freqs_cos, freqs_sin): 120 | x_even = x[..., 0::2].float() 121 | x_odd = x[...,
1::2].float() 122 | 123 | cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype) 124 | sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype) 125 | 126 | return torch.stack([cos, sin], dim=-1).flatten(-2) 127 | # Both halves of the video sequence get the same rotary frequencies (shared positional encoding for RGB and alpha) 128 | query[:,sequence_length//2:] = apply_rotary_emb(query[:,sequence_length//2:], *image_rotary_emb) 129 | query[:,:sequence_length//2] = apply_rotary_emb(query[:,:sequence_length//2], *image_rotary_emb) 130 | 131 | key[:,sequence_length//2:] = apply_rotary_emb(key[:,sequence_length//2:], *image_rotary_emb) 132 | key[:,:sequence_length//2] = apply_rotary_emb(key[:,:sequence_length//2], *image_rotary_emb) 133 | # query = apply_rotary_emb(query, *image_rotary_emb) 134 | # key = apply_rotary_emb(key, *image_rotary_emb) 135 | 136 | 137 | query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) 138 | encoder_query, encoder_key, encoder_value = ( 139 | encoder_query.transpose(1, 2), 140 | encoder_key.transpose(1, 2), 141 | encoder_value.transpose(1, 2), 142 | ) 143 | 144 | sequence_length = query.size(2) 145 | encoder_sequence_length = encoder_query.size(2) 146 | total_length = sequence_length + encoder_sequence_length 147 | 148 | batch_size, heads, _, dim = query.shape 149 | # Attention runs per sample so that masked-out text tokens can be dropped; rect_attention_mask[idx] is passed as the SDPA attn_mask 150 | attn_outputs = [] 151 | prompt_attention_mask = attention_mask["prompt_attention_mask"] 152 | rect_attention_mask = attention_mask["rect_attention_mask"] 153 | for idx in range(batch_size): 154 | mask = prompt_attention_mask[idx][None, :] # this sample's prompt mask row; its nonzero entries select the valid text tokens kept below 155 | valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten() 156 | 157 | valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :] 158 | valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :] 159 | valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :] 160 | 161 | valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2) 162 | valid_key
= torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2) 163 | valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2) 164 | 165 | attn_output = F.scaled_dot_product_attention( 166 | valid_query, 167 | valid_key, 168 | valid_value, 169 | dropout_p=0.0, 170 | attn_mask=rect_attention_mask[idx], 171 | is_causal=False 172 | ) 173 | valid_sequence_length = attn_output.size(2) 174 | attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length)) # zero-pad the dropped text positions so every sample has total_length tokens 175 | attn_outputs.append(attn_output) 176 | 177 | hidden_states = torch.cat(attn_outputs, dim=0) 178 | hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) 179 | 180 | hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( 181 | (sequence_length, encoder_sequence_length), dim=1 182 | ) 183 | 184 | # output projection, plus the LoRA delta on the second half of the video tokens 185 | original_hidden_states = attn.to_out[0](hidden_states) 186 | hidden_states_delta = self.to_out_lora(hidden_states).to(hidden_states.device) 187 | original_hidden_states[:, -sequence_length // 2:, :] += hidden_states_delta[:, -sequence_length // 2:, :] * scaling 188 | # dropout 189 | hidden_states = attn.to_out[1](original_hidden_states) 190 | 191 | if hasattr(attn, "to_add_out"): 192 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states) 193 | 194 | return hidden_states, encoder_hidden_states 195 | 196 | def prepare_for_rgba_inference( 197 | model, device: torch.device, dtype: torch.dtype, 198 | lora_rank: int = 128, lora_alpha: float = 1.0 199 | ): 200 | # Install a fresh RGBALoRAMochiAttnProcessor on every transformer block's attn1 and replace model.forward with an RGB+alpha aware forward (mutates `model` in place) 201 | def custom_forward(self): # build a replacement forward that closes over `self` (the transformer) 202 | def forward( 203 | hidden_states: torch.Tensor, 204 | encoder_hidden_states: torch.Tensor, 205 | timestep: torch.LongTensor, 206 | encoder_attention_mask: Dict[str, torch.Tensor], 207 | attention_kwargs: Optional[Dict[str, Any]] = None, 208 | return_dict: bool = True, 209 | ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]: 210 | if attention_kwargs is not None: 211 | attention_kwargs = attention_kwargs.copy() 212 | lora_scale = attention_kwargs.pop("scale", 1.0) 213 | else: 214 | lora_scale =
1.0 215 | 216 | if USE_PEFT_BACKEND: 217 | # weight the lora layers by setting `lora_scale` for each PEFT layer 218 | scale_lora_layers(self, lora_scale) 219 | else: 220 | if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: 221 | logger.warning( 222 | "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 223 | ) 224 | 225 | batch_size, num_channels, num_frames, height, width = hidden_states.shape 226 | p = self.config.patch_size 227 | 228 | post_patch_height = height // p 229 | post_patch_width = width // p 230 | 231 | temb, encoder_hidden_states = self.time_embed( 232 | timestep, 233 | encoder_hidden_states, 234 | encoder_attention_mask["prompt_attention_mask"], 235 | hidden_dtype=hidden_states.dtype, 236 | ) 237 | # fold frames into the batch for patch_embed, then unfold them and flatten frames x patches into one token axis 238 | hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) 239 | hidden_states = self.patch_embed(hidden_states) 240 | hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) 241 | 242 | image_rotary_emb = self.rope( 243 | self.pos_frequencies, 244 | num_frames // 2, # Identical PE for RGB and Alpha: rope covers half the frames and the processor reuses it for each half 245 | post_patch_height, 246 | post_patch_width, 247 | device=hidden_states.device, 248 | dtype=torch.float32, 249 | ) 250 | 251 | for i, block in enumerate(self.transformer_blocks): 252 | if torch.is_grad_enabled() and self.gradient_checkpointing: 253 | 254 | def create_custom_forward(module): 255 | def custom_forward(*inputs): 256 | return module(*inputs) 257 | 258 | return custom_forward 259 | 260 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 261 | hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( 262 | create_custom_forward(block), 263 | hidden_states, 264 | encoder_hidden_states, 265 | temb, 266 | encoder_attention_mask, 267 | image_rotary_emb, 268 | **ckpt_kwargs, 269 | ) 270 | else: 271 | hidden_states, encoder_hidden_states = block( 272 | hidden_states=hidden_states, 273
| encoder_hidden_states=encoder_hidden_states, 274 | temb=temb, 275 | encoder_attention_mask=encoder_attention_mask, 276 | image_rotary_emb=image_rotary_emb, 277 | ) 278 | hidden_states = self.norm_out(hidden_states, temb) 279 | hidden_states = self.proj_out(hidden_states) 280 | 281 | hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1) 282 | hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5) 283 | output = hidden_states.reshape(batch_size, -1, num_frames, height, width) 284 | 285 | if USE_PEFT_BACKEND: 286 | # remove `lora_scale` from each PEFT layer 287 | unscale_lora_layers(self, lora_scale) 288 | 289 | if not return_dict: 290 | return (output,) 291 | return Transformer2DModelOutput(sample=output) 292 | return forward 293 | # Attach a freshly initialized processor to each block's attn1; trained LoRA weights must be loaded separately (e.g. with load_processor_state_dict) 294 | for _, block in enumerate(model.transformer_blocks): 295 | attn_processor = RGBALoRAMochiAttnProcessor( 296 | device=device, 297 | dtype=dtype, 298 | lora_rank=lora_rank, 299 | lora_alpha=lora_alpha 300 | ) 301 | # block.attn1.set_processor(attn_processor) 302 | block.attn1.processor = attn_processor 303 | # override forward on this model instance only 304 | model.forward = custom_forward(model) 305 | 306 | def get_processor_state_dict(model): 307 | """Collect the LoRA weights of every block's attn1 processor into a flat dict keyed `block_{index}.{attr_name}.{param_name}`; values are cloned tensors and nothing is written to disk.""" 308 | processor_state_dict = {} 309 | 310 | for index, block in enumerate(model.transformer_blocks): 311 | if hasattr(block.attn1, "processor"): 312 | processor = block.attn1.processor 313 | for attr_name in ["to_q_lora", "to_k_lora", "to_v_lora", "to_out_lora"]: 314 | if hasattr(processor, attr_name): 315 | lora_layer = getattr(processor, attr_name) 316 | for param_name, param in lora_layer.named_parameters(): 317 | key = f"block_{index}.{attr_name}.{param_name}" 318 | processor_state_dict[key] = param.data.clone() 319 | 320 | # torch.save({"processor_state_dict": processor_state_dict}, checkpoint_path) 321 | # print(f"Processor state_dict saved to {checkpoint_path}") 322 | return processor_state_dict 323 | 324
| def load_processor_state_dict(model, processor_state_dict): 325 | """Copy LoRA weights from `processor_state_dict` (keys as produced by get_processor_state_dict) into each block's attn1 processor in place. Raises KeyError if an expected key is missing; extra keys are ignored.""" 326 | for index, block in enumerate(model.transformer_blocks): 327 | if hasattr(block.attn1, "processor"): 328 | processor = block.attn1.processor 329 | for attr_name in ["to_q_lora", "to_k_lora", "to_v_lora", "to_out_lora"]: 330 | if hasattr(processor, attr_name): 331 | lora_layer = getattr(processor, attr_name) 332 | for param_name, param in lora_layer.named_parameters(): 333 | key = f"block_{index}.{attr_name}.{param_name}" 334 | if key in processor_state_dict: 335 | param.data.copy_(processor_state_dict[key]) 336 | else: 337 | raise KeyError(f"Missing key {key} in checkpoint.") 338 | 339 | # Collect the trainable (requires_grad=True) LoRA parameters for training 340 | def get_processor_params(processor): # LoRA params of a single processor that have requires_grad=True 341 | params = [] 342 | for attr_name in ["to_q_lora", "to_k_lora", "to_v_lora", "to_out_lora"]: 343 | if hasattr(processor, attr_name): 344 | lora_layer = getattr(processor, attr_name) 345 | params.extend(p for p in lora_layer.parameters() if p.requires_grad) 346 | return params 347 | 348 | def get_all_processor_params(transformer): # concatenation of get_processor_params over every block's attn1 processor 349 | all_params = [] 350 | for block in transformer.transformer_blocks: 351 | if hasattr(block.attn1, "processor"): 352 | processor = block.attn1.processor 353 | all_params.extend(get_processor_params(processor)) 354 | return all_params -------------------------------------------------------------------------------- /Mochi/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import gc 17 | import random 18 | from glob import glob 19 | import math 20 | import os 21 | import torch.nn.functional as F 22 | import numpy as np 23 | from pathlib import Path 24 | from typing import Any, Dict, Tuple, List 25 | 26 | import torch 27 | import wandb 28 | from pipeline_mochi_rgba import * 29 | from diffusers import FlowMatchEulerDiscreteScheduler, MochiTransformer3DModel 30 | from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution 31 | from diffusers.training_utils import cast_training_params 32 | from diffusers.utils import export_to_video 33 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card 34 | from huggingface_hub import create_repo, upload_folder 35 | from torch.utils.data import DataLoader 36 | from tqdm.auto import tqdm 37 | 38 | 39 | from args import get_args # isort:skip 40 | from dataset_simple import LatentEmbedDataset 41 | 42 | from utils import print_memory, reset_memory # isort:skip 43 | from rgba_utils import * 44 | 45 | 46 | # Taken from 47 | # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/train.py#L139 48 | def get_cosine_annealing_lr_scheduler( 49 | optimizer: torch.optim.Optimizer, 50 | warmup_steps: int, 51 | total_steps: int, 52 | ) -> torch.optim.lr_scheduler.LambdaLR: 53 | def lr_lambda(step): # LR multiplier: linear warmup 0 -> 1, then cosine decay 1 -> 0 at total_steps, not clamped afterwards (NOTE(review): divides by zero if total_steps == warmup_steps) 54 | if step < warmup_steps: 55 | return float(step) / float(max(1, warmup_steps)) 56 | else: 57 | return 0.5 * (1 + np.cos(np.pi * (step - warmup_steps) / (total_steps - warmup_steps))) 58 | 59 | return
torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 60 | 61 | 62 | def save_model_card( 63 | repo_id: str, 64 | videos=None, 65 | base_model: str = None, 66 | validation_prompt=None, 67 | repo_folder=None, 68 | fps=30, 69 | ): # Export `videos` as final_video_{i}.mp4 into repo_folder and write a README.md model card there 70 | widget_dict = [] 71 | if videos is not None and len(videos) > 0: 72 | for i, video in enumerate(videos): 73 | export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps) 74 | widget_dict.append( 75 | { 76 | "text": validation_prompt if validation_prompt else " ", 77 | "output": {"url": f"final_video_{i}.mp4"}, 78 | } 79 | ) 80 | 81 | model_description = f""" 82 | # Mochi-1 Preview LoRA Finetune 83 | 84 | 85 | 86 | ## Model description 87 | 88 | This is a lora finetune of the Mochi-1 preview model `{base_model}`. 89 | 90 | The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX and Mochi family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py). 91 | 92 | ## Download model 93 | 94 | [Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. 95 | 96 | ## Usage 97 | 98 | Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
99 | 100 | ```py 101 | from diffusers import MochiPipeline 102 | from diffusers.utils import export_to_video 103 | import torch 104 | 105 | pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview") 106 | pipe.load_lora_weights("CHANGE_ME") 107 | pipe.enable_model_cpu_offload() 108 | 109 | with torch.autocast("cuda", torch.bfloat16): 110 | video = pipe( 111 | prompt="CHANGE_ME", 112 | guidance_scale=6.0, 113 | num_inference_steps=64, 114 | height=480, 115 | width=848, 116 | max_sequence_length=256, 117 | output_type="np" 118 | ).frames[0] 119 | export_to_video(video) 120 | ``` 121 | 122 | For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers. 123 | 124 | """ 125 | model_card = load_or_create_model_card( 126 | repo_id_or_path=repo_id, 127 | from_training=True, 128 | license="apache-2.0", 129 | base_model=base_model, 130 | prompt=validation_prompt, 131 | model_description=model_description, 132 | widget=widget_dict, 133 | ) 134 | tags = [ 135 | "text-to-video", 136 | "diffusers-training", 137 | "diffusers", 138 | "lora", 139 | "mochi-1-preview", 140 | "mochi-1-preview-diffusers", 141 | "template:sd-lora", 142 | ] 143 | 144 | model_card = populate_model_card(model_card, tags=tags) 145 | model_card.save(os.path.join(repo_folder, "README.md")) 146 | 147 | 148 | def log_validation( 149 | pipe: MochiPipeline, 150 | args: Any, # attribute-style namespace (args.seed, args.output_dir, ...), not a dict 151 | pipeline_args: Dict[str, Any], 152 | step: int, 153 | wandb_run: str = None, 154 | is_final_validation: bool = False, 155 | ): # Generate args.num_validation_videos videos, save them as mp4 under args.output_dir, optionally log them to wandb, and return them 156 | print( 157 | f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
158 | ) 159 | phase_name = "test" if is_final_validation else "validation" 160 | 161 | if not args.enable_model_cpu_offload: 162 | pipe = pipe.to("cuda") 163 | 164 | # run inference; NOTE(review): a seed of 0 is falsy, so it yields generator=None (unseeded) 165 | generator = torch.manual_seed(args.seed) if args.seed else None 166 | 167 | videos = [] 168 | with torch.autocast("cuda", torch.bfloat16, cache_enabled=False): 169 | for _ in range(args.num_validation_videos): 170 | video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0] 171 | videos.append(video) 172 | 173 | video_filenames = [] 174 | for i, video in enumerate(videos): 175 | prompt = ( 176 | pipeline_args["prompt"][:25] 177 | .replace(" ", "_") # NOTE(review): the next .replace repeats this one; perhaps a different whitespace character was intended — confirm 178 | .replace(" ", "_") 179 | .replace("'", "_") 180 | .replace('"', "_") 181 | .replace("/", "_") 182 | ) 183 | filename = os.path.join(args.output_dir, f"{phase_name}_{str(step)}_video_{i}_{prompt}.mp4") 184 | export_to_video(video, filename, fps=30) 185 | video_filenames.append(filename) 186 | 187 | if wandb_run: 188 | wandb.log( 189 | { 190 | phase_name: [ 191 | wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}", fps=30) 192 | for i, filename in enumerate(video_filenames) 193 | ] 194 | } 195 | ) 196 | 197 | return videos 198 | 199 | 200 | # Adapted from the original code: 201 | # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/pipelines.py#L578 202 | def cast_dit(model, dtype): # Cast every nn.Linear and nn.Conv2d in `model` to `dtype` in place and return `model`; asserts each Linear name contains an expected substring 203 | for name, module in model.named_modules(): 204 | if isinstance(module, torch.nn.Linear): 205 | assert any( 206 | n in name for n in ["time_embed", "proj_out", "blocks", "norm_out"] 207 | ), f"Unexpected linear layer: {name}" 208 | module.to(dtype=dtype) 209 | elif isinstance(module, torch.nn.Conv2d): 210 | module.to(dtype=dtype) 211 | return model 212 | 213 | 214 | def save_checkpoint(model, optimizer, lr_scheduler, global_step, checkpoint_path): 215 | # lora_state_dict = get_peft_model_state_dict(model) 216 | processor_state_dict = get_processor_state_dict(model)
217 | torch.save( 218 | { 219 | "state_dict": processor_state_dict, 220 | "optimizer": optimizer.state_dict(), 221 | "lr_scheduler": lr_scheduler.state_dict(), 222 | "global_step": global_step, 223 | }, 224 | checkpoint_path, 225 | ) 226 | 227 | 228 | class CollateFunction: 229 | def __init__(self, caption_dropout: float = None) -> None: 230 | self.caption_dropout = caption_dropout 231 | 232 | def __call__(self, samples: List[Tuple[dict, torch.Tensor]]) -> Dict[str, torch.Tensor]: 233 | ldists = torch.cat([data[0]["ldist"] for data in samples], dim=0) 234 | z = DiagonalGaussianDistribution(ldists).sample() 235 | assert torch.isfinite(z).all() 236 | 237 | # Sample noise which we will add to the samples. 238 | eps = torch.randn_like(z) 239 | sigma = torch.rand(z.shape[:1], device="cpu", dtype=torch.float32) 240 | 241 | prompt_embeds = torch.cat([data[1]["prompt_embeds"] for data in samples], dim=0) 242 | prompt_attention_mask = torch.cat([data[1]["prompt_attention_mask"] for data in samples], dim=0) 243 | if self.caption_dropout and random.random() < self.caption_dropout: 244 | prompt_embeds.zero_() 245 | prompt_attention_mask = prompt_attention_mask.long() 246 | prompt_attention_mask.zero_() 247 | prompt_attention_mask = prompt_attention_mask.bool() 248 | 249 | return dict( 250 | z=z, eps=eps, sigma=sigma, prompt_embeds=prompt_embeds, prompt_attention_mask=prompt_attention_mask 251 | ) 252 | 253 | 254 | def main(args): 255 | if not torch.cuda.is_available(): 256 | raise ValueError("Not supported without CUDA.") 257 | 258 | if args.report_to == "wandb" and args.hub_token is not None: 259 | raise ValueError( 260 | "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." 261 | " Please use `huggingface-cli login` to authenticate with the Hub." 
262 | ) 263 | 264 | # Handle the repository creation 265 | if args.output_dir is not None: 266 | os.makedirs(args.output_dir, exist_ok=True) 267 | 268 | # Prepare models and scheduler 269 | transformer = MochiTransformer3DModel.from_pretrained( 270 | args.pretrained_model_name_or_path, 271 | subfolder="transformer", 272 | revision=args.revision, 273 | variant=args.variant, 274 | ) 275 | scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( 276 | args.pretrained_model_name_or_path, subfolder="scheduler" 277 | ) 278 | 279 | transformer.requires_grad_(False) 280 | transformer.to("cuda") 281 | if args.gradient_checkpointing: 282 | transformer.enable_gradient_checkpointing() 283 | if args.cast_dit: 284 | transformer = cast_dit(transformer, torch.bfloat16) 285 | if args.compile_dit: 286 | transformer.compile() 287 | 288 | prepare_for_rgba_inference( 289 | model=transformer, 290 | device=torch.device("cuda"), 291 | dtype=torch.bfloat16, 292 | # seq_length=seq_length, 293 | ) 294 | processor_params = get_all_processor_params(transformer) 295 | 296 | # Enable TF32 for faster training on Ampere GPUs, 297 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 298 | if args.allow_tf32 and torch.cuda.is_available(): 299 | torch.backends.cuda.matmul.allow_tf32 = True 300 | 301 | if args.scale_lr: 302 | args.learning_rate = args.learning_rate * args.train_batch_size 303 | # only upcast trainable parameters (LoRA) into fp32 304 | 305 | if not isinstance(processor_params, list): 306 | processor_params = [processor_params] 307 | for m in processor_params: 308 | for param in m: 309 | # only upcast trainable parameters into fp32 310 | if param.requires_grad: 311 | param.data = param.to(torch.float32) 312 | 313 | # Prepare optimizer 314 | transformer_lora_parameters = processor_params # list(filter(lambda p: p.requires_grad, transformer.parameters())) 315 | num_trainable_parameters = sum(param.numel() for param in 
transformer_lora_parameters) 316 | optimizer = torch.optim.AdamW(transformer_lora_parameters, lr=args.learning_rate, weight_decay=args.weight_decay) 317 | 318 | # Dataset and DataLoader 319 | train_vids = list(sorted(glob(f"{args.data_root}/*.mp4"))) 320 | train_vids = [v for v in train_vids if not v.endswith(".recon.mp4")] 321 | print(f"Found {len(train_vids)} training videos in {args.data_root}") 322 | assert len(train_vids) > 0, f"No training data found in {args.data_root}" 323 | 324 | collate_fn = CollateFunction(caption_dropout=args.caption_dropout) 325 | train_dataset = LatentEmbedDataset(train_vids, repeat=1) 326 | train_dataloader = DataLoader( 327 | train_dataset, 328 | collate_fn=collate_fn, 329 | batch_size=args.train_batch_size, 330 | num_workers=args.dataloader_num_workers, 331 | pin_memory=args.pin_memory, 332 | ) 333 | 334 | # LR scheduler and math around the number of training steps. 335 | overrode_max_train_steps = False 336 | num_update_steps_per_epoch = len(train_dataloader) 337 | if args.max_train_steps is None: 338 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 339 | overrode_max_train_steps = True 340 | 341 | lr_scheduler = get_cosine_annealing_lr_scheduler( 342 | optimizer, warmup_steps=args.lr_warmup_steps, total_steps=args.max_train_steps 343 | ) 344 | 345 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 346 | num_update_steps_per_epoch = len(train_dataloader) 347 | if overrode_max_train_steps: 348 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 349 | # Afterwards we recalculate our number of training epochs 350 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 351 | 352 | # We need to initialize the trackers we use, and also store our configuration. 353 | # The trackers initializes automatically on the main process. 
354 | wandb_run = None 355 | if args.report_to == "wandb": 356 | tracker_name = args.tracker_name or "mochi-1-rgba-lora" 357 | wandb_run = wandb.init(project=tracker_name, config=vars(args)) 358 | 359 | # Resume from checkpoint if specified 360 | if args.resume_from_checkpoint: 361 | checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu") 362 | if "global_step" in checkpoint: 363 | global_step = checkpoint["global_step"] 364 | if "optimizer" in checkpoint: 365 | optimizer.load_state_dict(checkpoint["optimizer"]) 366 | if "lr_scheduler" in checkpoint: 367 | lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) 368 | 369 | # set_peft_model_state_dict(transformer, checkpoint["state_dict"]) # Luozhou: modify this line 370 | 371 | processor_state_dict = checkpoint["state_dict"] 372 | load_processor_state_dict(transformer, processor_state_dict) 373 | 374 | print(f"Resuming from checkpoint: {args.resume_from_checkpoint}") 375 | print(f"Resuming from global step: {global_step}") 376 | else: 377 | global_step = 0 378 | 379 | print("===== Memory before training =====") 380 | reset_memory("cuda") 381 | print_memory("cuda") 382 | 383 | # Train! 384 | total_batch_size = args.train_batch_size 385 | print("***** Running training *****") 386 | print(f" Num trainable parameters = {num_trainable_parameters}") 387 | print(f" Num examples = {len(train_dataset)}") 388 | print(f" Num batches each epoch = {len(train_dataloader)}") 389 | print(f" Num epochs = {args.num_train_epochs}") 390 | print(f" Instantaneous batch size per device = {args.train_batch_size}") 391 | print(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") 392 | print(f" Total optimization steps = {args.max_train_steps}") 393 | 394 | first_epoch = 0 395 | progress_bar = tqdm( 396 | range(0, args.max_train_steps), 397 | initial=global_step, 398 | desc="Steps", 399 | ) 400 | for epoch in range(first_epoch, args.num_train_epochs): 401 | transformer.train() 402 | 403 | for step, batch in enumerate(train_dataloader): 404 | with torch.no_grad(): 405 | z = batch["z"].to("cuda") 406 | eps = batch["eps"].to("cuda") 407 | sigma = batch["sigma"].to("cuda") 408 | prompt_embeds = batch["prompt_embeds"].to("cuda") 409 | prompt_attention_mask = batch["prompt_attention_mask"].to("cuda") 410 | 411 | all_attention_mask = prepare_attention_mask( 412 | prompt_attention_mask=prompt_attention_mask, 413 | latents=z 414 | ) 415 | 416 | sigma_bcthw = sigma[:, None, None, None, None] # [B, 1, 1, 1, 1] 417 | # Add noise according to flow matching. 418 | # zt = (1 - texp) * x + texp * z1 419 | z_sigma = (1 - sigma_bcthw) * z + sigma_bcthw * eps 420 | ut = z - eps 421 | 422 | # (1 - sigma) because of 423 | # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py#L656 424 | # Also, we operate on the scaled version of the `timesteps` directly in the `diffusers` implementation. 
425 | timesteps = (1 - sigma) * scheduler.config.num_train_timesteps 426 | 427 | with torch.autocast("cuda", torch.bfloat16): 428 | model_pred = transformer( 429 | hidden_states=z_sigma, 430 | encoder_hidden_states=prompt_embeds, 431 | encoder_attention_mask=all_attention_mask, 432 | timestep=timesteps, 433 | return_dict=False, 434 | )[0] 435 | assert model_pred.shape == z.shape 436 | loss = F.mse_loss(model_pred.float(), ut.float()) 437 | loss.backward() 438 | 439 | optimizer.step() 440 | optimizer.zero_grad() 441 | lr_scheduler.step() 442 | 443 | progress_bar.update(1) 444 | global_step += 1 445 | 446 | last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate 447 | logs = {"loss": loss.detach().item(), "lr": last_lr} 448 | progress_bar.set_postfix(**logs) 449 | if wandb_run: 450 | wandb_run.log(logs, step=global_step) 451 | 452 | if args.checkpointing_steps is not None and global_step % args.checkpointing_steps == 0: 453 | print(f"Saving checkpoint at step {global_step}") 454 | checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.pt") 455 | save_checkpoint( 456 | transformer, 457 | optimizer, 458 | lr_scheduler, 459 | global_step, 460 | checkpoint_path, 461 | ) 462 | 463 | # if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: 464 | print("===== Memory before validation =====") 465 | print_memory("cuda") 466 | 467 | transformer.eval() 468 | pipe = MochiPipeline.from_pretrained( 469 | args.pretrained_model_name_or_path, 470 | transformer=transformer, 471 | scheduler=scheduler, 472 | revision=args.revision, 473 | variant=args.variant, 474 | ) 475 | 476 | if args.enable_slicing: 477 | pipe.vae.enable_slicing() 478 | if args.enable_tiling: 479 | pipe.vae.enable_tiling() 480 | if args.enable_model_cpu_offload: 481 | pipe.enable_model_cpu_offload() 482 | 483 | # validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) 484 | validation_prompts = [ 485 | "A 
boy in a white shirt and shorts is seen bouncing a ball, isolated background", 486 | ] 487 | for validation_prompt in validation_prompts: 488 | pipeline_args = { 489 | "prompt": validation_prompt, 490 | "guidance_scale": 6.0, 491 | "num_frames": 37, 492 | "num_inference_steps": 64, 493 | "height": args.height, 494 | "width": args.width, 495 | "max_sequence_length": 256, 496 | } 497 | log_validation( 498 | pipe=pipe, 499 | args=args, 500 | pipeline_args=pipeline_args, 501 | step=global_step, 502 | wandb_run=wandb_run, 503 | ) 504 | 505 | print("===== Memory after validation =====") 506 | print_memory("cuda") 507 | reset_memory("cuda") 508 | 509 | del pipe.text_encoder 510 | del pipe.vae 511 | del pipe 512 | gc.collect() 513 | torch.cuda.empty_cache() 514 | 515 | transformer.train() 516 | 517 | if global_step >= args.max_train_steps: 518 | break 519 | 520 | if global_step >= args.max_train_steps: 521 | break 522 | 523 | transformer.eval() 524 | 525 | # saving lora weights 526 | # transformer_lora_layers = get_peft_model_state_dict(transformer) 527 | # MochiPipeline.save_lora_weights(save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers) 528 | 529 | # Cleanup trained models to save memory 530 | del transformer 531 | 532 | gc.collect() 533 | torch.cuda.empty_cache() 534 | 535 | # Final test inference 536 | # validation_outputs = [] 537 | # if args.validation_prompt and args.num_validation_videos > 0: 538 | # print("===== Memory before testing =====") 539 | # print_memory("cuda") 540 | # reset_memory("cuda") 541 | 542 | # pipe = MochiPipeline.from_pretrained( 543 | # args.pretrained_model_name_or_path, 544 | # revision=args.revision, 545 | # variant=args.variant, 546 | # ) 547 | 548 | 549 | 550 | # if args.enable_slicing: 551 | # pipe.vae.enable_slicing() 552 | # if args.enable_tiling: 553 | # pipe.vae.enable_tiling() 554 | # if args.enable_model_cpu_offload: 555 | # pipe.enable_model_cpu_offload() 556 | 557 | # # Load LoRA weights 558 | # # 
lora_scaling = args.lora_alpha / args.rank 559 | # # pipe.load_lora_weights(args.output_dir, adapter_name="mochi-lora") 560 | # # pipe.set_adapters(["mochi-lora"], [lora_scaling]) 561 | 562 | # # Run inference 563 | # validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) 564 | # for validation_prompt in validation_prompts: 565 | # pipeline_args = { 566 | # "prompt": validation_prompt, 567 | # "guidance_scale": 6.0, 568 | # "num_inference_steps": 64, 569 | # "height": args.height, 570 | # "width": args.width, 571 | # "max_sequence_length": 256, 572 | # } 573 | 574 | # video = log_validation( 575 | # pipe=pipe, 576 | # args=args, 577 | # pipeline_args=pipeline_args, 578 | # epoch=epoch, 579 | # wandb_run=wandb_run, 580 | # is_final_validation=True, 581 | # ) 582 | # validation_outputs.extend(video) 583 | 584 | # print("===== Memory after testing =====") 585 | # print_memory("cuda") 586 | # reset_memory("cuda") 587 | # torch.cuda.synchronize("cuda") 588 | 589 | 590 | 591 | if __name__ == "__main__": 592 | args = get_args() 593 | main(args) 594 | -------------------------------------------------------------------------------- /Mochi/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export NCCL_P2P_DISABLE=1 3 | export TORCH_NCCL_ENABLE_MONITORING=0 4 | 5 | GPU_IDS="3" 6 | 7 | DATA_ROOT="/hpc2hdd/home/lwang592/projects/finetrainers/training/data/video-matte-240k-rgb-prepared-f37" 8 | MODEL="genmo/mochi-1-preview" 9 | OUTPUT_PATH="mochi-rgba-lora-f37" 10 | 11 | cmd="CUDA_VISIBLE_DEVICES=$GPU_IDS python train.py \ 12 | --pretrained_model_name_or_path $MODEL \ 13 | --cast_dit \ 14 | --data_root $DATA_ROOT \ 15 | --seed 42 \ 16 | --output_dir $OUTPUT_PATH \ 17 | --train_batch_size 2 \ 18 | --dataloader_num_workers 4 \ 19 | --pin_memory \ 20 | --caption_dropout 0.0 \ 21 | --max_train_steps 5000 \ 22 | --gradient_checkpointing \ 23 | --enable_slicing \ 24 | --enable_tiling \ 25 | 
--enable_model_cpu_offload \ 26 | --optimizer adamw \ 27 | --allow_tf32" 28 | 29 | echo "Running command: $cmd" 30 | eval $cmd 31 | echo -ne "-------------------- Finished executing script --------------------\n\n" -------------------------------------------------------------------------------- /Mochi/trim_and_crop_videos.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from: 3 | https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/trim_and_crop_videos.py 4 | """ 5 | 6 | from pathlib import Path 7 | import shutil 8 | 9 | import click 10 | from moviepy.editor import VideoFileClip 11 | from tqdm import tqdm 12 | 13 | 14 | @click.command() 15 | @click.argument("folder", type=click.Path(exists=True, dir_okay=True)) 16 | @click.argument("output_folder", type=click.Path(dir_okay=True)) 17 | @click.option("--num_frames", "-f", type=float, default=30, help="Number of frames") 18 | @click.option("--resolution", "-r", type=str, default="480x848", help="Video resolution") 19 | @click.option("--force_upsample", is_flag=True, help="Force upsample.") 20 | def truncate_videos(folder, output_folder, num_frames, resolution, force_upsample): 21 | """Truncate all MP4 and MOV files in FOLDER to specified number of frames and resolution""" 22 | input_path = Path(folder) 23 | output_path = Path(output_folder) 24 | output_path.mkdir(parents=True, exist_ok=True) 25 | 26 | # Parse target resolution 27 | target_height, target_width = map(int, resolution.split("x")) 28 | 29 | # Calculate duration 30 | duration = (num_frames / 30) + 0.09 31 | 32 | # Find all MP4 and MOV files 33 | video_files = ( 34 | list(input_path.rglob("*.mp4")) 35 | + list(input_path.rglob("*.MOV")) 36 | + list(input_path.rglob("*.mov")) 37 | + list(input_path.rglob("*.MP4")) 38 | ) 39 | 40 | for file_path in tqdm(video_files): 41 | try: 42 | relative_path = file_path.relative_to(input_path) 43 | output_file = output_path / relative_path.with_suffix(".mp4") 
44 | output_file.parent.mkdir(parents=True, exist_ok=True) 45 | 46 | click.echo(f"Processing: {file_path}") 47 | video = VideoFileClip(str(file_path)) 48 | 49 | # Skip if video is too short 50 | if video.duration < duration: 51 | click.echo(f"Skipping {file_path} as it is too short") 52 | continue 53 | 54 | # Skip if target resolution is larger than input 55 | if target_width > video.w or target_height > video.h: 56 | if force_upsample: 57 | click.echo( 58 | f"{file_path} as target resolution {resolution} is larger than input {video.w}x{video.h}. So, upsampling the video." 59 | ) 60 | video = video.resize(width=target_width, height=target_height) 61 | else: 62 | click.echo( 63 | f"Skipping {file_path} as target resolution {resolution} is larger than input {video.w}x{video.h}" 64 | ) 65 | continue 66 | 67 | # First truncate duration 68 | truncated = video.subclip(0, duration) 69 | 70 | # Calculate crop dimensions to maintain aspect ratio 71 | target_ratio = target_width / target_height 72 | current_ratio = truncated.w / truncated.h 73 | 74 | if current_ratio > target_ratio: 75 | # Video is wider than target ratio - crop width 76 | new_width = int(truncated.h * target_ratio) 77 | x1 = (truncated.w - new_width) // 2 78 | final = truncated.crop(x1=x1, width=new_width).resize((target_width, target_height)) 79 | else: 80 | # Video is taller than target ratio - crop height 81 | new_height = int(truncated.w / target_ratio) 82 | y1 = (truncated.h - new_height) // 2 83 | final = truncated.crop(y1=y1, height=new_height).resize((target_width, target_height)) 84 | 85 | # Set output parameters for consistent MP4 encoding 86 | output_params = { 87 | "codec": "libx264", 88 | "audio": False, # Disable audio 89 | "preset": "medium", # Balance between speed and quality 90 | "bitrate": "5000k", # Adjust as needed 91 | } 92 | 93 | # Set FPS to 30 94 | final = final.set_fps(30) 95 | 96 | # Check for a corresponding .txt file 97 | txt_file_path = file_path.with_suffix(".txt") 98 | if 
txt_file_path.exists(): 99 | output_txt_file = output_path / relative_path.with_suffix(".txt") 100 | output_txt_file.parent.mkdir(parents=True, exist_ok=True) 101 | shutil.copy(txt_file_path, output_txt_file) 102 | click.echo(f"Copied {txt_file_path} to {output_txt_file}") 103 | else: 104 | # Print warning in bold yellow with a warning emoji 105 | click.echo( 106 | f"\033[1;33m⚠️ Warning: No caption found for {file_path}, using an empty caption. This may hurt fine-tuning quality.\033[0m" 107 | ) 108 | output_txt_file = output_path / relative_path.with_suffix(".txt") 109 | output_txt_file.parent.mkdir(parents=True, exist_ok=True) 110 | output_txt_file.touch() 111 | 112 | # Write the output file 113 | final.write_videofile(str(output_file), **output_params) 114 | 115 | # Clean up 116 | video.close() 117 | truncated.close() 118 | final.close() 119 | 120 | except Exception as e: 121 | click.echo(f"\033[1;31m Error processing {file_path}: {str(e)}\033[0m", err=True) 122 | raise 123 | 124 | 125 | if __name__ == "__main__": 126 | truncate_videos() -------------------------------------------------------------------------------- /Mochi/utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import inspect 3 | from typing import Optional, Tuple, Union 4 | 5 | import torch 6 | from accelerate import Accelerator 7 | from accelerate.logging import get_logger 8 | from diffusers.models.embeddings import get_3d_rotary_pos_embed 9 | from diffusers.utils.torch_utils import is_compiled_module 10 | 11 | 12 | logger = get_logger(__name__) 13 | 14 | 15 | def get_optimizer( 16 | params_to_optimize, 17 | optimizer_name: str = "adam", 18 | learning_rate: float = 1e-3, 19 | beta1: float = 0.9, 20 | beta2: float = 0.95, 21 | beta3: float = 0.98, 22 | epsilon: float = 1e-8, 23 | weight_decay: float = 1e-4, 24 | prodigy_decouple: bool = False, 25 | prodigy_use_bias_correction: bool = False, 26 | prodigy_safeguard_warmup: bool = False, 27 | 
use_8bit: bool = False, 28 | use_4bit: bool = False, 29 | use_torchao: bool = False, 30 | use_deepspeed: bool = False, 31 | use_cpu_offload_optimizer: bool = False, 32 | offload_gradients: bool = False, 33 | ) -> torch.optim.Optimizer: 34 | optimizer_name = optimizer_name.lower() 35 | 36 | # Use DeepSpeed optimzer 37 | if use_deepspeed: 38 | from accelerate.utils import DummyOptim 39 | 40 | return DummyOptim( 41 | params_to_optimize, 42 | lr=learning_rate, 43 | betas=(beta1, beta2), 44 | eps=epsilon, 45 | weight_decay=weight_decay, 46 | ) 47 | 48 | if use_8bit and use_4bit: 49 | raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.") 50 | 51 | if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer: 52 | try: 53 | import torchao 54 | 55 | torchao.__version__ 56 | except ImportError: 57 | raise ImportError( 58 | "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`." 59 | ) 60 | 61 | if not use_torchao and use_4bit: 62 | raise ValueError("4-bit Optimizers are only supported with torchao.") 63 | 64 | # Optimizer creation 65 | supported_optimizers = ["adam", "adamw", "prodigy", "came"] 66 | if optimizer_name not in supported_optimizers: 67 | logger.warning( 68 | f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`." 69 | ) 70 | optimizer_name = "adamw" 71 | 72 | if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]: 73 | raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.") 74 | 75 | if use_8bit: 76 | try: 77 | import bitsandbytes as bnb 78 | except ImportError: 79 | raise ImportError( 80 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
81 | ) 82 | 83 | if optimizer_name == "adamw": 84 | if use_torchao: 85 | from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit 86 | 87 | optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW 88 | else: 89 | optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW 90 | 91 | init_kwargs = { 92 | "betas": (beta1, beta2), 93 | "eps": epsilon, 94 | "weight_decay": weight_decay, 95 | } 96 | 97 | elif optimizer_name == "adam": 98 | if use_torchao: 99 | from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit 100 | 101 | optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam 102 | else: 103 | optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam 104 | 105 | init_kwargs = { 106 | "betas": (beta1, beta2), 107 | "eps": epsilon, 108 | "weight_decay": weight_decay, 109 | } 110 | 111 | elif optimizer_name == "prodigy": 112 | try: 113 | import prodigyopt 114 | except ImportError: 115 | raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") 116 | 117 | optimizer_class = prodigyopt.Prodigy 118 | 119 | if learning_rate <= 0.1: 120 | logger.warning( 121 | "Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0" 122 | ) 123 | 124 | init_kwargs = { 125 | "lr": learning_rate, 126 | "betas": (beta1, beta2), 127 | "beta3": beta3, 128 | "eps": epsilon, 129 | "weight_decay": weight_decay, 130 | "decouple": prodigy_decouple, 131 | "use_bias_correction": prodigy_use_bias_correction, 132 | "safeguard_warmup": prodigy_safeguard_warmup, 133 | } 134 | 135 | elif optimizer_name == "came": 136 | try: 137 | import came_pytorch 138 | except ImportError: 139 | raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`") 140 | 141 | optimizer_class = came_pytorch.CAME 142 | 143 | init_kwargs = { 144 | "lr": learning_rate, 145 | "eps": (1e-30, 1e-16), 146 | "betas": (beta1, beta2, beta3), 147 | "weight_decay": weight_decay, 148 | } 149 | 150 | if use_cpu_offload_optimizer: 151 | from torchao.prototype.low_bit_optim import CPUOffloadOptimizer 152 | 153 | if "fused" in inspect.signature(optimizer_class.__init__).parameters: 154 | init_kwargs.update({"fused": True}) 155 | 156 | optimizer = CPUOffloadOptimizer( 157 | params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs 158 | ) 159 | else: 160 | optimizer = optimizer_class(params_to_optimize, **init_kwargs) 161 | 162 | return optimizer 163 | 164 | 165 | def get_gradient_norm(parameters): 166 | norm = 0 167 | for param in parameters: 168 | if param.grad is None: 169 | continue 170 | local_norm = param.grad.detach().data.norm(2) 171 | norm += local_norm.item() ** 2 172 | norm = norm**0.5 173 | return norm 174 | 175 | 176 | # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid 177 | def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): 178 | tw = tgt_width 179 | th = tgt_height 180 | h, w = src 181 | r = h / w 182 | if r > (th / tw): 183 | resize_height = th 184 | resize_width = int(round(th / h * w)) 185 | else: 186 | resize_width = 
tw 187 | resize_height = int(round(tw / w * h)) 188 | 189 | crop_top = int(round((th - resize_height) / 2.0)) 190 | crop_left = int(round((tw - resize_width) / 2.0)) 191 | 192 | return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) 193 | 194 | 195 | def prepare_rotary_positional_embeddings( 196 | height: int, 197 | width: int, 198 | num_frames: int, 199 | vae_scale_factor_spatial: int = 8, 200 | patch_size: int = 2, 201 | patch_size_t: int = None, 202 | attention_head_dim: int = 64, 203 | device: Optional[torch.device] = None, 204 | base_height: int = 480, 205 | base_width: int = 720, 206 | ) -> Tuple[torch.Tensor, torch.Tensor]: 207 | grid_height = height // (vae_scale_factor_spatial * patch_size) 208 | grid_width = width // (vae_scale_factor_spatial * patch_size) 209 | base_size_width = base_width // (vae_scale_factor_spatial * patch_size) 210 | base_size_height = base_height // (vae_scale_factor_spatial * patch_size) 211 | 212 | if patch_size_t is None: 213 | # CogVideoX 1.0 214 | grid_crops_coords = get_resize_crop_region_for_grid( 215 | (grid_height, grid_width), base_size_width, base_size_height 216 | ) 217 | freqs_cos, freqs_sin = get_3d_rotary_pos_embed( 218 | embed_dim=attention_head_dim, 219 | crops_coords=grid_crops_coords, 220 | grid_size=(grid_height, grid_width), 221 | temporal_size=num_frames, 222 | ) 223 | else: 224 | # CogVideoX 1.5 225 | base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t 226 | 227 | freqs_cos, freqs_sin = get_3d_rotary_pos_embed( 228 | embed_dim=attention_head_dim, 229 | crops_coords=None, 230 | grid_size=(grid_height, grid_width), 231 | temporal_size=base_num_frames, 232 | grid_type="slice", 233 | max_size=(base_size_height, base_size_width), 234 | ) 235 | 236 | freqs_cos = freqs_cos.to(device=device) 237 | freqs_sin = freqs_sin.to(device=device) 238 | return freqs_cos, freqs_sin 239 | 240 | 241 | def reset_memory(device: Union[str, torch.device]) -> None: 242 | gc.collect() 243 | 
torch.cuda.empty_cache() 244 | torch.cuda.reset_peak_memory_stats(device) 245 | torch.cuda.reset_accumulated_memory_stats(device) 246 | 247 | 248 | def print_memory(device: Union[str, torch.device]) -> None: 249 | memory_allocated = torch.cuda.memory_allocated(device) / 1024**3 250 | max_memory_allocated = torch.cuda.max_memory_allocated(device) / 1024**3 251 | max_memory_reserved = torch.cuda.max_memory_reserved(device) / 1024**3 252 | print(f"{memory_allocated=:.3f} GB") 253 | print(f"{max_memory_allocated=:.3f} GB") 254 | print(f"{max_memory_reserved=:.3f} GB") 255 | 256 | 257 | def unwrap_model(accelerator: Accelerator, model): 258 | model = accelerator.unwrap_model(model) 259 | model = model._orig_mod if is_compiled_module(model) else model 260 | return model 261 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## TransPixeler: Advancing Text-to-Video Generation with Transparency (CVPR2025) 2 |
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | [Luozhou Wang*](https://wileewang.github.io/), 12 | [Yijun Li**](https://yijunmaverick.github.io/), 13 | [Zhifei Chen](), 14 | [Jui-Hsien Wang](http://juiwang.com/), 15 | [Zhifei Zhang](https://zzutk.github.io/), 16 | [He Zhang](https://sites.google.com/site/hezhangsprinter), 17 | [Zhe Lin](https://sites.google.com/site/zhelin625/home), 18 | [Ying-Cong Chen†](https://www.yingcong.me) 19 | 20 | HKUST(GZ), HKUST, Adobe Research. 21 | 22 | \* Internship Project 23 | \** Project Lead 24 | † Corresponding Author 25 | 26 | Text-to-video generative models have made significant strides, enabling diverse applications in entertainment, advertising, and education. However, generating RGBA video, which includes alpha channels for transparency, remains a challenge due to limited datasets and the difficulty of adapting existing models. Alpha channels are crucial for visual effects (VFX), allowing transparent elements like smoke and reflections to blend seamlessly into scenes. 27 | We introduce TransPixeler, a method to extend pretrained video models for RGBA generation while retaining the original RGB capabilities. TransPixeler leverages a diffusion transformer (DiT) architecture, incorporating alpha-specific tokens and using LoRA-based fine-tuning to jointly generate RGB and alpha channels with high consistency. By optimizing attention mechanisms, TransPixeler preserves the strengths of the original RGB model and achieves strong alignment between RGB and alpha channels despite limited training data. 28 | Our approach effectively generates diverse and consistent RGBA videos, advancing the possibilities for VFX and interactive content creation. 29 | 30 | 31 | 32 | 33 | 34 | 35 | ## 📰 News 36 | 37 | - **[2025.04.28]** We have introduced a new development branch [`wan`](https://github.com/wileewang/TransPixar/tree/wan) that integrates the [Wan2.1](https://github.com/Wan-Video/Wan2.1) video generation model to support **joint generation** tasks.
This branch includes training code tailored for generating both RGB and associated modalities (e.g., segmentation maps, alpha masks) from a shared text prompt. 38 | 39 | - **[2025.02.26]** **TransPixeler** is accepted by CVPR 2025! See you in Nashville! 40 | 41 | - **[2025.01.19]** We've renamed our project from **TransPixar** to **TransPixeler**!! 42 | 43 | - **[2025.01.17]** We’ve created a [Discord group](https://discord.gg/7Xds3Qjr) and a [WeChat group](https://github.com/wileewang/TransPixar/blob/main/wechat_group.jpg)! Everyone is welcome to join for discussions and collaborations. 44 | 45 | - **[2025.01.14]** Added new tasks to the repository's roadmap, including support for Hunyuan and LTX video models, and ComfyUI integration. 46 | 47 | - **[2025.01.07]** Released project page, arXiv paper, inference code, and Hugging Face demo. 48 | 49 | 50 | 51 | 52 | ## 🔥 New Branch for Joint Generation with Wan2.1 53 | 54 | We have introduced a new development branch [`wan`](https://github.com/wileewang/TransPixar/tree/wan) that integrates the [Wan2.1](https://github.com/Wan-Video/Wan2.1) video generation model to support **joint generation** tasks. 55 | 56 | In the `wan` branch, we have developed and released training code tailored for joint generation scenarios, enabling the simultaneous generation of RGB videos and associated modalities (e.g., segmentation maps, alpha masks) from a shared text prompt. 57 | 58 | **Key features of the `wan` branch:** 59 | - **Integration of Wan2.1**: Leverages the capabilities of the Wan2.1 video generation model for enhanced performance. 60 | - **Joint Generation Support**: Facilitates the concurrent generation of RGB and paired modality videos. 61 | - **Dataset Structure**: Expects each sample to include: 62 | - A primary video file (`001.mp4`) representing the RGB content. 63 | - A paired secondary video file (`001_seg.mp4`) with a fixed `_seg` suffix, representing the associated modality. 
64 | - A caption text file (`001.txt`) with the same base name as the primary video. 65 | - **Periodic Evaluation**: Supports periodic video sampling during training by setting `eval_every_step` or `eval_every_epoch` in the configuration. 66 | - **Customized Pipelines**: Offers tailored training and inference pipelines designed specifically for joint generation tasks. 67 | 68 | 👉 To utilize the joint generation features, please check out the [`wan`](https://github.com/wileewang/TransPixar/tree/wan) branch. 69 | 70 | 71 | 72 | 73 | ## Contents 74 | 75 | * [Installation](#installation) 76 | * [TransPixeler LoRA Weights](#transpixeler-lora-weights) 77 | * [Training](#training---rgb--alpha-joint-generation) 78 | * [Inference](#inference---gradio-demo) 79 | * [Acknowledgement](#acknowledgement) 80 | * [Citation](#citation) 81 | 82 | 83 | 84 | ## Installation 85 | 86 | ```bash 87 | # For the main branch 88 | conda create -n TransPixeler python=3.10 89 | conda activate TransPixeler 90 | pip install -r requirements.txt 91 | ``` 92 | 93 | **Note:** 94 | If you want to use the **Wan2.1 model**, please first check out the `wan` branch: 95 | 96 | ```bash 97 | git checkout wan 98 | ``` 99 | 100 | ## TransPixeler LoRA Weights 101 | 102 | Our pipeline is designed to support various video tasks, including Text-to-RGBA Video, Image-to-RGBA Video. 103 | 104 | We provide the following pre-trained LoRA weights: 105 | 106 | | Task | Base Model | Frames | LoRA weights | Inference VRAM | 107 | |---------------|---------------------------------------------------------------|--------|--------------------------------------------------------------------|----------------| 108 | | T2V + RGBA | [THUDM/CogVideoX-5B](https://huggingface.co/THUDM/CogVideoX-5b) | 49 | [link](https://huggingface.co/wileewang/TransPixar/blob/main/cogvideox_rgba_lora.safetensors) | ~24GB | 109 | 110 | 111 | ## Training - RGB + Alpha Joint Generation 112 | We have open-sourced the training code for **Mochi** on RGBA joint generation.
Please refer to the [Mochi README](Mochi/README.md) for details. 113 | 114 | 115 | ## Inference - Gradio Demo 116 | In addition to the [Hugging Face online demo](https://huggingface.co/spaces/wileewang/TransPixar), users can also launch a local inference demo based on CogVideoX-5B by running the following command: 117 | 118 | ```bash 119 | python app.py 120 | ``` 121 | 122 | ## Inference - Command Line Interface (CLI) 123 | To generate RGBA videos, navigate to the corresponding directory for the video model and execute the following command: 124 | ```bash 125 | python cli.py \ 126 | --lora_path /path/to/lora \ 127 | --prompt "..." 128 | ``` 129 | 130 | --- 131 | 132 | ## Acknowledgement 133 | 134 | * [finetrainers](https://github.com/a-r-r-o-w/finetrainers): We followed their implementation of Mochi training and inference. 135 | * [CogVideoX](https://github.com/THUDM/CogVideo): We followed their implementation of CogVideoX training and inference. 136 | 137 | We are grateful for their exceptional work and generous contribution to the open-source community. 138 | 139 | ## Citation 140 | 141 | ```bibtex 142 | @misc{wang2025transpixeler, 143 | title={TransPixeler: Advancing Text-to-Video Generation with Transparency}, 144 | author={Luozhou Wang and Yijun Li and Zhifei Chen and Jui-Hsien Wang and Zhifei Zhang and He Zhang and Zhe Lin and Ying-Cong Chen}, 145 | year={2025}, 146 | eprint={2501.03006}, 147 | archivePrefix={arXiv}, 148 | primaryClass={cs.CV}, 149 | url={https://arxiv.org/abs/2501.03006}, 150 | } 151 | ``` 152 | 153 | ## Star History 154 | 155 | [![Star History Chart](https://api.star-history.com/svg?repos=wileewang/TransPixeler&type=Date)](https://star-history.com/#wileewang/TransPixeler&Date) 156 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | """ 2 | THis is the main file for the gradio web demo. 
It uses the CogVideoX-5B model to generate videos gradio web demo. 3 | set environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt. 4 | Usage: 5 | OpenAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py 6 | """ 7 | 8 | import math 9 | import os 10 | import random 11 | import threading 12 | import time 13 | 14 | import cv2 15 | import tempfile 16 | import imageio_ffmpeg 17 | import gradio as gr 18 | import torch 19 | from PIL import Image 20 | # from diffusers import ( 21 | # CogVideoXPipeline, 22 | # CogVideoXDPMScheduler, 23 | # CogVideoXVideoToVideoPipeline, 24 | # CogVideoXImageToVideoPipeline, 25 | # CogVideoXTransformer3DModel, 26 | # ) 27 | from typing import Union, List 28 | from CogVideoX.pipeline_rgba import CogVideoXPipeline 29 | from CogVideoX.rgba_utils import * 30 | from diffusers import CogVideoXDPMScheduler 31 | 32 | from diffusers.utils import load_video, load_image, export_to_video 33 | from datetime import datetime, timedelta 34 | 35 | from diffusers.image_processor import VaeImageProcessor 36 | import moviepy.editor as mp 37 | import numpy as np 38 | from huggingface_hub import hf_hub_download, snapshot_download 39 | import gc 40 | 41 | device = "cuda" if torch.cuda.is_available() else "cpu" 42 | 43 | # hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran") 44 | hf_hub_download(repo_id="wileewang/TransPixar", filename="cogvideox_rgba_lora.safetensors", local_dir="model_cogvideox_rgba_lora") 45 | # snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife") 46 | 47 | pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5B", torch_dtype=torch.bfloat16) 48 | # pipe.enable_sequential_cpu_offload() 49 | pipe.vae.enable_slicing() 50 | pipe.vae.enable_tiling() 51 | pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") 52 | seq_length = 2 * ( 53 | (480 // 
pipe.vae_scale_factor_spatial // 2) 54 | * (720 // pipe.vae_scale_factor_spatial // 2) 55 | * ((13 - 1) // pipe.vae_scale_factor_temporal + 1) 56 | ) 57 | prepare_for_rgba_inference( 58 | pipe.transformer, 59 | rgba_weights_path="model_cogvideox_rgba_lora/cogvideox_rgba_lora.safetensors", 60 | device="cuda", 61 | dtype=torch.bfloat16, 62 | text_length=226, 63 | seq_length=seq_length, # this is for the creation of attention mask. 64 | ) 65 | 66 | # pipe.transformer.to(memory_format=torch.channels_last) 67 | # pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True) 68 | # pipe_image.transformer.to(memory_format=torch.channels_last) 69 | # pipe_image.transformer = torch.compile(pipe_image.transformer, mode="max-autotune", fullgraph=True) 70 | 71 | os.makedirs("./output", exist_ok=True) 72 | os.makedirs("./gradio_tmp", exist_ok=True) 73 | 74 | # upscale_model = utils.load_sd_upscale("model_real_esran/RealESRGAN_x4.pth", device) 75 | # frame_interpolation_model = load_rife_model("model_rife") 76 | 77 | 78 | sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets. 79 | For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive. 80 | There are a few rules to follow: 81 | You will only ever output a single video description per user request. 82 | When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions. 83 | Other times the user will not want modifications , but instead want a new image . 
In this case , you should ignore your previous conversation with the user. 84 | Video descriptions must have the same num of words as examples below. Extra words will be ignored. 85 | """ 86 | def save_video(tensor: Union[List[np.ndarray], List[Image.Image]], fps: int = 8, prefix='rgb'): 87 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 88 | video_path = f"./output/{prefix}_{timestamp}.mp4" 89 | os.makedirs(os.path.dirname(video_path), exist_ok=True) 90 | export_to_video(tensor, video_path, fps=fps) 91 | return video_path 92 | 93 | def resize_if_unfit(input_video, progress=gr.Progress(track_tqdm=True)): 94 | width, height = get_video_dimensions(input_video) 95 | 96 | if width == 720 and height == 480: 97 | processed_video = input_video 98 | else: 99 | processed_video = center_crop_resize(input_video) 100 | return processed_video 101 | 102 | 103 | def get_video_dimensions(input_video_path): 104 | reader = imageio_ffmpeg.read_frames(input_video_path) 105 | metadata = next(reader) 106 | return metadata["size"] 107 | 108 | 109 | def center_crop_resize(input_video_path, target_width=720, target_height=480): 110 | cap = cv2.VideoCapture(input_video_path) 111 | 112 | orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 113 | orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 114 | orig_fps = cap.get(cv2.CAP_PROP_FPS) 115 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 116 | 117 | width_factor = target_width / orig_width 118 | height_factor = target_height / orig_height 119 | resize_factor = max(width_factor, height_factor) 120 | 121 | inter_width = int(orig_width * resize_factor) 122 | inter_height = int(orig_height * resize_factor) 123 | 124 | target_fps = 8 125 | ideal_skip = max(0, math.ceil(orig_fps / target_fps) - 1) 126 | skip = min(5, ideal_skip) # Cap at 5 127 | 128 | while (total_frames / (skip + 1)) < 49 and skip > 0: 129 | skip -= 1 130 | 131 | processed_frames = [] 132 | frame_count = 0 133 | total_read = 0 134 | 135 | while frame_count < 49 
and total_read < total_frames: 136 | ret, frame = cap.read() 137 | if not ret: 138 | break 139 | 140 | if total_read % (skip + 1) == 0: 141 | resized = cv2.resize(frame, (inter_width, inter_height), interpolation=cv2.INTER_AREA) 142 | 143 | start_x = (inter_width - target_width) // 2 144 | start_y = (inter_height - target_height) // 2 145 | cropped = resized[start_y : start_y + target_height, start_x : start_x + target_width] 146 | 147 | processed_frames.append(cropped) 148 | frame_count += 1 149 | 150 | total_read += 1 151 | 152 | cap.release() 153 | 154 | with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file: 155 | temp_video_path = temp_file.name 156 | fourcc = cv2.VideoWriter_fourcc(*"mp4v") 157 | out = cv2.VideoWriter(temp_video_path, fourcc, target_fps, (target_width, target_height)) 158 | 159 | for frame in processed_frames: 160 | out.write(frame) 161 | 162 | out.release() 163 | 164 | return temp_video_path 165 | 166 | 167 | 168 | def infer( 169 | prompt: str, 170 | num_inference_steps: int, 171 | guidance_scale: float, 172 | seed: int = -1, 173 | progress=gr.Progress(track_tqdm=True), 174 | ): 175 | if seed == -1: 176 | seed = random.randint(0, 2**8 - 1) 177 | pipe.to(device) 178 | video_pt = pipe( 179 | prompt=prompt + ", isolated background", 180 | num_videos_per_prompt=1, 181 | num_inference_steps=num_inference_steps, 182 | num_frames=13, 183 | use_dynamic_cfg=True, 184 | output_type="latent", 185 | guidance_scale=guidance_scale, 186 | generator=torch.Generator(device=device).manual_seed(int(seed)), 187 | ).frames 188 | # pipe.to("cpu") 189 | gc.collect() 190 | return (video_pt, seed) 191 | 192 | 193 | def convert_to_gif(video_path): 194 | clip = mp.VideoFileClip(video_path) 195 | clip = clip.set_fps(8) 196 | clip = clip.resize(height=240) 197 | gif_path = video_path.replace(".mp4", ".gif") 198 | clip.write_gif(gif_path, fps=8) 199 | return gif_path 200 | 201 | 202 | def delete_old_files(): 203 | while True: 204 | now = 
datetime.now() 205 | cutoff = now - timedelta(minutes=10) 206 | directories = ["./output", "./gradio_tmp"] 207 | 208 | for directory in directories: 209 | for filename in os.listdir(directory): 210 | file_path = os.path.join(directory, filename) 211 | if os.path.isfile(file_path): 212 | file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path)) 213 | if file_mtime < cutoff: 214 | os.remove(file_path) 215 | time.sleep(600) 216 | 217 | 218 | threading.Thread(target=delete_old_files, daemon=True).start() 219 | 220 | with gr.Blocks() as demo: 221 | gr.HTML(""" 222 |
223 | TransPixar + CogVideoX-5B Huggingface Space🤗 224 |
225 |
226 | 🤗 TransPixar LoRA Hub | 227 | 🌐 Github | 228 | 📜 arxiv 229 |
230 |
231 | ⚠️ This demo is for academic research and experiential use only. 232 |
233 | """) 234 | with gr.Row(): 235 | with gr.Column(): 236 | prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5) 237 | with gr.Group(): 238 | with gr.Column(): 239 | with gr.Row(): 240 | seed_param = gr.Number( 241 | label="Inference Seed (Enter a positive number, -1 for random)", value=-1 242 | ) 243 | 244 | generate_button = gr.Button("🎬 Generate Video") 245 | with gr.Row(): 246 | gr.Markdown( 247 | """ 248 | **Note:** The output RGB is a premultiplied version to avoid the color decontamination problem. 249 | It can directly composite with a background using: 250 | ``` 251 | composite = rgb + (1 - alpha) * background 252 | ``` 253 | """ 254 | ) 255 | 256 | with gr.Column(): 257 | rgb_video_output = gr.Video(label="Generated RGB Video", width=720, height=480) 258 | alpha_video_output = gr.Video(label="Generated Alpha Video", width=720, height=480) 259 | with gr.Row(): 260 | download_rgb_video_button = gr.File(label="📥 Download RGB Video", visible=False) 261 | download_alpha_video_button = gr.File(label="📥 Download Alpha Video", visible=False) 262 | seed_text = gr.Number(label="Seed Used for Video Generation", visible=False) 263 | 264 | 265 | def generate( 266 | prompt, 267 | seed_value, 268 | progress=gr.Progress(track_tqdm=True) 269 | ): 270 | latents, seed = infer( 271 | prompt, 272 | num_inference_steps=25, # NOT Changed 273 | guidance_scale=7.0, # NOT Changed 274 | seed=seed_value, 275 | progress=progress, 276 | ) 277 | 278 | latents_rgb, latents_alpha = latents.chunk(2, dim=1) 279 | 280 | frames_rgb = decode_latents(pipe, latents_rgb) 281 | frames_alpha = decode_latents(pipe, latents_alpha) 282 | 283 | pooled_alpha = np.max(frames_alpha, axis=-1, keepdims=True) 284 | frames_alpha_pooled = np.repeat(pooled_alpha, 3, axis=-1) 285 | premultiplied_rgb = frames_rgb * frames_alpha_pooled 286 | 287 | rgb_video_path = save_video(premultiplied_rgb[0], fps=8, prefix='rgb') 288 | rgb_video_update = 
gr.update(visible=True, value=rgb_video_path) 289 | 290 | alpha_video_path = save_video(frames_alpha_pooled[0], fps=8, prefix='alpha') 291 | alpha_video_update = gr.update(visible=True, value=alpha_video_path) 292 | seed_update = gr.update(visible=True, value=seed) 293 | 294 | return rgb_video_path, alpha_video_path, rgb_video_update, alpha_video_update, seed_update 295 | 296 | 297 | generate_button.click( 298 | generate, 299 | inputs=[prompt, seed_param], 300 | outputs=[rgb_video_output, alpha_video_output, download_rgb_video_button, download_alpha_video_button, seed_text], 301 | ) 302 | 303 | 304 | if __name__ == "__main__": 305 | demo.queue(max_size=15) 306 | demo.launch() 307 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.4.0 2 | torchvision 3 | torchaudio 4 | wandb 5 | gradio 6 | sentencepiece 7 | diffusers==0.32.0 8 | huggingface_hub==0.27.0 9 | transformers 10 | imageio>=2.5.0 11 | imageio-ffmpeg 12 | moviepy==1.0.3 13 | opencv-python>=4.5 14 | accelerate 15 | -------------------------------------------------------------------------------- /wechat_group.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wileewang/TransPixeler/a704f8037f5ceb93770eb432f9db847ad5a49c27/wechat_group.jpg --------------------------------------------------------------------------------