.embed.pt
39 | ```
40 |
41 |
42 | Now, we're ready to fine-tune. To launch, run:
43 |
44 | ```bash
45 | bash train.sh
46 | ```
47 | **Note:**
48 |
49 | The `--num_frames` argument specifies the number of frames in the generated **RGB** video. During generation, the number of frames is doubled so that the **RGB** video and the **Alpha** video are generated jointly. This doubling is handled automatically by our implementation.
50 |
51 | On an 80GB GPU, training supports RGB videos with dimensions of 480 × 848 × 79 (Height × Width × Frames) at a batch size of 1 in bfloat16 precision. However, training is relatively slow (over one minute per iteration) because the model processes a total of 79 × 2 frames as input.
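
As a rough illustration of the frame bookkeeping described above, the sketch below computes the joint frame count and the number of latent frames per modality. The helper names are hypothetical. The temporal compression factor of 6 comes from `vae_temporal_scale_factor` in `pipeline_mochi_rgba.py`, and the latent-frame formula is an assumption based on the `(T - 1) % 6 == 0` check in `embed.py`.

```python
# Temporal compression factor, taken from `vae_temporal_scale_factor`
# in pipeline_mochi_rgba.py.
VAE_TEMPORAL_SCALE_FACTOR = 6


def joint_frame_count(num_frames: int) -> int:
    """Frames processed when the RGB and alpha videos are generated jointly."""
    return 2 * num_frames


def latent_frame_count(num_frames: int, temporal_scale: int = VAE_TEMPORAL_SCALE_FACTOR) -> int:
    """Latent frames for one modality (assumed formula: (T - 1) / scale + 1)."""
    if (num_frames - 1) % temporal_scale != 0:
        raise ValueError(f"num_frames must be 1 mod {temporal_scale}, got {num_frames}")
    return (num_frames - 1) // temporal_scale + 1


print(joint_frame_count(79))   # 158
print(latent_frame_count(79))  # 14
```

With `--num_frames 79`, the model therefore handles 158 pixel-space frames in total, which is why a single iteration is slow.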
52 |
53 |
54 |
55 |
56 |
57 | ~~We haven't rigorously tested but without validation enabled, this script should run under 40GBs of GPU VRAM.~~
58 |
59 | ## Inference
60 |
61 | To generate the RGBA video, run:
62 |
63 | ```bash
64 | python cli.py \
65 | --lora_path /path/to/lora \
66 | --prompt "..."
67 | ```
68 |
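69 | This command generates the RGB and Alpha videos simultaneously and saves them to the output directory (`./output` by default) as `rgb.mp4` and `alpha.mp4`; `original_rgb.mp4` holds the RGB video before premultiplication. The RGB video in `rgb.mp4` is saved in premultiplied form, so you can blend it with any background image using the following formula: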
70 |
71 | ```python
72 | com = rgb + (1 - alpha) * bgr
73 | ```
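
A minimal NumPy sketch of this blend follows. It assumes that `rgb`, `alpha`, and `bgr` are float arrays of the same shape with values in `[0, 1]`; the function name and the toy data are illustrative and not part of the repository.

```python
import numpy as np


def composite_over_background(rgb, alpha, bgr):
    """Blend a premultiplied RGB frame over a background: com = rgb + (1 - alpha) * bgr.

    All inputs are assumed to be arrays of identical shape with values in [0, 1].
    """
    rgb = np.asarray(rgb, dtype=np.float32)
    alpha = np.asarray(alpha, dtype=np.float32)
    bgr = np.asarray(bgr, dtype=np.float32)
    # Clip to guard against small numerical overshoot.
    return np.clip(rgb + (1.0 - alpha) * bgr, 0.0, 1.0)


# Toy 1 x 2 frame: the left pixel is fully opaque, the right pixel fully transparent.
alpha = np.array([[[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]]], dtype=np.float32)
rgb = np.array([[[0.2, 0.4, 0.6], [0.0, 0.0, 0.0]]], dtype=np.float32)  # premultiplied
bgr = np.full((1, 2, 3), 0.5, dtype=np.float32)  # uniform grey background

com = composite_over_background(rgb, alpha, bgr)
```

The opaque pixel keeps its foreground colour, while the transparent pixel shows the background. Because `rgb` is already premultiplied, it is not multiplied by `alpha` again.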
74 |
75 | ## Known limitations
76 |
77 | (Contributions are welcome 🤗)
78 |
79 | Our script currently doesn't use `accelerate`; some of the consequences are listed below:
80 |
81 | * No support for distributed training.
82 | * `train_batch_size > 1` is supported but can lead to OOMs because we currently don't support gradient accumulation.
83 | * No support for 8-bit optimizers (but this should be relatively easy to add).
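
For contributors, one possible way to add gradient accumulation is sketched below in plain PyTorch. This is not the repository's training loop: the toy model, the MSE loss, and the `train_with_accumulation` helper are placeholders chosen only to make the pattern runnable on CPU.

```python
import torch


def train_with_accumulation(model, optimizer, batches, accumulation_steps):
    """Step the optimizer once every `accumulation_steps` micro-batches.

    Returns the number of optimizer steps taken. Gradients from a trailing,
    incomplete group of micro-batches are discarded in this sketch.
    """
    optimizer_steps = 0
    optimizer.zero_grad()
    for i, (x, y) in enumerate(batches, start=1):
        loss = torch.nn.functional.mse_loss(model(x), y)
        # Scale the loss so the accumulated gradient is an average, not a sum.
        (loss / accumulation_steps).backward()
        if i % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            optimizer_steps += 1
    return optimizer_steps


torch.manual_seed(0)
model = torch.nn.Linear(4, 1)  # stand-in for the LoRA-wrapped transformer
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=0.01)
batches = [(torch.randn(1, 4), torch.randn(1, 1)) for _ in range(8)]

steps = train_with_accumulation(model, optimizer, batches, accumulation_steps=4)
```

Eight micro-batches of size 1 with `accumulation_steps=4` give two optimizer steps, which approximates an effective batch size of 4 while holding only one sample's activations in memory at a time.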
84 |
85 | **Misc**:
86 |
87 | * We're aware of the quality issues in the `diffusers` implementation of Mochi-1. This is being fixed in [this PR](https://github.com/huggingface/diffusers/pull/10033).
88 | * The `embed.py` script is non-batched.
89 |
--------------------------------------------------------------------------------
/Mochi/args.py:
--------------------------------------------------------------------------------
1 | """
2 | Default values taken from
3 | https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/configs/lora.yaml
4 | when applicable.
5 | """
6 |
7 | import argparse
8 |
9 |
10 | def _get_model_args(parser: argparse.ArgumentParser) -> None:
11 | parser.add_argument(
12 | "--pretrained_model_name_or_path",
13 | type=str,
14 | default=None,
15 | required=True,
16 | help="Path to pretrained model or model identifier from huggingface.co/models.",
17 | )
18 | parser.add_argument(
19 | "--revision",
20 | type=str,
21 | default=None,
22 | required=False,
23 | help="Revision of pretrained model identifier from huggingface.co/models.",
24 | )
25 | parser.add_argument(
26 | "--variant",
27 | type=str,
28 | default=None,
29 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. 'fp16'.",
30 | )
31 | parser.add_argument(
32 | "--cache_dir",
33 | type=str,
34 | default=None,
35 | help="The directory where the downloaded models and datasets will be stored.",
36 | )
37 | parser.add_argument(
38 | "--cast_dit",
39 | action="store_true",
40 | help="If we should cast DiT params to a lower precision.",
41 | )
42 | parser.add_argument(
43 | "--compile_dit",
44 | action="store_true",
45 | help="If we should compile the DiT.",
46 | )
47 |
48 |
49 | def _get_dataset_args(parser: argparse.ArgumentParser) -> None:
50 | parser.add_argument(
51 | "--data_root",
52 | type=str,
53 | default=None,
54 | help=("A folder containing the training data."),
55 | )
56 | parser.add_argument(
57 | "--caption_dropout",
58 | type=float,
59 | default=None,
60 | help=("Probability to drop out captions randomly."),
61 | )
62 |
63 | parser.add_argument(
64 | "--dataloader_num_workers",
65 | type=int,
66 | default=0,
67 | help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
68 | )
69 | parser.add_argument(
70 | "--pin_memory",
71 | action="store_true",
72 | help="Whether or not to use the pinned memory setting in pytorch dataloader.",
73 | )
74 |
75 |
76 | def _get_validation_args(parser: argparse.ArgumentParser) -> None:
77 | parser.add_argument(
78 | "--validation_prompt",
79 | type=str,
80 | default=None,
81 | help="One or more prompt(s) used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
82 | )
83 | parser.add_argument(
84 | "--validation_images",
85 | type=str,
86 | default=None,
87 | help="One or more image path(s)/URLs used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_separator' string. These should correspond to the order of the validation prompts.",
88 | )
89 | parser.add_argument(
90 | "--validation_prompt_separator",
91 | type=str,
92 | default=":::",
93 | help="String that separates multiple validation prompts",
94 | )
95 | parser.add_argument(
96 | "--num_validation_videos",
97 | type=int,
98 | default=1,
99 | help="Number of videos that should be generated during validation per `validation_prompt`.",
100 | )
101 | parser.add_argument(
102 | "--validation_epochs",
103 | type=int,
104 | default=50,
105 | help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.",
106 | )
107 | parser.add_argument(
108 | "--enable_slicing",
109 | action="store_true",
110 | default=False,
111 | help="Whether or not to use VAE slicing for saving memory.",
112 | )
113 | parser.add_argument(
114 | "--enable_tiling",
115 | action="store_true",
116 | default=False,
117 | help="Whether or not to use VAE tiling for saving memory.",
118 | )
119 | parser.add_argument(
120 | "--enable_model_cpu_offload",
121 | action="store_true",
122 | default=False,
123 | help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
124 | )
125 | parser.add_argument(
126 | "--fps",
127 | type=int,
128 | default=30,
129 | help="FPS to use when serializing the output videos.",
130 | )
131 | parser.add_argument(
132 | "--height",
133 | type=int,
134 | default=480,
135 | )
136 | parser.add_argument(
137 | "--width",
138 | type=int,
139 | default=848,
140 | )
141 |
142 |
143 | def _get_training_args(parser: argparse.ArgumentParser) -> None:
144 | parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
145 | parser.add_argument("--rank", type=int, default=16, help="The rank for LoRA matrices.")
146 | parser.add_argument(
147 | "--lora_alpha",
148 | type=int,
149 | default=16,
150 | help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.",
151 | )
152 | parser.add_argument(
153 | "--target_modules",
154 | nargs="+",
155 | type=str,
156 | default=["to_k", "to_q", "to_v", "to_out.0"],
157 | help="Target modules to train LoRA for.",
158 | )
159 | parser.add_argument(
160 | "--output_dir",
161 | type=str,
162 | default="mochi-lora",
163 | help="The output directory where the model predictions and checkpoints will be written.",
164 | )
165 | parser.add_argument(
166 | "--train_batch_size",
167 | type=int,
168 | default=4,
169 | help="Batch size (per device) for the training dataloader.",
170 | )
171 | parser.add_argument("--num_train_epochs", type=int, default=1)
172 | parser.add_argument(
173 | "--max_train_steps",
174 | type=int,
175 | default=None,
176 | help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
177 | )
178 | parser.add_argument(
179 | "--gradient_checkpointing",
180 | action="store_true",
181 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
182 | )
183 | parser.add_argument(
184 | "--learning_rate",
185 | type=float,
186 | default=2e-4,
187 | help="Initial learning rate (after the potential warmup period) to use.",
188 | )
189 | parser.add_argument(
190 | "--scale_lr",
191 | action="store_true",
192 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
193 | )
194 | parser.add_argument(
195 | "--lr_warmup_steps",
196 | type=int,
197 | default=200,
198 | help="Number of steps for the warmup in the lr scheduler.",
199 | )
200 | parser.add_argument(
201 | "--checkpointing_steps",
202 | type=int,
203 | default=1000,
204 | )
205 | parser.add_argument(
206 | "--resume_from_checkpoint",
207 | type=str,
208 | default=None,
209 | )
210 |
211 |
212 | def _get_optimizer_args(parser: argparse.ArgumentParser) -> None:
213 | parser.add_argument(
214 | "--optimizer",
215 | type=lambda s: s.lower(),
216 | default="adam",
217 | choices=["adam", "adamw"],
218 | help=("The optimizer type to use."),
219 | )
220 | parser.add_argument(
221 | "--weight_decay",
222 | type=float,
223 | default=0.01,
224 | help="Weight decay to use for optimizer.",
225 | )
226 |
227 |
228 | def _get_configuration_args(parser: argparse.ArgumentParser) -> None:
229 | parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name")
230 | parser.add_argument(
231 | "--push_to_hub",
232 | action="store_true",
233 | help="Whether or not to push the model to the Hub.",
234 | )
235 | parser.add_argument(
236 | "--hub_token",
237 | type=str,
238 | default=None,
239 | help="The token to use to push to the Model Hub.",
240 | )
241 | parser.add_argument(
242 | "--hub_model_id",
243 | type=str,
244 | default=None,
245 | help="The name of the repository to keep in sync with the local `output_dir`.",
246 | )
247 | parser.add_argument(
248 | "--allow_tf32",
249 | action="store_true",
250 | help=(
251 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
252 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
253 | ),
254 | )
255 | parser.add_argument("--report_to", type=str, default=None, help="If logging to wandb.")
256 |
257 |
258 | def get_args():
259 | parser = argparse.ArgumentParser(description="Simple example of a training script for Mochi-1.")
260 |
261 | _get_model_args(parser)
262 | _get_dataset_args(parser)
263 | _get_training_args(parser)
264 | _get_validation_args(parser)
265 | _get_optimizer_args(parser)
266 | _get_configuration_args(parser)
267 |
268 | return parser.parse_args()
269 |
--------------------------------------------------------------------------------
/Mochi/cli.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | # from diffusers import MochiPipeline
4 | from pipeline_mochi_rgba import MochiPipeline
5 | from diffusers.utils import export_to_video
6 | import argparse
7 | from rgba_utils import decode_latents, load_processor_state_dict, prepare_for_rgba_inference
8 | import numpy as np
9 |
10 |
11 | def main(args):
12 | # 1. load pipeline
13 | pipe = MochiPipeline.from_pretrained(args.model_path, torch_dtype=torch.bfloat16).to("cuda")
14 | pipe.enable_vae_tiling()
15 |
16 | # 2. define prompt and arguments
17 | pipeline_args = {
18 | "prompt": args.prompt,
19 | "guidance_scale": args.guidance_scale,
20 | "num_inference_steps": args.num_inference_steps,
21 | "height": args.height,
22 | "width": args.width,
23 | "num_frames": args.num_frames,
24 | "max_sequence_length": 256,
25 | "output_type": "latent",
26 | }
27 |
28 | # 3. prepare the transformer for RGBA inference
29 | prepare_for_rgba_inference(
30 | pipe.transformer,
31 | device="cuda",
32 | dtype=torch.bfloat16,
33 | )
34 |
35 | if args.lora_path is not None:
36 | checkpoint = torch.load(args.lora_path, map_location="cpu")
37 | processor_state_dict = checkpoint["state_dict"]
38 | load_processor_state_dict(pipe.transformer, processor_state_dict)
39 |
40 |
41 | # 4. inference
42 | generator = torch.manual_seed(args.seed) if args.seed is not None else None  # a seed of 0 is valid
43 | frames_latents = pipe(**pipeline_args, generator=generator).frames
44 |
45 | frames_latents_rgb, frames_latents_alpha = frames_latents.chunk(2, dim=2)
46 |
47 | frames_rgb = decode_latents(pipe, frames_latents_rgb)
48 | frames_alpha = decode_latents(pipe, frames_latents_alpha)
49 |
50 | pooled_alpha = np.max(frames_alpha, axis=-1, keepdims=True)
51 | frames_alpha_pooled = np.repeat(pooled_alpha, 3, axis=-1)
52 | premultiplied_rgb = frames_rgb * frames_alpha_pooled
53 |
54 | # create the output directory if it does not already exist
55 | os.makedirs(args.output_path, exist_ok=True)
56 |
57 | export_to_video(premultiplied_rgb[0], os.path.join(args.output_path, "rgb.mp4"), fps=args.fps)
58 | export_to_video(frames_alpha_pooled[0], os.path.join(args.output_path, "alpha.mp4"), fps=args.fps)
59 | export_to_video(frames_rgb[0], os.path.join(args.output_path, "original_rgb.mp4"), fps=args.fps)
60 |
61 | if __name__ == "__main__":
62 | parser = argparse.ArgumentParser(description="Generate a video from a text prompt")
63 | parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
64 | parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
65 |
66 | parser.add_argument(
67 | "--model_path", type=str, default="genmo/mochi-1-preview", help="Path of the pre-trained model to use"
68 | )
69 | parser.add_argument("--output_path", type=str, default="./output", help="Directory in which to save the generated videos")
70 | parser.add_argument("--guidance_scale", type=float, default=6, help="The scale for classifier-free guidance")
71 | parser.add_argument("--num_inference_steps", type=int, default=64, help="Number of inference steps")
72 | parser.add_argument("--num_frames", type=int, default=79, help="Number of frames in the generated RGB video")
73 | parser.add_argument("--width", type=int, default=848, help="Width of the generated video in pixels")
74 | parser.add_argument("--height", type=int, default=480, help="Height of the generated video in pixels")
75 | parser.add_argument("--fps", type=int, default=30, help="Frames per second of the saved videos")
76 | parser.add_argument("--seed", type=int, default=None, help="The seed for reproducibility")
77 | args = parser.parse_args()
78 |
79 | main(args)
80 |
--------------------------------------------------------------------------------
/Mochi/dataset_simple.py:
--------------------------------------------------------------------------------
1 | """
2 | Taken from
3 | https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/dataset.py
4 | """
5 |
6 | from pathlib import Path
7 |
8 | import click
9 | import torch
10 | from torch.utils.data import DataLoader, Dataset
11 |
12 |
13 | def load_to_cpu(x):
14 | return torch.load(x, map_location=torch.device("cpu"), weights_only=True)
15 |
16 |
17 | class LatentEmbedDataset(Dataset):
18 | def __init__(self, file_paths, repeat=1):
19 | self.items = [
20 | (Path(p).with_suffix(".latent.pt"), Path(p).with_suffix(".embed.pt"))
21 | for p in file_paths
22 | if Path(p).with_suffix(".latent.pt").is_file() and Path(p).with_suffix(".embed.pt").is_file()
23 | ]
24 | self.items = self.items * repeat
25 | print(f"Loaded {len(self.items)}/{len(file_paths)} valid file pairs.")
26 |
27 | def __len__(self):
28 | return len(self.items)
29 |
30 | def __getitem__(self, idx):
31 | latent_path, embed_path = self.items[idx]
32 | return load_to_cpu(latent_path), load_to_cpu(embed_path)
33 |
34 |
35 | @click.command()
36 | @click.argument("directory", type=click.Path(exists=True, file_okay=False))
37 | def process_videos(directory):
38 | dir_path = Path(directory)
39 | mp4_files = [str(f) for f in dir_path.glob("**/*.mp4") if not f.name.endswith(".recon.mp4")]
40 | assert mp4_files, f"No mp4 files found in {directory}"
41 |
42 | dataset = LatentEmbedDataset(mp4_files)
43 | dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
44 |
45 | for latents, embeds in dataloader:
46 | print([(k, v.shape) for k, v in latents.items()])
47 |
48 |
49 | if __name__ == "__main__":
50 | process_videos()
51 |
--------------------------------------------------------------------------------
/Mochi/embed.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from:
3 | https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/encode_videos.py
4 | https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/embed_captions.py
5 | """
6 |
7 | import click
8 | import torch
9 | import torchvision
10 | from pathlib import Path
11 | from diffusers import AutoencoderKLMochi, MochiPipeline
12 | from transformers import T5EncoderModel, T5Tokenizer
13 | from tqdm.auto import tqdm
14 |
15 |
16 | def encode_videos(model: torch.nn.Module, vid_path: Path, shape: str):
17 | T, H, W = [int(s) for s in shape.split("x")]
18 | assert (T - 1) % 6 == 0, "Expected T to be 1 mod 6"
19 | video, _, metadata = torchvision.io.read_video(str(vid_path), output_format="THWC", pts_unit="secs")
20 | fps = metadata["video_fps"]
21 | video = video.permute(3, 0, 1, 2)
22 | og_shape = video.shape
23 | assert video.shape[2] == H, f"Expected {vid_path} to have height {H}, got {video.shape}"
24 | assert video.shape[3] == W, f"Expected {vid_path} to have width {W}, got {video.shape}"
25 | assert video.shape[1] >= T, f"Expected {vid_path} to have at least {T} frames, got {video.shape}"
26 | if video.shape[1] > T:
27 | video = video[:, :T]
28 | print(f"Trimmed video from {og_shape[1]} to first {T} frames")
29 | video = video.unsqueeze(0)
30 | video = video.float() / 127.5 - 1.0
31 | video = video.to(model.device)
32 |
33 | assert video.ndim == 5
34 |
35 | with torch.inference_mode():
36 | with torch.autocast("cuda", dtype=torch.bfloat16):
37 | ldist = model._encode(video)
38 |
39 | torch.save(dict(ldist=ldist), vid_path.with_suffix(".latent.pt"))
40 |
41 |
42 | @click.command()
43 | @click.argument("output_dir", type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path))
44 | @click.option(
45 | "--model_id",
46 | type=str,
47 | help="Repo id. Should be genmo/mochi-1-preview",
48 | default="genmo/mochi-1-preview",
49 | )
50 | @click.option("--shape", default="163x480x848", help="Shape of the video to encode")
51 | @click.option("--overwrite", "-ow", is_flag=True, help="Overwrite existing latents and caption embeddings.")
52 | def batch_process(output_dir: Path, model_id: str, shape: str, overwrite: bool) -> None:
53 | """Process all videos and captions in a directory using a single GPU."""
54 | # comment out when running on unsupported hardware
55 | torch.backends.cuda.matmul.allow_tf32 = True
56 | torch.backends.cudnn.allow_tf32 = True
57 |
58 | # Get all video paths
59 | video_paths = list(output_dir.glob("**/*.mp4"))
60 | if not video_paths:
61 | print(f"No MP4 files found in {output_dir}")
62 | return
63 |
64 | text_paths = list(output_dir.glob("**/*.txt"))
65 | if not text_paths:
66 | print(f"No text files found in {output_dir}")
67 | return
68 |
69 | # load the models
70 | vae = AutoencoderKLMochi.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32).to("cuda")
71 | text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder")
72 | tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
73 | pipeline = MochiPipeline.from_pretrained(
74 | model_id, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None, vae=None
75 | ).to("cuda")
76 |
77 | for video_path in tqdm(sorted(video_paths)):
78 | print(f"Processing {video_path}")
79 | try:
80 | if video_path.with_suffix(".latent.pt").exists() and not overwrite:
81 | print(f"Skipping {video_path}")
82 | continue
83 |
84 | # encode videos.
85 | encode_videos(vae, vid_path=video_path, shape=shape)
86 |
87 | # embed captions.
88 | prompt_path = video_path.with_suffix(".txt")
89 | embed_path = prompt_path.with_suffix(".embed.pt")
90 |
91 | if embed_path.exists() and not overwrite:
92 | print(f"Skipping {prompt_path} - embeddings already exist")
93 | continue
94 |
95 | with open(prompt_path) as f:
96 | text = f.read().strip()
97 | with torch.inference_mode():
98 | conditioning = pipeline.encode_prompt(prompt=[text])
99 |
100 | conditioning = {"prompt_embeds": conditioning[0], "prompt_attention_mask": conditioning[1]}
101 | torch.save(conditioning, embed_path)
102 |
103 | except Exception as e:
104 | import traceback
105 |
106 | traceback.print_exc()
107 | print(f"Error processing {video_path}: {str(e)}")
108 |
109 |
110 | if __name__ == "__main__":
111 | batch_process()
--------------------------------------------------------------------------------
/Mochi/pipeline_mochi_rgba.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Genmo and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import inspect
16 | from typing import Any, Callable, Dict, List, Optional, Union
17 |
18 | import numpy as np
19 | import torch
20 | from transformers import T5EncoderModel, T5TokenizerFast
21 |
22 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
23 | from diffusers.loaders import Mochi1LoraLoaderMixin
24 | from diffusers.models.autoencoders import AutoencoderKL
25 | from diffusers.models.transformers import MochiTransformer3DModel
26 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
27 | from diffusers.utils import (
28 | is_torch_xla_available,
29 | logging,
30 | replace_example_docstring,
31 | )
32 | from diffusers.utils.torch_utils import randn_tensor
33 | from diffusers.video_processor import VideoProcessor
34 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
35 | from diffusers.pipelines.mochi.pipeline_output import MochiPipelineOutput
36 |
37 |
38 | if is_torch_xla_available():
39 | import torch_xla.core.xla_model as xm
40 |
41 | XLA_AVAILABLE = True
42 | else:
43 | XLA_AVAILABLE = False
44 |
45 |
46 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47 |
48 | EXAMPLE_DOC_STRING = """
49 | Examples:
50 | ```py
51 | >>> import torch
52 | >>> from diffusers import MochiPipeline
53 | >>> from diffusers.utils import export_to_video
54 |
55 | >>> pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
56 | >>> pipe.enable_model_cpu_offload()
57 | >>> pipe.enable_vae_tiling()
58 | >>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
59 | >>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0]
60 | >>> export_to_video(frames, "mochi.mp4")
61 | ```
62 | """
63 |
64 |
65 | def calculate_shift(
66 | image_seq_len,
67 | base_seq_len: int = 256,
68 | max_seq_len: int = 4096,
69 | base_shift: float = 0.5,
70 | max_shift: float = 1.16,
71 | ):
72 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
73 | b = base_shift - m * base_seq_len
74 | mu = image_seq_len * m + b
75 | return mu
76 |
77 |
78 | # from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
79 | def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
80 | if linear_steps is None:
81 | linear_steps = num_steps // 2
82 | linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
83 | threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
84 | quadratic_steps = num_steps - linear_steps
85 | quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
86 | linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
87 | const = quadratic_coef * (linear_steps**2)
88 | quadratic_sigma_schedule = [
89 | quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
90 | ]
91 | sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule
92 | sigma_schedule = [1.0 - x for x in sigma_schedule]
93 | return sigma_schedule
94 |
95 |
96 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
97 | def retrieve_timesteps(
98 | scheduler,
99 | num_inference_steps: Optional[int] = None,
100 | device: Optional[Union[str, torch.device]] = None,
101 | timesteps: Optional[List[int]] = None,
102 | sigmas: Optional[List[float]] = None,
103 | **kwargs,
104 | ):
105 | r"""
106 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
107 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
108 |
109 | Args:
110 | scheduler (`SchedulerMixin`):
111 | The scheduler to get timesteps from.
112 | num_inference_steps (`int`):
113 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
114 | must be `None`.
115 | device (`str` or `torch.device`, *optional*):
116 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
117 | timesteps (`List[int]`, *optional*):
118 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
119 | `num_inference_steps` and `sigmas` must be `None`.
120 | sigmas (`List[float]`, *optional*):
121 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
122 | `num_inference_steps` and `timesteps` must be `None`.
123 |
124 | Returns:
125 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
126 | second element is the number of inference steps.
127 | """
128 | if timesteps is not None and sigmas is not None:
129 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
130 | if timesteps is not None:
131 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
132 | if not accepts_timesteps:
133 | raise ValueError(
134 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
135 | f" timestep schedules. Please check whether you are using the correct scheduler."
136 | )
137 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
138 | timesteps = scheduler.timesteps
139 | num_inference_steps = len(timesteps)
140 | elif sigmas is not None:
141 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
142 | if not accept_sigmas:
143 | raise ValueError(
144 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
145 | f" sigmas schedules. Please check whether you are using the correct scheduler."
146 | )
147 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
148 | timesteps = scheduler.timesteps
149 | num_inference_steps = len(timesteps)
150 | else:
151 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
152 | timesteps = scheduler.timesteps
153 | return timesteps, num_inference_steps
154 |
155 |
156 |
157 |
158 |
159 |
160 | def prepare_attention_mask(prompt_attention_mask, latents):
161 |
162 | device = prompt_attention_mask.device
163 |
164 | (_, _, num_frames, height, width) = latents.shape  # num_frames covers both modalities (RGB and alpha)
165 | seq_length = (height // 2) * (width // 2) * num_frames
166 |
167 | rect_attention_mask = []
168 | for prompt_attention_mask_i in prompt_attention_mask:
169 | text_length = torch.sum(prompt_attention_mask_i).item()
170 | total_length = text_length + seq_length
171 |
172 | if text_length == 0:
173 | rect_attention_mask.append(None)
174 | else:
175 | dense_mask = torch.ones((total_length, total_length), dtype=torch.bool)
176 | dense_mask[seq_length:, seq_length // 2 : seq_length] = False  # mask text-token rows against the second (alpha) half of the video tokens
177 | rect_attention_mask.append(dense_mask.to(device))
178 |
179 | return {
180 | "prompt_attention_mask": prompt_attention_mask,
181 | "rect_attention_mask": rect_attention_mask,
182 | }
183 |
184 |
185 | class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
186 | r"""
187 | The mochi pipeline for text-to-video generation.
188 |
189 | Reference: https://github.com/genmoai/models
190 |
191 | Args:
192 | transformer ([`MochiTransformer3DModel`]):
193 | Conditional Transformer architecture to denoise the encoded video latents.
194 | scheduler ([`FlowMatchEulerDiscreteScheduler`]):
195 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
196 | vae ([`AutoencoderKL`]):
197 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
198 | text_encoder ([`T5EncoderModel`]):
199 | [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
200 | the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
201 | tokenizer (`T5TokenizerFast`):
202 | Tokenizer of class
203 | [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
204 | force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `False`):
205 | Whether to zero out the input ids and attention mask when the prompt is empty (see
206 | `_get_t5_prompt_embeds`).
207 | """
208 |
209 | model_cpu_offload_seq = "text_encoder->transformer->vae"
210 | _optional_components = []
211 | _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
212 |
213 | def __init__(
214 | self,
215 | scheduler: FlowMatchEulerDiscreteScheduler,
216 | vae: AutoencoderKL,
217 | text_encoder: T5EncoderModel,
218 | tokenizer: T5TokenizerFast,
219 | transformer: MochiTransformer3DModel,
220 | force_zeros_for_empty_prompt: bool = False,
221 | ):
222 | super().__init__()
223 |
224 | self.register_modules(
225 | vae=vae,
226 | text_encoder=text_encoder,
227 | tokenizer=tokenizer,
228 | transformer=transformer,
229 | scheduler=scheduler,
230 | )
231 | # TODO: determine these scaling factors from model parameters
232 | self.vae_spatial_scale_factor = 8
233 | self.vae_temporal_scale_factor = 6
234 | self.patch_size = 2
235 |
236 | self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
237 | self.tokenizer_max_length = (
238 | self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256
239 | )
240 | self.default_height = 480
241 | self.default_width = 848
242 | self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
243 |
244 | def _get_t5_prompt_embeds(
245 | self,
246 | prompt: Union[str, List[str]] = None,
247 | num_videos_per_prompt: int = 1,
248 | max_sequence_length: int = 256,
249 | device: Optional[torch.device] = None,
250 | dtype: Optional[torch.dtype] = None,
251 | ):
252 | device = device or self._execution_device
253 | dtype = dtype or self.text_encoder.dtype
254 |
255 | prompt = [prompt] if isinstance(prompt, str) else prompt
256 | batch_size = len(prompt)
257 |
258 | text_inputs = self.tokenizer(
259 | prompt,
260 | padding="max_length",
261 | max_length=max_sequence_length,
262 | truncation=True,
263 | add_special_tokens=True,
264 | return_tensors="pt",
265 | )
266 |
267 | text_input_ids = text_inputs.input_ids
268 | prompt_attention_mask = text_inputs.attention_mask
269 | prompt_attention_mask = prompt_attention_mask.bool().to(device)
270 |
271 | # The original Mochi implementation zeros out empty negative prompts
272 | # but this can lead to overflow when placing the entire pipeline under the autocast context
273 | # adding this here so that we can enable zeroing prompts if necessary
274 | if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""):
275 | text_input_ids = torch.zeros_like(text_input_ids, device=device)
276 | prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)
277 |
278 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
279 |
280 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
281 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
282 | logger.warning(
283 | "The following part of your input was truncated because `max_sequence_length` is set to "
284 | f" {max_sequence_length} tokens: {removed_text}"
285 | )
286 |
287 | prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
288 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
289 |
290 | # duplicate text embeddings for each generation per prompt, using mps friendly method
291 | _, seq_len, _ = prompt_embeds.shape
292 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
293 | prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
294 |
295 | prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
296 | prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
297 |
298 | return prompt_embeds, prompt_attention_mask
299 |
300 | # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
301 | def encode_prompt(
302 | self,
303 | prompt: Union[str, List[str]],
304 | negative_prompt: Optional[Union[str, List[str]]] = None,
305 | do_classifier_free_guidance: bool = True,
306 | num_videos_per_prompt: int = 1,
307 | prompt_embeds: Optional[torch.Tensor] = None,
308 | negative_prompt_embeds: Optional[torch.Tensor] = None,
309 | prompt_attention_mask: Optional[torch.Tensor] = None,
310 | negative_prompt_attention_mask: Optional[torch.Tensor] = None,
311 | max_sequence_length: int = 256,
312 | device: Optional[torch.device] = None,
313 | dtype: Optional[torch.dtype] = None,
314 | ):
315 | r"""
316 | Encodes the prompt into text encoder hidden states.
317 |
318 | Args:
319 | prompt (`str` or `List[str]`, *optional*):
320 | prompt to be encoded
321 | negative_prompt (`str` or `List[str]`, *optional*):
322 | The prompt or prompts not to guide the video generation. If not defined, one has to pass
323 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
324 | not greater than `1`).
325 | do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
326 | Whether to use classifier free guidance or not.
327 | num_videos_per_prompt (`int`, *optional*, defaults to 1):
328 | Number of videos that should be generated per prompt.
329 | prompt_embeds (`torch.Tensor`, *optional*):
330 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
331 | provided, text embeddings will be generated from `prompt` input argument.
332 | negative_prompt_embeds (`torch.Tensor`, *optional*):
333 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
334 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
335 | argument.
336 | device: (`torch.device`, *optional*):
337 | torch device
338 | dtype: (`torch.dtype`, *optional*):
339 | torch dtype
340 | """
341 | device = device or self._execution_device
342 |
343 | prompt = [prompt] if isinstance(prompt, str) else prompt
344 | if prompt is not None:
345 | batch_size = len(prompt)
346 | else:
347 | batch_size = prompt_embeds.shape[0]
348 |
349 | if prompt_embeds is None:
350 | prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
351 | prompt=prompt,
352 | num_videos_per_prompt=num_videos_per_prompt,
353 | max_sequence_length=max_sequence_length,
354 | device=device,
355 | dtype=dtype,
356 | )
357 |
358 | if do_classifier_free_guidance and negative_prompt_embeds is None:
359 | negative_prompt = negative_prompt or ""
360 | negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
361 |
362 | if prompt is not None and type(prompt) is not type(negative_prompt):
363 | raise TypeError(
364 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
365 | f" {type(prompt)}."
366 | )
367 | elif batch_size != len(negative_prompt):
368 | raise ValueError(
369 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
370 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
371 | " the batch size of `prompt`."
372 | )
373 |
374 | negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
375 | prompt=negative_prompt,
376 | num_videos_per_prompt=num_videos_per_prompt,
377 | max_sequence_length=max_sequence_length,
378 | device=device,
379 | dtype=dtype,
380 | )
381 |
382 | return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
383 |
384 | def check_inputs(
385 | self,
386 | prompt,
387 | height,
388 | width,
389 | callback_on_step_end_tensor_inputs=None,
390 | prompt_embeds=None,
391 | negative_prompt_embeds=None,
392 | prompt_attention_mask=None,
393 | negative_prompt_attention_mask=None,
394 | ):
395 | if height % 8 != 0 or width % 8 != 0:
396 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
397 |
398 | if callback_on_step_end_tensor_inputs is not None and not all(
399 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
400 | ):
401 | raise ValueError(
402 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
403 | )
404 |
405 | if prompt is not None and prompt_embeds is not None:
406 | raise ValueError(
407 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
408 | " only forward one of the two."
409 | )
410 | elif prompt is None and prompt_embeds is None:
411 | raise ValueError(
412 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
413 | )
414 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
415 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
416 |
417 | if prompt_embeds is not None and prompt_attention_mask is None:
418 | raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
419 |
420 | if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
421 | raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
422 |
423 | if prompt_embeds is not None and negative_prompt_embeds is not None:
424 | if prompt_embeds.shape != negative_prompt_embeds.shape:
425 | raise ValueError(
426 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
427 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
428 | f" {negative_prompt_embeds.shape}."
429 | )
430 | if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
431 | raise ValueError(
432 | "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
433 | f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
434 | f" {negative_prompt_attention_mask.shape}."
435 | )
436 |
437 | def enable_vae_slicing(self):
438 | r"""
439 | Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
440 | compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
441 | """
442 | self.vae.enable_slicing()
443 |
444 | def disable_vae_slicing(self):
445 | r"""
446 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
447 | computing decoding in one step.
448 | """
449 | self.vae.disable_slicing()
450 |
451 | def enable_vae_tiling(self):
452 | r"""
453 | Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
454 | compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
455 | processing larger images.
456 | """
457 | self.vae.enable_tiling()
458 |
459 | def disable_vae_tiling(self):
460 | r"""
461 | Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
462 | computing decoding in one step.
463 | """
464 | self.vae.disable_tiling()
465 |
466 | def prepare_latents(
467 | self,
468 | batch_size,
469 | num_channels_latents,
470 | height,
471 | width,
472 | num_frames,
473 | dtype,
474 | device,
475 | generator,
476 | latents=None,
477 | ):
478 | height = height // self.vae_spatial_scale_factor
479 | width = width // self.vae_spatial_scale_factor
480 | num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
481 |
482 | shape = (batch_size, num_channels_latents, num_frames, height, width)
483 |
484 | if latents is not None:
485 | return latents.to(device=device, dtype=dtype)
486 | if isinstance(generator, list) and len(generator) != batch_size:
487 | raise ValueError(
488 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
489 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
490 | )
491 |
492 | latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)
493 | latents = latents.to(dtype)
494 | return latents
495 |
496 | @property
497 | def guidance_scale(self):
498 | return self._guidance_scale
499 |
500 | @property
501 | def do_classifier_free_guidance(self):
502 | return self._guidance_scale > 1.0
503 |
504 | @property
505 | def num_timesteps(self):
506 | return self._num_timesteps
507 |
508 | @property
509 | def attention_kwargs(self):
510 | return self._attention_kwargs
511 |
512 | @property
513 | def interrupt(self):
514 | return self._interrupt
515 |
516 |
517 | @torch.no_grad()
518 | @replace_example_docstring(EXAMPLE_DOC_STRING)
519 | def __call__(
520 | self,
521 | prompt: Union[str, List[str]] = None,
522 | negative_prompt: Optional[Union[str, List[str]]] = None,
523 | height: Optional[int] = None,
524 | width: Optional[int] = None,
525 | num_frames: int = 19,
526 | num_inference_steps: int = 64,
527 | timesteps: List[int] = None,
528 | guidance_scale: float = 4.5,
529 | num_videos_per_prompt: Optional[int] = 1,
530 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
531 | latents: Optional[torch.Tensor] = None,
532 | prompt_embeds: Optional[torch.Tensor] = None,
533 | prompt_attention_mask: Optional[torch.Tensor] = None,
534 | negative_prompt_embeds: Optional[torch.Tensor] = None,
535 | negative_prompt_attention_mask: Optional[torch.Tensor] = None,
536 | output_type: Optional[str] = "pil",
537 | return_dict: bool = True,
538 | attention_kwargs: Optional[Dict[str, Any]] = None,
539 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
540 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
541 | max_sequence_length: int = 256,
542 | ):
543 | r"""
544 | Function invoked when calling the pipeline for generation.
545 |
546 | Args:
547 | prompt (`str` or `List[str]`, *optional*):
548 | The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
549 | instead.
550 | height (`int`, *optional*, defaults to `self.default_height`):
551 | The height in pixels of the generated video. This is set to 480 by default for the best results.
552 | width (`int`, *optional*, defaults to `self.default_width`):
553 | The width in pixels of the generated video. This is set to 848 by default for the best results.
554 | num_frames (`int`, defaults to `19`):
555 | The number of RGB video frames to generate. The latents are duplicated along the frame axis so that the RGB and alpha frames are generated jointly.
556 | num_inference_steps (`int`, *optional*, defaults to 64):
557 | The number of denoising steps. More denoising steps usually lead to a higher quality video at the
558 | expense of slower inference.
559 | timesteps (`List[int]`, *optional*):
560 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
561 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
562 | passed will be used. Must be in descending order.
563 | guidance_scale (`float`, defaults to `4.5`):
564 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
565 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
566 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
567 | 1`. A higher guidance scale encourages generating videos that are closely linked to the text `prompt`,
568 | usually at the expense of lower video quality.
569 | num_videos_per_prompt (`int`, *optional*, defaults to 1):
570 | The number of videos to generate per prompt.
571 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
572 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
573 | to make generation deterministic.
574 | latents (`torch.Tensor`, *optional*):
575 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
576 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
577 | tensor will be generated by sampling using the supplied random `generator`.
578 | prompt_embeds (`torch.Tensor`, *optional*):
579 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
580 | provided, text embeddings will be generated from `prompt` input argument.
581 | prompt_attention_mask (`torch.Tensor`, *optional*):
582 | Pre-generated attention mask for text embeddings.
583 | negative_prompt_embeds (`torch.Tensor`, *optional*):
584 | Pre-generated negative text embeddings. If not provided, `negative_prompt_embeds` will be generated
585 | from the `negative_prompt` input argument.
586 | negative_prompt_attention_mask (`torch.Tensor`, *optional*):
587 | Pre-generated attention mask for negative text embeddings.
588 | output_type (`str`, *optional*, defaults to `"pil"`):
589 | The output format of the generated video. Choose between
590 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
591 | return_dict (`bool`, *optional*, defaults to `True`):
592 | Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple.
593 | attention_kwargs (`dict`, *optional*):
594 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
595 | `self.processor` in
596 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
597 | callback_on_step_end (`Callable`, *optional*):
598 | A function that is called at the end of each denoising step during inference. The function is called
599 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
600 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
601 | `callback_on_step_end_tensor_inputs`.
602 | callback_on_step_end_tensor_inputs (`List`, *optional*):
603 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
604 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
605 | `._callback_tensor_inputs` attribute of your pipeline class.
606 | max_sequence_length (`int` defaults to `256`):
607 | Maximum sequence length to use with the `prompt`.
608 |
609 | Examples:
610 |
611 | Returns:
612 | [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`:
613 | If `return_dict` is `True`, [`~pipelines.mochi.MochiPipelineOutput`] is returned, otherwise a `tuple`
614 | is returned where the first element is a list with the generated video frames.
615 | """
616 |
617 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
618 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
619 |
620 | height = height or self.default_height
621 | width = width or self.default_width
622 |
623 | # 1. Check inputs. Raise error if not correct
624 | self.check_inputs(
625 | prompt=prompt,
626 | height=height,
627 | width=width,
628 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
629 | prompt_embeds=prompt_embeds,
630 | negative_prompt_embeds=negative_prompt_embeds,
631 | prompt_attention_mask=prompt_attention_mask,
632 | negative_prompt_attention_mask=negative_prompt_attention_mask,
633 | )
634 |
635 | self._guidance_scale = guidance_scale
636 | self._attention_kwargs = attention_kwargs
637 | self._interrupt = False
638 |
639 | # 2. Define call parameters
640 | if prompt is not None and isinstance(prompt, str):
641 | batch_size = 1
642 | elif prompt is not None and isinstance(prompt, list):
643 | batch_size = len(prompt)
644 | else:
645 | batch_size = prompt_embeds.shape[0]
646 |
647 | device = self._execution_device
648 | # 3. Prepare text embeddings
649 | (
650 | prompt_embeds,
651 | prompt_attention_mask,
652 | negative_prompt_embeds,
653 | negative_prompt_attention_mask,
654 | ) = self.encode_prompt(
655 | prompt=prompt,
656 | negative_prompt=negative_prompt,
657 | do_classifier_free_guidance=self.do_classifier_free_guidance,
658 | num_videos_per_prompt=num_videos_per_prompt,
659 | prompt_embeds=prompt_embeds,
660 | negative_prompt_embeds=negative_prompt_embeds,
661 | prompt_attention_mask=prompt_attention_mask,
662 | negative_prompt_attention_mask=negative_prompt_attention_mask,
663 | max_sequence_length=max_sequence_length,
664 | device=device,
665 | )
666 | # 4. Prepare latent variables
667 | num_channels_latents = self.transformer.config.in_channels
668 | latents = self.prepare_latents(
669 | batch_size * num_videos_per_prompt,
670 | num_channels_latents,
671 | height,
672 | width,
673 | num_frames,
674 | prompt_embeds.dtype,
675 | device,
676 | generator,
677 | latents,
678 | ).repeat(1, 1, 2, 1, 1)  # duplicate along the frame axis so RGB and alpha frames are denoised jointly
679 |
680 | if self.do_classifier_free_guidance:
681 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
682 | prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
683 |
684 |
685 | # 5.5 Prepare attention rectification masks
686 | all_attention_mask = prepare_attention_mask(prompt_attention_mask, latents)
687 |
688 | # 5. Prepare timestep
689 | # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
690 | threshold_noise = 0.025
691 | sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
692 | sigmas = np.array(sigmas)
693 |
694 | timesteps, num_inference_steps = retrieve_timesteps(
695 | self.scheduler,
696 | num_inference_steps,
697 | device,
698 | timesteps,
699 | sigmas,
700 | )
701 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
702 | self._num_timesteps = len(timesteps)
703 |
704 | # 6. Denoising loop
705 | with self.progress_bar(total=num_inference_steps) as progress_bar:
706 | for i, t in enumerate(timesteps):
707 | if self.interrupt:
708 | continue
709 |
710 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
711 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
712 | timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
713 |
714 | noise_pred = self.transformer(
715 | hidden_states=latent_model_input,
716 | encoder_hidden_states=prompt_embeds,
717 | timestep=timestep,
718 | encoder_attention_mask=all_attention_mask,
719 | attention_kwargs=attention_kwargs,
720 | return_dict=False,
721 | )[0]
722 | # Mochi CFG + Sampling runs in FP32
723 | noise_pred = noise_pred.to(torch.float32)
724 |
725 | if self.do_classifier_free_guidance:
726 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
727 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
728 |
729 | # compute the previous noisy sample x_t -> x_t-1
730 | latents_dtype = latents.dtype
731 | latents = self.scheduler.step(noise_pred, t, latents.to(torch.float32), return_dict=False)[0]
732 | latents = latents.to(latents_dtype)
733 |
734 | if latents.dtype != latents_dtype:
735 | if torch.backends.mps.is_available():
736 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
737 | latents = latents.to(latents_dtype)
738 |
739 | if callback_on_step_end is not None:
740 | callback_kwargs = {}
741 | for k in callback_on_step_end_tensor_inputs:
742 | callback_kwargs[k] = locals()[k]
743 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
744 |
745 | latents = callback_outputs.pop("latents", latents)
746 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
747 |
748 | # update the progress bar
749 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
750 | progress_bar.update()
751 |
752 | if XLA_AVAILABLE:
753 | xm.mark_step()
754 |
755 | if output_type == "latent":
756 | video = latents
757 | else:
758 | # unscale/denormalize the latents
759 | # denormalize with the mean and std if available and not None
760 | has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
761 | has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
762 | if has_latents_mean and has_latents_std:
763 | latents_mean = (
764 | torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
765 | )
766 | latents_std = (
767 | torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
768 | )
769 | latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
770 | else:
771 | latents = latents / self.vae.config.scaling_factor
772 |
773 | video = self.vae.decode(latents, return_dict=False)[0]
774 | video = self.video_processor.postprocess_video(video, output_type=output_type)
775 |
776 | # Offload all models
777 | self.maybe_free_model_hooks()
778 |
779 | if not return_dict:
780 | return (video,)
781 |
782 | return MochiPipelineOutput(frames=video)
783 |
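As a rough sketch of the shape arithmetic in `prepare_latents` combined with the `.repeat(1, 1, 2, 1, 1)` call in `__call__`, the helper below computes the latent shape used for joint RGB and alpha generation. The function name `rgba_latent_shape` is hypothetical and not part of the repository. The scale factors 8 and 6 are the values hard-coded in `__init__`; the channel count of 12 is an assumption taken from the `view(1, 12, 1, 1, 1)` used when denormalizing latents.

```python
def rgba_latent_shape(batch_size, num_channels, height, width, num_frames,
                      spatial_scale=8, temporal_scale=6):
    """Mirror the arithmetic in `prepare_latents`, then double the frame axis."""
    latent_height = height // spatial_scale
    latent_width = width // spatial_scale
    # The first frame maps to one latent frame; every further group of
    # `temporal_scale` frames adds one more.
    latent_frames = (num_frames - 1) // temporal_scale + 1
    # `__call__` repeats the latents twice along dim 2 (frames), so the
    # transformer sees twice as many latent frames as an RGB-only run.
    return (batch_size, num_channels, 2 * latent_frames, latent_height, latent_width)


# 79 RGB frames at 480 x 848 (Height x Width), the setting quoted in the README.
print(rgba_latent_shape(1, 12, 480, 848, 79))  # (1, 12, 28, 60, 106)
```

The doubled frame axis is why the README describes the model as processing 79 × 2 frames: the sequence the transformer attends over is twice as long as for an RGB-only generation.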
--------------------------------------------------------------------------------
/Mochi/prepare_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | GPU_ID=0
4 | VIDEO_DIR=video-dataset-disney-organized
5 | OUTPUT_DIR=videos_prepared
6 | NUM_FRAMES=37
7 | RESOLUTION=480x848
8 |
9 | # Extract width and height from RESOLUTION
10 | WIDTH=$(echo $RESOLUTION | cut -dx -f1)
11 | HEIGHT=$(echo $RESOLUTION | cut -dx -f2)
12 |
13 | python trim_and_crop_videos.py $VIDEO_DIR $OUTPUT_DIR --num_frames=$NUM_FRAMES --resolution=$RESOLUTION --force_upsample
14 |
15 | CUDA_VISIBLE_DEVICES=$GPU_ID python embed.py $OUTPUT_DIR --shape=${NUM_FRAMES}x${WIDTH}x${HEIGHT}
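
The string handling in this script can be sketched in Python as follows. The helper name `build_embed_shape` is hypothetical and only illustrates how the `--shape` argument is assembled. The script assigns the first field of `RESOLUTION` to `WIDTH` and the second to `HEIGHT`, while the README describes 480 × 848 as Height × Width, so the variable names appear swapped; the resulting shape string is the same either way because the fields are passed through in their original order.

```python
def build_embed_shape(resolution: str, num_frames: int) -> str:
    """Reproduce `${NUM_FRAMES}x${WIDTH}x${HEIGHT}` from prepare_dataset.sh."""
    # Equivalent to `cut -dx -f1` and `cut -dx -f2` in the shell script.
    first, second = resolution.split("x")
    return f"{num_frames}x{first}x{second}"


print(build_embed_shape("480x848", 37))  # 37x480x848
```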
--------------------------------------------------------------------------------
/Mochi/rgba_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | from typing import Any, Dict, Optional, Tuple, Union
6 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
7 | from diffusers.models.modeling_outputs import Transformer2DModelOutput
8 |
9 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
10 |
11 | @torch.no_grad()
12 | def decode_latents(pipe, latents):
13 | has_latents_mean = hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None
14 | has_latents_std = hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None
15 | if has_latents_mean and has_latents_std:
16 | latents_mean = (
17 | torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
18 | )
19 | latents_std = (
20 | torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
21 | )
22 | latents = latents * latents_std / pipe.vae.config.scaling_factor + latents_mean
23 | else:
24 | latents = latents / pipe.vae.config.scaling_factor
25 |
26 | video = pipe.vae.decode(latents, return_dict=False)[0]
27 | video = pipe.video_processor.postprocess_video(video, output_type='np')
28 |
29 | return video
30 |
31 |
32 | class RGBALoRAMochiAttnProcessor:
33 | """Mochi attention processor that adds LoRA updates to the second half of the token sequence."""
34 | def __init__(self, device, dtype, lora_rank=128, lora_alpha=1.0, latent_dim=3072):
35 | if not hasattr(F, "scaled_dot_product_attention"):
36 | raise ImportError("RGBALoRAMochiAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
37 |
38 |
39 | # Initialize LoRA layers
40 | self.lora_alpha = lora_alpha
41 | self.lora_rank = lora_rank
42 |
43 | # Helper function to create LoRA layers
44 | def create_lora_layer(in_dim, mid_dim, out_dim, device=device, dtype=dtype):
45 | # Define the LoRA layers
46 | lora_a = nn.Linear(in_dim, mid_dim, bias=False, device=device, dtype=dtype)
47 | lora_b = nn.Linear(mid_dim, out_dim, bias=False, device=device, dtype=dtype)
48 |
49 | # Initialize lora_a with random parameters (default initialization)
50 | nn.init.kaiming_uniform_(lora_a.weight, a=math.sqrt(5)) # or another suitable initialization
51 |
52 | # Initialize lora_b with zero values
53 | nn.init.zeros_(lora_b.weight)
54 |
55 | lora_a.weight.requires_grad = True
56 | lora_b.weight.requires_grad = True
57 |
58 | # Combine the layers into a sequential module
59 | return nn.Sequential(lora_a, lora_b)
60 |
61 | self.to_q_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
62 | self.to_k_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
63 | self.to_v_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
64 | self.to_out_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
65 |
66 | def _apply_lora(self, hidden_states, seq_len, query, key, value, scaling):
67 | """Applies LoRA updates to the second half of the query, key, and value sequences only."""
68 | query_delta = self.to_q_lora(hidden_states).to(query.device)
69 | query[:, -seq_len // 2:, :] += query_delta[:, -seq_len // 2:, :] * scaling
70 |
71 | key_delta = self.to_k_lora(hidden_states).to(key.device)
72 | key[:, -seq_len // 2:, :] += key_delta[:, -seq_len // 2:, :] * scaling
73 |
74 | value_delta = self.to_v_lora(hidden_states).to(value.device)
75 | value[:, -seq_len // 2:, :] += value_delta[:, -seq_len // 2:, :] * scaling
76 |
77 | return query, key, value
78 |
79 | def __call__(
80 | self,
81 | attn,
82 | hidden_states: torch.Tensor,
83 | encoder_hidden_states: torch.Tensor,
84 | attention_mask: Optional[torch.Tensor] = None,
85 | image_rotary_emb: Optional[torch.Tensor] = None,
86 | ) -> torch.Tensor:
87 | query = attn.to_q(hidden_states)
88 | key = attn.to_k(hidden_states)
89 | value = attn.to_v(hidden_states)
90 |
91 | scaling = self.lora_alpha / self.lora_rank
92 | sequence_length = query.size(1)
93 | query, key, value = self._apply_lora(hidden_states, sequence_length, query, key, value, scaling)
94 |
95 | query = query.unflatten(2, (attn.heads, -1))
96 | key = key.unflatten(2, (attn.heads, -1))
97 | value = value.unflatten(2, (attn.heads, -1))
98 |
99 | if attn.norm_q is not None:
100 | query = attn.norm_q(query)
101 | if attn.norm_k is not None:
102 | key = attn.norm_k(key)
103 |
104 | encoder_query = attn.add_q_proj(encoder_hidden_states)
105 | encoder_key = attn.add_k_proj(encoder_hidden_states)
106 | encoder_value = attn.add_v_proj(encoder_hidden_states)
107 |
108 | encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
109 | encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
110 | encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
111 |
112 | if attn.norm_added_q is not None:
113 | encoder_query = attn.norm_added_q(encoder_query)
114 | if attn.norm_added_k is not None:
115 | encoder_key = attn.norm_added_k(encoder_key)
116 |
117 | if image_rotary_emb is not None:
118 |
119 | def apply_rotary_emb(x, freqs_cos, freqs_sin):
120 | x_even = x[..., 0::2].float()
121 | x_odd = x[..., 1::2].float()
122 |
123 | cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
124 | sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
125 |
126 | return torch.stack([cos, sin], dim=-1).flatten(-2)
127 |
128 | query[:,sequence_length//2:] = apply_rotary_emb(query[:,sequence_length//2:], *image_rotary_emb)
129 | query[:,:sequence_length//2] = apply_rotary_emb(query[:,:sequence_length//2], *image_rotary_emb)
130 |
131 | key[:,sequence_length//2:] = apply_rotary_emb(key[:,sequence_length//2:], *image_rotary_emb)
132 | key[:,:sequence_length//2] = apply_rotary_emb(key[:,:sequence_length//2], *image_rotary_emb)
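133 | # RoPE is applied to each half of the sequence separately, so both halves share the same positions,
134 | # instead of once over the full sequence: query = apply_rotary_emb(query, *image_rotary_emb) (same for key).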
135 |
136 |
137 | query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
138 | encoder_query, encoder_key, encoder_value = (
139 | encoder_query.transpose(1, 2),
140 | encoder_key.transpose(1, 2),
141 | encoder_value.transpose(1, 2),
142 | )
143 |
144 | sequence_length = query.size(2)
145 | encoder_sequence_length = encoder_query.size(2)
146 | total_length = sequence_length + encoder_sequence_length
147 |
148 | batch_size, heads, _, dim = query.shape
149 |
150 | attn_outputs = []
151 | prompt_attention_mask = attention_mask["prompt_attention_mask"]
152 | rect_attention_mask = attention_mask["rect_attention_mask"]
153 | for idx in range(batch_size):
154 | mask = prompt_attention_mask[idx][None, :] # two components: attention mask and prompt mask
155 | valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
156 |
157 | valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
158 | valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
159 | valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]
160 |
161 | valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
162 | valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
163 | valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)
164 |
165 | attn_output = F.scaled_dot_product_attention(
166 | valid_query,
167 | valid_key,
168 | valid_value,
169 | dropout_p=0.0,
170 | attn_mask=rect_attention_mask[idx],
171 | is_causal=False
172 | )
173 | valid_sequence_length = attn_output.size(2)
174 | attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
175 | attn_outputs.append(attn_output)
176 |
177 | hidden_states = torch.cat(attn_outputs, dim=0)
178 | hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
179 |
180 | hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
181 | (sequence_length, encoder_sequence_length), dim=1
182 | )
183 |
184 | # linear proj
185 | original_hidden_states = attn.to_out[0](hidden_states)
186 | hidden_states_delta = self.to_out_lora(hidden_states).to(hidden_states.device)
187 | original_hidden_states[:, -sequence_length // 2:, :] += hidden_states_delta[:, -sequence_length // 2:, :] * scaling
188 | # dropout
189 | hidden_states = attn.to_out[1](original_hidden_states)
190 |
191 | if hasattr(attn, "to_add_out"):
192 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
193 |
194 | return hidden_states, encoder_hidden_states
195 |
196 | def prepare_for_rgba_inference(
197 | model, device: torch.device, dtype: torch.dtype,
198 | lora_rank: int = 128, lora_alpha: float = 1.0
199 | ):
200 |
201 | def custom_forward(self):
202 | def forward(
203 | hidden_states: torch.Tensor,
204 | encoder_hidden_states: torch.Tensor,
205 | timestep: torch.LongTensor,
206 | encoder_attention_mask: torch.Tensor,
207 | attention_kwargs: Optional[Dict[str, Any]] = None,
208 | return_dict: bool = True,
209 | ) -> torch.Tensor:
210 | if attention_kwargs is not None:
211 | attention_kwargs = attention_kwargs.copy()
212 | lora_scale = attention_kwargs.pop("scale", 1.0)
213 | else:
214 | lora_scale = 1.0
215 |
216 | if USE_PEFT_BACKEND:
217 | # weight the lora layers by setting `lora_scale` for each PEFT layer
218 | scale_lora_layers(self, lora_scale)
219 | else:
220 | if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
221 | logger.warning(
222 | "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
223 | )
224 |
225 | batch_size, num_channels, num_frames, height, width = hidden_states.shape
226 | p = self.config.patch_size
227 |
228 | post_patch_height = height // p
229 | post_patch_width = width // p
230 |
231 | temb, encoder_hidden_states = self.time_embed(
232 | timestep,
233 | encoder_hidden_states,
234 | encoder_attention_mask["prompt_attention_mask"],
235 | hidden_dtype=hidden_states.dtype,
236 | )
237 |
238 | hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
239 | hidden_states = self.patch_embed(hidden_states)
240 | hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
241 |
242 | image_rotary_emb = self.rope(
243 | self.pos_frequencies,
244 |                 num_frames // 2,  # Identical PE for RGB and Alpha
245 | post_patch_height,
246 | post_patch_width,
247 | device=hidden_states.device,
248 | dtype=torch.float32,
249 | )
250 |
251 | for i, block in enumerate(self.transformer_blocks):
252 | if torch.is_grad_enabled() and self.gradient_checkpointing:
253 |
254 | def create_custom_forward(module):
255 | def custom_forward(*inputs):
256 | return module(*inputs)
257 |
258 | return custom_forward
259 |
260 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
261 | hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
262 | create_custom_forward(block),
263 | hidden_states,
264 | encoder_hidden_states,
265 | temb,
266 | encoder_attention_mask,
267 | image_rotary_emb,
268 | **ckpt_kwargs,
269 | )
270 | else:
271 | hidden_states, encoder_hidden_states = block(
272 | hidden_states=hidden_states,
273 | encoder_hidden_states=encoder_hidden_states,
274 | temb=temb,
275 | encoder_attention_mask=encoder_attention_mask,
276 | image_rotary_emb=image_rotary_emb,
277 | )
278 | hidden_states = self.norm_out(hidden_states, temb)
279 | hidden_states = self.proj_out(hidden_states)
280 |
281 | hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
282 | hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
283 | output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
284 |
285 | if USE_PEFT_BACKEND:
286 | # remove `lora_scale` from each PEFT layer
287 | unscale_lora_layers(self, lora_scale)
288 |
289 | if not return_dict:
290 | return (output,)
291 | return Transformer2DModelOutput(sample=output)
292 | return forward
293 |
294 | for _, block in enumerate(model.transformer_blocks):
295 | attn_processor = RGBALoRAMochiAttnProcessor(
296 | device=device,
297 | dtype=dtype,
298 | lora_rank=lora_rank,
299 | lora_alpha=lora_alpha
300 | )
301 | # block.attn1.set_processor(attn_processor)
302 | block.attn1.processor = attn_processor
303 |
304 | model.forward = custom_forward(model)
305 |
306 | def get_processor_state_dict(model):
307 | """Save trainable parameters of processors to a checkpoint."""
308 | processor_state_dict = {}
309 |
310 | for index, block in enumerate(model.transformer_blocks):
311 | if hasattr(block.attn1, "processor"):
312 | processor = block.attn1.processor
313 | for attr_name in ["to_q_lora", "to_k_lora", "to_v_lora", "to_out_lora"]:
314 | if hasattr(processor, attr_name):
315 | lora_layer = getattr(processor, attr_name)
316 | for param_name, param in lora_layer.named_parameters():
317 | key = f"block_{index}.{attr_name}.{param_name}"
318 | processor_state_dict[key] = param.data.clone()
319 |
320 | # torch.save({"processor_state_dict": processor_state_dict}, checkpoint_path)
321 | # print(f"Processor state_dict saved to {checkpoint_path}")
322 | return processor_state_dict
323 |
324 | def load_processor_state_dict(model, processor_state_dict):
325 | """Load trainable parameters of processors from a checkpoint."""
326 | for index, block in enumerate(model.transformer_blocks):
327 | if hasattr(block.attn1, "processor"):
328 | processor = block.attn1.processor
329 | for attr_name in ["to_q_lora", "to_k_lora", "to_v_lora", "to_out_lora"]:
330 | if hasattr(processor, attr_name):
331 | lora_layer = getattr(processor, attr_name)
332 | for param_name, param in lora_layer.named_parameters():
333 | key = f"block_{index}.{attr_name}.{param_name}"
334 | if key in processor_state_dict:
335 | param.data.copy_(processor_state_dict[key])
336 | else:
337 | raise KeyError(f"Missing key {key} in checkpoint.")
338 |
339 | # Prepare training parameters
340 | def get_processor_params(processor):
341 | params = []
342 | for attr_name in ["to_q_lora", "to_k_lora", "to_v_lora", "to_out_lora"]:
343 | if hasattr(processor, attr_name):
344 | lora_layer = getattr(processor, attr_name)
345 | params.extend(p for p in lora_layer.parameters() if p.requires_grad)
346 | return params
347 |
348 | def get_all_processor_params(transformer):
349 | all_params = []
350 | for block in transformer.transformer_blocks:
351 | if hasattr(block.attn1, "processor"):
352 | processor = block.attn1.processor
353 | all_params.extend(get_processor_params(processor))
354 | return all_params
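
The output projection in `RGBALoRAMochiAttnProcessor` adds the LoRA delta only to the slice `[:, -sequence_length // 2:, :]`. The sketch below shows what that slice selects, using a plain Python list in place of the token axis. The helper name is hypothetical, and the RGB-first, alpha-second token ordering is an assumption made for illustration, not something the code above states.

```python
def lora_token_slice(sequence_length: int) -> slice:
    """Return the slice that receives the LoRA delta: the trailing half of the tokens.

    Unary minus binds tighter than `//`, so `-n // 2` is `(-n) // 2`,
    which floors toward negative infinity.
    """
    return slice(-sequence_length // 2, None)


# Hypothetical token labels; the RGB-first ordering is an assumption.
tokens = ["rgb0", "rgb1", "rgb2", "alpha0", "alpha1", "alpha2"]
selected = tokens[lora_token_slice(len(tokens))]
print(selected)  # ['alpha0', 'alpha1', 'alpha2']

# With an odd length the floor division rounds the start index down,
# so the slice covers one token more than half.
odd_tokens = list(range(7))
print(odd_tokens[lora_token_slice(7)])  # [3, 4, 5, 6]
```

Because the forward pass doubles the frame count, the visual sequence length should normally be even, in which case the slice is exactly the second half.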
--------------------------------------------------------------------------------
/Mochi/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The HuggingFace Team.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import gc
17 | import random
18 | from glob import glob
19 | import math
20 | import os
21 | import torch.nn.functional as F
22 | import numpy as np
23 | from pathlib import Path
24 | from typing import Any, Dict, Tuple, List
25 |
26 | import torch
27 | import wandb
28 | from pipeline_mochi_rgba import *
29 | from diffusers import FlowMatchEulerDiscreteScheduler, MochiTransformer3DModel
30 | from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
31 | from diffusers.training_utils import cast_training_params
32 | from diffusers.utils import export_to_video
33 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
34 | from huggingface_hub import create_repo, upload_folder
35 | from torch.utils.data import DataLoader
36 | from tqdm.auto import tqdm
37 |
38 |
39 | from args import get_args # isort:skip
40 | from dataset_simple import LatentEmbedDataset
41 |
42 | from utils import print_memory, reset_memory # isort:skip
43 | from rgba_utils import *
44 |
45 |
46 | # Taken from
47 | # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/train.py#L139
48 | def get_cosine_annealing_lr_scheduler(
49 | optimizer: torch.optim.Optimizer,
50 | warmup_steps: int,
51 | total_steps: int,
52 | ):
53 | def lr_lambda(step):
54 | if step < warmup_steps:
55 | return float(step) / float(max(1, warmup_steps))
56 | else:
57 |             return 0.5 * (1 + np.cos(np.pi * (step - warmup_steps) / max(1, total_steps - warmup_steps)))
58 |
59 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
60 |
61 |
62 | def save_model_card(
63 | repo_id: str,
64 | videos=None,
65 | base_model: str = None,
66 | validation_prompt=None,
67 | repo_folder=None,
68 | fps=30,
69 | ):
70 | widget_dict = []
71 | if videos is not None and len(videos) > 0:
72 | for i, video in enumerate(videos):
73 | export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps)
74 | widget_dict.append(
75 | {
76 | "text": validation_prompt if validation_prompt else " ",
77 | "output": {"url": f"final_video_{i}.mp4"},
78 | }
79 | )
80 |
81 | model_description = f"""
82 | # Mochi-1 Preview LoRA Finetune
83 |
84 |
85 |
86 | ## Model description
87 |
88 | This is a lora finetune of the Mochi-1 preview model `{base_model}`.
89 |
90 | The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX and Mochi family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adapted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
91 |
92 | ## Download model
93 |
94 | [Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
95 |
96 | ## Usage
97 |
98 | Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
99 |
100 | ```py
101 | from diffusers import MochiPipeline
102 | from diffusers.utils import export_to_video
103 | import torch
104 |
105 | pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview")
106 | pipe.load_lora_weights("CHANGE_ME")
107 | pipe.enable_model_cpu_offload()
108 |
109 | with torch.autocast("cuda", torch.bfloat16):
110 | video = pipe(
111 | prompt="CHANGE_ME",
112 | guidance_scale=6.0,
113 | num_inference_steps=64,
114 | height=480,
115 | width=848,
116 | max_sequence_length=256,
117 | output_type="np"
118 | ).frames[0]
119 | export_to_video(video)
120 | ```
121 |
122 | For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
123 |
124 | """
125 | model_card = load_or_create_model_card(
126 | repo_id_or_path=repo_id,
127 | from_training=True,
128 | license="apache-2.0",
129 | base_model=base_model,
130 | prompt=validation_prompt,
131 | model_description=model_description,
132 | widget=widget_dict,
133 | )
134 | tags = [
135 | "text-to-video",
136 | "diffusers-training",
137 | "diffusers",
138 | "lora",
139 | "mochi-1-preview",
140 | "mochi-1-preview-diffusers",
141 | "template:sd-lora",
142 | ]
143 |
144 | model_card = populate_model_card(model_card, tags=tags)
145 | model_card.save(os.path.join(repo_folder, "README.md"))
146 |
147 |
148 | def log_validation(
149 | pipe: MochiPipeline,
150 | args: Dict[str, Any],
151 | pipeline_args: Dict[str, Any],
152 | step: int,
153 | wandb_run: str = None,
154 | is_final_validation: bool = False,
155 | ):
156 | print(
157 | f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
158 | )
159 | phase_name = "test" if is_final_validation else "validation"
160 |
161 | if not args.enable_model_cpu_offload:
162 | pipe = pipe.to("cuda")
163 |
164 | # run inference
165 |     generator = torch.manual_seed(args.seed) if args.seed is not None else None
166 |
167 | videos = []
168 | with torch.autocast("cuda", torch.bfloat16, cache_enabled=False):
169 | for _ in range(args.num_validation_videos):
170 | video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
171 | videos.append(video)
172 |
173 | video_filenames = []
174 | for i, video in enumerate(videos):
175 | prompt = (
176 | pipeline_args["prompt"][:25]
177 | .replace(" ", "_")
178 | .replace(" ", "_")
179 | .replace("'", "_")
180 | .replace('"', "_")
181 | .replace("/", "_")
182 | )
183 | filename = os.path.join(args.output_dir, f"{phase_name}_{str(step)}_video_{i}_{prompt}.mp4")
184 | export_to_video(video, filename, fps=30)
185 | video_filenames.append(filename)
186 |
187 | if wandb_run:
188 | wandb.log(
189 | {
190 | phase_name: [
191 | wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}", fps=30)
192 | for i, filename in enumerate(video_filenames)
193 | ]
194 | }
195 | )
196 |
197 | return videos
198 |
199 |
200 | # Adapted from the original code:
201 | # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/pipelines.py#L578
202 | def cast_dit(model, dtype):
203 | for name, module in model.named_modules():
204 | if isinstance(module, torch.nn.Linear):
205 | assert any(
206 | n in name for n in ["time_embed", "proj_out", "blocks", "norm_out"]
207 | ), f"Unexpected linear layer: {name}"
208 | module.to(dtype=dtype)
209 | elif isinstance(module, torch.nn.Conv2d):
210 | module.to(dtype=dtype)
211 | return model
212 |
213 |
214 | def save_checkpoint(model, optimizer, lr_scheduler, global_step, checkpoint_path):
215 | # lora_state_dict = get_peft_model_state_dict(model)
216 | processor_state_dict = get_processor_state_dict(model)
217 | torch.save(
218 | {
219 | "state_dict": processor_state_dict,
220 | "optimizer": optimizer.state_dict(),
221 | "lr_scheduler": lr_scheduler.state_dict(),
222 | "global_step": global_step,
223 | },
224 | checkpoint_path,
225 | )
226 |
227 |
228 | class CollateFunction:
229 | def __init__(self, caption_dropout: float = None) -> None:
230 | self.caption_dropout = caption_dropout
231 |
232 | def __call__(self, samples: List[Tuple[dict, torch.Tensor]]) -> Dict[str, torch.Tensor]:
233 | ldists = torch.cat([data[0]["ldist"] for data in samples], dim=0)
234 | z = DiagonalGaussianDistribution(ldists).sample()
235 | assert torch.isfinite(z).all()
236 |
237 | # Sample noise which we will add to the samples.
238 | eps = torch.randn_like(z)
239 | sigma = torch.rand(z.shape[:1], device="cpu", dtype=torch.float32)
240 |
241 | prompt_embeds = torch.cat([data[1]["prompt_embeds"] for data in samples], dim=0)
242 | prompt_attention_mask = torch.cat([data[1]["prompt_attention_mask"] for data in samples], dim=0)
243 | if self.caption_dropout and random.random() < self.caption_dropout:
244 | prompt_embeds.zero_()
245 | prompt_attention_mask = prompt_attention_mask.long()
246 | prompt_attention_mask.zero_()
247 | prompt_attention_mask = prompt_attention_mask.bool()
248 |
249 | return dict(
250 | z=z, eps=eps, sigma=sigma, prompt_embeds=prompt_embeds, prompt_attention_mask=prompt_attention_mask
251 | )
252 |
253 |
254 | def main(args):
255 | if not torch.cuda.is_available():
256 | raise ValueError("Not supported without CUDA.")
257 |
258 | if args.report_to == "wandb" and args.hub_token is not None:
259 | raise ValueError(
260 | "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
261 | " Please use `huggingface-cli login` to authenticate with the Hub."
262 | )
263 |
264 | # Handle the repository creation
265 | if args.output_dir is not None:
266 | os.makedirs(args.output_dir, exist_ok=True)
267 |
268 | # Prepare models and scheduler
269 | transformer = MochiTransformer3DModel.from_pretrained(
270 | args.pretrained_model_name_or_path,
271 | subfolder="transformer",
272 | revision=args.revision,
273 | variant=args.variant,
274 | )
275 | scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
276 | args.pretrained_model_name_or_path, subfolder="scheduler"
277 | )
278 |
279 | transformer.requires_grad_(False)
280 | transformer.to("cuda")
281 | if args.gradient_checkpointing:
282 | transformer.enable_gradient_checkpointing()
283 | if args.cast_dit:
284 | transformer = cast_dit(transformer, torch.bfloat16)
285 | if args.compile_dit:
286 | transformer.compile()
287 |
288 | prepare_for_rgba_inference(
289 | model=transformer,
290 | device=torch.device("cuda"),
291 | dtype=torch.bfloat16,
292 | # seq_length=seq_length,
293 | )
294 | processor_params = get_all_processor_params(transformer)
295 |
296 | # Enable TF32 for faster training on Ampere GPUs,
297 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
298 | if args.allow_tf32 and torch.cuda.is_available():
299 | torch.backends.cuda.matmul.allow_tf32 = True
300 |
301 | if args.scale_lr:
302 | args.learning_rate = args.learning_rate * args.train_batch_size
303 | # only upcast trainable parameters (LoRA) into fp32
304 |
305 |     # `get_all_processor_params` returns a flat list of parameters, so iterate over it directly.
306 |     if not isinstance(processor_params, list):
307 |         processor_params = [processor_params]
308 |     for param in processor_params:
309 |         # only upcast trainable parameters into fp32
310 |         if param.requires_grad:
311 |             param.data = param.to(torch.float32)
312 |
313 | # Prepare optimizer
314 | transformer_lora_parameters = processor_params # list(filter(lambda p: p.requires_grad, transformer.parameters()))
315 | num_trainable_parameters = sum(param.numel() for param in transformer_lora_parameters)
316 | optimizer = torch.optim.AdamW(transformer_lora_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)
317 |
318 | # Dataset and DataLoader
319 | train_vids = list(sorted(glob(f"{args.data_root}/*.mp4")))
320 | train_vids = [v for v in train_vids if not v.endswith(".recon.mp4")]
321 | print(f"Found {len(train_vids)} training videos in {args.data_root}")
322 | assert len(train_vids) > 0, f"No training data found in {args.data_root}"
323 |
324 | collate_fn = CollateFunction(caption_dropout=args.caption_dropout)
325 | train_dataset = LatentEmbedDataset(train_vids, repeat=1)
326 | train_dataloader = DataLoader(
327 | train_dataset,
328 | collate_fn=collate_fn,
329 | batch_size=args.train_batch_size,
330 | num_workers=args.dataloader_num_workers,
331 | pin_memory=args.pin_memory,
332 | )
333 |
334 | # LR scheduler and math around the number of training steps.
335 | overrode_max_train_steps = False
336 | num_update_steps_per_epoch = len(train_dataloader)
337 | if args.max_train_steps is None:
338 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
339 | overrode_max_train_steps = True
340 |
341 | lr_scheduler = get_cosine_annealing_lr_scheduler(
342 | optimizer, warmup_steps=args.lr_warmup_steps, total_steps=args.max_train_steps
343 | )
344 |
345 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
346 | num_update_steps_per_epoch = len(train_dataloader)
347 | if overrode_max_train_steps:
348 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
349 | # Afterwards we recalculate our number of training epochs
350 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
351 |
352 | # We need to initialize the trackers we use, and also store our configuration.
353 |     # The trackers initialize automatically on the main process.
354 | wandb_run = None
355 | if args.report_to == "wandb":
356 | tracker_name = args.tracker_name or "mochi-1-rgba-lora"
357 | wandb_run = wandb.init(project=tracker_name, config=vars(args))
358 |
359 | # Resume from checkpoint if specified
360 | if args.resume_from_checkpoint:
361 | checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
362 |         # Fall back to 0 when the checkpoint has no step counter, so `global_step` is always defined.
363 |         global_step = checkpoint.get("global_step", 0)
364 | if "optimizer" in checkpoint:
365 | optimizer.load_state_dict(checkpoint["optimizer"])
366 | if "lr_scheduler" in checkpoint:
367 | lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
368 |
369 | # set_peft_model_state_dict(transformer, checkpoint["state_dict"]) # Luozhou: modify this line
370 |
371 | processor_state_dict = checkpoint["state_dict"]
372 | load_processor_state_dict(transformer, processor_state_dict)
373 |
374 | print(f"Resuming from checkpoint: {args.resume_from_checkpoint}")
375 | print(f"Resuming from global step: {global_step}")
376 | else:
377 | global_step = 0
378 |
379 | print("===== Memory before training =====")
380 | reset_memory("cuda")
381 | print_memory("cuda")
382 |
383 | # Train!
384 | total_batch_size = args.train_batch_size
385 | print("***** Running training *****")
386 | print(f" Num trainable parameters = {num_trainable_parameters}")
387 | print(f" Num examples = {len(train_dataset)}")
388 | print(f" Num batches each epoch = {len(train_dataloader)}")
389 | print(f" Num epochs = {args.num_train_epochs}")
390 | print(f" Instantaneous batch size per device = {args.train_batch_size}")
391 | print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
392 | print(f" Total optimization steps = {args.max_train_steps}")
393 |
394 | first_epoch = 0
395 | progress_bar = tqdm(
396 | range(0, args.max_train_steps),
397 | initial=global_step,
398 | desc="Steps",
399 | )
400 | for epoch in range(first_epoch, args.num_train_epochs):
401 | transformer.train()
402 |
403 | for step, batch in enumerate(train_dataloader):
404 | with torch.no_grad():
405 | z = batch["z"].to("cuda")
406 | eps = batch["eps"].to("cuda")
407 | sigma = batch["sigma"].to("cuda")
408 | prompt_embeds = batch["prompt_embeds"].to("cuda")
409 | prompt_attention_mask = batch["prompt_attention_mask"].to("cuda")
410 |
411 | all_attention_mask = prepare_attention_mask(
412 | prompt_attention_mask=prompt_attention_mask,
413 | latents=z
414 | )
415 |
416 | sigma_bcthw = sigma[:, None, None, None, None] # [B, 1, 1, 1, 1]
417 | # Add noise according to flow matching.
418 | # zt = (1 - texp) * x + texp * z1
419 | z_sigma = (1 - sigma_bcthw) * z + sigma_bcthw * eps
420 | ut = z - eps
421 |
422 | # (1 - sigma) because of
423 | # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py#L656
424 | # Also, we operate on the scaled version of the `timesteps` directly in the `diffusers` implementation.
425 | timesteps = (1 - sigma) * scheduler.config.num_train_timesteps
426 |
427 | with torch.autocast("cuda", torch.bfloat16):
428 | model_pred = transformer(
429 | hidden_states=z_sigma,
430 | encoder_hidden_states=prompt_embeds,
431 | encoder_attention_mask=all_attention_mask,
432 | timestep=timesteps,
433 | return_dict=False,
434 | )[0]
435 | assert model_pred.shape == z.shape
436 | loss = F.mse_loss(model_pred.float(), ut.float())
437 | loss.backward()
438 |
439 | optimizer.step()
440 | optimizer.zero_grad()
441 | lr_scheduler.step()
442 |
443 | progress_bar.update(1)
444 | global_step += 1
445 |
446 | last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
447 | logs = {"loss": loss.detach().item(), "lr": last_lr}
448 | progress_bar.set_postfix(**logs)
449 | if wandb_run:
450 | wandb_run.log(logs, step=global_step)
451 |
452 | if args.checkpointing_steps is not None and global_step % args.checkpointing_steps == 0:
453 | print(f"Saving checkpoint at step {global_step}")
454 | checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.pt")
455 | save_checkpoint(
456 | transformer,
457 | optimizer,
458 | lr_scheduler,
459 | global_step,
460 | checkpoint_path,
461 | )
462 |
463 | # if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
464 | print("===== Memory before validation =====")
465 | print_memory("cuda")
466 |
467 | transformer.eval()
468 | pipe = MochiPipeline.from_pretrained(
469 | args.pretrained_model_name_or_path,
470 | transformer=transformer,
471 | scheduler=scheduler,
472 | revision=args.revision,
473 | variant=args.variant,
474 | )
475 |
476 | if args.enable_slicing:
477 | pipe.vae.enable_slicing()
478 | if args.enable_tiling:
479 | pipe.vae.enable_tiling()
480 | if args.enable_model_cpu_offload:
481 | pipe.enable_model_cpu_offload()
482 |
483 | # validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
484 | validation_prompts = [
485 | "A boy in a white shirt and shorts is seen bouncing a ball, isolated background",
486 | ]
487 | for validation_prompt in validation_prompts:
488 | pipeline_args = {
489 | "prompt": validation_prompt,
490 | "guidance_scale": 6.0,
491 | "num_frames": 37,
492 | "num_inference_steps": 64,
493 | "height": args.height,
494 | "width": args.width,
495 | "max_sequence_length": 256,
496 | }
497 | log_validation(
498 | pipe=pipe,
499 | args=args,
500 | pipeline_args=pipeline_args,
501 | step=global_step,
502 | wandb_run=wandb_run,
503 | )
504 |
505 | print("===== Memory after validation =====")
506 | print_memory("cuda")
507 | reset_memory("cuda")
508 |
509 | del pipe.text_encoder
510 | del pipe.vae
511 | del pipe
512 | gc.collect()
513 | torch.cuda.empty_cache()
514 |
515 | transformer.train()
516 |
517 | if global_step >= args.max_train_steps:
518 | break
519 |
520 | if global_step >= args.max_train_steps:
521 | break
522 |
523 | transformer.eval()
524 |
525 | # saving lora weights
526 | # transformer_lora_layers = get_peft_model_state_dict(transformer)
527 | # MochiPipeline.save_lora_weights(save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers)
528 |
529 | # Cleanup trained models to save memory
530 | del transformer
531 |
532 | gc.collect()
533 | torch.cuda.empty_cache()
534 |
535 | # Final test inference
536 | # validation_outputs = []
537 | # if args.validation_prompt and args.num_validation_videos > 0:
538 | # print("===== Memory before testing =====")
539 | # print_memory("cuda")
540 | # reset_memory("cuda")
541 |
542 | # pipe = MochiPipeline.from_pretrained(
543 | # args.pretrained_model_name_or_path,
544 | # revision=args.revision,
545 | # variant=args.variant,
546 | # )
547 |
548 |
549 |
550 | # if args.enable_slicing:
551 | # pipe.vae.enable_slicing()
552 | # if args.enable_tiling:
553 | # pipe.vae.enable_tiling()
554 | # if args.enable_model_cpu_offload:
555 | # pipe.enable_model_cpu_offload()
556 |
557 | # # Load LoRA weights
558 | # # lora_scaling = args.lora_alpha / args.rank
559 | # # pipe.load_lora_weights(args.output_dir, adapter_name="mochi-lora")
560 | # # pipe.set_adapters(["mochi-lora"], [lora_scaling])
561 |
562 | # # Run inference
563 | # validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
564 | # for validation_prompt in validation_prompts:
565 | # pipeline_args = {
566 | # "prompt": validation_prompt,
567 | # "guidance_scale": 6.0,
568 | # "num_inference_steps": 64,
569 | # "height": args.height,
570 | # "width": args.width,
571 | # "max_sequence_length": 256,
572 | # }
573 |
574 | # video = log_validation(
575 | # pipe=pipe,
576 | # args=args,
577 | # pipeline_args=pipeline_args,
578 | # epoch=epoch,
579 | # wandb_run=wandb_run,
580 | # is_final_validation=True,
581 | # )
582 | # validation_outputs.extend(video)
583 |
584 | # print("===== Memory after testing =====")
585 | # print_memory("cuda")
586 | # reset_memory("cuda")
587 | # torch.cuda.synchronize("cuda")
588 |
589 |
590 |
591 | if __name__ == "__main__":
592 | args = get_args()
593 | main(args)
594 |
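
The training step in `main` builds its noisy input and regression target from three lines: `z_sigma = (1 - sigma) * z + sigma * eps`, `ut = z - eps`, and `timesteps = (1 - sigma) * scheduler.config.num_train_timesteps`. The sketch below replays that arithmetic on scalars. The helper name and the value 1000 for `num_train_timesteps` are assumptions for illustration, not values read from the scheduler config.

```python
def flow_matching_pair(z: float, eps: float, sigma: float, num_train_timesteps: int = 1000):
    """Scalar version of the noising arithmetic in the training loop.

    `num_train_timesteps=1000` is an assumed placeholder value.
    """
    z_sigma = (1 - sigma) * z + sigma * eps  # interpolate between data and noise
    ut = z - eps  # regression target for the model output
    timestep = (1 - sigma) * num_train_timesteps  # scaled timestep passed to the model
    return z_sigma, ut, timestep


z_sigma, ut, timestep = flow_matching_pair(z=2.0, eps=-1.0, sigma=0.25)
print(z_sigma, ut, timestep)  # 1.25 3.0 750.0

# Moving from the noisy sample along the target by `sigma` recovers the clean value,
# since (1 - sigma) * z + sigma * eps + sigma * (z - eps) == z.
recovered = z_sigma + 0.25 * ut
print(recovered)  # 2.0
```

The identity in the last comment is why a model that predicts `ut` well can be integrated from noise (`sigma = 1`) back to data (`sigma = 0`).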
--------------------------------------------------------------------------------
/Mochi/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export NCCL_P2P_DISABLE=1
3 | export TORCH_NCCL_ENABLE_MONITORING=0
4 |
5 | GPU_IDS="3"
6 |
7 | DATA_ROOT="/hpc2hdd/home/lwang592/projects/finetrainers/training/data/video-matte-240k-rgb-prepared-f37"
8 | MODEL="genmo/mochi-1-preview"
9 | OUTPUT_PATH="mochi-rgba-lora-f37"
10 |
11 | cmd="CUDA_VISIBLE_DEVICES=$GPU_IDS python train.py \
12 | --pretrained_model_name_or_path $MODEL \
13 | --cast_dit \
14 | --data_root $DATA_ROOT \
15 | --seed 42 \
16 | --output_dir $OUTPUT_PATH \
17 | --train_batch_size 2 \
18 | --dataloader_num_workers 4 \
19 | --pin_memory \
20 | --caption_dropout 0.0 \
21 | --max_train_steps 5000 \
22 | --gradient_checkpointing \
23 | --enable_slicing \
24 | --enable_tiling \
25 | --enable_model_cpu_offload \
26 | --optimizer adamw \
27 | --allow_tf32"
28 |
29 | echo "Running command: $cmd"
30 | eval $cmd
31 | echo -ne "-------------------- Finished executing script --------------------\n\n"
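
`train.sh` sets `--max_train_steps 5000`, which `train.py` passes as `total_steps` to its warmup-plus-cosine learning-rate multiplier. A minimal pure-Python sketch of that multiplier follows. It uses `math.cos` instead of NumPy and guards against a zero-length decay phase, and the warmup of 200 steps is an assumed value, since the script above does not set `--lr_warmup_steps`.

```python
import math


def lr_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup followed by cosine decay, modeled on `lr_lambda` in train.py.

    The `max(1, ...)` in the decay phase is a guard added in this sketch.
    """
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))


# warmup_steps=200 is a hypothetical choice; total_steps matches --max_train_steps.
for step in (0, 100, 200, 2600, 5000):
    print(step, round(lr_multiplier(step, warmup_steps=200, total_steps=5000), 4))
```

The multiplier rises linearly to 1 at the end of warmup, passes through one half at the midpoint of the decay phase, and reaches zero at the final step, so the effective learning rate at any step is `--learning_rate` times this value.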
--------------------------------------------------------------------------------
/Mochi/trim_and_crop_videos.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from:
3 | https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/trim_and_crop_videos.py
4 | """
5 |
6 | from pathlib import Path
7 | import shutil
8 |
9 | import click
10 | from moviepy.editor import VideoFileClip
11 | from tqdm import tqdm
12 |
13 |
14 | @click.command()
15 | @click.argument("folder", type=click.Path(exists=True, dir_okay=True))
16 | @click.argument("output_folder", type=click.Path(dir_okay=True))
17 | @click.option("--num_frames", "-f", type=float, default=30, help="Number of frames")
18 | @click.option("--resolution", "-r", type=str, default="480x848", help="Video resolution")
19 | @click.option("--force_upsample", is_flag=True, help="Force upsample.")
20 | def truncate_videos(folder, output_folder, num_frames, resolution, force_upsample):
21 | """Truncate all MP4 and MOV files in FOLDER to specified number of frames and resolution"""
22 | input_path = Path(folder)
23 | output_path = Path(output_folder)
24 | output_path.mkdir(parents=True, exist_ok=True)
25 |
26 | # Parse target resolution
27 | target_height, target_width = map(int, resolution.split("x"))
28 |
29 | # Calculate duration
30 | duration = (num_frames / 30) + 0.09
31 |
32 | # Find all MP4 and MOV files
33 | video_files = (
34 | list(input_path.rglob("*.mp4"))
35 | + list(input_path.rglob("*.MOV"))
36 | + list(input_path.rglob("*.mov"))
37 | + list(input_path.rglob("*.MP4"))
38 | )
39 |
40 | for file_path in tqdm(video_files):
41 | try:
42 | relative_path = file_path.relative_to(input_path)
43 | output_file = output_path / relative_path.with_suffix(".mp4")
44 | output_file.parent.mkdir(parents=True, exist_ok=True)
45 |
46 | click.echo(f"Processing: {file_path}")
47 | video = VideoFileClip(str(file_path))
48 |
49 | # Skip if video is too short
50 | if video.duration < duration:
51 | click.echo(f"Skipping {file_path} as it is too short")
52 | continue
53 |
54 | # Skip if target resolution is larger than input
55 | if target_width > video.w or target_height > video.h:
56 | if force_upsample:
57 | click.echo(
58 |                         f"Upsampling {file_path} because target resolution {resolution} is larger than input {video.w}x{video.h}."
59 | )
60 | video = video.resize(width=target_width, height=target_height)
61 | else:
62 | click.echo(
63 | f"Skipping {file_path} as target resolution {resolution} is larger than input {video.w}x{video.h}"
64 | )
65 | continue
66 |
67 | # First truncate duration
68 | truncated = video.subclip(0, duration)
69 |
70 | # Calculate crop dimensions to maintain aspect ratio
71 | target_ratio = target_width / target_height
72 | current_ratio = truncated.w / truncated.h
73 |
74 | if current_ratio > target_ratio:
75 | # Video is wider than target ratio - crop width
76 | new_width = int(truncated.h * target_ratio)
77 | x1 = (truncated.w - new_width) // 2
78 | final = truncated.crop(x1=x1, width=new_width).resize((target_width, target_height))
79 | else:
80 | # Video is taller than target ratio - crop height
81 | new_height = int(truncated.w / target_ratio)
82 | y1 = (truncated.h - new_height) // 2
83 | final = truncated.crop(y1=y1, height=new_height).resize((target_width, target_height))
84 |
85 | # Set output parameters for consistent MP4 encoding
86 | output_params = {
87 | "codec": "libx264",
88 | "audio": False, # Disable audio
89 | "preset": "medium", # Balance between speed and quality
90 | "bitrate": "5000k", # Adjust as needed
91 | }
92 |
93 | # Set FPS to 30
94 | final = final.set_fps(30)
95 |
96 | # Check for a corresponding .txt file
97 | txt_file_path = file_path.with_suffix(".txt")
98 | if txt_file_path.exists():
99 | output_txt_file = output_path / relative_path.with_suffix(".txt")
100 | output_txt_file.parent.mkdir(parents=True, exist_ok=True)
101 | shutil.copy(txt_file_path, output_txt_file)
102 | click.echo(f"Copied {txt_file_path} to {output_txt_file}")
103 | else:
104 | # Print warning in bold yellow with a warning emoji
105 | click.echo(
106 | f"\033[1;33m⚠️ Warning: No caption found for {file_path}, using an empty caption. This may hurt fine-tuning quality.\033[0m"
107 | )
108 | output_txt_file = output_path / relative_path.with_suffix(".txt")
109 | output_txt_file.parent.mkdir(parents=True, exist_ok=True)
110 | output_txt_file.touch()
111 |
112 | # Write the output file
113 | final.write_videofile(str(output_file), **output_params)
114 |
115 | # Clean up
116 | video.close()
117 | truncated.close()
118 | final.close()
119 |
120 | except Exception as e:
121 | click.echo(f"\033[1;31m Error processing {file_path}: {str(e)}\033[0m", err=True)
122 | raise
123 |
124 |
125 | if __name__ == "__main__":
126 | truncate_videos()
--------------------------------------------------------------------------------
/Mochi/utils.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import inspect
3 | from typing import Optional, Tuple, Union
4 |
5 | import torch
6 | from accelerate import Accelerator
7 | from accelerate.logging import get_logger
8 | from diffusers.models.embeddings import get_3d_rotary_pos_embed
9 | from diffusers.utils.torch_utils import is_compiled_module
10 |
11 |
12 | logger = get_logger(__name__)
13 |
14 |
15 | def get_optimizer(
16 | params_to_optimize,
17 | optimizer_name: str = "adam",
18 | learning_rate: float = 1e-3,
19 | beta1: float = 0.9,
20 | beta2: float = 0.95,
21 | beta3: float = 0.98,
22 | epsilon: float = 1e-8,
23 | weight_decay: float = 1e-4,
24 | prodigy_decouple: bool = False,
25 | prodigy_use_bias_correction: bool = False,
26 | prodigy_safeguard_warmup: bool = False,
27 | use_8bit: bool = False,
28 | use_4bit: bool = False,
29 | use_torchao: bool = False,
30 | use_deepspeed: bool = False,
31 | use_cpu_offload_optimizer: bool = False,
32 | offload_gradients: bool = False,
33 | ) -> torch.optim.Optimizer:
34 | optimizer_name = optimizer_name.lower()
35 |
36 | # Use DeepSpeed optimizer
37 | if use_deepspeed:
38 | from accelerate.utils import DummyOptim
39 |
40 | return DummyOptim(
41 | params_to_optimize,
42 | lr=learning_rate,
43 | betas=(beta1, beta2),
44 | eps=epsilon,
45 | weight_decay=weight_decay,
46 | )
47 |
48 | if use_8bit and use_4bit:
49 | raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")
50 |
51 | if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
52 | try:
53 | import torchao
54 |
55 | torchao.__version__
56 | except ImportError:
57 | raise ImportError(
58 | "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
59 | )
60 |
61 | if not use_torchao and use_4bit:
62 | raise ValueError("4-bit Optimizers are only supported with torchao.")
63 |
64 | # Optimizer creation
65 | supported_optimizers = ["adam", "adamw", "prodigy", "came"]
66 | if optimizer_name not in supported_optimizers:
67 | logger.warning(
68 | f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
69 | )
70 | optimizer_name = "adamw"
71 |
72 | if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
73 | raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")
74 |
75 | if use_8bit and not use_torchao:
76 | try:
77 | import bitsandbytes as bnb
78 | except ImportError:
79 | raise ImportError(
80 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
81 | )
82 |
83 | if optimizer_name == "adamw":
84 | if use_torchao:
85 | from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
86 |
87 | optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
88 | else:
89 | optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW
90 |
91 | init_kwargs = {
92 | "betas": (beta1, beta2),
93 | "eps": epsilon,
94 | "weight_decay": weight_decay,
95 | }
96 |
97 | elif optimizer_name == "adam":
98 | if use_torchao:
99 | from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit
100 |
101 | optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam
102 | else:
103 | optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam
104 |
105 | init_kwargs = {
106 | "betas": (beta1, beta2),
107 | "eps": epsilon,
108 | "weight_decay": weight_decay,
109 | }
110 |
111 | elif optimizer_name == "prodigy":
112 | try:
113 | import prodigyopt
114 | except ImportError:
115 | raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
116 |
117 | optimizer_class = prodigyopt.Prodigy
118 |
119 | if learning_rate <= 0.1:
120 | logger.warning(
121 | "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
122 | )
123 |
124 | init_kwargs = {
125 | "lr": learning_rate,
126 | "betas": (beta1, beta2),
127 | "beta3": beta3,
128 | "eps": epsilon,
129 | "weight_decay": weight_decay,
130 | "decouple": prodigy_decouple,
131 | "use_bias_correction": prodigy_use_bias_correction,
132 | "safeguard_warmup": prodigy_safeguard_warmup,
133 | }
134 |
135 | elif optimizer_name == "came":
136 | try:
137 | import came_pytorch
138 | except ImportError:
139 | raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`")
140 |
141 | optimizer_class = came_pytorch.CAME
142 |
143 | init_kwargs = {
144 | "lr": learning_rate,
145 | "eps": (1e-30, 1e-16),
146 | "betas": (beta1, beta2, beta3),
147 | "weight_decay": weight_decay,
148 | }
149 |
150 | if use_cpu_offload_optimizer:
151 | from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
152 |
153 | if "fused" in inspect.signature(optimizer_class.__init__).parameters:
154 | init_kwargs.update({"fused": True})
155 |
156 | optimizer = CPUOffloadOptimizer(
157 | params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
158 | )
159 | else:
160 | optimizer = optimizer_class(params_to_optimize, **init_kwargs)
161 |
162 | return optimizer
163 |
164 |
165 | def get_gradient_norm(parameters):
166 | norm = 0
167 | for param in parameters:
168 | if param.grad is None:
169 | continue
170 | local_norm = param.grad.detach().data.norm(2)
171 | norm += local_norm.item() ** 2
172 | norm = norm**0.5
173 | return norm
174 |
175 |
176 | # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
177 | def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
178 | tw = tgt_width
179 | th = tgt_height
180 | h, w = src
181 | r = h / w
182 | if r > (th / tw):
183 | resize_height = th
184 | resize_width = int(round(th / h * w))
185 | else:
186 | resize_width = tw
187 | resize_height = int(round(tw / w * h))
188 |
189 | crop_top = int(round((th - resize_height) / 2.0))
190 | crop_left = int(round((tw - resize_width) / 2.0))
191 |
192 | return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
193 |
194 |
195 | def prepare_rotary_positional_embeddings(
196 | height: int,
197 | width: int,
198 | num_frames: int,
199 | vae_scale_factor_spatial: int = 8,
200 | patch_size: int = 2,
201 | patch_size_t: Optional[int] = None,
202 | attention_head_dim: int = 64,
203 | device: Optional[torch.device] = None,
204 | base_height: int = 480,
205 | base_width: int = 720,
206 | ) -> Tuple[torch.Tensor, torch.Tensor]:
207 | grid_height = height // (vae_scale_factor_spatial * patch_size)
208 | grid_width = width // (vae_scale_factor_spatial * patch_size)
209 | base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
210 | base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
211 |
212 | if patch_size_t is None:
213 | # CogVideoX 1.0
214 | grid_crops_coords = get_resize_crop_region_for_grid(
215 | (grid_height, grid_width), base_size_width, base_size_height
216 | )
217 | freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
218 | embed_dim=attention_head_dim,
219 | crops_coords=grid_crops_coords,
220 | grid_size=(grid_height, grid_width),
221 | temporal_size=num_frames,
222 | )
223 | else:
224 | # CogVideoX 1.5
225 | base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
226 |
227 | freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
228 | embed_dim=attention_head_dim,
229 | crops_coords=None,
230 | grid_size=(grid_height, grid_width),
231 | temporal_size=base_num_frames,
232 | grid_type="slice",
233 | max_size=(base_size_height, base_size_width),
234 | )
235 |
236 | freqs_cos = freqs_cos.to(device=device)
237 | freqs_sin = freqs_sin.to(device=device)
238 | return freqs_cos, freqs_sin
239 |
240 |
241 | def reset_memory(device: Union[str, torch.device]) -> None:
242 | gc.collect()
243 | torch.cuda.empty_cache()
244 | torch.cuda.reset_peak_memory_stats(device)
245 | torch.cuda.reset_accumulated_memory_stats(device)
246 |
247 |
248 | def print_memory(device: Union[str, torch.device]) -> None:
249 | memory_allocated = torch.cuda.memory_allocated(device) / 1024**3
250 | max_memory_allocated = torch.cuda.max_memory_allocated(device) / 1024**3
251 | max_memory_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
252 | print(f"{memory_allocated=:.3f} GB")
253 | print(f"{max_memory_allocated=:.3f} GB")
254 | print(f"{max_memory_reserved=:.3f} GB")
255 |
256 |
257 | def unwrap_model(accelerator: Accelerator, model):
258 | model = accelerator.unwrap_model(model)
259 | model = model._orig_mod if is_compiled_module(model) else model
260 | return model
261 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## TransPixeler: Advancing Text-to-Video Generation with Transparency (CVPR2025)
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | [Luozhou Wang*](https://wileewang.github.io/),
12 | [Yijun Li**](https://yijunmaverick.github.io/),
13 | [Zhifei Chen](),
14 | [Jui-Hsien Wang](http://juiwang.com/),
15 | [Zhifei Zhang](https://zzutk.github.io/),
16 | [He Zhang](https://sites.google.com/site/hezhangsprinter),
17 | [Zhe Lin](https://sites.google.com/site/zhelin625/home),
18 | [Ying-Cong Chen†](https://www.yingcong.me)
19 |
20 | HKUST(GZ), HKUST, Adobe Research.
21 |
22 | \* Internship Project
23 | \** Project Lead
24 | † Corresponding Author
25 |
26 | Text-to-video generative models have made significant strides, enabling diverse applications in entertainment, advertising, and education. However, generating RGBA video, which includes alpha channels for transparency, remains a challenge due to limited datasets and the difficulty of adapting existing models. Alpha channels are crucial for visual effects (VFX), allowing transparent elements like smoke and reflections to blend seamlessly into scenes.
27 | We introduce TransPixeler, a method to extend pretrained video models for RGBA generation while retaining the original RGB capabilities. TransPixeler leverages a diffusion transformer (DiT) architecture, incorporating alpha-specific tokens and using LoRA-based fine-tuning to jointly generate RGB and alpha channels with high consistency. By optimizing attention mechanisms, TransPixeler preserves the strengths of the original RGB model and achieves strong alignment between RGB and alpha channels despite limited training data.
28 | Our approach effectively generates diverse and consistent RGBA videos, advancing the possibilities for VFX and interactive content creation.
29 |
30 |
31 |
32 |
33 |
34 |
35 | ## 📰 News
36 |
37 | - **[2025.04.28]** We have introduced a new development branch [`wan`](https://github.com/wileewang/TransPixar/tree/wan) that integrates the [Wan2.1](https://github.com/Wan-Video/Wan2.1) video generation model to support **joint generation** tasks. This branch includes training code tailored for generating both RGB and associated modalities (e.g., segmentation maps, alpha masks) from a shared text prompt.
38 |
39 | - **[2025.02.26]** **TransPixeler** is accepted by CVPR 2025! See you in Nashville!
40 |
41 | - **[2025.01.19]** We've renamed our project from **TransPixar** to **TransPixeler**!!
42 |
43 | - **[2025.01.17]** We’ve created a [Discord group](https://discord.gg/7Xds3Qjr) and a [WeChat group](https://github.com/wileewang/TransPixar/blob/main/wechat_group.jpg)! Everyone is welcome to join for discussions and collaborations.
44 |
45 | - **[2025.01.14]** Added new tasks to the repository's roadmap, including support for Hunyuan and LTX video models, and ComfyUI integration.
46 |
47 | - **[2025.01.07]** Released project page, arXiv paper, inference code, and Hugging Face demo.
48 |
49 |
50 |
51 |
52 | ## 🔥 New Branch for Joint Generation with Wan2.1
53 |
54 | We have introduced a new development branch [`wan`](https://github.com/wileewang/TransPixar/tree/wan) that integrates the [Wan2.1](https://github.com/Wan-Video/Wan2.1) video generation model to support **joint generation** tasks.
55 |
56 | In the `wan` branch, we have developed and released training code tailored for joint generation scenarios, enabling the simultaneous generation of RGB videos and associated modalities (e.g., segmentation maps, alpha masks) from a shared text prompt.
57 |
58 | **Key features of the `wan` branch:**
59 | - **Integration of Wan2.1**: Leverages the capabilities of the Wan2.1 video generation model for enhanced performance.
60 | - **Joint Generation Support**: Facilitates the concurrent generation of RGB and paired modality videos.
61 | - **Dataset Structure**: Expects each sample to include:
62 | - A primary video file (`001.mp4`) representing the RGB content.
63 | - A paired secondary video file (`001_seg.mp4`) with a fixed `_seg` suffix, representing the associated modality.
64 | - A caption text file (`001.txt`) with the same base name as the primary video.
65 | - **Periodic Evaluation**: Supports periodic video sampling during training by setting `eval_every_step` or `eval_every_epoch` in the configuration.
66 | - **Customized Pipelines**: Offers tailored training and inference pipelines designed specifically for joint generation tasks.
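The dataset layout described above can be checked before training with a short script. This is only a sketch, not part of the repository: the function name `find_incomplete_samples` is hypothetical, and it assumes all files of a sample sit flat in a single directory, which the `wan` branch may or may not require.

```python
from pathlib import Path


def find_incomplete_samples(dataset_dir):
    """Map each primary video's base name to the companion files it is missing.

    Hypothetical helper: assumes a flat directory holding `NNN.mp4`,
    `NNN_seg.mp4`, and `NNN.txt` for every sample.
    """
    root = Path(dataset_dir)
    incomplete = {}
    for video in sorted(root.glob("*.mp4")):
        if video.stem.endswith("_seg"):
            continue  # secondary videos are checked through their primary video
        companions = (
            video.with_name(f"{video.stem}_seg.mp4"),  # paired modality video
            video.with_suffix(".txt"),  # caption file
        )
        missing = [path.name for path in companions if not path.exists()]
        if missing:
            incomplete[video.stem] = missing
    return incomplete
```

Calling `find_incomplete_samples("/path/to/dataset")` returns an empty dictionary when every primary video has both its `_seg` video and its caption. Orphaned `_seg` videos without a primary video are not reported by this sketch.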
67 |
68 | 👉 To use the joint generation features, please check out the [`wan`](https://github.com/wileewang/TransPixar/tree/wan) branch.
69 |
70 |
71 |
72 |
73 | ## Contents
74 |
75 | * [Installation](#installation)
76 | * [TransPixeler LoRA Weights](#transpixeler-lora-weights)
77 | * [Training](#training---rgb--alpha-joint-generation)
78 | * [Inference](#inference---gradio-demo)
79 | * [Acknowledgement](#acknowledgement)
80 | * [Citation](#citation)
81 |
82 |
83 |
84 | ## Installation
85 |
86 | ```bash
87 | # For the main branch
88 | conda create -n TransPixeler python=3.10
89 | conda activate TransPixeler
90 | pip install -r requirements.txt
91 | ```
92 |
93 | **Note:**
94 | If you want to use the **Wan2.1 model**, please first check out the `wan` branch:
95 |
96 | ```bash
97 | git checkout wan
98 | ```
99 |
100 | ## TransPixeler LoRA Weights
101 |
102 | Our pipeline is designed to support various video tasks, including Text-to-RGBA Video and Image-to-RGBA Video.
103 |
104 | We provide the following pre-trained LoRA weights:
105 |
106 | | Task | Base Model | Frames | LoRA weights | Inference VRAM |
107 | |---------------|---------------------------------------------------------------|--------|--------------------------------------------------------------------|----------------|
108 | | T2V + RGBA | [THUDM/CogVideoX-5B](https://huggingface.co/THUDM/CogVideoX-5b) | 49 | [link](https://huggingface.co/wileewang/TransPixar/blob/main/cogvideox_rgba_lora.safetensors) | ~24GB |
109 |
110 |
111 | ## Training - RGB + Alpha Joint Generation
112 | We have open-sourced the training code for **Mochi** on RGBA joint generation. Please refer to the [Mochi README](Mochi/README.md) for details.
113 |
114 |
115 | ## Inference - Gradio Demo
116 | In addition to the [Hugging Face online demo](https://huggingface.co/spaces/wileewang/TransPixar), users can also launch a local inference demo based on CogVideoX-5B by running the following command:
117 |
118 | ```bash
119 | python app.py
120 | ```
121 |
122 | ## Inference - Command Line Interface (CLI)
123 | To generate RGBA videos, navigate to the corresponding directory for the video model and execute the following command:
124 | ```bash
125 | python cli.py \
126 | --lora_path /path/to/lora \
127 | --prompt "..."
128 | ```
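The Gradio demo notes that the output RGB video is premultiplied and can be composited with `composite = rgb + (1 - alpha) * background`. A minimal per-pixel sketch of that formula follows; the function name `composite_pixel` is hypothetical, and it assumes all values are floats normalized to [0, 1].

```python
def composite_pixel(rgb, alpha, background):
    """Blend one premultiplied RGB pixel over a background pixel.

    Hypothetical helper: `rgb` and `background` are 3-tuples of floats in
    [0, 1]; `alpha` is a float in [0, 1]. Because `rgb` is already
    premultiplied by alpha, it is added without rescaling.
    """
    return tuple(c + (1.0 - alpha) * b for c, b in zip(rgb, background))


# A half-transparent red pixel over a blue background.
blended = composite_pixel((0.5, 0.0, 0.0), 0.5, (0.0, 0.0, 1.0))
```

The same arithmetic applies elementwise to whole frames when they are held as arrays, provided the alpha frame is broadcast across the three color channels.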
129 |
130 | ---
131 |
132 | ## Acknowledgement
133 |
134 | * [finetrainers](https://github.com/a-r-r-o-w/finetrainers): We followed their implementation of Mochi training and inference.
135 | * [CogVideoX](https://github.com/THUDM/CogVideo): We followed their implementation of CogVideoX training and inference.
136 |
137 | We are grateful for their exceptional work and generous contribution to the open-source community.
138 |
139 | ## Citation
140 |
141 | ```bibtex
142 | @misc{wang2025transpixeler,
143 | title={TransPixeler: Advancing Text-to-Video Generation with Transparency},
144 | author={Luozhou Wang and Yijun Li and Zhifei Chen and Jui-Hsien Wang and Zhifei Zhang and He Zhang and Zhe Lin and Ying-Cong Chen},
145 | year={2025},
146 | eprint={2501.03006},
147 | archivePrefix={arXiv},
148 | primaryClass={cs.CV},
149 | url={https://arxiv.org/abs/2501.03006},
150 | }
151 | ```
152 |
153 | ## Star History
154 |
155 | [Star History Chart](https://star-history.com/#wileewang/TransPixeler&Date)
156 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | """
2 | This is the main file for the Gradio web demo. It uses the CogVideoX-5B model to generate videos.
3 | Set the environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt.
4 | Usage:
5 | OPENAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python app.py
6 | """
7 |
8 | import math
9 | import os
10 | import random
11 | import threading
12 | import time
13 |
14 | import cv2
15 | import tempfile
16 | import imageio_ffmpeg
17 | import gradio as gr
18 | import torch
19 | from PIL import Image
20 | # from diffusers import (
21 | # CogVideoXPipeline,
22 | # CogVideoXDPMScheduler,
23 | # CogVideoXVideoToVideoPipeline,
24 | # CogVideoXImageToVideoPipeline,
25 | # CogVideoXTransformer3DModel,
26 | # )
27 | from typing import Union, List
28 | from CogVideoX.pipeline_rgba import CogVideoXPipeline
29 | from CogVideoX.rgba_utils import *
30 | from diffusers import CogVideoXDPMScheduler
31 |
32 | from diffusers.utils import load_video, load_image, export_to_video
33 | from datetime import datetime, timedelta
34 |
35 | from diffusers.image_processor import VaeImageProcessor
36 | import moviepy.editor as mp
37 | import numpy as np
38 | from huggingface_hub import hf_hub_download, snapshot_download
39 | import gc
40 |
41 | device = "cuda" if torch.cuda.is_available() else "cpu"
42 |
43 | # hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran")
44 | hf_hub_download(repo_id="wileewang/TransPixar", filename="cogvideox_rgba_lora.safetensors", local_dir="model_cogvideox_rgba_lora")
45 | # snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
46 |
47 | pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5B", torch_dtype=torch.bfloat16)
48 | # pipe.enable_sequential_cpu_offload()
49 | pipe.vae.enable_slicing()
50 | pipe.vae.enable_tiling()
51 | pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
52 | seq_length = 2 * (
53 | (480 // pipe.vae_scale_factor_spatial // 2)
54 | * (720 // pipe.vae_scale_factor_spatial // 2)
55 | * ((13 - 1) // pipe.vae_scale_factor_temporal + 1)
56 | )
57 | prepare_for_rgba_inference(
58 | pipe.transformer,
59 | rgba_weights_path="model_cogvideox_rgba_lora/cogvideox_rgba_lora.safetensors",
60 | device=device,
61 | dtype=torch.bfloat16,
62 | text_length=226,
63 | seq_length=seq_length, # this is for the creation of attention mask.
64 | )
65 |
66 | # pipe.transformer.to(memory_format=torch.channels_last)
67 | # pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
68 | # pipe_image.transformer.to(memory_format=torch.channels_last)
69 | # pipe_image.transformer = torch.compile(pipe_image.transformer, mode="max-autotune", fullgraph=True)
70 |
71 | os.makedirs("./output", exist_ok=True)
72 | os.makedirs("./gradio_tmp", exist_ok=True)
73 |
74 | # upscale_model = utils.load_sd_upscale("model_real_esran/RealESRGAN_x4.pth", device)
75 | # frame_interpolation_model = load_rife_model("model_rife")
76 |
77 |
78 | sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
79 | For example , outputting " a beautiful morning in the woods with the sun peeking through the trees " will trigger your partner bot to output a video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
80 | There are a few rules to follow:
81 | You will only ever output a single video description per user request.
82 | When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
83 | Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
84 | Video descriptions must have the same num of words as examples below. Extra words will be ignored.
85 | """
86 | def save_video(tensor: Union[List[np.ndarray], List[Image.Image]], fps: int = 8, prefix='rgb'):
87 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
88 | video_path = f"./output/{prefix}_{timestamp}.mp4"
89 | os.makedirs(os.path.dirname(video_path), exist_ok=True)
90 | export_to_video(tensor, video_path, fps=fps)
91 | return video_path
92 |
93 | def resize_if_unfit(input_video, progress=gr.Progress(track_tqdm=True)):
94 | width, height = get_video_dimensions(input_video)
95 |
96 | if width == 720 and height == 480:
97 | processed_video = input_video
98 | else:
99 | processed_video = center_crop_resize(input_video)
100 | return processed_video
101 |
102 |
103 | def get_video_dimensions(input_video_path):
104 | reader = imageio_ffmpeg.read_frames(input_video_path)
105 | metadata = next(reader)
106 | return metadata["size"]
107 |
108 |
109 | def center_crop_resize(input_video_path, target_width=720, target_height=480):
110 | cap = cv2.VideoCapture(input_video_path)
111 |
112 | orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
113 | orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
114 | orig_fps = cap.get(cv2.CAP_PROP_FPS)
115 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
116 |
117 | width_factor = target_width / orig_width
118 | height_factor = target_height / orig_height
119 | resize_factor = max(width_factor, height_factor)
120 |
121 | inter_width = int(orig_width * resize_factor)
122 | inter_height = int(orig_height * resize_factor)
123 |
124 | target_fps = 8
125 | ideal_skip = max(0, math.ceil(orig_fps / target_fps) - 1)
126 | skip = min(5, ideal_skip) # Cap at 5
127 |
128 | while (total_frames / (skip + 1)) < 49 and skip > 0:
129 | skip -= 1
130 |
131 | processed_frames = []
132 | frame_count = 0
133 | total_read = 0
134 |
135 | while frame_count < 49 and total_read < total_frames:
136 | ret, frame = cap.read()
137 | if not ret:
138 | break
139 |
140 | if total_read % (skip + 1) == 0:
141 | resized = cv2.resize(frame, (inter_width, inter_height), interpolation=cv2.INTER_AREA)
142 |
143 | start_x = (inter_width - target_width) // 2
144 | start_y = (inter_height - target_height) // 2
145 | cropped = resized[start_y : start_y + target_height, start_x : start_x + target_width]
146 |
147 | processed_frames.append(cropped)
148 | frame_count += 1
149 |
150 | total_read += 1
151 |
152 | cap.release()
153 |
154 | with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
155 | temp_video_path = temp_file.name
156 | fourcc = cv2.VideoWriter_fourcc(*"mp4v")
157 | out = cv2.VideoWriter(temp_video_path, fourcc, target_fps, (target_width, target_height))
158 |
159 | for frame in processed_frames:
160 | out.write(frame)
161 |
162 | out.release()
163 |
164 | return temp_video_path
165 |
166 |
167 |
168 | def infer(
169 | prompt: str,
170 | num_inference_steps: int,
171 | guidance_scale: float,
172 | seed: int = -1,
173 | progress=gr.Progress(track_tqdm=True),
174 | ):
175 | if seed == -1:
176 | seed = random.randint(0, 2**8 - 1)
177 | pipe.to(device)
178 | video_pt = pipe(
179 | prompt=prompt + ", isolated background",
180 | num_videos_per_prompt=1,
181 | num_inference_steps=num_inference_steps,
182 | num_frames=13,
183 | use_dynamic_cfg=True,
184 | output_type="latent",
185 | guidance_scale=guidance_scale,
186 | generator=torch.Generator(device=device).manual_seed(int(seed)),
187 | ).frames
188 | # pipe.to("cpu")
189 | gc.collect()
190 | return (video_pt, seed)
191 |
192 |
193 | def convert_to_gif(video_path):
194 | clip = mp.VideoFileClip(video_path)
195 | clip = clip.set_fps(8)
196 | clip = clip.resize(height=240)
197 | gif_path = video_path.replace(".mp4", ".gif")
198 | clip.write_gif(gif_path, fps=8)
199 | return gif_path
200 |
201 |
202 | def delete_old_files():
203 | while True:
204 | now = datetime.now()
205 | cutoff = now - timedelta(minutes=10)
206 | directories = ["./output", "./gradio_tmp"]
207 |
208 | for directory in directories:
209 | for filename in os.listdir(directory):
210 | file_path = os.path.join(directory, filename)
211 | if os.path.isfile(file_path):
212 | file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
213 | if file_mtime < cutoff:
214 | os.remove(file_path)
215 | time.sleep(600)
216 |
217 |
218 | threading.Thread(target=delete_old_files, daemon=True).start()
219 |
220 | with gr.Blocks() as demo:
221 | gr.HTML("""
222 |
223 | TransPixar + CogVideoX-5B Huggingface Space🤗
224 |
225 |
230 |
231 | ⚠️ This demo is for academic research and experiential use only.
232 |
233 | """)
234 | with gr.Row():
235 | with gr.Column():
236 | prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
237 | with gr.Group():
238 | with gr.Column():
239 | with gr.Row():
240 | seed_param = gr.Number(
241 | label="Inference Seed (Enter a positive number, -1 for random)", value=-1
242 | )
243 |
244 | generate_button = gr.Button("🎬 Generate Video")
245 | with gr.Row():
246 | gr.Markdown(
247 | """
248 | **Note:** The output RGB is a premultiplied version to avoid the color decontamination problem.
249 | It can directly composite with a background using:
250 | ```
251 | composite = rgb + (1 - alpha) * background
252 | ```
253 | """
254 | )
255 |
256 | with gr.Column():
257 | rgb_video_output = gr.Video(label="Generated RGB Video", width=720, height=480)
258 | alpha_video_output = gr.Video(label="Generated Alpha Video", width=720, height=480)
259 | with gr.Row():
260 | download_rgb_video_button = gr.File(label="📥 Download RGB Video", visible=False)
261 | download_alpha_video_button = gr.File(label="📥 Download Alpha Video", visible=False)
262 | seed_text = gr.Number(label="Seed Used for Video Generation", visible=False)
263 |
264 |
265 | def generate(
266 | prompt,
267 | seed_value,
268 | progress=gr.Progress(track_tqdm=True)
269 | ):
270 | latents, seed = infer(
271 | prompt,
272 | num_inference_steps=25, # NOT Changed
273 | guidance_scale=7.0, # NOT Changed
274 | seed=seed_value,
275 | progress=progress,
276 | )
277 |
278 | latents_rgb, latents_alpha = latents.chunk(2, dim=1)
279 |
280 | frames_rgb = decode_latents(pipe, latents_rgb)
281 | frames_alpha = decode_latents(pipe, latents_alpha)
282 |
283 | pooled_alpha = np.max(frames_alpha, axis=-1, keepdims=True)
284 | frames_alpha_pooled = np.repeat(pooled_alpha, 3, axis=-1)
285 | premultiplied_rgb = frames_rgb * frames_alpha_pooled
286 |
287 | rgb_video_path = save_video(premultiplied_rgb[0], fps=8, prefix='rgb')
288 | rgb_video_update = gr.update(visible=True, value=rgb_video_path)
289 |
290 | alpha_video_path = save_video(frames_alpha_pooled[0], fps=8, prefix='alpha')
291 | alpha_video_update = gr.update(visible=True, value=alpha_video_path)
292 | seed_update = gr.update(visible=True, value=seed)
293 |
294 | return rgb_video_path, alpha_video_path, rgb_video_update, alpha_video_update, seed_update
295 |
296 |
297 | generate_button.click(
298 | generate,
299 | inputs=[prompt, seed_param],
300 | outputs=[rgb_video_output, alpha_video_output, download_rgb_video_button, download_alpha_video_button, seed_text],
301 | )
302 |
303 |
304 | if __name__ == "__main__":
305 | demo.queue(max_size=15)
306 | demo.launch()
307 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.4.0
2 | torchvision
3 | torchaudio
4 | wandb
5 | gradio
6 | sentencepiece
7 | diffusers==0.32.0
8 | huggingface_hub==0.27.0
9 | transformers
10 | imageio>=2.5.0
11 | imageio-ffmpeg
12 | moviepy==1.0.3
13 | opencv-python>=4.5
14 | accelerate
15 |
--------------------------------------------------------------------------------
/wechat_group.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wileewang/TransPixeler/a704f8037f5ceb93770eb432f9db847ad5a49c27/wechat_group.jpg
--------------------------------------------------------------------------------