├── launch_venv.bat ├── requirements.txt ├── LICENSE ├── noise_warp ├── raft.py └── GetWarpedNoiseFromVideo.py ├── test_skyreels_t2v.py ├── .gitignore ├── test_skyreels_i2v.py ├── README.md ├── test_hunyuan_lora.py ├── convert_diffusers_lora_to_original.py ├── train_hunyuan_lora.py └── pipelines ├── pipeline_hunyuan_video.py ├── pipeline_skyreels_t2v.py └── pipeline_skyreels_i2v.py /launch_venv.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | call %~dp0.venv\Scripts\activate.bat 3 | cmd /K -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers 2 | transformers 3 | accelerate 4 | bitsandbytes 5 | tensorboard 6 | decord 7 | peft 8 | imageio 9 | imageio-ffmpeg 10 | einops -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 spacepxl 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /noise_warp/raft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms 3 | from torchvision.models.optical_flow import raft_large, raft_small, Raft_Large_Weights, Raft_Small_Weights 4 | import torch.nn.functional as F 5 | 6 | class RaftOpticalFlow(): 7 | def __init__(self, version='large', device="cuda", dtype=torch.float32): 8 | """ 9 | Automatically downloads the model you select upon instantiation if not already downloaded 10 | """ 11 | 12 | models = { 13 | 'large': raft_large, 14 | 'small': raft_small, 15 | } 16 | weights = { 17 | 'large': Raft_Large_Weights.DEFAULT, 18 | 'small': Raft_Small_Weights.DEFAULT, 19 | } 20 | 21 | assert version in models 22 | 23 | model = models[version](weights=weights[version], progress=False).to(device, dtype=dtype) 24 | model.requires_grad_(False) 25 | model.eval() 26 | 27 | self.version = version 28 | self.device = device 29 | self.dtype = dtype 30 | self.model = model 31 | 32 | def _preprocess_image(self, image): 33 | 34 | image = image.to(self.device, dtype=self.dtype) 35 | 36 | #Floor height and width to the nearest multiple of 8 37 | height, width = image.shape[-2:] 38 | new_height = (height // 8) * 8 39 | new_width = (width // 8) * 8 40 | 41 | #Resize the image 42 | image = F.interpolate(image.unsqueeze(0), size=(new_height, new_width)).squeeze(0) 43 | 44 | #Map [0, 1] to [-1, 1] 45 | # image = image * 2 - 1 46 | 47 | #CHW --> 1CHW 48 | output = image[None] 49 | 50 | assert output.shape == (1, 3, new_height, new_width) 51 | 52 | return output 53 | 54 | def __call__(self, from_image, to_image): 55 | """ 56 | Calculates the optical flow from from_image to to_image, returned in 2HW form 57 | In other words, returns (dx, dy) where dx and dy are both HW torch matrices with the same height and width as the input image 58 | 59 | Works best when the image's dimensions are a multiple of 8 pixels 60 | Works fastest when passed torch images on the same device as this model 61 | 62 | Args: 63 | from_image: An RGB torch image (a 3HW torch tensor) 64 | to_image : An RGB torch image (a 3HW torch tensor) 65 | """ 66 | height, width = from_image.shape[-2:] 67 | 68 | with torch.no_grad(): 69 | img1 = self._preprocess_image(from_image) 70 | img2 = self._preprocess_image(to_image ) 71 | 72 | list_of_flows = self.model(img1, img2) 73 | output_flow = list_of_flows[-1][0] 74 | 75 | # Resize the predicted flow back to the original image size 76 | resize = torchvision.transforms.Resize((height, width)) 77 | output_flow = resize(output_flow[None])[0] 78 | 79 | assert output_flow.shape == (2, height, width) 80 | 81 | return output_flow -------------------------------------------------------------------------------- /test_skyreels_t2v.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from transformers import CLIPTextModel, CLIPTokenizerFast, LlamaModel, LlamaTokenizerFast 5 | # from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel 6 | from diffusers import HunyuanVideoTransformer3DModel 7 | from diffusers.utils import export_to_video 8 | from pipelines.pipeline_skyreels_t2v import HunyuanVideoPipeline 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser( 13 | description = "HunyuanVideo lora
test script", 14 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 15 | ) 16 | parser.add_argument( 17 | "--pretrained_model", 18 | type=str, 19 | default="./models", 20 | help="Path to pretrained model base directory", 21 | ) 22 | parser.add_argument( 23 | "--output_dir", 24 | type = str, 25 | default = "./", 26 | help = "Output directory for results" 27 | ) 28 | parser.add_argument( 29 | "--seed", 30 | type = int, 31 | default = 42, 32 | help = "Seed for inference" 33 | ) 34 | parser.add_argument( 35 | "--width", 36 | type = int, 37 | default = 512, 38 | help = "Width for inference" 39 | ) 40 | parser.add_argument( 41 | "--height", 42 | type = int, 43 | default = 512, 44 | help = "Width for inference" 45 | ) 46 | parser.add_argument( 47 | "--num_frames", 48 | type = int, 49 | default = 33, 50 | help = "Number of frames per video, must be divisible by 4+1" 51 | ) 52 | parser.add_argument( 53 | "--inference_steps", 54 | type = int, 55 | default = 20, 56 | help = "Number of steps for inference", 57 | ) 58 | parser.add_argument( 59 | "--prompt", 60 | type=str, 61 | default="A person typing on a laptop keyboard", 62 | help="Prompt for inference", 63 | ) 64 | 65 | args = parser.parse_args() 66 | return args 67 | 68 | @torch.inference_mode() 69 | def main(args): 70 | 71 | transformer = HunyuanVideoTransformer3DModel.from_pretrained( 72 | args.pretrained_model, 73 | subfolder = "transformer-skyreels-t2v", 74 | torch_dtype = torch.bfloat16, 75 | ) 76 | 77 | pipe = HunyuanVideoPipeline.from_pretrained( 78 | args.pretrained_model, 79 | transformer=transformer, 80 | torch_dtype=torch.float16, 81 | ) 82 | 83 | pipe.vae.enable_tiling( 84 | tile_sample_min_height = 256, 85 | tile_sample_min_width = 256, 86 | tile_sample_min_num_frames = 64, 87 | tile_sample_stride_height = 192, 88 | tile_sample_stride_width = 192, 89 | tile_sample_stride_num_frames = 16, 90 | ) 91 | 92 | pipe.enable_sequential_cpu_offload() 93 | # pipe.scheduler.set_shift(17.0) 94 | 95 | output = pipe( 96 | prompt = args.prompt, 97 | guidance_scale = 1.0, 98 | cfg_scale = 6.0, 99 | cfg_steps = 5, 100 | height = args.height, 101 | width = args.width, 102 | num_frames = args.num_frames, 103 | num_inference_steps = args.inference_steps, 104 | generator = torch.Generator(device="cpu").manual_seed(args.seed), 105 | ).frames[0] 106 | 107 | export_to_video( 108 | output, 109 | os.path.join(args.output_dir, "output_skyreels_t2v.mp4"), 110 | fps=15, 111 | ) 112 | 113 | 114 | if __name__ == "__main__": 115 | args = parse_args() 116 | main(args) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | models/ 2 | outputs/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # UV 101 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | #uv.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # pdm 114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 115 | #pdm.lock 116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 117 | # in version control. 118 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 119 | .pdm.toml 120 | .pdm-python 121 | .pdm-build/ 122 | 123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 124 | __pypackages__/ 125 | 126 | # Celery stuff 127 | celerybeat-schedule 128 | celerybeat.pid 129 | 130 | # SageMath parsed files 131 | *.sage.py 132 | 133 | # Environments 134 | .env 135 | .venv 136 | env/ 137 | venv/ 138 | ENV/ 139 | env.bak/ 140 | venv.bak/ 141 | 142 | # Spyder project settings 143 | .spyderproject 144 | .spyproject 145 | 146 | # Rope project settings 147 | .ropeproject 148 | 149 | # mkdocs documentation 150 | /site 151 | 152 | # mypy 153 | .mypy_cache/ 154 | .dmypy.json 155 | dmypy.json 156 | 157 | # Pyre type checker 158 | .pyre/ 159 | 160 | # pytype static type analyzer 161 | .pytype/ 162 | 163 | # Cython debug symbols 164 | cython_debug/ 165 | 166 | # PyCharm 167 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 168 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 169 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 170 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 171 | #.idea/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /test_skyreels_i2v.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from PIL import Image 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from transformers import CLIPTextModel, CLIPTokenizerFast, LlamaModel, LlamaTokenizerFast 8 | from diffusers import HunyuanVideoTransformer3DModel 9 | from diffusers.utils import export_to_video 10 | from pipelines.pipeline_skyreels_i2v import HunyuanVideoPipeline 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser( 15 | description = "HunyuanVideo lora test script", 16 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 17 | ) 18 | parser.add_argument( 19 | "--pretrained_model", 20 | type=str, 21 | default="./models", 22 | help="Path to pretrained model base directory", 23 | ) 24 | parser.add_argument( 25 | "--output_dir", 26 | type = str, 27 | default = "./", 28 | help = "Output directory for results" 29 | ) 30 | parser.add_argument( 31 | "--seed", 32 | type = int, 33 | default = 42, 34 | help = "Seed for inference" 35 | ) 36 | parser.add_argument( 37 | "--width", 38 | type = int, 39 | default = 512, 40 | help = "Width for inference" 41 | ) 42 | parser.add_argument( 43 | "--height", 44 | type = int, 45 | default = 512, 46 | help = "Width for inference" 47 | ) 48 | parser.add_argument( 49 | "--num_frames", 50 | type = int, 51 | default = 33, 52 | help = "Number of frames per video, must be divisible by 4+1" 53 | ) 54 | parser.add_argument( 55 | "--inference_steps", 56 | type = int, 57 | default = 20, 58 | help = "Number of steps for inference", 59 | ) 60 | parser.add_argument( 61 | "--cfg_steps", 62 | type = int, 63 | default = 5, 64 | help = "Number of steps for inference", 65 | ) 66 | parser.add_argument( 67 | "--prompt", 68 | type=str, 69 | default="A person typing on a laptop keyboard", 70 | help="Prompt for inference", 71 | ) 72 | parser.add_argument( 73 | "--image", 74 | type=str, 75 | default="./test.png", 76 | help="First frame image", 77 | ) 78 | 79 | args = parser.parse_args() 80 | return args 81 | 82 | @torch.inference_mode() 83 | def main(args): 84 | 85 | transformer = HunyuanVideoTransformer3DModel.from_pretrained( 86 | args.pretrained_model, 87 | subfolder = "transformer-skyreels-i2v", 88 | torch_dtype = torch.bfloat16, 89 | ) 90 | 91 | pipe = HunyuanVideoPipeline.from_pretrained( 92 | args.pretrained_model, 93 | transformer=transformer, 94 | torch_dtype=torch.float16, 95 | ) 96 | 97 | pipe.vae.enable_tiling( 98 | tile_sample_min_height = 256, 99 | tile_sample_min_width = 256, 100 | tile_sample_min_num_frames = 64, 101 | tile_sample_stride_height = 192, 102 | tile_sample_stride_width = 192, 103 | tile_sample_stride_num_frames = 16, 104 | ) 105 | 106 | pipe.enable_sequential_cpu_offload() 107 | # pipe.scheduler.set_shift(17.0) 108 | 109 | image = Image.open(args.image).convert('RGB') 110 | image = torch.as_tensor(np.array(image)).movedim(-1, 0).unsqueeze(0) # BCHW 111 | image = (image.float() / 255) * 2 - 1 112 | image = F.interpolate(image, size=(args.height, args.width), mode="bilinear") 113 | image = image.movedim(1, 0).unsqueeze(0) # BCFHW 114 | 115 | output = pipe( 116 | image = image, 117 | prompt = args.prompt, 118 | guidance_scale = 1.0, 
119 | cfg_scale = 6.0, 120 | cfg_steps = args.cfg_steps, 121 | height = args.height, 122 | width = args.width, 123 | num_frames = args.num_frames, 124 | num_inference_steps = args.inference_steps, 125 | generator = torch.Generator(device="cpu").manual_seed(args.seed), 126 | ).frames[0] 127 | 128 | export_to_video( 129 | output, 130 | os.path.join(args.output_dir, "output_skyreels_i2v.mp4"), 131 | fps=15, 132 | ) 133 | 134 | 135 | if __name__ == "__main__": 136 | args = parse_args() 137 | main(args) -------------------------------------------------------------------------------- /noise_warp/GetWarpedNoiseFromVideo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from tqdm import tqdm 5 | import torch.nn.functional as F 6 | 7 | from .noise_warp import NoiseWarper, mix_new_noise 8 | from .raft import RaftOpticalFlow 9 | 10 | def get_downtemp_noise(noise, noise_downtemp_interp, interp_to=13): 11 | if noise_downtemp_interp == 'nearest': 12 | return resize_list(noise, interp_to) 13 | elif noise_downtemp_interp == 'blend': 14 | return downsamp_mean(noise, interp_to) 15 | elif noise_downtemp_interp == 'blend_norm': 16 | return normalized_noises(downsamp_mean(noise, interp_to)) 17 | elif noise_downtemp_interp == 'randn': 18 | return torch.randn_like(resize_list(noise, interp_to)) 19 | else: 20 | return noise 21 | 22 | def downsamp_mean(x, l=13): 23 | return torch.stack([sum(u) / len(u) for u in split_into_n_sublists(x, l)]) 24 | 25 | def normalized_noises(noises): 26 | #Noises is in TCHW form 27 | return torch.stack([x / x.std(1, keepdim=True) for x in noises]) 28 | 29 | def resize_list(array:list, length: int): 30 | assert isinstance(length, int), "Length must be an integer, but got %s instead"%repr(type(length)) 31 | assert length >= 0, "Length must be a non-negative integer, but got %i instead"%length 32 | 33 | if len(array) > 1 and length > 1: 34 | step = (len(array) - 1) / (length - 1) 35 | else: 36 | step = 0 # default step size to 0 if array has only 1 element or target length is 1 37 | 38 | indices = [round(i * step) for i in range(length)] 39 | 40 | if isinstance(array, np.ndarray) or isinstance(array, torch.Tensor): 41 | return array[indices] 42 | else: 43 | return [array[i] for i in indices] 44 | 45 | def split_into_n_sublists(l, n): 46 | if n <= 0: 47 | raise ValueError("n must be greater than 0 but n is "+str(n)) 48 | 49 | if isinstance(l, str): 50 | return ''.join(split_into_n_sublists(list(l), n)) 51 | 52 | L = len(l) 53 | indices = [int(i * L / n) for i in range(n + 1)] 54 | return [l[indices[i]:indices[i + 1]] for i in range(n)] 55 | 56 | 57 | class GetWarpedNoiseFromVideo: 58 | def __init__(self, raft_size="large", device="cuda", dtype=torch.float32): 59 | self.device = device 60 | self.dtype = dtype 61 | self.raft_model = RaftOpticalFlow(version=raft_size, device=self.device, dtype=self.dtype) 62 | 63 | def __call__( 64 | self, 65 | images, 66 | degradation = 0.0, 67 | noise_channels = 4, 68 | spatial_downscale_factor = 8, 69 | target_latent_count = 16, 70 | noise_downtemp_interp = "nearest", 71 | ): 72 | resize_flow = 1 73 | resize_frames = 1 74 | downscale_factor = round(resize_frames * resize_flow) * spatial_downscale_factor 75 | 76 | # Load video frames into a [B, C, H, W] tensor, where C=3 and values are between -1 and 1 77 | B, C, H, W = images.shape 78 | video_frames = images 79 | 80 | def downscale_noise(noise): 81 | down_noise = F.interpolate(noise, 
scale_factor=1/downscale_factor, mode='area') # Avg pooling 82 | down_noise = down_noise * downscale_factor #Adjust for STD 83 | return down_noise 84 | 85 | warper = NoiseWarper( 86 | c = noise_channels, 87 | h = resize_flow * H, 88 | w = resize_flow * W, 89 | device = self.device, 90 | post_noise_alpha = 0, 91 | progressive_noise_alpha = 0, 92 | ) 93 | 94 | prev_video_frame = video_frames[0] 95 | noise = warper.noise 96 | 97 | down_noise = downscale_noise(noise) 98 | numpy_noise = down_noise.cpu().numpy().astype(np.float16) # In 1CHW form. Using float16 to save RAM, but it might cause problems on some CPUs 99 | 100 | numpy_noises = [numpy_noise] 101 | # for video_frame in tqdm(video_frames[1:], desc="Calculating noise warp", leave=False): 102 | for video_frame in video_frames[1:]: 103 | dx, dy = self.raft_model(prev_video_frame, video_frame) 104 | noise = warper(dx, dy).noise 105 | prev_video_frame = video_frame 106 | 107 | numpy_flow = np.stack( 108 | [ 109 | dx.cpu().numpy().astype(np.float16), 110 | dy.cpu().numpy().astype(np.float16), 111 | ] 112 | ) 113 | down_noise = downscale_noise(noise) 114 | numpy_noise = down_noise.cpu().numpy().astype(np.float16) 115 | numpy_noises.append(numpy_noise) 116 | 117 | numpy_noises = np.stack(numpy_noises).astype(np.float16) 118 | noise_tensor = torch.from_numpy(numpy_noises).squeeze(1).cpu().float() 119 | 120 | downtemp_noise_tensor = get_downtemp_noise( 121 | noise_tensor, 122 | noise_downtemp_interp = noise_downtemp_interp, 123 | interp_to = target_latent_count, 124 | ) # B, F, C, H, W 125 | downtemp_noise_tensor = downtemp_noise_tensor[None] 126 | downtemp_noise_tensor = mix_new_noise(downtemp_noise_tensor, degradation) 127 | downtemp_noise_tensor = downtemp_noise_tensor.squeeze(0) 128 | 129 | return downtemp_noise_tensor # FCHW -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HunyuanVideo-Training 2 | 3 | This is not intended to be a comprehensive all-in-one trainer; it's just the bare minimum framework to support simple one-file training scripts, in the spirit of the diffusers example training scripts. It's meant to be easy to read and modify, without too much abstraction getting in the way, but with enough optimization to not require rental cloud compute. The default configuration uses < 16 GB of VRAM (although < 24 GB is the target), and runs natively on Windows. To achieve this, the diffusion model is quantized to nf4 using bitsandbytes (similar to QLoRA), and text embeddings are pre-computed. Latents, however, are encoded on the fly, to reduce overfitting and make it easier to work with larger datasets. 4 | 5 | Don't expect much support, as this is primarily for my own use, but I'm sharing it for others who want to tinker with training code, and because I was frustrated with diffusers switching from single-file training scripts to finetrainers for recent video models. Any code written by me is released under the MIT license (aka do whatever you want), but the HunyuanVideo model is subject to the [tencent community license](https://github.com/Tencent/HunyuanVideo/blob/main/LICENSE.txt).
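For reference, the nf4 quantization works the same way as QLoRA: the frozen base weights are loaded in 4-bit NF4 through diffusers' `BitsAndBytesConfig`, while the LoRA weights are trained in higher precision. A minimal sketch of how `train_hunyuan_lora.py` loads the transformer with the default `--quant_type nf4` and `./models` path:

```
import torch
from diffusers import HunyuanVideoTransformer3DModel
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig

# 4-bit NF4 with double quantization and bf16 compute dtype
quant_config = DiffusersBitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type = "nf4",
    bnb_4bit_use_double_quant = True,
    bnb_4bit_compute_dtype = torch.bfloat16,
)

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    "./models",
    subfolder = "transformer",
    quantization_config = quant_config,
    torch_dtype = torch.bfloat16,
)
```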
6 | 7 | ## Install 8 | 9 | ``` 10 | git clone https://github.com/spacepxl/HunyuanVideo-Training 11 | cd HunyuanVideo-Training 12 | 13 | python -m venv .venv 14 | .venv\Scripts\activate 15 | 16 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 17 | ``` 18 | 19 | (or follow whatever approach you prefer for environment management and pytorch/triton/etc) 20 | 21 | ``` 22 | pip install -r requirements.txt 23 | ``` 24 | 25 | ## Before training 26 | 27 | Activate the venv (you can use launch_venv.bat if on windows) 28 | 29 | Download models: 30 | 31 | ``` 32 | python train_hunyuan_lora.py --download_model 33 | ``` 34 | 35 | or if you want to train skyreels i2v: 36 | 37 | ``` 38 | python train_hunyuan_lora.py --download_model --skyreels_i2v 39 | ``` 40 | 41 | (if you already have the diffusers models saved elsewhere, you can skip downloading and train with `--pretrained_model` pointing to the correct folder) 42 | 43 | Expected folder structure: 44 | ``` 45 | HunyuanVideo-Training 46 | /models 47 | /scheduler 48 | /text_encoder 49 | /text_encoder_2 50 | /tokenizer 51 | /tokenizer_2 52 | /transformer 53 | /transformer-skyreels-i2v 54 | /vae 55 | ``` 56 | 57 | ## Dataset prep 58 | 59 | Your dataset should be structured something like this: 60 | 61 | ``` 62 | dataset 63 | /train 64 | /subfolder 65 | sample1.mp4 66 | sample1.txt 67 | sample2.jpg 68 | sample2.txt 69 | ... 70 | /val 71 | ... 72 | ``` 73 | Training data goes in `/train`, validation data goes in `/val`, `/validation`, or `/test`. Subfolders within the train or val folder are scanned recursively, so organize them however you like. 74 | 75 | If no validation set is provided, then validation will fall back to the training set. This is NOT RECOMMENDED, it will make it impossible to judge overfitting from the validation loss. 76 | 77 | Image and video files are supported, and captions should be the same filename as the media file but with .txt extension. See [here](https://github.com/spacepxl/HunyuanVideo-Training/blob/main/train_hunyuan_lora.py#L43) for the list of filetypes, although not all are tested, and others might work just by adding them to the list (anything supported by decord for videos or Pillow for images) 78 | 79 | Once you have a dataset, cache the text embeddings: 80 | 81 | ``` 82 | python train_hunyuan_lora.py --dataset "example/dataset" --cache_embeddings 83 | ``` 84 | 85 | ## Training 86 | 87 | ``` 88 | python train_hunyuan_lora.py --dataset "example/dataset" 89 | ``` 90 | 91 | All other arguments are optional, the defaults should be a reasonable starting point. 92 | 93 | By default, resolutions are randomized in buckets, frame length is set based on a context length budget (so, semi-random but keeping similar memory and compute cost), and the start frame is randomly chosen from the range allowed by frame length. The default `--token_limit` of 10000 is good for < 16 GB, and about 30s/step on a 3090. Raising the token limit will use more memory and more time per step, and increase the number of frames at every resolution. 94 | 95 | Optionally you can set `--resolution` to disable the resolution buckets and use a square crop of that exact size. If you set `--num_frames` it will use that as the upper limit (some samples may use fewer frames depending on resolution, or limited by short clips/images). 96 | 97 | For image to video training, use `--skyreels_i2v` to load the skyreels model and use the first frame as image conditioning. 
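To make the `--token_limit` budget described above concrete: the training script counts transformer tokens as `(width // 16) * (height // 16) * ((frames - 1) // 4 + 1)` and picks the largest frame count that stays under the limit. A rough illustration using the `count_tokens` helper from `train_hunyuan_lora.py`:

```
def count_tokens(width, height, frames):
    return (width // 16) * (height // 16) * ((frames - 1) // 4 + 1)

# with the default --token_limit of 10000:
count_tokens(512, 512, 33)  # 32 * 32 * 9 = 9216  -> fits
count_tokens(768, 432, 33)  # 48 * 27 * 9 = 11664 -> over budget, so fewer frames are used
count_tokens(768, 432, 25)  # 48 * 27 * 7 = 9072  -> fits
```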
98 | 99 | Warped noise from [Go With The Flow](https://eyeline-research.github.io/Go-with-the-Flow/) can be enabled by `--warped_noise`, note that this will take longer to adapt to than normal random noise, so it's more for general adapter training than character/style loras, and you should use a large dataset, ideally larger than the number of training steps. 100 | 101 | ## After training 102 | 103 | The saved lora checkpoints are in diffusers format, so if you want to use them with the original tencent model (or in comfyui), you'll need to convert them: 104 | 105 | ``` 106 | python convert_diffusers_lora_to_original.py --input_lora "./outputs/example/checkpoints/hyv-lora-00001000.safetensors" 107 | ``` 108 | 109 | You can optionally convert the lora dtype to bf16 or fp16 to save file size. If you set alpha to anything other than rank during training, you'll need to manually input the alpha during conversion with `--alpha` 110 | -------------------------------------------------------------------------------- /test_hunyuan_lora.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import argparse 4 | import torch 5 | import bitsandbytes as bnb 6 | 7 | from peft import PeftModel, LoraConfig, set_peft_model_state_dict 8 | 9 | from transformers import CLIPTextModel, CLIPTokenizerFast, LlamaModel, LlamaTokenizerFast 10 | from safetensors.torch import load_file, save_file 11 | 12 | from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel 13 | from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig 14 | from diffusers.utils import export_to_video 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser( 19 | description = "HunyuanVideo lora test script", 20 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 21 | ) 22 | parser.add_argument( 23 | "--pretrained_model", 24 | type=str, 25 | default="./models", 26 | help="Path to pretrained model base directory", 27 | ) 28 | parser.add_argument( 29 | "--lora", 30 | type = str, 31 | default = None, 32 | help = "LoRA file to test", 33 | ) 34 | parser.add_argument( 35 | "--alpha", 36 | type = int, 37 | default = None, 38 | help = "lora alpha, defaults to 1" 39 | ) 40 | parser.add_argument( 41 | "--output_dir", 42 | type = str, 43 | default = "./", 44 | help = "Output directory for results" 45 | ) 46 | parser.add_argument( 47 | "--seed", 48 | type = int, 49 | default = 42, 50 | help = "Seed for inference" 51 | ) 52 | parser.add_argument( 53 | "--width", 54 | type = int, 55 | default = 512, 56 | help = "Width for inference" 57 | ) 58 | parser.add_argument( 59 | "--height", 60 | type = int, 61 | default = 512, 62 | help = "Width for inference" 63 | ) 64 | parser.add_argument( 65 | "--num_frames", 66 | type = int, 67 | default = 33, 68 | help = "Number of frames per video, must be divisible by 4+1" 69 | ) 70 | parser.add_argument( 71 | "--inference_steps", 72 | type = int, 73 | default = 20, 74 | help = "Number of steps for inference", 75 | ) 76 | parser.add_argument( 77 | "--prompt", 78 | type=str, 79 | default="A person typing on a laptop keyboard", 80 | help="Prompt for inference", 81 | ) 82 | 83 | args = parser.parse_args() 84 | return args 85 | 86 | @torch.inference_mode() 87 | def main(args): 88 | 89 | transformer = HunyuanVideoTransformer3DModel.from_pretrained( 90 | args.pretrained_model, 91 | subfolder = "transformer", 92 | torch_dtype = torch.bfloat16, 93 | ) 94 | 95 | pipe = 
HunyuanVideoPipeline.from_pretrained(args.pretrained_model, transformer=transformer, torch_dtype=torch.float16) 96 | pipe.vae.enable_tiling( 97 | tile_sample_min_height = 256, 98 | tile_sample_min_width = 256, 99 | tile_sample_min_num_frames = 64, 100 | tile_sample_stride_height = 192, 101 | tile_sample_stride_width = 192, 102 | tile_sample_stride_num_frames = 16, 103 | ) 104 | pipe.enable_sequential_cpu_offload() 105 | 106 | output = pipe( 107 | prompt = args.prompt, 108 | height = args.height, 109 | width = args.width, 110 | num_frames = args.num_frames, 111 | num_inference_steps = args.inference_steps, 112 | generator = torch.Generator(device="cpu").manual_seed(args.seed), 113 | ).frames[0] 114 | 115 | export_to_video( 116 | output, 117 | os.path.join(args.output_dir, "output_base.mp4"), 118 | fps=15, 119 | ) 120 | 121 | if args.lora is not None: 122 | del transformer 123 | pipe.transformer = None 124 | gc.collect() 125 | torch.cuda.empty_cache() 126 | 127 | lora_sd = load_file(args.lora) 128 | rank = 0 129 | for key in lora_sd.keys(): 130 | if ".lora_A.weight" in key: 131 | rank = lora_sd[key].shape[0] 132 | 133 | alpha = 1 if args.alpha is None else args.alpha 134 | lora_weight = alpha / rank 135 | 136 | print(f"lora rank = {rank}") 137 | print(f"alpha = {alpha}") 138 | print(f"lora weight = {lora_weight}") 139 | 140 | transformer = HunyuanVideoTransformer3DModel.from_pretrained( 141 | args.pretrained_model, 142 | subfolder = "transformer", 143 | torch_dtype = torch.bfloat16, 144 | ) 145 | 146 | transformer.load_lora_adapter(lora_sd, adapter_name="default_lora") 147 | 148 | transformer.set_adapters(adapter_names = "default_lora", weights = lora_weight) 149 | pipe.transformer = transformer 150 | pipe.enable_sequential_cpu_offload() 151 | 152 | output_lora = pipe( 153 | prompt = args.prompt, 154 | height = args.height, 155 | width = args.width, 156 | num_frames = args.num_frames, 157 | num_inference_steps = args.inference_steps, 158 | generator = torch.Generator(device="cpu").manual_seed(args.seed), 159 | ).frames[0] 160 | 161 | export_to_video( 162 | output_lora, 163 | os.path.join(args.output_dir, "output_lora.mp4"), 164 | fps=15, 165 | ) 166 | 167 | 168 | if __name__ == "__main__": 169 | args = parse_args() 170 | main(args) -------------------------------------------------------------------------------- /convert_diffusers_lora_to_original.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | from safetensors.torch import load_file, save_file 6 | 7 | 8 | def convert_lora_sd(diffusers_lora_sd): 9 | double_block_patterns = { 10 | "attn.to_out.0": "img_attn_proj", 11 | "ff.net.0.proj": "img_mlp.0", 12 | "ff.net.2": "img_mlp.2", 13 | "attn.to_add_out": "txt_attn_proj", 14 | "ff_context.net.0.proj": "txt_mlp.0", 15 | "ff_context.net.2": "txt_mlp.2", 16 | } 17 | 18 | token_refiner_patterns = { 19 | "attn.to_out.0": "self_attn_proj", 20 | "ff.net.0.proj": "mlp.fc1", 21 | "ff.net.2": "mlp.fc2", 22 | } 23 | 24 | prefix = "diffusion_model." 
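# Note on the key mapping below: diffusers stores separate LoRA modules for the
# q/k/v projections, while the original HunyuanVideo checkpoints use fused qkv
# (and linear1) layers. The loop therefore stacks the three lora_A matrices with
# torch.cat along the rank dimension and combines the three lora_B matrices with
# torch.block_diag, so each rank block only feeds its own q/k/v output slice.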
25 | 26 | converted_lora_sd = {} 27 | for key in diffusers_lora_sd.keys(): 28 | # txt_in.individual_token_refiner 29 | if key.startswith("context_embedder.token_refiner"): 30 | if key.endswith("to_q.lora_A.weight"): 31 | # lora_A 32 | to_q_A = diffusers_lora_sd[key] 33 | to_k_A = diffusers_lora_sd[key.replace("to_q", "to_k")] 34 | to_v_A = diffusers_lora_sd[key.replace("to_q", "to_v")] 35 | 36 | to_qkv_A = torch.cat([to_q_A, to_k_A, to_v_A], dim=0) 37 | qkv_A_key = key.replace( 38 | "context_embedder.token_refiner.refiner_blocks", 39 | prefix + "txt_in.individual_token_refiner.blocks", 40 | ) 41 | qkv_A_key = qkv_A_key.replace("attn.to_q", "self_attn_qkv") 42 | converted_lora_sd[qkv_A_key] = to_qkv_A 43 | 44 | # lora_B 45 | to_q_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_q.lora_B")] 46 | to_k_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_k.lora_B")] 47 | to_v_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_v.lora_B")] 48 | 49 | to_qkv_B = torch.block_diag(to_q_B, to_k_B, to_v_B) 50 | qkv_B_key = qkv_A_key.replace("lora_A", "lora_B") 51 | converted_lora_sd[qkv_B_key] = to_qkv_B 52 | 53 | # just rename 54 | for k, v in token_refiner_patterns.items(): 55 | if k in key: 56 | new_key = key.replace(k, v).replace( 57 | "context_embedder.token_refiner.refiner_blocks", 58 | prefix + "txt_in.individual_token_refiner.blocks", 59 | ) 60 | converted_lora_sd[new_key] = diffusers_lora_sd[key] 61 | 62 | # double_blocks 63 | elif key.startswith("transformer_blocks"): 64 | # img_attn 65 | if key.endswith("to_q.lora_A.weight"): 66 | # lora_A 67 | to_q_A = diffusers_lora_sd[key] 68 | to_k_A = diffusers_lora_sd[key.replace("to_q", "to_k")] 69 | to_v_A = diffusers_lora_sd[key.replace("to_q", "to_v")] 70 | 71 | to_qkv_A = torch.cat([to_q_A, to_k_A, to_v_A], dim=0) 72 | qkv_A_key = key.replace("transformer_blocks", prefix + "double_blocks").replace( 73 | "attn.to_q", "img_attn_qkv" 74 | ) 75 | converted_lora_sd[qkv_A_key] = to_qkv_A 76 | 77 | # lora_B 78 | to_q_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_q.lora_B")] 79 | to_k_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_k.lora_B")] 80 | to_v_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_v.lora_B")] 81 | 82 | to_qkv_B = torch.block_diag(to_q_B, to_k_B, to_v_B) 83 | qkv_B_key = qkv_A_key.replace("lora_A", "lora_B") 84 | converted_lora_sd[qkv_B_key] = to_qkv_B 85 | 86 | # txt_attn 87 | elif key.endswith("add_q_proj.lora_A.weight"): 88 | # lora_A 89 | to_q_A = diffusers_lora_sd[key] 90 | to_k_A = diffusers_lora_sd[key.replace("add_q_proj", "add_k_proj")] 91 | to_v_A = diffusers_lora_sd[key.replace("add_q_proj", "add_v_proj")] 92 | 93 | to_qkv_A = torch.cat([to_q_A, to_k_A, to_v_A], dim=0) 94 | qkv_A_key = key.replace("transformer_blocks", prefix + "double_blocks").replace( 95 | "attn.add_q_proj", "txt_attn_qkv" 96 | ) 97 | converted_lora_sd[qkv_A_key] = to_qkv_A 98 | 99 | # lora_B 100 | to_q_B = diffusers_lora_sd[key.replace("add_q_proj.lora_A", "add_q_proj.lora_B")] 101 | to_k_B = diffusers_lora_sd[key.replace("add_q_proj.lora_A", "add_k_proj.lora_B")] 102 | to_v_B = diffusers_lora_sd[key.replace("add_q_proj.lora_A", "add_v_proj.lora_B")] 103 | 104 | to_qkv_B = torch.block_diag(to_q_B, to_k_B, to_v_B) 105 | qkv_B_key = qkv_A_key.replace("lora_A", "lora_B") 106 | converted_lora_sd[qkv_B_key] = to_qkv_B 107 | 108 | # just rename 109 | for k, v in double_block_patterns.items(): 110 | if k in key: 111 | new_key = key.replace(k, v).replace("transformer_blocks", prefix + "double_blocks") 112 | converted_lora_sd[new_key] = 
diffusers_lora_sd[key] 113 | 114 | # single_blocks 115 | elif key.startswith("single_transformer_blocks"): 116 | if key.endswith("to_q.lora_A.weight"): 117 | # lora_A 118 | to_q_A = diffusers_lora_sd[key] 119 | to_k_A = diffusers_lora_sd[key.replace("to_q", "to_k")] 120 | to_v_A = diffusers_lora_sd[key.replace("to_q", "to_v")] 121 | proj_mlp_A = diffusers_lora_sd[key.replace("attn.to_q", "proj_mlp")] 122 | 123 | linear1_A = torch.cat([to_q_A, to_k_A, to_v_A, proj_mlp_A], dim=0) 124 | linear1_A_key = key.replace("single_transformer_blocks", prefix + "single_blocks").replace( 125 | "attn.to_q", "linear1" 126 | ) 127 | converted_lora_sd[linear1_A_key] = linear1_A 128 | 129 | # lora_B 130 | to_q_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_q.lora_B")] 131 | to_k_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_k.lora_B")] 132 | to_v_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_v.lora_B")] 133 | proj_mlp_B = diffusers_lora_sd[key.replace("attn.to_q.lora_A", "proj_mlp.lora_B")] 134 | 135 | linear1_B = torch.block_diag(to_q_B, to_k_B, to_v_B, proj_mlp_B) 136 | linear1_B_key = linear1_A_key.replace("lora_A", "lora_B") 137 | converted_lora_sd[linear1_B_key] = linear1_B 138 | 139 | elif "proj_out" in key: 140 | new_key = key.replace("proj_out", "linear2").replace( 141 | "single_transformer_blocks", prefix + "single_blocks" 142 | ) 143 | converted_lora_sd[new_key] = diffusers_lora_sd[key] 144 | 145 | else: 146 | print(f"unknown or not implemented: {key}") 147 | 148 | return converted_lora_sd 149 | 150 | 151 | def get_args(): 152 | parser = argparse.ArgumentParser() 153 | parser.add_argument("--input_lora", type=str, required=True, help="Path to LoRA .safetensors") 154 | parser.add_argument("--alpha", type=float, default=None, help="Optional alpha value, defaults to rank") 155 | parser.add_argument( 156 | "--dtype", type=str, default=None, help="Optional dtype (bfloat16, float16, float32), defaults to input dtype" 157 | ) 158 | parser.add_argument("--debug", action="store_true", help="Print converted keys instead of saving") 159 | return parser.parse_args() 160 | 161 | 162 | if __name__ == "__main__": 163 | args = get_args() 164 | 165 | converted_lora_sd = convert_lora_sd(load_file(args.input_lora)) 166 | 167 | if args.alpha is not None: 168 | for key in list(converted_lora_sd.keys()): 169 | if "lora_A" in key: 170 | alpha_name = key.replace(".lora_A.weight", ".alpha") 171 | converted_lora_sd[alpha_name] = torch.tensor([args.alpha], dtype=converted_lora_sd[key].dtype) 172 | 173 | dtype = None 174 | if args.dtype == "bfloat16": 175 | dtype = torch.bfloat16 176 | elif args.dtype == "float16": 177 | dtype = torch.float16 178 | elif args.dtype == "float32": 179 | dtype = torch.float32 180 | 181 | if dtype is not None: 182 | dtype_min = torch.finfo(dtype).min 183 | dtype_max = torch.finfo(dtype).max 184 | for key in converted_lora_sd.keys(): 185 | if converted_lora_sd[key].min() < dtype_min or converted_lora_sd[key].max() > dtype_max: 186 | print(f"warning: {key} has values outside of {dtype} {dtype_min} {dtype_max} range") 187 | converted_lora_sd[key] = converted_lora_sd[key].to(dtype) 188 | 189 | if args.debug: 190 | for key in sorted(converted_lora_sd.keys()): 191 | print(key, converted_lora_sd[key].shape, converted_lora_sd[key].dtype) 192 | exit() 193 | 194 | output_path = os.path.splitext(args.input_lora)[0] + "_converted.safetensors" 195 | save_file(converted_lora_sd, output_path) 196 | print(f"saved to {output_path}") 197 | 
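A quick sanity check on the qkv fusion used above (a standalone sketch with arbitrary toy shapes, not part of the repo): stacking the `lora_A` matrices and block-diagonalizing the `lora_B` matrices produces a fused weight delta whose q/k/v slices match the per-projection LoRA deltas.

```
import torch

rank, dim = 4, 16
q_A, k_A, v_A = (torch.randn(rank, dim) for _ in range(3))
q_B, k_B, v_B = (torch.randn(dim, rank) for _ in range(3))

fused_A = torch.cat([q_A, k_A, v_A], dim=0)  # (3*rank, dim)
fused_B = torch.block_diag(q_B, k_B, v_B)    # (3*dim, 3*rank)
fused_delta = fused_B @ fused_A              # (3*dim, dim), the fused qkv LoRA update

# each output slice matches the corresponding per-projection update
assert torch.allclose(fused_delta[:dim], q_B @ q_A)
assert torch.allclose(fused_delta[dim:2*dim], k_B @ k_A)
assert torch.allclose(fused_delta[2*dim:], v_B @ v_A)
```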
-------------------------------------------------------------------------------- /train_hunyuan_lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.utils.data import Dataset, DataLoader 4 | from torch.utils.tensorboard import SummaryWriter 5 | 6 | if torch.cuda.is_available(): 7 | device = torch.cuda.current_device() 8 | torch.cuda.init() 9 | torch.backends.cuda.matmul.allow_tf32 = True 10 | else: 11 | raise Exception("unable to initialize CUDA") 12 | 13 | import os 14 | import gc 15 | import random 16 | import numpy as np 17 | import argparse 18 | import json 19 | import datetime 20 | from tqdm import tqdm 21 | import decord 22 | from contextlib import contextmanager 23 | from time import perf_counter 24 | from glob import glob 25 | from PIL import Image 26 | 27 | from torchvision.transforms import v2, InterpolationMode 28 | from safetensors.torch import load_file, save_file 29 | 30 | from transformers import CLIPTextModel, CLIPTokenizerFast, LlamaModel, LlamaTokenizerFast 31 | from diffusers import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel, FlowMatchEulerDiscreteScheduler 32 | from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig 33 | from peft import LoraConfig, set_peft_model_state_dict 34 | from peft.utils import get_peft_model_state_dict 35 | import bitsandbytes as bnb 36 | 37 | 38 | DEFAULT_PROMPT_TEMPLATE = { 39 | "template": ( 40 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " 41 | "1. The main content and theme of the video." 42 | "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." 43 | "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." 44 | "4. background environment, light, style and atmosphere." 45 | "5. 
camera angles, movements, and transitions used in the video:<|eot_id|>" 46 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" 47 | ), 48 | "crop_start": 95, 49 | } 50 | 51 | IMAGE_TYPES = [".jpg", ".png"] 52 | VIDEO_TYPES = [".mp4", ".mkv", ".mov", ".avi", ".webm"] 53 | 54 | BUCKET_RESOLUTIONS = { 55 | "1x1": [ 56 | # (256, 256), 57 | (384, 384), 58 | (512, 512), 59 | (640, 640), 60 | (720, 720), 61 | (768, 768), 62 | (848, 848), 63 | (960, 960), 64 | (1024, 1024), 65 | ], 66 | "4x3": [ 67 | # (320, 240), 68 | (384, 288), 69 | (512, 384), 70 | (640, 480), 71 | (768, 576), 72 | (960, 720), 73 | (1024, 768), 74 | (1152, 864), 75 | (1280, 960), 76 | ], 77 | "16x9": [ 78 | # (256, 144), 79 | (512, 288), 80 | (768, 432), 81 | (848, 480), 82 | (960, 544), 83 | (1024, 576), 84 | (1280, 720), 85 | ], 86 | } 87 | 88 | def count_tokens(width, height, frames): 89 | return (width // 16) * (height // 16) * ((frames - 1) // 4 + 1) 90 | 91 | 92 | class CombinedDataset(Dataset): 93 | def __init__( 94 | self, 95 | root_folder, 96 | token_limit = 10_000, 97 | limit_samples = None, 98 | max_frame_stride = 4, 99 | manual_resolution = None, 100 | manual_frames = None, 101 | ): 102 | self.root_folder = root_folder 103 | self.token_limit = token_limit 104 | self.max_frame_stride = max_frame_stride 105 | self.manual_resolution = manual_resolution 106 | self.manual_frames = manual_frames 107 | 108 | # search for all files matching image or video extensions 109 | self.media_files = [] 110 | for ext in IMAGE_TYPES + VIDEO_TYPES: 111 | self.media_files.extend( 112 | glob(os.path.join(self.root_folder, "**", "*" + ext), recursive=True) 113 | ) 114 | 115 | # pull samples evenly from the whole dataset 116 | if limit_samples is not None: 117 | stride = max(1, len(self.media_files) // limit_samples) 118 | self.media_files = self.media_files[::stride] 119 | self.media_files = self.media_files[:limit_samples] 120 | 121 | def __len__(self): 122 | return len(self.media_files) 123 | 124 | def find_max_frames(self, width, height): 125 | if self.manual_frames is not None: 126 | return self.manual_frames 127 | 128 | frames = 1 129 | tokens = count_tokens(width, height, frames) 130 | while tokens < self.token_limit: 131 | new_frames = frames + 4 132 | new_tokens = count_tokens(width, height, new_frames) 133 | if new_tokens < self.token_limit: 134 | frames = new_frames 135 | tokens = new_tokens 136 | else: 137 | return frames 138 | 139 | def get_ar_buckets(self, width, height): 140 | if self.manual_resolution is not None: 141 | return [(self.manual_resolution, self.manual_resolution)] 142 | 143 | ar = width / height 144 | if ar > 1.555: 145 | buckets = BUCKET_RESOLUTIONS["16x9"] 146 | elif ar > 1.166: 147 | buckets = BUCKET_RESOLUTIONS["4x3"] 148 | elif ar > 0.875: 149 | buckets = BUCKET_RESOLUTIONS["1x1"] 150 | elif ar > 0.656: 151 | buckets = [b[::-1] for b in BUCKET_RESOLUTIONS["4x3"]] 152 | else: 153 | buckets = [b[::-1] for b in BUCKET_RESOLUTIONS["16x9"]] 154 | 155 | return [b for b in buckets if b[0] <= width and b[1] <= height] 156 | 157 | def __getitem__(self, idx): 158 | ext = os.path.splitext(self.media_files[idx])[1].lower() 159 | if ext in IMAGE_TYPES: 160 | image = Image.open(self.media_files[idx]).convert('RGB') 161 | pixels = torch.as_tensor(np.array(image)).unsqueeze(0) # FHWC 162 | buckets = self.get_ar_buckets(pixels.shape[2], pixels.shape[1]) 163 | width, height = random.choice(buckets) 164 | else: 165 | vr = decord.VideoReader(self.media_files[idx]) 166 | orig_height, orig_width = vr[0].shape[:2] 167 
| orig_frames = len(vr) 168 | 169 | # randomize resolution bucket and frame length 170 | buckets = self.get_ar_buckets(orig_width, orig_height) 171 | width, height = random.choice(buckets) 172 | max_frames = self.find_max_frames(width, height) 173 | stride = max(min(random.randint(1, self.max_frame_stride), orig_frames // max_frames), 1) 174 | 175 | # sample a clip from the video based on frame stride and length 176 | seg_len = min(stride * max_frames, orig_frames) 177 | start_frame = random.randint(0, orig_frames - seg_len) 178 | pixels = vr[start_frame : start_frame+seg_len : stride] 179 | max_frames = ((pixels.shape[0] - 1) // 4) * 4 + 1 180 | pixels = pixels[:max_frames] # clip frames to match vae 181 | 182 | # determine crop dimensions to prevent stretching during resize 183 | pixels_ar = pixels.shape[2] / pixels.shape[1] 184 | target_ar = width / height 185 | if pixels_ar > target_ar: 186 | crop_width = min(int(pixels.shape[1] * target_ar), pixels.shape[2]) 187 | crop_height = pixels.shape[1] 188 | elif pixels_ar < target_ar: 189 | crop_width = pixels.shape[2] 190 | crop_height = min(int(pixels.shape[2] / target_ar), pixels.shape[1]) 191 | else: 192 | crop_width = pixels.shape[2] 193 | crop_height = pixels.shape[1] 194 | 195 | # convert to expected dtype, resolution, shape, and value range 196 | transform = v2.Compose([ 197 | v2.ToDtype(torch.float32, scale=True), 198 | v2.RandomCrop(size=(crop_height, crop_width)), 199 | v2.Resize(size=(height, width)), 200 | ]) 201 | 202 | pixels = pixels.movedim(3, 1).unsqueeze(0).contiguous() # FHWC -> FCHW -> BFCHW 203 | pixels = transform(pixels) * 2 - 1 204 | pixels = torch.clamp(torch.nan_to_num(pixels), min=-1, max=1) 205 | 206 | # load precomputed text embeddings from file 207 | embedding_file = os.path.splitext(self.media_files[idx])[0] + "_hyv.safetensors" 208 | if not os.path.exists(embedding_file): 209 | embedding_file = os.path.join( 210 | os.path.dirname(self.media_files[idx]), 211 | random.choice(["caption_original_hyv.safetensors", "caption_florence_hyv.safetensors"]), 212 | ) 213 | 214 | if os.path.exists(embedding_file): 215 | embedding_dict = load_file(embedding_file) 216 | else: 217 | raise Exception(f"No embedding file found for {self.media_files[idx]}, you may need to precompute embeddings with --cache_embeddings") 218 | 219 | return {"pixels": pixels, "embedding_dict": embedding_dict} 220 | 221 | 222 | @contextmanager 223 | def temp_rng(new_seed=None): 224 | """ 225 | https://github.com/fpgaminer/bigasp-training/blob/main/utils.py#L73 226 | Context manager that saves and restores the RNG state of PyTorch, NumPy and Python. 227 | If new_seed is not None, the RNG state is set to this value before the context is entered. 
228 | """ 229 | 230 | # Save RNG state 231 | old_torch_rng_state = torch.get_rng_state() 232 | old_torch_cuda_rng_state = torch.cuda.get_rng_state() 233 | old_numpy_rng_state = np.random.get_state() 234 | old_python_rng_state = random.getstate() 235 | 236 | # Set new seed 237 | if new_seed is not None: 238 | torch.manual_seed(new_seed) 239 | torch.cuda.manual_seed(new_seed) 240 | np.random.seed(new_seed) 241 | random.seed(new_seed) 242 | 243 | yield 244 | 245 | # Restore RNG state 246 | torch.set_rng_state(old_torch_rng_state) 247 | torch.cuda.set_rng_state(old_torch_cuda_rng_state) 248 | np.random.set_state(old_numpy_rng_state) 249 | random.setstate(old_python_rng_state) 250 | 251 | 252 | @contextmanager 253 | def timer(message=""): 254 | start_time = perf_counter() 255 | yield 256 | end_time = perf_counter() 257 | print(f"{message} {end_time - start_time:0.2f} seconds") 258 | 259 | 260 | def make_dir(base, folder): 261 | new_dir = os.path.join(base, folder) 262 | os.makedirs(new_dir, exist_ok=True) 263 | return new_dir 264 | 265 | 266 | def download_model(args): 267 | from huggingface_hub import snapshot_download 268 | snapshot_download( 269 | repo_type = "model", 270 | repo_id = "hunyuanvideo-community/HunyuanVideo", 271 | local_dir = "./models", 272 | max_workers = 1, 273 | ) 274 | 275 | if args.skyreels_i2v: 276 | snapshot_download( 277 | repo_type = "model", 278 | repo_id = "Skywork/SkyReels-V1-Hunyuan-I2V", 279 | local_dir = "./models/transformer-skyreels-i2v", 280 | max_workers = 1, 281 | ) 282 | 283 | 284 | def parse_args(): 285 | parser = argparse.ArgumentParser( 286 | description = "HunyuanVideo training script", 287 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 288 | ) 289 | parser.add_argument( 290 | "--download_model", 291 | action = "store_true", 292 | help = "auto download hunyuanvideo-community/HunyuanVideo to ./models if it's missing", 293 | ) 294 | parser.add_argument( 295 | "--skyreels_i2v", 296 | action = "store_true", 297 | help = "download/train skyreels image to video model", 298 | ) 299 | parser.add_argument( 300 | "--cache_embeddings", 301 | action = "store_true", 302 | help = "preprocess dataset to encode captions", 303 | ) 304 | parser.add_argument( 305 | "--pretrained_model", 306 | type = str, 307 | default = "./models", 308 | help = "Path to pretrained model base directory", 309 | ) 310 | parser.add_argument( 311 | "--quant_type", 312 | type = str, 313 | default = "nf4", 314 | help = "Bit depth for the base model, default config with nf4=16GB", 315 | choices=["nf4", "int8", "bf16"], 316 | ) 317 | parser.add_argument( 318 | "--init_lora", 319 | type = str, 320 | default = None, 321 | help = "LoRA checkpoint to load instead of random init, must be the same rank and target layers", 322 | ) 323 | parser.add_argument( 324 | "--dataset", 325 | type = str, 326 | required = True, 327 | help = "Path to dataset directory with train and val subdirectories", 328 | ) 329 | parser.add_argument( 330 | "--val_samples", 331 | type = int, 332 | default = 4, 333 | help = "Maximum number of videos to use for validation loss" 334 | ) 335 | parser.add_argument( 336 | "--output_dir", 337 | type = str, 338 | default = "./outputs", 339 | help = "Output directory for training results" 340 | ) 341 | parser.add_argument( 342 | "--seed", 343 | type = int, 344 | default = 42, 345 | help = "Seed for reproducible training" 346 | ) 347 | parser.add_argument( 348 | "--resolution", 349 | type = int, 350 | default = None, 351 | help = "Manual override resolution for training/testing" 
352 | ) 353 | parser.add_argument( 354 | "--num_frames", 355 | type = int, 356 | default = None, 357 | help = "Manual override number of frames per video, must be divisible by 4+1" 358 | ) 359 | parser.add_argument( 360 | "--token_limit", 361 | type = int, 362 | default = 10_000, 363 | help = "Combined resolution/frame limit based on transformer patch sequence length: (width // 16) * (height // 16) * ((frames - 1) // 4 + 1)" 364 | ) 365 | parser.add_argument( 366 | "--max_frame_stride", 367 | type = int, 368 | default = 2, 369 | help = "1: use native framerate only. Higher values allow randomly choosing lower framerates (skipping frames to speed up the video)" 370 | ) 371 | parser.add_argument( 372 | "--learning_rate", 373 | type = float, 374 | default = 2e-4, 375 | help = "Base learning rate", 376 | ) 377 | parser.add_argument( 378 | "--lora_rank", 379 | type = int, 380 | default = 16, 381 | help = "The dimension of the LoRA update matrices", 382 | ) 383 | parser.add_argument( 384 | "--lora_alpha", 385 | type = int, 386 | default = None, 387 | help = "The alpha value for LoRA, defaults to alpha=rank. Note: changing alpha will affect the learning rate, and if alpha=rank then changing rank will also affect learning rate", 388 | ) 389 | parser.add_argument( 390 | "--val_steps", 391 | type = int, 392 | default = 50, 393 | help = "Validate after every n steps", 394 | ) 395 | parser.add_argument( 396 | "--checkpointing_steps", 397 | type = int, 398 | default = 50, 399 | help = "Save a checkpoint of the training state every X steps", 400 | ) 401 | parser.add_argument( 402 | "--max_train_steps", 403 | type = int, 404 | default = 1000, 405 | help = "Total number of training steps", 406 | ) 407 | parser.add_argument( 408 | "--warped_noise", 409 | action = "store_true", 410 | help = "Use warped noise from Go-With-The-Flow instead of pure random noise", 411 | ) 412 | 413 | args = parser.parse_args() 414 | return args 415 | 416 | 417 | def cache_embeddings(args): 418 | print("loading CLIP") 419 | with timer("loaded CLIP in"): 420 | tokenizer_clip = CLIPTokenizerFast.from_pretrained(args.pretrained_model, subfolder="tokenizer_2") 421 | text_encoder_clip = CLIPTextModel.from_pretrained(args.pretrained_model, subfolder="text_encoder_2").to(device=device, dtype=torch.bfloat16) 422 | text_encoder_clip.requires_grad_(False) 423 | 424 | print("loading Llama") 425 | with timer("loaded Llama in"): 426 | tokenizer_llama = LlamaTokenizerFast.from_pretrained(args.pretrained_model, subfolder="tokenizer") 427 | text_encoder_llama = LlamaModel.from_pretrained(args.pretrained_model, subfolder="text_encoder").to(device=device, dtype=torch.bfloat16) 428 | text_encoder_llama.requires_grad_(False) 429 | 430 | def encode_clip(prompt): 431 | input_ids = tokenizer_clip( 432 | prompt, 433 | padding = "max_length", 434 | max_length = 77, 435 | truncation = True, 436 | return_tensors = "pt", 437 | ).input_ids.to(text_encoder_clip.device) 438 | 439 | prompt_embeds = text_encoder_clip( 440 | input_ids, 441 | output_hidden_states = False, 442 | ).pooler_output 443 | 444 | return prompt_embeds 445 | 446 | def encode_llama( 447 | prompt, 448 | prompt_template=DEFAULT_PROMPT_TEMPLATE, 449 | max_sequence_length = 256, 450 | num_hidden_layers_to_skip = 2, 451 | ): 452 | prompt = prompt_template["template"].format(prompt) 453 | crop_start = prompt_template.get("crop_start", None) 454 | max_sequence_length += crop_start 455 | 456 | text_inputs = tokenizer_llama( 457 | prompt, 458 | max_length=max_sequence_length, 459 | 
padding="max_length", 460 | truncation=True, 461 | return_tensors="pt", 462 | return_length=False, 463 | return_overflowing_tokens=False, 464 | return_attention_mask=True, 465 | ) 466 | text_input_ids = text_inputs.input_ids.to(device=text_encoder_llama.device) 467 | prompt_attention_mask = text_inputs.attention_mask.to(device=text_encoder_llama.device) 468 | 469 | prompt_embeds = text_encoder_llama( 470 | input_ids = text_input_ids, 471 | attention_mask = prompt_attention_mask, 472 | output_hidden_states = True, 473 | ).hidden_states[-(num_hidden_layers_to_skip + 1)] 474 | 475 | if crop_start is not None and crop_start > 0: 476 | prompt_embeds = prompt_embeds[:, crop_start:] 477 | prompt_attention_mask = prompt_attention_mask[:, crop_start:] 478 | 479 | return prompt_embeds, prompt_attention_mask 480 | 481 | def preprocess_captions(dataset_path): 482 | caption_files = glob(os.path.join(dataset_path, "**", "*.txt" ), recursive=True) 483 | for file in tqdm(caption_files): 484 | embedding_path = os.path.splitext(file)[0] + "_hyv.safetensors" 485 | 486 | if not os.path.exists(embedding_path): 487 | with open(file, "r") as f: 488 | caption = f.read() 489 | 490 | clip_embed = encode_clip(caption) 491 | llama_embed, llama_mask = encode_llama(caption) 492 | 493 | embedding_dict = {"clip_embed": clip_embed, "llama_embed": llama_embed, "llama_mask": llama_mask} 494 | save_file(embedding_dict, embedding_path) 495 | 496 | print("preprocessing caption embeddings") 497 | preprocess_captions(args.dataset) 498 | 499 | del tokenizer_clip, text_encoder_clip, tokenizer_llama, text_encoder_llama 500 | gc.collect() 501 | torch.cuda.empty_cache() 502 | 503 | 504 | def main(args): 505 | decord.bridge.set_bridge('torch') 506 | 507 | torch.manual_seed(args.seed) 508 | torch.cuda.manual_seed(args.seed) 509 | np.random.seed(args.seed) 510 | random.seed(args.seed) 511 | 512 | date_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 513 | real_output_dir = make_dir(args.output_dir, date_time) 514 | checkpoint_dir = make_dir(real_output_dir, "checkpoints") 515 | t_writer = SummaryWriter(log_dir=real_output_dir, flush_secs=60) 516 | with open(os.path.join(real_output_dir, "command_args.json"), "w") as f: 517 | json.dump(args.__dict__, f, indent=4) 518 | 519 | def collate_batch(batch): 520 | pixels = torch.cat([sample["pixels"] for sample in batch], dim=0) 521 | clip_embed = torch.cat([sample["embedding_dict"]["clip_embed"] for sample in batch], dim=0) 522 | llama_embed = torch.cat([sample["embedding_dict"]["llama_embed"] for sample in batch], dim=0) 523 | llama_mask = torch.cat([sample["embedding_dict"]["llama_mask"] for sample in batch], dim=0) 524 | return pixels, clip_embed, llama_embed, llama_mask 525 | 526 | train_dataset = os.path.join(args.dataset, "train") 527 | if not os.path.exists(train_dataset): 528 | train_dataset = args.dataset 529 | print(f"WARNING: train subfolder not found, using root folder {train_dataset} as train dataset") 530 | 531 | val_dataset = None 532 | for subfolder in ["val", "validation", "test"]: 533 | subfolder_path = os.path.join(args.dataset, subfolder) 534 | if os.path.exists(subfolder_path): 535 | val_dataset = subfolder_path 536 | break 537 | 538 | if val_dataset is None: 539 | val_dataset = args.dataset 540 | print(f"WARNING: val/validation/test subfolder not found, using root folder {val_dataset} for stable loss validation") 541 | print("\033[33mThis will make it impossible to judge overfitting by the validation loss. 
Using a val split held out from training is highly recommended\033[m") 542 | 543 | with timer("scanned dataset in"): 544 | train_dataset = CombinedDataset( 545 | root_folder = train_dataset, 546 | token_limit = args.token_limit, 547 | max_frame_stride = args.max_frame_stride, 548 | manual_resolution = args.resolution, 549 | manual_frames = args.num_frames, 550 | ) 551 | val_dataset = CombinedDataset( 552 | root_folder = val_dataset, 553 | token_limit = args.token_limit, 554 | limit_samples = args.val_samples, 555 | max_frame_stride = args.max_frame_stride, 556 | manual_resolution = args.resolution, 557 | manual_frames = args.num_frames, 558 | ) 559 | 560 | train_dataloader = DataLoader( 561 | train_dataset, 562 | shuffle = True, 563 | collate_fn = collate_batch, 564 | batch_size = 1, 565 | num_workers = 0, 566 | pin_memory = True, 567 | ) 568 | val_dataloader = DataLoader( 569 | val_dataset, 570 | shuffle = False, 571 | collate_fn = collate_batch, 572 | batch_size = 1, 573 | num_workers = 0, 574 | pin_memory = True, 575 | ) 576 | 577 | # noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(args.pretrained_model, subfolder="scheduler") 578 | 579 | with timer("loaded VAE in"): 580 | vae = AutoencoderKLHunyuanVideo.from_pretrained(args.pretrained_model, subfolder="vae").to(device=device, dtype=torch.float16) 581 | vae.requires_grad_(False) 582 | vae.enable_tiling( 583 | tile_sample_min_height = 256, 584 | tile_sample_min_width = 256, 585 | tile_sample_min_num_frames = 48, 586 | tile_sample_stride_height = 192, 587 | tile_sample_stride_width = 192, 588 | tile_sample_stride_num_frames = 32, 589 | ) 590 | 591 | with timer("loaded diffusion model in"): 592 | if args.quant_type == "nf4": 593 | quant_config = DiffusersBitsAndBytesConfig( 594 | load_in_4bit=True, 595 | bnb_4bit_quant_type="nf4", 596 | bnb_4bit_use_double_quant=True, 597 | bnb_4bit_compute_dtype=torch.bfloat16 598 | ) 599 | elif args.quant_type == "int8": 600 | quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True) 601 | else: 602 | quant_config = None 603 | 604 | if args.skyreels_i2v: 605 | transformer_subfolder = "transformer-skyreels-i2v" 606 | else: 607 | transformer_subfolder = "transformer" 608 | 609 | diffusion_model = HunyuanVideoTransformer3DModel.from_pretrained( 610 | args.pretrained_model, 611 | subfolder = transformer_subfolder, 612 | quantization_config = quant_config, 613 | torch_dtype = torch.bfloat16, 614 | ) 615 | 616 | diffusion_model.requires_grad_(False) 617 | diffusion_model.enable_gradient_checkpointing() 618 | torch.cuda.empty_cache() 619 | 620 | with timer("added LoRA in"): 621 | lora_params = [] 622 | attn_blocks = ["transformer_blocks", "single_transformer_blocks"] 623 | lora_keys = ["to_k", "to_q", "to_v", "to_out.0", "proj_mlp"] # mmdit img attention + single blocks attention 624 | # lora_keys += ["add_q_proj", "add_k_proj", "add_v_proj", "to_add_out"] # mmdit text attention 625 | # lora_keys += ["ff.net", "proj_out"] # mmdit img mlp + single blocks mlp 626 | # lora_keys += ["ff_context.net"] # mmdit text mlp 627 | for name, param in diffusion_model.named_parameters(): 628 | name = name.replace(".weight", "").replace(".bias", "") 629 | for block in attn_blocks: 630 | if name.startswith(block): 631 | for key in lora_keys: 632 | if key in name: 633 | lora_params.append(name) 634 | 635 | lora_config = LoraConfig( 636 | r = args.lora_rank, 637 | lora_alpha = args.lora_alpha or args.lora_rank, 638 | init_lora_weights = "gaussian", 639 | target_modules = lora_params, 640 | ) 641 | 
diffusion_model.add_adapter(lora_config) 642 | 643 | if args.init_lora is not None: 644 | loaded_lora_sd = load_file(args.init_lora) 645 | outcome = set_peft_model_state_dict(diffusion_model, loaded_lora_sd) 646 | if len(outcome.unexpected_keys) > 0: 647 | for key in outcome.unexpected_keys: 648 | print(key) 649 | 650 | lora_parameters = [] 651 | total_parameters = 0 652 | for param in diffusion_model.parameters(): 653 | if param.requires_grad: 654 | param.data = param.to(torch.float32) 655 | lora_parameters.append(param) 656 | total_parameters += param.numel() 657 | print(f"total trainable parameters: {total_parameters:,}") 658 | 659 | # Instead of having just one optimizer, we will have a dict of optimizers 660 | # for every parameter so we could reference them in our hook. 661 | optimizer_dict = {p: bnb.optim.AdamW8bit([p], lr=args.learning_rate) for p in lora_parameters} 662 | 663 | # Define our hook, which will call the optimizer step() and zero_grad() 664 | def optimizer_hook(parameter) -> None: 665 | optimizer_dict[parameter].step() 666 | optimizer_dict[parameter].zero_grad() 667 | 668 | # Register the hook onto every trainable parameter 669 | for p in lora_parameters: 670 | p.register_post_accumulate_grad_hook(optimizer_hook) 671 | 672 | if args.warped_noise: 673 | from noise_warp.GetWarpedNoiseFromVideo import GetWarpedNoiseFromVideo 674 | get_warped_noise = GetWarpedNoiseFromVideo(raft_size="large", device=device, dtype=torch.float32) 675 | 676 | def prepare_conditions(batch): 677 | pixels, clip_embed, llama_embed, llama_mask = batch 678 | pixels = pixels.movedim(1, 2).to(device=vae.device, dtype=vae.dtype) # BFCHW -> BCFHW 679 | latents = vae.encode(pixels).latent_dist.sample() * vae.config.scaling_factor 680 | 681 | if args.skyreels_i2v: 682 | image_cond_latents = torch.zeros_like(latents) 683 | image_latents = vae.encode(pixels[:, :, 0].unsqueeze(2)).latent_dist.sample() * vae.config.scaling_factor 684 | image_cond_latents[:, :, 0] = image_latents[:, :, 0] 685 | del image_latents 686 | 687 | t_writer.add_scalar("debug/context_len", latents.shape[-3] * (latents.shape[-2] / 2) * (latents.shape[-1] / 2), global_step) 688 | t_writer.add_scalar("debug/width", pixels.shape[-1], global_step) 689 | t_writer.add_scalar("debug/height", pixels.shape[-2], global_step) 690 | t_writer.add_scalar("debug/frames", pixels.shape[-3], global_step) 691 | 692 | if args.warped_noise: 693 | noise = get_warped_noise( 694 | pixels.movedim(2, 1)[0], # BCFHW -> BFCHW -> FCHW 695 | degradation = torch.rand(1).item(), 696 | noise_channels = 16, 697 | target_latent_count = latents.shape[2], 698 | ).movedim(0, 1).unsqueeze(0).to(latents) # FCHW -> CFHW -> BCFHW 699 | else: 700 | noise = torch.randn_like(latents) 701 | 702 | # TODO: add sd3/flux timestep density sampling? 
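# --- Editor's sketch (not part of the original train_hunyuan_lora.py) ---
# The TODO above refers to the SD3/Flux-style timestep density: instead of drawing sigma
# uniformly as done just below, those trainers sample sigma from a logit-normal distribution
# so mid-range noise levels are seen more often. A minimal, hedged sketch of what that could
# look like is left here as an unused helper; the name `sample_logit_normal_sigma` and the
# default loc/scale values are illustrative assumptions, not values taken from this repo.
def sample_logit_normal_sigma(batch_size, loc=0.0, scale=1.0):
    # Draw a normal sample, then squash it through a sigmoid so sigma lies in (0, 1),
    # giving a logit-normal density peaked around sigmoid(loc).
    u = torch.randn(batch_size) * scale + loc
    return torch.sigmoid(u)
# To try it, the line `sigma = torch.rand(latents.shape[0])` below could be swapped for
# `sigma = sample_logit_normal_sigma(latents.shape[0])`.
# -------------------------------------------------------------------------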
703 | sigma = torch.rand(latents.shape[0]) 704 | timesteps = torch.round(sigma * 1000).long() 705 | sigma = sigma[:, None, None, None, None].to(latents) 706 | noisy_model_input = (noise * sigma) + (latents * (1 - sigma)) 707 | 708 | if args.skyreels_i2v: 709 | noisy_model_input = torch.cat([noisy_model_input, image_cond_latents], dim=1) 710 | 711 | guidance_scale = 1.0 712 | guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=torch.float32, device=device) * 1000.0 713 | 714 | return { 715 | "target": (noise - latents).to(device=device), 716 | "noisy_model_input": noisy_model_input.to(device=device, dtype=torch.bfloat16), 717 | "timesteps": timesteps.to(device=device, dtype=torch.bfloat16), 718 | "llama_embed": llama_embed.to(device=device, dtype=torch.bfloat16), 719 | "llama_mask": llama_mask.to(device=device, dtype=torch.bfloat16), 720 | "clip_embed": clip_embed.to(device=device, dtype=torch.bfloat16), 721 | "guidance": guidance.to(device=device, dtype=torch.bfloat16), 722 | } 723 | 724 | def predict_loss(conditions): 725 | pred = diffusion_model( 726 | hidden_states = conditions["noisy_model_input"], 727 | timestep = conditions["timesteps"], 728 | encoder_hidden_states = conditions["llama_embed"], 729 | encoder_attention_mask = conditions["llama_mask"], 730 | pooled_projections = conditions["clip_embed"], 731 | guidance = conditions["guidance"], 732 | return_dict = False, 733 | )[0] 734 | return F.mse_loss(pred.float(), conditions["target"].float()) 735 | 736 | gc.collect() 737 | torch.cuda.empty_cache() 738 | diffusion_model.train() 739 | 740 | global_step = 0 741 | progress_bar = tqdm(range(0, args.max_train_steps)) 742 | while global_step < args.max_train_steps: 743 | for step, batch in enumerate(train_dataloader): 744 | start_step = perf_counter() 745 | with torch.inference_mode(): 746 | conditions = prepare_conditions(batch) 747 | 748 | torch.cuda.empty_cache() 749 | loss = predict_loss(conditions) 750 | t_writer.add_scalar("loss/train", loss.detach().item(), global_step) 751 | 752 | loss.backward() 753 | 754 | progress_bar.update(1) 755 | global_step += 1 756 | torch.cuda.empty_cache() 757 | t_writer.add_scalar("debug/step_time", perf_counter() - start_step, global_step) 758 | 759 | if global_step == 1 or global_step % args.val_steps == 0: 760 | with torch.inference_mode(), temp_rng(args.seed): 761 | val_loss = 0.0 762 | for step, batch in enumerate(tqdm(val_dataloader, desc="validation", leave=False)): 763 | conditions = prepare_conditions(batch) 764 | torch.cuda.empty_cache() 765 | loss = predict_loss(conditions) 766 | val_loss += loss.detach().item() 767 | torch.cuda.empty_cache() 768 | t_writer.add_scalar("loss/validation", val_loss / len(val_dataloader), global_step) 769 | progress_bar.unpause() 770 | 771 | if global_step >= args.max_train_steps or global_step % args.checkpointing_steps == 0: 772 | save_file( 773 | get_peft_model_state_dict(diffusion_model), 774 | os.path.join(checkpoint_dir, f"hyv-lora-{global_step:08}.safetensors"), 775 | ) 776 | 777 | if global_step >= args.max_train_steps: 778 | break 779 | 780 | # train 781 | # basic t2i and randn noise to start 782 | # guidance=1 783 | # uncond/caption dropout? 
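# --- Editor's sketch (not part of the original train_hunyuan_lora.py) ---
# The "uncond/caption dropout?" note above refers to randomly dropping the text conditioning
# for a small fraction of training samples so the model also learns an unconditional
# prediction (useful if classifier-free guidance is ever wanted at inference). A minimal,
# hedged sketch of one common approach is left here as an unused helper; the name
# `maybe_drop_caption`, the 10% dropout probability, and the choice of zeroing the embeddings
# (rather than substituting an encoded empty prompt "") are illustrative assumptions only.
def maybe_drop_caption(llama_embed, llama_mask, clip_embed, drop_prob=0.1):
    # With probability drop_prob, replace the caption embeddings with zeros; the attention
    # mask is left unchanged here for simplicity.
    if torch.rand(1).item() < drop_prob:
        llama_embed = torch.zeros_like(llama_embed)
        clip_embed = torch.zeros_like(clip_embed)
    return llama_embed, llama_mask, clip_embed
# If used, it could be applied to the cached embeddings inside prepare_conditions before
# they are moved to the device.
# -------------------------------------------------------------------------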
784 | 785 | 786 | if __name__ == "__main__": 787 | args = parse_args() 788 | 789 | if args.download_model: 790 | download_model(args) 791 | exit() 792 | 793 | if args.cache_embeddings and args.dataset != "pexels": 794 | cache_embeddings(args) 795 | exit() 796 | 797 | main(args) -------------------------------------------------------------------------------- /pipelines/pipeline_hunyuan_video.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import inspect 16 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 17 | 18 | import numpy as np 19 | import torch 20 | from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast 21 | 22 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback 23 | from diffusers.loaders import HunyuanVideoLoraLoaderMixin 24 | from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel 25 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 26 | from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring 27 | from diffusers.utils.torch_utils import randn_tensor 28 | from diffusers.video_processor import VideoProcessor 29 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 30 | from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput 31 | 32 | 33 | if is_torch_xla_available(): 34 | import torch_xla.core.xla_model as xm 35 | 36 | XLA_AVAILABLE = True 37 | else: 38 | XLA_AVAILABLE = False 39 | 40 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 41 | 42 | 43 | EXAMPLE_DOC_STRING = """ 44 | Examples: 45 | ```python 46 | >>> import torch 47 | >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel 48 | >>> from diffusers.utils import export_to_video 49 | 50 | >>> model_id = "hunyuanvideo-community/HunyuanVideo" 51 | >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( 52 | ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 53 | ... ) 54 | >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) 55 | >>> pipe.vae.enable_tiling() 56 | >>> pipe.to("cuda") 57 | 58 | >>> output = pipe( 59 | ... prompt="A cat walks on the grass, realistic", 60 | ... height=320, 61 | ... width=512, 62 | ... num_frames=61, 63 | ... num_inference_steps=30, 64 | ... ).frames[0] 65 | >>> export_to_video(output, "output.mp4", fps=15) 66 | ``` 67 | """ 68 | 69 | 70 | DEFAULT_PROMPT_TEMPLATE = { 71 | "template": ( 72 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " 73 | "1. The main content and theme of the video." 74 | "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." 75 | "3. 
Actions, events, behaviors temporal relationships, physical movement changes of the objects." 76 | "4. background environment, light, style and atmosphere." 77 | "5. camera angles, movements, and transitions used in the video:<|eot_id|>" 78 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" 79 | ), 80 | "crop_start": 95, 81 | } 82 | 83 | 84 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 85 | def retrieve_timesteps( 86 | scheduler, 87 | num_inference_steps: Optional[int] = None, 88 | device: Optional[Union[str, torch.device]] = None, 89 | timesteps: Optional[List[int]] = None, 90 | sigmas: Optional[List[float]] = None, 91 | **kwargs, 92 | ): 93 | r""" 94 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 95 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 96 | 97 | Args: 98 | scheduler (`SchedulerMixin`): 99 | The scheduler to get timesteps from. 100 | num_inference_steps (`int`): 101 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 102 | must be `None`. 103 | device (`str` or `torch.device`, *optional*): 104 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 105 | timesteps (`List[int]`, *optional*): 106 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 107 | `num_inference_steps` and `sigmas` must be `None`. 108 | sigmas (`List[float]`, *optional*): 109 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 110 | `num_inference_steps` and `timesteps` must be `None`. 111 | 112 | Returns: 113 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 114 | second element is the number of inference steps. 115 | """ 116 | if timesteps is not None and sigmas is not None: 117 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 118 | if timesteps is not None: 119 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 120 | if not accepts_timesteps: 121 | raise ValueError( 122 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 123 | f" timestep schedules. Please check whether you are using the correct scheduler." 124 | ) 125 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 126 | timesteps = scheduler.timesteps 127 | num_inference_steps = len(timesteps) 128 | elif sigmas is not None: 129 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 130 | if not accept_sigmas: 131 | raise ValueError( 132 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 133 | f" sigmas schedules. Please check whether you are using the correct scheduler." 134 | ) 135 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 136 | timesteps = scheduler.timesteps 137 | num_inference_steps = len(timesteps) 138 | else: 139 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 140 | timesteps = scheduler.timesteps 141 | return timesteps, num_inference_steps 142 | 143 | 144 | class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): 145 | r""" 146 | Pipeline for text-to-video generation using HunyuanVideo. 
147 | 148 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 149 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 150 | 151 | Args: 152 | text_encoder ([`LlamaModel`]): 153 | [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). 154 | tokenizer (`LlamaTokenizer`): 155 | Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). 156 | transformer ([`HunyuanVideoTransformer3DModel`]): 157 | Conditional Transformer to denoise the encoded image latents. 158 | scheduler ([`FlowMatchEulerDiscreteScheduler`]): 159 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 160 | vae ([`AutoencoderKLHunyuanVideo`]): 161 | Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 162 | text_encoder_2 ([`CLIPTextModel`]): 163 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 164 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 165 | tokenizer_2 (`CLIPTokenizer`): 166 | Tokenizer of class 167 | [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 168 | """ 169 | 170 | model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" 171 | _callback_tensor_inputs = ["latents", "prompt_embeds"] 172 | 173 | def __init__( 174 | self, 175 | text_encoder: LlamaModel, 176 | tokenizer: LlamaTokenizerFast, 177 | transformer: HunyuanVideoTransformer3DModel, 178 | vae: AutoencoderKLHunyuanVideo, 179 | scheduler: FlowMatchEulerDiscreteScheduler, 180 | text_encoder_2: CLIPTextModel, 181 | tokenizer_2: CLIPTokenizer, 182 | ): 183 | super().__init__() 184 | 185 | self.register_modules( 186 | vae=vae, 187 | text_encoder=text_encoder, 188 | tokenizer=tokenizer, 189 | transformer=transformer, 190 | scheduler=scheduler, 191 | text_encoder_2=text_encoder_2, 192 | tokenizer_2=tokenizer_2, 193 | ) 194 | 195 | self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 196 | self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8 197 | self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) 198 | 199 | def _get_llama_prompt_embeds( 200 | self, 201 | prompt: Union[str, List[str]], 202 | prompt_template: Dict[str, Any], 203 | num_videos_per_prompt: int = 1, 204 | device: Optional[torch.device] = None, 205 | dtype: Optional[torch.dtype] = None, 206 | max_sequence_length: int = 256, 207 | num_hidden_layers_to_skip: int = 2, 208 | ) -> Tuple[torch.Tensor, torch.Tensor]: 209 | device = device or self._execution_device 210 | dtype = dtype or self.text_encoder.dtype 211 | 212 | prompt = [prompt] if isinstance(prompt, str) else prompt 213 | batch_size = len(prompt) 214 | 215 | prompt = [prompt_template["template"].format(p) for p in prompt] 216 | 217 | crop_start = prompt_template.get("crop_start", None) 218 | if crop_start is None: 219 | prompt_template_input = self.tokenizer( 220 | prompt_template["template"], 221 | padding="max_length", 222 | return_tensors="pt", 223 | return_length=False, 224 | return_overflowing_tokens=False, 225 | return_attention_mask=False, 226 | ) 227 | crop_start = prompt_template_input["input_ids"].shape[-1] 228 | # Remove <|eot_id|> token and placeholder {} 229 | crop_start 
-= 2 230 | 231 | max_sequence_length += crop_start 232 | text_inputs = self.tokenizer( 233 | prompt, 234 | max_length=max_sequence_length, 235 | padding="max_length", 236 | truncation=True, 237 | return_tensors="pt", 238 | return_length=False, 239 | return_overflowing_tokens=False, 240 | return_attention_mask=True, 241 | ) 242 | text_input_ids = text_inputs.input_ids.to(device=device) 243 | prompt_attention_mask = text_inputs.attention_mask.to(device=device) 244 | 245 | prompt_embeds = self.text_encoder( 246 | input_ids=text_input_ids, 247 | attention_mask=prompt_attention_mask, 248 | output_hidden_states=True, 249 | ).hidden_states[-(num_hidden_layers_to_skip + 1)] 250 | prompt_embeds = prompt_embeds.to(dtype=dtype) 251 | 252 | if crop_start is not None and crop_start > 0: 253 | prompt_embeds = prompt_embeds[:, crop_start:] 254 | prompt_attention_mask = prompt_attention_mask[:, crop_start:] 255 | 256 | # duplicate text embeddings for each generation per prompt, using mps friendly method 257 | _, seq_len, _ = prompt_embeds.shape 258 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) 259 | prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) 260 | prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) 261 | prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) 262 | 263 | return prompt_embeds, prompt_attention_mask 264 | 265 | def _get_clip_prompt_embeds( 266 | self, 267 | prompt: Union[str, List[str]], 268 | num_videos_per_prompt: int = 1, 269 | device: Optional[torch.device] = None, 270 | dtype: Optional[torch.dtype] = None, 271 | max_sequence_length: int = 77, 272 | ) -> torch.Tensor: 273 | device = device or self._execution_device 274 | dtype = dtype or self.text_encoder_2.dtype 275 | 276 | prompt = [prompt] if isinstance(prompt, str) else prompt 277 | batch_size = len(prompt) 278 | 279 | text_inputs = self.tokenizer_2( 280 | prompt, 281 | padding="max_length", 282 | max_length=max_sequence_length, 283 | truncation=True, 284 | return_tensors="pt", 285 | ) 286 | 287 | text_input_ids = text_inputs.input_ids 288 | untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids 289 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 290 | removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) 291 | logger.warning( 292 | "The following part of your input was truncated because CLIP can only handle sequences up to" 293 | f" {max_sequence_length} tokens: {removed_text}" 294 | ) 295 | 296 | prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output 297 | 298 | # duplicate text embeddings for each generation per prompt, using mps friendly method 299 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) 300 | prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) 301 | 302 | return prompt_embeds 303 | 304 | def encode_prompt( 305 | self, 306 | prompt: Union[str, List[str]], 307 | prompt_2: Union[str, List[str]] = None, 308 | prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, 309 | num_videos_per_prompt: int = 1, 310 | prompt_embeds: Optional[torch.Tensor] = None, 311 | pooled_prompt_embeds: Optional[torch.Tensor] = None, 312 | prompt_attention_mask: Optional[torch.Tensor] = None, 313 | device: Optional[torch.device] = None, 314 | dtype: Optional[torch.dtype] = None, 315 | 
max_sequence_length: int = 256, 316 | ): 317 | if prompt_embeds is None: 318 | prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( 319 | prompt, 320 | prompt_template, 321 | num_videos_per_prompt, 322 | device=device, 323 | dtype=dtype, 324 | max_sequence_length=max_sequence_length, 325 | ) 326 | 327 | if pooled_prompt_embeds is None: 328 | if prompt_2 is None and pooled_prompt_embeds is None: 329 | prompt_2 = prompt 330 | pooled_prompt_embeds = self._get_clip_prompt_embeds( 331 | prompt, 332 | num_videos_per_prompt, 333 | device=device, 334 | dtype=dtype, 335 | max_sequence_length=77, 336 | ) 337 | 338 | return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask 339 | 340 | def check_inputs( 341 | self, 342 | prompt, 343 | prompt_2, 344 | height, 345 | width, 346 | prompt_embeds=None, 347 | callback_on_step_end_tensor_inputs=None, 348 | prompt_template=None, 349 | ): 350 | if height % 16 != 0 or width % 16 != 0: 351 | raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") 352 | 353 | if callback_on_step_end_tensor_inputs is not None and not all( 354 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs 355 | ): 356 | raise ValueError( 357 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 358 | ) 359 | 360 | if prompt is not None and prompt_embeds is not None: 361 | raise ValueError( 362 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 363 | " only forward one of the two." 364 | ) 365 | elif prompt_2 is not None and prompt_embeds is not None: 366 | raise ValueError( 367 | f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 368 | " only forward one of the two." 369 | ) 370 | elif prompt is None and prompt_embeds is None: 371 | raise ValueError( 372 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
373 | ) 374 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 375 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 376 | elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): 377 | raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") 378 | 379 | if prompt_template is not None: 380 | if not isinstance(prompt_template, dict): 381 | raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") 382 | if "template" not in prompt_template: 383 | raise ValueError( 384 | f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" 385 | ) 386 | 387 | def prepare_latents( 388 | self, 389 | batch_size: int, 390 | num_channels_latents: 32, 391 | height: int = 720, 392 | width: int = 1280, 393 | num_frames: int = 129, 394 | dtype: Optional[torch.dtype] = None, 395 | device: Optional[torch.device] = None, 396 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 397 | latents: Optional[torch.Tensor] = None, 398 | ) -> torch.Tensor: 399 | if latents is not None: 400 | return latents.to(device=device, dtype=dtype) 401 | 402 | shape = ( 403 | batch_size, 404 | num_channels_latents, 405 | num_frames, 406 | int(height) // self.vae_scale_factor_spatial, 407 | int(width) // self.vae_scale_factor_spatial, 408 | ) 409 | if isinstance(generator, list) and len(generator) != batch_size: 410 | raise ValueError( 411 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 412 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 413 | ) 414 | 415 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 416 | return latents 417 | 418 | def enable_vae_slicing(self): 419 | r""" 420 | Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to 421 | compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 422 | """ 423 | self.vae.enable_slicing() 424 | 425 | def disable_vae_slicing(self): 426 | r""" 427 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to 428 | computing decoding in one step. 429 | """ 430 | self.vae.disable_slicing() 431 | 432 | def enable_vae_tiling(self): 433 | r""" 434 | Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to 435 | compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow 436 | processing larger images. 437 | """ 438 | self.vae.enable_tiling() 439 | 440 | def disable_vae_tiling(self): 441 | r""" 442 | Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to 443 | computing decoding in one step. 
444 | """ 445 | self.vae.disable_tiling() 446 | 447 | @property 448 | def guidance_scale(self): 449 | return self._guidance_scale 450 | 451 | @property 452 | def num_timesteps(self): 453 | return self._num_timesteps 454 | 455 | @property 456 | def attention_kwargs(self): 457 | return self._attention_kwargs 458 | 459 | @property 460 | def current_timestep(self): 461 | return self._current_timestep 462 | 463 | @property 464 | def interrupt(self): 465 | return self._interrupt 466 | 467 | @torch.no_grad() 468 | @replace_example_docstring(EXAMPLE_DOC_STRING) 469 | def __call__( 470 | self, 471 | prompt: Union[str, List[str]] = None, 472 | prompt_2: Union[str, List[str]] = None, 473 | height: int = 720, 474 | width: int = 1280, 475 | num_frames: int = 129, 476 | num_inference_steps: int = 50, 477 | sigmas: List[float] = None, 478 | guidance_scale: float = 6.0, 479 | num_videos_per_prompt: Optional[int] = 1, 480 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 481 | latents: Optional[torch.Tensor] = None, 482 | prompt_embeds: Optional[torch.Tensor] = None, 483 | pooled_prompt_embeds: Optional[torch.Tensor] = None, 484 | prompt_attention_mask: Optional[torch.Tensor] = None, 485 | output_type: Optional[str] = "pil", 486 | return_dict: bool = True, 487 | attention_kwargs: Optional[Dict[str, Any]] = None, 488 | callback_on_step_end: Optional[ 489 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] 490 | ] = None, 491 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 492 | prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, 493 | max_sequence_length: int = 256, 494 | ): 495 | r""" 496 | The call function to the pipeline for generation. 497 | 498 | Args: 499 | prompt (`str` or `List[str]`, *optional*): 500 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 501 | instead. 502 | prompt_2 (`str` or `List[str]`, *optional*): 503 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 504 | will be used instead. 505 | height (`int`, defaults to `720`): 506 | The height in pixels of the generated image. 507 | width (`int`, defaults to `1280`): 508 | The width in pixels of the generated image. 509 | num_frames (`int`, defaults to `129`): 510 | The number of frames in the generated video. 511 | num_inference_steps (`int`, defaults to `50`): 512 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 513 | expense of slower inference. 514 | sigmas (`List[float]`, *optional*): 515 | Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in 516 | their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed 517 | will be used. 518 | guidance_scale (`float`, defaults to `6.0`): 519 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 520 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 521 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 522 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 523 | usually at the expense of lower image quality. Note that the only available HunyuanVideo model is 524 | CFG-distilled, which means that traditional guidance between unconditional and conditional latent is 525 | not applied. 
526 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 527 | The number of images to generate per prompt. 528 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 529 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 530 | generation deterministic. 531 | latents (`torch.Tensor`, *optional*): 532 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 533 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 534 | tensor is generated by sampling using the supplied random `generator`. 535 | prompt_embeds (`torch.Tensor`, *optional*): 536 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 537 | provided, text embeddings are generated from the `prompt` input argument. 538 | output_type (`str`, *optional*, defaults to `"pil"`): 539 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 540 | return_dict (`bool`, *optional*, defaults to `True`): 541 | Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. 542 | attention_kwargs (`dict`, *optional*): 543 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 544 | `self.processor` in 545 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 546 | clip_skip (`int`, *optional*): 547 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that 548 | the output of the pre-final layer will be used for computing the prompt embeddings. 549 | callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): 550 | A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of 551 | each denoising step during the inference. with the following arguments: `callback_on_step_end(self: 552 | DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a 553 | list of all tensors as specified by `callback_on_step_end_tensor_inputs`. 554 | callback_on_step_end_tensor_inputs (`List`, *optional*): 555 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 556 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 557 | `._callback_tensor_inputs` attribute of your pipeline class. 558 | 559 | Examples: 560 | 561 | Returns: 562 | [`~HunyuanVideoPipelineOutput`] or `tuple`: 563 | If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned 564 | where the first element is a list with the generated images and the second element is a list of `bool`s 565 | indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. 566 | """ 567 | 568 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): 569 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs 570 | 571 | # 1. Check inputs. 
Raise error if not correct 572 | self.check_inputs( 573 | prompt, 574 | prompt_2, 575 | height, 576 | width, 577 | prompt_embeds, 578 | callback_on_step_end_tensor_inputs, 579 | prompt_template, 580 | ) 581 | 582 | self._guidance_scale = guidance_scale 583 | self._attention_kwargs = attention_kwargs 584 | self._current_timestep = None 585 | self._interrupt = False 586 | 587 | device = self._execution_device 588 | 589 | # 2. Define call parameters 590 | if prompt is not None and isinstance(prompt, str): 591 | batch_size = 1 592 | elif prompt is not None and isinstance(prompt, list): 593 | batch_size = len(prompt) 594 | else: 595 | batch_size = prompt_embeds.shape[0] 596 | 597 | # 3. Encode input prompt 598 | prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( 599 | prompt=prompt, 600 | prompt_2=prompt_2, 601 | prompt_template=prompt_template, 602 | num_videos_per_prompt=num_videos_per_prompt, 603 | prompt_embeds=prompt_embeds, 604 | pooled_prompt_embeds=pooled_prompt_embeds, 605 | prompt_attention_mask=prompt_attention_mask, 606 | device=device, 607 | max_sequence_length=max_sequence_length, 608 | ) 609 | 610 | transformer_dtype = self.transformer.dtype 611 | prompt_embeds = prompt_embeds.to(transformer_dtype) 612 | prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) 613 | if pooled_prompt_embeds is not None: 614 | pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) 615 | 616 | # 4. Prepare timesteps 617 | sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas 618 | timesteps, num_inference_steps = retrieve_timesteps( 619 | self.scheduler, 620 | num_inference_steps, 621 | device, 622 | sigmas=sigmas, 623 | ) 624 | 625 | # 5. Prepare latent variables 626 | num_channels_latents = self.transformer.config.in_channels 627 | num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 628 | latents = self.prepare_latents( 629 | batch_size * num_videos_per_prompt, 630 | num_channels_latents, 631 | height, 632 | width, 633 | num_latent_frames, 634 | torch.float32, 635 | device, 636 | generator, 637 | latents, 638 | ) 639 | 640 | # 6. Prepare guidance condition 641 | guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 642 | 643 | # 7. 
Denoising loop 644 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 645 | self._num_timesteps = len(timesteps) 646 | 647 | with self.progress_bar(total=num_inference_steps) as progress_bar: 648 | for i, t in enumerate(timesteps): 649 | if self.interrupt: 650 | continue 651 | 652 | self._current_timestep = t 653 | latent_model_input = latents.to(transformer_dtype) 654 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 655 | timestep = t.expand(latents.shape[0]).to(latents.dtype) 656 | 657 | noise_pred = self.transformer( 658 | hidden_states=latent_model_input, 659 | timestep=timestep, 660 | encoder_hidden_states=prompt_embeds, 661 | encoder_attention_mask=prompt_attention_mask, 662 | pooled_projections=pooled_prompt_embeds, 663 | guidance=guidance, 664 | attention_kwargs=attention_kwargs, 665 | return_dict=False, 666 | )[0] 667 | 668 | # compute the previous noisy sample x_t -> x_t-1 669 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 670 | 671 | if callback_on_step_end is not None: 672 | callback_kwargs = {} 673 | for k in callback_on_step_end_tensor_inputs: 674 | callback_kwargs[k] = locals()[k] 675 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 676 | 677 | latents = callback_outputs.pop("latents", latents) 678 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 679 | 680 | # call the callback, if provided 681 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 682 | progress_bar.update() 683 | 684 | if XLA_AVAILABLE: 685 | xm.mark_step() 686 | 687 | self._current_timestep = None 688 | 689 | if not output_type == "latent": 690 | latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor 691 | video = self.vae.decode(latents, return_dict=False)[0] 692 | video = self.video_processor.postprocess_video(video, output_type=output_type) 693 | else: 694 | video = latents 695 | 696 | # Offload all models 697 | self.maybe_free_model_hooks() 698 | 699 | if not return_dict: 700 | return (video,) 701 | 702 | return HunyuanVideoPipelineOutput(frames=video) 703 | -------------------------------------------------------------------------------- /pipelines/pipeline_skyreels_t2v.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import inspect 16 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 17 | 18 | import numpy as np 19 | import torch 20 | from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast 21 | 22 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback 23 | from diffusers.loaders import HunyuanVideoLoraLoaderMixin 24 | from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel 25 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 26 | from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring 27 | from diffusers.utils.torch_utils import randn_tensor 28 | from diffusers.video_processor import VideoProcessor 29 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 30 | from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput 31 | 32 | 33 | if is_torch_xla_available(): 34 | import torch_xla.core.xla_model as xm 35 | 36 | XLA_AVAILABLE = True 37 | else: 38 | XLA_AVAILABLE = False 39 | 40 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 41 | 42 | 43 | EXAMPLE_DOC_STRING = """ 44 | Examples: 45 | ```python 46 | >>> import torch 47 | >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel 48 | >>> from diffusers.utils import export_to_video 49 | 50 | >>> model_id = "hunyuanvideo-community/HunyuanVideo" 51 | >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( 52 | ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 53 | ... ) 54 | >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) 55 | >>> pipe.vae.enable_tiling() 56 | >>> pipe.to("cuda") 57 | 58 | >>> output = pipe( 59 | ... prompt="A cat walks on the grass, realistic", 60 | ... height=320, 61 | ... width=512, 62 | ... num_frames=61, 63 | ... num_inference_steps=30, 64 | ... ).frames[0] 65 | >>> export_to_video(output, "output.mp4", fps=15) 66 | ``` 67 | """ 68 | 69 | 70 | DEFAULT_PROMPT_TEMPLATE = { 71 | "template": ( 72 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " 73 | "1. The main content and theme of the video." 74 | "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." 75 | "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." 76 | "4. background environment, light, style and atmosphere." 77 | "5. camera angles, movements, and transitions used in the video:<|eot_id|>" 78 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" 79 | ), 80 | "crop_start": 95, 81 | } 82 | 83 | 84 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 85 | def retrieve_timesteps( 86 | scheduler, 87 | num_inference_steps: Optional[int] = None, 88 | device: Optional[Union[str, torch.device]] = None, 89 | timesteps: Optional[List[int]] = None, 90 | sigmas: Optional[List[float]] = None, 91 | **kwargs, 92 | ): 93 | r""" 94 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 95 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 96 | 97 | Args: 98 | scheduler (`SchedulerMixin`): 99 | The scheduler to get timesteps from. 100 | num_inference_steps (`int`): 101 | The number of diffusion steps used when generating samples with a pre-trained model. 
If used, `timesteps` 102 | must be `None`. 103 | device (`str` or `torch.device`, *optional*): 104 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 105 | timesteps (`List[int]`, *optional*): 106 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 107 | `num_inference_steps` and `sigmas` must be `None`. 108 | sigmas (`List[float]`, *optional*): 109 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 110 | `num_inference_steps` and `timesteps` must be `None`. 111 | 112 | Returns: 113 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 114 | second element is the number of inference steps. 115 | """ 116 | if timesteps is not None and sigmas is not None: 117 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 118 | if timesteps is not None: 119 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 120 | if not accepts_timesteps: 121 | raise ValueError( 122 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 123 | f" timestep schedules. Please check whether you are using the correct scheduler." 124 | ) 125 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 126 | timesteps = scheduler.timesteps 127 | num_inference_steps = len(timesteps) 128 | elif sigmas is not None: 129 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 130 | if not accept_sigmas: 131 | raise ValueError( 132 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 133 | f" sigmas schedules. Please check whether you are using the correct scheduler." 134 | ) 135 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 136 | timesteps = scheduler.timesteps 137 | num_inference_steps = len(timesteps) 138 | else: 139 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 140 | timesteps = scheduler.timesteps 141 | return timesteps, num_inference_steps 142 | 143 | 144 | class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): 145 | r""" 146 | Pipeline for text-to-video generation using HunyuanVideo. 147 | 148 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 149 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 150 | 151 | Args: 152 | text_encoder ([`LlamaModel`]): 153 | [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). 154 | tokenizer (`LlamaTokenizer`): 155 | Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). 156 | transformer ([`HunyuanVideoTransformer3DModel`]): 157 | Conditional Transformer to denoise the encoded image latents. 158 | scheduler ([`FlowMatchEulerDiscreteScheduler`]): 159 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 160 | vae ([`AutoencoderKLHunyuanVideo`]): 161 | Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
162 | text_encoder_2 ([`CLIPTextModel`]): 163 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 164 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 165 | tokenizer_2 (`CLIPTokenizer`): 166 | Tokenizer of class 167 | [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 168 | """ 169 | 170 | model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" 171 | _callback_tensor_inputs = ["latents", "prompt_embeds"] 172 | 173 | def __init__( 174 | self, 175 | text_encoder: LlamaModel, 176 | tokenizer: LlamaTokenizerFast, 177 | transformer: HunyuanVideoTransformer3DModel, 178 | vae: AutoencoderKLHunyuanVideo, 179 | scheduler: FlowMatchEulerDiscreteScheduler, 180 | text_encoder_2: CLIPTextModel, 181 | tokenizer_2: CLIPTokenizer, 182 | ): 183 | super().__init__() 184 | 185 | self.register_modules( 186 | vae=vae, 187 | text_encoder=text_encoder, 188 | tokenizer=tokenizer, 189 | transformer=transformer, 190 | scheduler=scheduler, 191 | text_encoder_2=text_encoder_2, 192 | tokenizer_2=tokenizer_2, 193 | ) 194 | 195 | self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 196 | self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8 197 | self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) 198 | 199 | def _get_llama_prompt_embeds( 200 | self, 201 | prompt: Union[str, List[str]], 202 | prompt_template: Dict[str, Any], 203 | num_videos_per_prompt: int = 1, 204 | device: Optional[torch.device] = None, 205 | dtype: Optional[torch.dtype] = None, 206 | max_sequence_length: int = 256, 207 | num_hidden_layers_to_skip: int = 2, 208 | ) -> Tuple[torch.Tensor, torch.Tensor]: 209 | device = device or self._execution_device 210 | dtype = dtype or self.text_encoder.dtype 211 | 212 | prompt = [prompt] if isinstance(prompt, str) else prompt 213 | batch_size = len(prompt) 214 | 215 | prompt = [prompt_template["template"].format(p) for p in prompt] 216 | 217 | crop_start = prompt_template.get("crop_start", None) 218 | if crop_start is None: 219 | prompt_template_input = self.tokenizer( 220 | prompt_template["template"], 221 | padding="max_length", 222 | return_tensors="pt", 223 | return_length=False, 224 | return_overflowing_tokens=False, 225 | return_attention_mask=False, 226 | ) 227 | crop_start = prompt_template_input["input_ids"].shape[-1] 228 | # Remove <|eot_id|> token and placeholder {} 229 | crop_start -= 2 230 | 231 | max_sequence_length += crop_start 232 | text_inputs = self.tokenizer( 233 | prompt, 234 | max_length=max_sequence_length, 235 | padding="max_length", 236 | truncation=True, 237 | return_tensors="pt", 238 | return_length=False, 239 | return_overflowing_tokens=False, 240 | return_attention_mask=True, 241 | ) 242 | text_input_ids = text_inputs.input_ids.to(device=device) 243 | prompt_attention_mask = text_inputs.attention_mask.to(device=device) 244 | 245 | prompt_embeds = self.text_encoder( 246 | input_ids=text_input_ids, 247 | attention_mask=prompt_attention_mask, 248 | output_hidden_states=True, 249 | ).hidden_states[-(num_hidden_layers_to_skip + 1)] 250 | prompt_embeds = prompt_embeds.to(dtype=dtype) 251 | 252 | if crop_start is not None and crop_start > 0: 253 | prompt_embeds = prompt_embeds[:, crop_start:] 254 | prompt_attention_mask = prompt_attention_mask[:, crop_start:] 255 | 256 | # duplicate text 
embeddings for each generation per prompt, using mps friendly method 257 | _, seq_len, _ = prompt_embeds.shape 258 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) 259 | prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) 260 | prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) 261 | prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) 262 | 263 | return prompt_embeds, prompt_attention_mask 264 | 265 | def _get_clip_prompt_embeds( 266 | self, 267 | prompt: Union[str, List[str]], 268 | num_videos_per_prompt: int = 1, 269 | device: Optional[torch.device] = None, 270 | dtype: Optional[torch.dtype] = None, 271 | max_sequence_length: int = 77, 272 | ) -> torch.Tensor: 273 | device = device or self._execution_device 274 | dtype = dtype or self.text_encoder_2.dtype 275 | 276 | prompt = [prompt] if isinstance(prompt, str) else prompt 277 | batch_size = len(prompt) 278 | 279 | text_inputs = self.tokenizer_2( 280 | prompt, 281 | padding="max_length", 282 | max_length=max_sequence_length, 283 | truncation=True, 284 | return_tensors="pt", 285 | ) 286 | 287 | text_input_ids = text_inputs.input_ids 288 | untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids 289 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 290 | removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) 291 | logger.warning( 292 | "The following part of your input was truncated because CLIP can only handle sequences up to" 293 | f" {max_sequence_length} tokens: {removed_text}" 294 | ) 295 | 296 | prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output 297 | 298 | # duplicate text embeddings for each generation per prompt, using mps friendly method 299 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) 300 | prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) 301 | 302 | return prompt_embeds 303 | 304 | def encode_prompt( 305 | self, 306 | prompt: Union[str, List[str]], 307 | prompt_2: Union[str, List[str]] = None, 308 | prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, 309 | num_videos_per_prompt: int = 1, 310 | prompt_embeds: Optional[torch.Tensor] = None, 311 | pooled_prompt_embeds: Optional[torch.Tensor] = None, 312 | prompt_attention_mask: Optional[torch.Tensor] = None, 313 | device: Optional[torch.device] = None, 314 | dtype: Optional[torch.dtype] = None, 315 | max_sequence_length: int = 256, 316 | ): 317 | if prompt_embeds is None: 318 | prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( 319 | prompt, 320 | prompt_template, 321 | num_videos_per_prompt, 322 | device=device, 323 | dtype=dtype, 324 | max_sequence_length=max_sequence_length, 325 | ) 326 | 327 | if pooled_prompt_embeds is None: 328 | if prompt_2 is None and pooled_prompt_embeds is None: 329 | prompt_2 = prompt 330 | pooled_prompt_embeds = self._get_clip_prompt_embeds( 331 | prompt, 332 | num_videos_per_prompt, 333 | device=device, 334 | dtype=dtype, 335 | max_sequence_length=77, 336 | ) 337 | 338 | return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask 339 | 340 | def check_inputs( 341 | self, 342 | prompt, 343 | prompt_2, 344 | height, 345 | width, 346 | prompt_embeds=None, 347 | callback_on_step_end_tensor_inputs=None, 348 | prompt_template=None, 349 | ): 350 | if height % 16 != 0 or 
width % 16 != 0: 351 | raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") 352 | 353 | if callback_on_step_end_tensor_inputs is not None and not all( 354 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs 355 | ): 356 | raise ValueError( 357 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 358 | ) 359 | 360 | if prompt is not None and prompt_embeds is not None: 361 | raise ValueError( 362 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 363 | " only forward one of the two." 364 | ) 365 | elif prompt_2 is not None and prompt_embeds is not None: 366 | raise ValueError( 367 | f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 368 | " only forward one of the two." 369 | ) 370 | elif prompt is None and prompt_embeds is None: 371 | raise ValueError( 372 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 373 | ) 374 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 375 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 376 | elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): 377 | raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") 378 | 379 | if prompt_template is not None: 380 | if not isinstance(prompt_template, dict): 381 | raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") 382 | if "template" not in prompt_template: 383 | raise ValueError( 384 | f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" 385 | ) 386 | 387 | def prepare_latents( 388 | self, 389 | batch_size: int, 390 | num_channels_latents: 32, 391 | height: int = 720, 392 | width: int = 1280, 393 | num_frames: int = 129, 394 | dtype: Optional[torch.dtype] = None, 395 | device: Optional[torch.device] = None, 396 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 397 | latents: Optional[torch.Tensor] = None, 398 | ) -> torch.Tensor: 399 | if latents is not None: 400 | return latents.to(device=device, dtype=dtype) 401 | 402 | shape = ( 403 | batch_size, 404 | num_channels_latents, 405 | num_frames, 406 | int(height) // self.vae_scale_factor_spatial, 407 | int(width) // self.vae_scale_factor_spatial, 408 | ) 409 | if isinstance(generator, list) and len(generator) != batch_size: 410 | raise ValueError( 411 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 412 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 413 | ) 414 | 415 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 416 | return latents 417 | 418 | def enable_vae_slicing(self): 419 | r""" 420 | Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to 421 | compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 422 | """ 423 | self.vae.enable_slicing() 424 | 425 | def disable_vae_slicing(self): 426 | r""" 427 | Disable sliced VAE decoding. 
If `enable_vae_slicing` was previously enabled, this method will go back to 428 | computing decoding in one step. 429 | """ 430 | self.vae.disable_slicing() 431 | 432 | def enable_vae_tiling(self): 433 | r""" 434 | Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to 435 | compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow 436 | processing larger images. 437 | """ 438 | self.vae.enable_tiling() 439 | 440 | def disable_vae_tiling(self): 441 | r""" 442 | Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to 443 | computing decoding in one step. 444 | """ 445 | self.vae.disable_tiling() 446 | 447 | @property 448 | def guidance_scale(self): 449 | return self._guidance_scale 450 | 451 | @property 452 | def num_timesteps(self): 453 | return self._num_timesteps 454 | 455 | @property 456 | def attention_kwargs(self): 457 | return self._attention_kwargs 458 | 459 | @property 460 | def current_timestep(self): 461 | return self._current_timestep 462 | 463 | @property 464 | def interrupt(self): 465 | return self._interrupt 466 | 467 | @torch.no_grad() 468 | @replace_example_docstring(EXAMPLE_DOC_STRING) 469 | def __call__( 470 | self, 471 | prompt: Union[str, List[str]] = None, 472 | prompt_2: Union[str, List[str]] = None, 473 | height: int = 720, 474 | width: int = 1280, 475 | num_frames: int = 129, 476 | num_inference_steps: int = 50, 477 | sigmas: List[float] = None, 478 | guidance_scale: float = 1.0, 479 | cfg_scale: float = 6.0, 480 | cfg_steps: int = 5, 481 | num_videos_per_prompt: Optional[int] = 1, 482 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 483 | latents: Optional[torch.Tensor] = None, 484 | prompt_embeds: Optional[torch.Tensor] = None, 485 | pooled_prompt_embeds: Optional[torch.Tensor] = None, 486 | prompt_attention_mask: Optional[torch.Tensor] = None, 487 | output_type: Optional[str] = "pil", 488 | return_dict: bool = True, 489 | attention_kwargs: Optional[Dict[str, Any]] = None, 490 | callback_on_step_end: Optional[ 491 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] 492 | ] = None, 493 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 494 | prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, 495 | max_sequence_length: int = 256, 496 | ): 497 | r""" 498 | The call function to the pipeline for generation. 499 | 500 | Args: 501 | prompt (`str` or `List[str]`, *optional*): 502 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 503 | instead. 504 | prompt_2 (`str` or `List[str]`, *optional*): 505 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 506 | will be used instead. 507 | height (`int`, defaults to `720`): 508 | The height in pixels of the generated image. 509 | width (`int`, defaults to `1280`): 510 | The width in pixels of the generated image. 511 | num_frames (`int`, defaults to `129`): 512 | The number of frames in the generated video. 513 | num_inference_steps (`int`, defaults to `50`): 514 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 515 | expense of slower inference. 516 | sigmas (`List[float]`, *optional*): 517 | Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in 518 | their `set_timesteps` method. 
If not defined, the default behavior when `num_inference_steps` is passed 519 | will be used. 520 | guidance_scale (`float`, defaults to `1.0`): 521 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 522 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 523 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 524 | 1`. Higher guidance scale encourages generating videos that are closely linked to the text `prompt`, 525 | usually at the expense of lower quality. Note that this value controls the embedded (distilled) guidance 526 | of the HunyuanVideo transformer; true classifier-free guidance between the empty and conditional prompts 527 | is controlled by `cfg_scale` and `cfg_steps` instead. 528 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 529 | The number of videos to generate per prompt. 530 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 531 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 532 | generation deterministic. 533 | latents (`torch.Tensor`, *optional*): 534 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video 535 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 536 | tensor is generated by sampling using the supplied random `generator`. 537 | prompt_embeds (`torch.Tensor`, *optional*): 538 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 539 | provided, text embeddings are generated from the `prompt` input argument. 540 | output_type (`str`, *optional*, defaults to `"pil"`): 541 | The output format of the generated video. Choose between `PIL.Image` or `np.array`. 542 | return_dict (`bool`, *optional*, defaults to `True`): 543 | Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. 544 | attention_kwargs (`dict`, *optional*): 545 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 546 | `self.processor` in 547 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 548 | cfg_scale (`float`, defaults to `6.0`): 549 | True classifier-free guidance scale applied between the conditional and empty-prompt noise predictions. 550 | cfg_steps (`int`, defaults to `5`): The number of initial denoising steps on which true CFG is applied. 551 | callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): 552 | A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of 553 | each denoising step during inference, with the following arguments: `callback_on_step_end(self: 554 | DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a 555 | list of all tensors as specified by `callback_on_step_end_tensor_inputs`. 556 | callback_on_step_end_tensor_inputs (`List`, *optional*): 557 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 558 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 559 | `._callback_tensor_inputs` attribute of your pipeline class.
560 | 561 | Examples: 562 | 563 | Returns: 564 | [`~HunyuanVideoPipelineOutput`] or `tuple`: 565 | If `return_dict` is `True`, a [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is 566 | returned where the first element is a list containing the generated video frames (this pipeline 567 | performs no safety checking, so no nsfw flags are returned). 568 | """ 569 | 570 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): 571 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs 572 | 573 | # 1. Check inputs. Raise error if not correct 574 | self.check_inputs( 575 | prompt, 576 | prompt_2, 577 | height, 578 | width, 579 | prompt_embeds, 580 | callback_on_step_end_tensor_inputs, 581 | prompt_template, 582 | ) 583 | 584 | self._guidance_scale = guidance_scale 585 | self._attention_kwargs = attention_kwargs 586 | self._current_timestep = None 587 | self._interrupt = False 588 | 589 | device = self._execution_device 590 | 591 | # 2. Define call parameters 592 | if prompt is not None and isinstance(prompt, str): 593 | batch_size = 1 594 | elif prompt is not None and isinstance(prompt, list): 595 | batch_size = len(prompt) 596 | else: 597 | batch_size = prompt_embeds.shape[0] 598 | 599 | # 3. Encode input prompt 600 | prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( 601 | prompt=prompt, 602 | prompt_2=prompt_2, 603 | prompt_template=prompt_template, 604 | num_videos_per_prompt=num_videos_per_prompt, 605 | prompt_embeds=prompt_embeds, 606 | pooled_prompt_embeds=pooled_prompt_embeds, 607 | prompt_attention_mask=prompt_attention_mask, 608 | device=device, 609 | max_sequence_length=max_sequence_length, 610 | ) 611 | 612 | neg_prompt_embeds, neg_pooled_prompt_embeds, neg_prompt_attention_mask = self.encode_prompt( 613 | prompt="", 614 | prompt_2=None, 615 | prompt_template=prompt_template, 616 | num_videos_per_prompt=num_videos_per_prompt, 617 | prompt_embeds=None, # encode the empty prompt rather than reusing the positive embeddings 618 | pooled_prompt_embeds=None, 619 | prompt_attention_mask=None, 620 | device=device, 621 | max_sequence_length=max_sequence_length, 622 | ) 623 | 624 | transformer_dtype = self.transformer.dtype 625 | prompt_embeds = prompt_embeds.to(transformer_dtype) 626 | neg_prompt_embeds = neg_prompt_embeds.to(transformer_dtype) 627 | prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) 628 | neg_prompt_attention_mask = neg_prompt_attention_mask.to(transformer_dtype) 629 | if pooled_prompt_embeds is not None: 630 | pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) 631 | neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.to(transformer_dtype) 632 | 633 | # 4. Prepare timesteps 634 | sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas 635 | timesteps, num_inference_steps = retrieve_timesteps( 636 | self.scheduler, 637 | num_inference_steps, 638 | device, 639 | sigmas=sigmas, 640 | ) 641 | 642 | # 5. Prepare latent variables 643 | num_channels_latents = self.transformer.config.in_channels 644 | num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 645 | latents = self.prepare_latents( 646 | batch_size * num_videos_per_prompt, 647 | num_channels_latents, 648 | height, 649 | width, 650 | num_latent_frames, 651 | torch.float32, 652 | device, 653 | generator, 654 | latents, 655 | ) 656 | 657 | # 6.
Prepare guidance condition 658 | guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 659 | 660 | # 7. Denoising loop 661 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 662 | self._num_timesteps = len(timesteps) 663 | 664 | with self.progress_bar(total=num_inference_steps) as progress_bar: 665 | for i, t in enumerate(timesteps): 666 | if self.interrupt: 667 | continue 668 | 669 | self._current_timestep = t 670 | latent_model_input = latents.to(transformer_dtype) 671 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 672 | timestep = t.expand(latents.shape[0]).to(latents.dtype) 673 | 674 | noise_pred = self.transformer( 675 | hidden_states=latent_model_input, 676 | timestep=timestep, 677 | encoder_hidden_states=prompt_embeds, 678 | encoder_attention_mask=prompt_attention_mask, 679 | pooled_projections=pooled_prompt_embeds, 680 | guidance=guidance, 681 | attention_kwargs=attention_kwargs, 682 | return_dict=False, 683 | )[0] 684 | 685 | if i < cfg_steps: 686 | noise_pred_neg = self.transformer( 687 | hidden_states=latent_model_input, 688 | timestep=timestep, 689 | encoder_hidden_states=neg_prompt_embeds, 690 | encoder_attention_mask=neg_prompt_attention_mask, 691 | pooled_projections=neg_pooled_prompt_embeds, 692 | guidance=guidance, 693 | attention_kwargs=attention_kwargs, 694 | return_dict=False, 695 | )[0] 696 | 697 | # CFG 698 | noise_pred = noise_pred_neg + cfg_scale * (noise_pred - noise_pred_neg) 699 | 700 | # compute the previous noisy sample x_t -> x_t-1 701 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 702 | 703 | if callback_on_step_end is not None: 704 | callback_kwargs = {} 705 | for k in callback_on_step_end_tensor_inputs: 706 | callback_kwargs[k] = locals()[k] 707 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 708 | 709 | latents = callback_outputs.pop("latents", latents) 710 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 711 | 712 | # call the callback, if provided 713 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 714 | progress_bar.update() 715 | 716 | if XLA_AVAILABLE: 717 | xm.mark_step() 718 | 719 | self._current_timestep = None 720 | 721 | if not output_type == "latent": 722 | latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor 723 | video = self.vae.decode(latents, return_dict=False)[0] 724 | video = self.video_processor.postprocess_video(video, output_type=output_type) 725 | else: 726 | video = latents 727 | 728 | # Offload all models 729 | self.maybe_free_model_hooks() 730 | 731 | if not return_dict: 732 | return (video,) 733 | 734 | return HunyuanVideoPipelineOutput(frames=video) 735 | -------------------------------------------------------------------------------- /pipelines/pipeline_skyreels_i2v.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import inspect 16 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 17 | 18 | import numpy as np 19 | import torch 20 | from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast 21 | 22 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback 23 | from diffusers.loaders import HunyuanVideoLoraLoaderMixin 24 | from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel 25 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 26 | from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring 27 | from diffusers.utils.torch_utils import randn_tensor 28 | from diffusers.video_processor import VideoProcessor 29 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 30 | from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput 31 | 32 | 33 | if is_torch_xla_available(): 34 | import torch_xla.core.xla_model as xm 35 | 36 | XLA_AVAILABLE = True 37 | else: 38 | XLA_AVAILABLE = False 39 | 40 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 41 | 42 | 43 | EXAMPLE_DOC_STRING = """ 44 | Examples: 45 | ```python 46 | >>> import torch 47 | >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel 48 | >>> from diffusers.utils import export_to_video 49 | 50 | >>> model_id = "hunyuanvideo-community/HunyuanVideo" 51 | >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( 52 | ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 53 | ... ) 54 | >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) 55 | >>> pipe.vae.enable_tiling() 56 | >>> pipe.to("cuda") 57 | 58 | >>> output = pipe( 59 | ... prompt="A cat walks on the grass, realistic", 60 | ... height=320, 61 | ... width=512, 62 | ... num_frames=61, 63 | ... num_inference_steps=30, 64 | ... ).frames[0] 65 | >>> export_to_video(output, "output.mp4", fps=15) 66 | ``` 67 | """ 68 | 69 | 70 | DEFAULT_PROMPT_TEMPLATE = { 71 | "template": ( 72 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " 73 | "1. The main content and theme of the video." 74 | "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." 75 | "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." 76 | "4. background environment, light, style and atmosphere." 77 | "5. 
camera angles, movements, and transitions used in the video:<|eot_id|>" 78 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" 79 | ), 80 | "crop_start": 95, 81 | } 82 | 83 | 84 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 85 | def retrieve_timesteps( 86 | scheduler, 87 | num_inference_steps: Optional[int] = None, 88 | device: Optional[Union[str, torch.device]] = None, 89 | timesteps: Optional[List[int]] = None, 90 | sigmas: Optional[List[float]] = None, 91 | **kwargs, 92 | ): 93 | r""" 94 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 95 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 96 | 97 | Args: 98 | scheduler (`SchedulerMixin`): 99 | The scheduler to get timesteps from. 100 | num_inference_steps (`int`): 101 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 102 | must be `None`. 103 | device (`str` or `torch.device`, *optional*): 104 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 105 | timesteps (`List[int]`, *optional*): 106 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 107 | `num_inference_steps` and `sigmas` must be `None`. 108 | sigmas (`List[float]`, *optional*): 109 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 110 | `num_inference_steps` and `timesteps` must be `None`. 111 | 112 | Returns: 113 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 114 | second element is the number of inference steps. 115 | """ 116 | if timesteps is not None and sigmas is not None: 117 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 118 | if timesteps is not None: 119 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 120 | if not accepts_timesteps: 121 | raise ValueError( 122 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 123 | f" timestep schedules. Please check whether you are using the correct scheduler." 124 | ) 125 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 126 | timesteps = scheduler.timesteps 127 | num_inference_steps = len(timesteps) 128 | elif sigmas is not None: 129 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 130 | if not accept_sigmas: 131 | raise ValueError( 132 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 133 | f" sigmas schedules. Please check whether you are using the correct scheduler." 134 | ) 135 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 136 | timesteps = scheduler.timesteps 137 | num_inference_steps = len(timesteps) 138 | else: 139 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 140 | timesteps = scheduler.timesteps 141 | return timesteps, num_inference_steps 142 | 143 | 144 | class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): 145 | r""" 146 | Pipeline for text-to-video generation using HunyuanVideo. 147 | 148 | This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods 149 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 150 | 151 | Args: 152 | text_encoder ([`LlamaModel`]): 153 | [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). 154 | tokenizer (`LlamaTokenizer`): 155 | Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). 156 | transformer ([`HunyuanVideoTransformer3DModel`]): 157 | Conditional Transformer to denoise the encoded image latents. 158 | scheduler ([`FlowMatchEulerDiscreteScheduler`]): 159 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 160 | vae ([`AutoencoderKLHunyuanVideo`]): 161 | Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 162 | text_encoder_2 ([`CLIPTextModel`]): 163 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 164 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 165 | tokenizer_2 (`CLIPTokenizer`): 166 | Tokenizer of class 167 | [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 168 | """ 169 | 170 | model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" 171 | _callback_tensor_inputs = ["latents", "prompt_embeds"] 172 | 173 | def __init__( 174 | self, 175 | text_encoder: LlamaModel, 176 | tokenizer: LlamaTokenizerFast, 177 | transformer: HunyuanVideoTransformer3DModel, 178 | vae: AutoencoderKLHunyuanVideo, 179 | scheduler: FlowMatchEulerDiscreteScheduler, 180 | text_encoder_2: CLIPTextModel, 181 | tokenizer_2: CLIPTokenizer, 182 | ): 183 | super().__init__() 184 | 185 | self.register_modules( 186 | vae=vae, 187 | text_encoder=text_encoder, 188 | tokenizer=tokenizer, 189 | transformer=transformer, 190 | scheduler=scheduler, 191 | text_encoder_2=text_encoder_2, 192 | tokenizer_2=tokenizer_2, 193 | ) 194 | 195 | self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 196 | self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8 197 | self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) 198 | 199 | def _get_llama_prompt_embeds( 200 | self, 201 | prompt: Union[str, List[str]], 202 | prompt_template: Dict[str, Any], 203 | num_videos_per_prompt: int = 1, 204 | device: Optional[torch.device] = None, 205 | dtype: Optional[torch.dtype] = None, 206 | max_sequence_length: int = 256, 207 | num_hidden_layers_to_skip: int = 2, 208 | ) -> Tuple[torch.Tensor, torch.Tensor]: 209 | device = device or self._execution_device 210 | dtype = dtype or self.text_encoder.dtype 211 | 212 | prompt = [prompt] if isinstance(prompt, str) else prompt 213 | batch_size = len(prompt) 214 | 215 | prompt = [prompt_template["template"].format(p) for p in prompt] 216 | 217 | crop_start = prompt_template.get("crop_start", None) 218 | if crop_start is None: 219 | prompt_template_input = self.tokenizer( 220 | prompt_template["template"], 221 | padding="max_length", 222 | return_tensors="pt", 223 | return_length=False, 224 | return_overflowing_tokens=False, 225 | return_attention_mask=False, 226 | ) 227 | crop_start = prompt_template_input["input_ids"].shape[-1] 228 | # Remove <|eot_id|> token and placeholder {} 229 | crop_start -= 2 230 | 231 | max_sequence_length += crop_start 232 | 
text_inputs = self.tokenizer( 233 | prompt, 234 | max_length=max_sequence_length, 235 | padding="max_length", 236 | truncation=True, 237 | return_tensors="pt", 238 | return_length=False, 239 | return_overflowing_tokens=False, 240 | return_attention_mask=True, 241 | ) 242 | text_input_ids = text_inputs.input_ids.to(device=device) 243 | prompt_attention_mask = text_inputs.attention_mask.to(device=device) 244 | 245 | prompt_embeds = self.text_encoder( 246 | input_ids=text_input_ids, 247 | attention_mask=prompt_attention_mask, 248 | output_hidden_states=True, 249 | ).hidden_states[-(num_hidden_layers_to_skip + 1)] 250 | prompt_embeds = prompt_embeds.to(dtype=dtype) 251 | 252 | if crop_start is not None and crop_start > 0: 253 | prompt_embeds = prompt_embeds[:, crop_start:] 254 | prompt_attention_mask = prompt_attention_mask[:, crop_start:] 255 | 256 | # duplicate text embeddings for each generation per prompt, using mps friendly method 257 | _, seq_len, _ = prompt_embeds.shape 258 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) 259 | prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) 260 | prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) 261 | prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) 262 | 263 | return prompt_embeds, prompt_attention_mask 264 | 265 | def _get_clip_prompt_embeds( 266 | self, 267 | prompt: Union[str, List[str]], 268 | num_videos_per_prompt: int = 1, 269 | device: Optional[torch.device] = None, 270 | dtype: Optional[torch.dtype] = None, 271 | max_sequence_length: int = 77, 272 | ) -> torch.Tensor: 273 | device = device or self._execution_device 274 | dtype = dtype or self.text_encoder_2.dtype 275 | 276 | prompt = [prompt] if isinstance(prompt, str) else prompt 277 | batch_size = len(prompt) 278 | 279 | text_inputs = self.tokenizer_2( 280 | prompt, 281 | padding="max_length", 282 | max_length=max_sequence_length, 283 | truncation=True, 284 | return_tensors="pt", 285 | ) 286 | 287 | text_input_ids = text_inputs.input_ids 288 | untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids 289 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 290 | removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) 291 | logger.warning( 292 | "The following part of your input was truncated because CLIP can only handle sequences up to" 293 | f" {max_sequence_length} tokens: {removed_text}" 294 | ) 295 | 296 | prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output 297 | 298 | # duplicate text embeddings for each generation per prompt, using mps friendly method 299 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) 300 | prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) 301 | 302 | return prompt_embeds 303 | 304 | def encode_prompt( 305 | self, 306 | prompt: Union[str, List[str]], 307 | prompt_2: Union[str, List[str]] = None, 308 | prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, 309 | num_videos_per_prompt: int = 1, 310 | prompt_embeds: Optional[torch.Tensor] = None, 311 | pooled_prompt_embeds: Optional[torch.Tensor] = None, 312 | prompt_attention_mask: Optional[torch.Tensor] = None, 313 | device: Optional[torch.device] = None, 314 | dtype: Optional[torch.dtype] = None, 315 | max_sequence_length: int = 256, 316 | ): 317 | if 
prompt_embeds is None: 318 | prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( 319 | prompt, 320 | prompt_template, 321 | num_videos_per_prompt, 322 | device=device, 323 | dtype=dtype, 324 | max_sequence_length=max_sequence_length, 325 | ) 326 | 327 | if pooled_prompt_embeds is None: 328 | if prompt_2 is None and pooled_prompt_embeds is None: 329 | prompt_2 = prompt 330 | pooled_prompt_embeds = self._get_clip_prompt_embeds( 331 | prompt, 332 | num_videos_per_prompt, 333 | device=device, 334 | dtype=dtype, 335 | max_sequence_length=77, 336 | ) 337 | 338 | return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask 339 | 340 | def check_inputs( 341 | self, 342 | prompt, 343 | prompt_2, 344 | height, 345 | width, 346 | prompt_embeds=None, 347 | callback_on_step_end_tensor_inputs=None, 348 | prompt_template=None, 349 | ): 350 | if height % 16 != 0 or width % 16 != 0: 351 | raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") 352 | 353 | if callback_on_step_end_tensor_inputs is not None and not all( 354 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs 355 | ): 356 | raise ValueError( 357 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 358 | ) 359 | 360 | if prompt is not None and prompt_embeds is not None: 361 | raise ValueError( 362 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 363 | " only forward one of the two." 364 | ) 365 | elif prompt_2 is not None and prompt_embeds is not None: 366 | raise ValueError( 367 | f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 368 | " only forward one of the two." 369 | ) 370 | elif prompt is None and prompt_embeds is None: 371 | raise ValueError( 372 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
373 | ) 374 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 375 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 376 | elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): 377 | raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") 378 | 379 | if prompt_template is not None: 380 | if not isinstance(prompt_template, dict): 381 | raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") 382 | if "template" not in prompt_template: 383 | raise ValueError( 384 | f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" 385 | ) 386 | 387 | def prepare_latents( 388 | self, 389 | batch_size: int, 390 | num_channels_latents: 32, 391 | height: int = 720, 392 | width: int = 1280, 393 | num_frames: int = 129, 394 | dtype: Optional[torch.dtype] = None, 395 | device: Optional[torch.device] = None, 396 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 397 | latents: Optional[torch.Tensor] = None, 398 | ) -> torch.Tensor: 399 | if latents is not None: 400 | return latents.to(device=device, dtype=dtype) 401 | 402 | shape = ( 403 | batch_size, 404 | num_channels_latents, 405 | num_frames, 406 | int(height) // self.vae_scale_factor_spatial, 407 | int(width) // self.vae_scale_factor_spatial, 408 | ) 409 | if isinstance(generator, list) and len(generator) != batch_size: 410 | raise ValueError( 411 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 412 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 413 | ) 414 | 415 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 416 | return latents 417 | 418 | def enable_vae_slicing(self): 419 | r""" 420 | Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to 421 | compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 422 | """ 423 | self.vae.enable_slicing() 424 | 425 | def disable_vae_slicing(self): 426 | r""" 427 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to 428 | computing decoding in one step. 429 | """ 430 | self.vae.disable_slicing() 431 | 432 | def enable_vae_tiling(self): 433 | r""" 434 | Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to 435 | compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow 436 | processing larger images. 437 | """ 438 | self.vae.enable_tiling() 439 | 440 | def disable_vae_tiling(self): 441 | r""" 442 | Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to 443 | computing decoding in one step. 
444 | """ 445 | self.vae.disable_tiling() 446 | 447 | @property 448 | def guidance_scale(self): 449 | return self._guidance_scale 450 | 451 | @property 452 | def num_timesteps(self): 453 | return self._num_timesteps 454 | 455 | @property 456 | def attention_kwargs(self): 457 | return self._attention_kwargs 458 | 459 | @property 460 | def current_timestep(self): 461 | return self._current_timestep 462 | 463 | @property 464 | def interrupt(self): 465 | return self._interrupt 466 | 467 | @torch.no_grad() 468 | @replace_example_docstring(EXAMPLE_DOC_STRING) 469 | def __call__( 470 | self, 471 | image, 472 | prompt: Union[str, List[str]] = None, 473 | prompt_2: Union[str, List[str]] = None, 474 | height: int = 720, 475 | width: int = 1280, 476 | num_frames: int = 129, 477 | num_inference_steps: int = 50, 478 | sigmas: List[float] = None, 479 | guidance_scale: float = 1.0, 480 | cfg_scale: float = 6.0, 481 | cfg_steps: int = 5, 482 | num_videos_per_prompt: Optional[int] = 1, 483 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 484 | latents: Optional[torch.Tensor] = None, 485 | prompt_embeds: Optional[torch.Tensor] = None, 486 | pooled_prompt_embeds: Optional[torch.Tensor] = None, 487 | prompt_attention_mask: Optional[torch.Tensor] = None, 488 | output_type: Optional[str] = "pil", 489 | return_dict: bool = True, 490 | attention_kwargs: Optional[Dict[str, Any]] = None, 491 | callback_on_step_end: Optional[ 492 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] 493 | ] = None, 494 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 495 | prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, 496 | max_sequence_length: int = 256, 497 | ): 498 | r""" 499 | The call function to the pipeline for generation. 500 | 501 | Args: 502 | prompt (`str` or `List[str]`, *optional*): 503 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 504 | instead. 505 | prompt_2 (`str` or `List[str]`, *optional*): 506 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 507 | will be used instead. 508 | height (`int`, defaults to `720`): 509 | The height in pixels of the generated image. 510 | width (`int`, defaults to `1280`): 511 | The width in pixels of the generated image. 512 | num_frames (`int`, defaults to `129`): 513 | The number of frames in the generated video. 514 | num_inference_steps (`int`, defaults to `50`): 515 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 516 | expense of slower inference. 517 | sigmas (`List[float]`, *optional*): 518 | Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in 519 | their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed 520 | will be used. 521 | guidance_scale (`float`, defaults to `6.0`): 522 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 523 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 524 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 525 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 526 | usually at the expense of lower image quality. 
Note that the only available HunyuanVideo model is 527 | CFG-distilled, which means that traditional guidance between unconditional and conditional latent is 528 | not applied. 529 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 530 | The number of images to generate per prompt. 531 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 532 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 533 | generation deterministic. 534 | latents (`torch.Tensor`, *optional*): 535 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 536 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 537 | tensor is generated by sampling using the supplied random `generator`. 538 | prompt_embeds (`torch.Tensor`, *optional*): 539 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 540 | provided, text embeddings are generated from the `prompt` input argument. 541 | output_type (`str`, *optional*, defaults to `"pil"`): 542 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 543 | return_dict (`bool`, *optional*, defaults to `True`): 544 | Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. 545 | attention_kwargs (`dict`, *optional*): 546 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 547 | `self.processor` in 548 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 549 | clip_skip (`int`, *optional*): 550 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that 551 | the output of the pre-final layer will be used for computing the prompt embeddings. 552 | callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): 553 | A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of 554 | each denoising step during the inference. with the following arguments: `callback_on_step_end(self: 555 | DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a 556 | list of all tensors as specified by `callback_on_step_end_tensor_inputs`. 557 | callback_on_step_end_tensor_inputs (`List`, *optional*): 558 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 559 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 560 | `._callback_tensor_inputs` attribute of your pipeline class. 561 | 562 | Examples: 563 | 564 | Returns: 565 | [`~HunyuanVideoPipelineOutput`] or `tuple`: 566 | If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned 567 | where the first element is a list with the generated images and the second element is a list of `bool`s 568 | indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. 569 | """ 570 | 571 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): 572 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs 573 | 574 | # 1. Check inputs. 
Raise error if not correct 575 | self.check_inputs( 576 | prompt, 577 | prompt_2, 578 | height, 579 | width, 580 | prompt_embeds, 581 | callback_on_step_end_tensor_inputs, 582 | prompt_template, 583 | ) 584 | 585 | self._guidance_scale = guidance_scale 586 | self._attention_kwargs = attention_kwargs 587 | self._current_timestep = None 588 | self._interrupt = False 589 | 590 | device = self._execution_device 591 | 592 | # 2. Define call parameters 593 | if prompt is not None and isinstance(prompt, str): 594 | batch_size = 1 595 | elif prompt is not None and isinstance(prompt, list): 596 | batch_size = len(prompt) 597 | else: 598 | batch_size = prompt_embeds.shape[0] 599 | 600 | # 3. Encode input prompt 601 | prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( 602 | prompt=prompt, 603 | prompt_2=prompt_2, 604 | prompt_template=prompt_template, 605 | num_videos_per_prompt=num_videos_per_prompt, 606 | prompt_embeds=prompt_embeds, 607 | pooled_prompt_embeds=pooled_prompt_embeds, 608 | prompt_attention_mask=prompt_attention_mask, 609 | device=device, 610 | max_sequence_length=max_sequence_length, 611 | ) 612 | 613 | neg_prompt_embeds, neg_pooled_prompt_embeds, neg_prompt_attention_mask = self.encode_prompt( 614 | prompt="", 615 | prompt_2=None, 616 | prompt_template=prompt_template, 617 | num_videos_per_prompt=num_videos_per_prompt, 618 | prompt_embeds=None, # encode the empty prompt rather than reusing the positive embeddings 619 | pooled_prompt_embeds=None, 620 | prompt_attention_mask=None, 621 | device=device, 622 | max_sequence_length=max_sequence_length, 623 | ) 624 | 625 | transformer_dtype = self.transformer.dtype 626 | prompt_embeds = prompt_embeds.to(transformer_dtype) 627 | neg_prompt_embeds = neg_prompt_embeds.to(transformer_dtype) 628 | prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) 629 | neg_prompt_attention_mask = neg_prompt_attention_mask.to(transformer_dtype) 630 | if pooled_prompt_embeds is not None: 631 | pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) 632 | neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.to(transformer_dtype) 633 | 634 | # 4. Prepare timesteps 635 | sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas 636 | timesteps, num_inference_steps = retrieve_timesteps( 637 | self.scheduler, 638 | num_inference_steps, 639 | device, 640 | sigmas=sigmas, 641 | ) 642 | 643 | # 5. Prepare latent variables 644 | num_channels_latents = self.transformer.config.out_channels 645 | num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 646 | latents = self.prepare_latents( 647 | batch_size * num_videos_per_prompt, 648 | num_channels_latents, 649 | height, 650 | width, 651 | num_latent_frames, 652 | torch.float32, 653 | device, 654 | generator, 655 | latents, 656 | ) 657 | 658 | # Prepare image conditioning (`image` is expected as a (B, C, F, H, W) tensor; only its first frame is used) 659 | image = image.to(device, dtype=self.vae.dtype) 660 | image_latents = self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor 661 | # print(image_latents.shape, latents.shape) 662 | assert image_latents[:, :, 0].shape == latents[:, :, 0].shape 663 | image_cond_latents = torch.zeros_like(latents) 664 | image_cond_latents[:, :, 0] = image_latents[:, :, 0] # replace first frame with image 665 | image_cond_latents = image_cond_latents.to(transformer_dtype) 666 | 667 | # 6. Prepare guidance condition 668 | guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 669 | 670 | # 7.
Denoising loop 671 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 672 | self._num_timesteps = len(timesteps) 673 | 674 | with self.progress_bar(total=num_inference_steps) as progress_bar: 675 | for i, t in enumerate(timesteps): 676 | if self.interrupt: 677 | continue 678 | 679 | self._current_timestep = t 680 | # latent_model_input = latents.to(transformer_dtype) 681 | latent_model_input = torch.cat([latents.to(transformer_dtype), image_cond_latents], dim=1) # concat on channel dim 682 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 683 | timestep = t.expand(latents.shape[0]).to(latents.dtype) 684 | 685 | noise_pred = self.transformer( 686 | hidden_states=latent_model_input, 687 | timestep=timestep, 688 | encoder_hidden_states=prompt_embeds, 689 | encoder_attention_mask=prompt_attention_mask, 690 | pooled_projections=pooled_prompt_embeds, 691 | guidance=guidance, 692 | attention_kwargs=attention_kwargs, 693 | return_dict=False, 694 | )[0] 695 | 696 | if i < cfg_steps: 697 | noise_pred_neg = self.transformer( 698 | hidden_states=latent_model_input, 699 | timestep=timestep, 700 | encoder_hidden_states=neg_prompt_embeds, 701 | encoder_attention_mask=neg_prompt_attention_mask, 702 | pooled_projections=neg_pooled_prompt_embeds, 703 | guidance=guidance, 704 | attention_kwargs=attention_kwargs, 705 | return_dict=False, 706 | )[0] 707 | 708 | # CFG 709 | noise_pred = noise_pred_neg + cfg_scale * (noise_pred - noise_pred_neg) 710 | 711 | # compute the previous noisy sample x_t -> x_t-1 712 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 713 | 714 | if callback_on_step_end is not None: 715 | callback_kwargs = {} 716 | for k in callback_on_step_end_tensor_inputs: 717 | callback_kwargs[k] = locals()[k] 718 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 719 | 720 | latents = callback_outputs.pop("latents", latents) 721 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 722 | 723 | # call the callback, if provided 724 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 725 | progress_bar.update() 726 | 727 | if XLA_AVAILABLE: 728 | xm.mark_step() 729 | 730 | self._current_timestep = None 731 | 732 | if not output_type == "latent": 733 | latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor 734 | video = self.vae.decode(latents, return_dict=False)[0] 735 | video = self.video_processor.postprocess_video(video, output_type=output_type) 736 | else: 737 | video = latents 738 | 739 | # Offload all models 740 | self.maybe_free_model_hooks() 741 | 742 | if not return_dict: 743 | return (video,) 744 | 745 | return HunyuanVideoPipelineOutput(frames=video) 746 | --------------------------------------------------------------------------------
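For reference, here is a minimal, hypothetical usage sketch for the image-to-video pipeline defined in pipelines/pipeline_skyreels_i2v.py. The checkpoint identifiers, the image preprocessing (resolution, [-1, 1] normalization, and the (B, C, F, H, W) layout expected by the video VAE), and the output settings are assumptions for illustration only; test_skyreels_i2v.py in this repository is the actual example script.

```python
import numpy as np
import torch
from PIL import Image
from diffusers import HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

from pipelines.pipeline_skyreels_i2v import HunyuanVideoPipeline

# Placeholder paths: base HunyuanVideo components plus a SkyReels I2V transformer checkpoint.
base_model_id = "hunyuanvideo-community/HunyuanVideo"
skyreels_transformer_path = "path/to/SkyReels-V1-I2V"  # placeholder, local dir or hub repo id

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    skyreels_transformer_path, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(
    base_model_id, transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.vae.enable_tiling()
pipe.to("cuda")

# The pipeline encodes `image` with the video VAE, so pass a 5D (B, C, F, H, W) tensor;
# [-1, 1] normalization is assumed here to match the VAE's expected input range.
img = Image.open("first_frame.png").convert("RGB").resize((512, 320))  # (width, height)
image = torch.from_numpy(np.array(img)).float() / 127.5 - 1.0          # HWC in [-1, 1]
image = image.permute(2, 0, 1)[None, :, None]                          # -> (1, C, 1, H, W)

output = pipe(
    image=image,
    prompt="A cat walks on the grass, realistic",
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
    guidance_scale=1.0,  # embedded (distilled) guidance effectively off
    cfg_scale=6.0,       # true CFG against the empty prompt
    cfg_steps=5,         # apply true CFG only on the first 5 steps
).frames[0]
export_to_video(output, "output_i2v.mp4", fps=15)
```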