├── README.md
├── __init__.py
├── diffusers_helper
│   ├── bucket_tools.py
│   ├── clip_vision.py
│   ├── dit_common.py
│   ├── gradio
│   │   ├── __pycache__
│   │   │   └── progress_bar.cpython-310.pyc
│   │   └── progress_bar.py
│   ├── hf_login.py
│   ├── hunyuan.py
│   ├── k_diffusion
│   │   ├── __pycache__
│   │   │   ├── uni_pc_fm.cpython-310.pyc
│   │   │   ├── uni_pc_fm.cpython-312.pyc
│   │   │   ├── wrapper.cpython-310.pyc
│   │   │   └── wrapper.cpython-312.pyc
│   │   ├── uni_pc_fm.py
│   │   └── wrapper.py
│   ├── memory.py
│   ├── models
│   │   ├── __pycache__
│   │   │   ├── hunyuan_video_packed.cpython-310.pyc
│   │   │   └── hunyuan_video_packed.cpython-312.pyc
│   │   └── hunyuan_video_packed.py
│   ├── pipelines
│   │   ├── __pycache__
│   │   │   ├── k_diffusion_hunyuan.cpython-310.pyc
│   │   │   └── k_diffusion_hunyuan.cpython-312.pyc
│   │   └── k_diffusion_hunyuan.py
│   ├── thread_utils.py
│   └── utils.py
├── examples
│   ├── FramePack_endimage.json
│   └── FramePack_regular.json
├── nodes.py
└── requirements.txt
/README.md:
--------------------------------------------------------------------------------
1 | # FramePack for ComfyUI
2 |
3 | **20250506 Update:** Added support for `FramePack_F1`.
4 | - **Download F1 Workflow (English)**: [https://www.runninghub.ai/post/1919141028262252546](https://www.runninghub.ai/post/1919141028262252546)
5 | - **Download F1 Workflow (Chinese)**: [https://www.runninghub.cn/post/1919141028262252546](https://www.runninghub.cn/post/1919141028262252546)
6 |
7 | **20250421 Update:** Added support for first/last-frame image-to-video generation, contributed by
8 | [TTPlanetPig](https://github.com/TTPlanetPig) (https://github.com/lllyasviel/FramePack/pull/167).
9 |
10 | ## Online Access
11 | You can access RunningHub online to use this plugin and models for free:
12 | ### English Version
13 | - **Run & Download Workflow**:
14 | [https://www.runninghub.ai/post/1912930457355517954](https://www.runninghub.ai/post/1912930457355517954)
15 | ### Chinese Version
16 | - **Run & Download Workflow**:
17 | [https://www.runninghub.cn/post/1912930457355517954](https://www.runninghub.cn/post/1912930457355517954)
18 |
19 | ## Features
20 | This is a simple implementation of https://github.com/lllyasviel/FramePack. Its main advantages are:
21 | - Better automatic adaptation for 24GB GPUs, enabling higher resolution processing whenever possible.
22 | - The entire workflow requires no parameter adjustments, making it extremely user-friendly.
23 |
24 |
25 |
26 |
27 | # Model Download Guide
28 |
29 | ## Choose a Download Method (Pick One)
30 |
31 | 1. **Download via Cloud Storage (for users in China)**
32 | - [T8 model pack (Quark drive)](https://pan.quark.cn/s/9669ce6c7356)
33 | 2. **One-Click Download with Python Script**
34 | ```python
35 | from huggingface_hub import snapshot_download
36 |
37 | # Download HunyuanVideo model
38 | snapshot_download(
39 | repo_id="hunyuanvideo-community/HunyuanVideo",
40 | local_dir="HunyuanVideo",
41 | ignore_patterns=["transformer/*", "*.git*", "*.log*", "*.md"],
42 | local_dir_use_symlinks=False
43 | )
44 |
45 | # Download flux_redux_bfl model
46 | snapshot_download(
47 | repo_id="lllyasviel/flux_redux_bfl",
48 | local_dir="flux_redux_bfl",
49 | ignore_patterns=["*.git*", "*.log*", "*.md"],
50 | local_dir_use_symlinks=False
51 | )
52 |
53 | # Download FramePackI2V_HY model
54 | snapshot_download(
55 | repo_id="lllyasviel/FramePackI2V_HY",
56 | local_dir="FramePackI2V_HY",
57 | ignore_patterns=["*.git*", "*.log*", "*.md"],
58 | local_dir_use_symlinks=False
59 | )
60 |
61 | # Download FramePackF1_HY model
62 | snapshot_download(
63 | repo_id="lllyasviel/FramePack_F1_I2V_HY_20250503",
64 | local_dir="FramePackF1_HY",
65 | ignore_patterns=["transformer/*", "*.git*", "*.log*", "*.md"],
66 | local_dir_use_symlinks=False
67 | )
68 | ```
69 | 3. **Manual Download**
70 | - HunyuanVideo: [HuggingFace Link](https://huggingface.co/hunyuanvideo-community/HunyuanVideo/tree/main)
71 | - Flux Redux BFL: [HuggingFace Link](https://huggingface.co/lllyasviel/flux_redux_bfl/tree/main)
72 | - FramePackI2V: [HuggingFace Link](https://huggingface.co/lllyasviel/FramePackI2V_HY/tree/main)
73 | - FramePackF1_HY: [HuggingFace Link](https://huggingface.co/lllyasviel/FramePack_F1_I2V_HY_20250503/tree/main)
74 |
75 | 4. **File Structure After Download**
76 | ```
77 | comfyui/models/
78 | FramePackF1_HY
79 | ├── config.json
80 | ├── diffusion_pytorch_model-00001-of-00003.safetensors
81 | ├── diffusion_pytorch_model-00002-of-00003.safetensors
82 | ├── diffusion_pytorch_model-00003-of-00003.safetensors
83 | ├── diffusion_pytorch_model.safetensors.index.json
84 | └── down.py
85 | FramePackI2V_HY
86 | ├── config.json
87 | ├── diffusion_pytorch_model-00001-of-00003.safetensors
88 | ├── diffusion_pytorch_model-00002-of-00003.safetensors
89 | ├── diffusion_pytorch_model-00003-of-00003.safetensors
90 | └── diffusion_pytorch_model.safetensors.index.json
91 | flux_redux_bfl
92 | ├── feature_extractor
93 | │ └── preprocessor_config.json
94 | ├── image_embedder
95 | │ ├── config.json
96 | │ └── diffusion_pytorch_model.safetensors
97 | ├── image_encoder
98 | │ ├── config.json
99 | │ └── model.safetensors
100 | └── model_index.json
101 | HunyuanVideo
102 | ├── config.json
103 | ├── model_index.json
104 | ├── scheduler
105 | │ └── scheduler_config.json
106 | ├── text_encoder
107 | │ ├── config.json
108 | │ ├── model-00001-of-00004.safetensors
109 | │ ├── model-00002-of-00004.safetensors
110 | │ ├── model-00003-of-00004.safetensors
111 | │ ├── model-00004-of-00004.safetensors
112 | │ └── model.safetensors.index.json
113 | ├── text_encoder_2
114 | │ ├── config.json
115 | │ └── model.safetensors
116 | ├── tokenizer
117 | │ ├── special_tokens_map.json
118 | │ ├── tokenizer.json
119 | │ └── tokenizer_config.json
120 | ├── tokenizer_2
121 | │ ├── merges.txt
122 | │ ├── special_tokens_map.json
123 | │ ├── tokenizer_config.json
124 | │ └── vocab.json
125 | └── vae
126 | ├── config.json
127 | └── diffusion_pytorch_model.safetensors
128 | ```
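
After downloading, you can optionally sanity-check the layout with a short script. This is a minimal sketch; adjust `models_root` to your actual ComfyUI install path.

```python
import os

# Check that the four model folders landed under ComfyUI's models directory.
models_root = "comfyui/models"  # adjust to your ComfyUI install path
for name in ["HunyuanVideo", "flux_redux_bfl", "FramePackI2V_HY", "FramePackF1_HY"]:
    path = os.path.join(models_root, name)
    print(f"{name}: {'OK' if os.path.isdir(path) else 'MISSING'} ({path})")
```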
129 |
130 |
131 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .nodes import NODE_CLASS_MAPPINGS
2 | NODE_DISPLAY_NAME_MAPPINGS = {k:v.TITLE for k,v in NODE_CLASS_MAPPINGS.items()}
3 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
--------------------------------------------------------------------------------
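
The display names above are built from a `TITLE` attribute on each node class, so `nodes.py` is expected to register classes that carry one. A minimal hypothetical sketch of that contract (the class, category, and input names below are illustrative, not the plugin's actual nodes):

```python
# Hypothetical ComfyUI node illustrating the contract __init__.py relies on:
# a class with a TITLE attribute, registered in NODE_CLASS_MAPPINGS in nodes.py.
class ExampleFramePackNode:
    TITLE = "Example FramePack Node"   # consumed as v.TITLE in __init__.py
    CATEGORY = "Runninghub/FramePack"  # illustrative category
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"image": ("IMAGE",)}}

    def run(self, image):
        return (image,)


NODE_CLASS_MAPPINGS = {"ExampleFramePackNode": ExampleFramePackNode}
```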
/diffusers_helper/bucket_tools.py:
--------------------------------------------------------------------------------
1 | bucket_options = {
2 | 640: [
3 | (416, 960),
4 | (448, 864),
5 | (480, 832),
6 | (512, 768),
7 | (544, 704),
8 | (576, 672),
9 | (608, 640),
10 | (640, 608),
11 | (672, 576),
12 | (704, 544),
13 | (768, 512),
14 | (832, 480),
15 | (864, 448),
16 | (960, 416),
17 | ],
18 | }
19 |
20 |
21 | def find_nearest_bucket(h, w, resolution=640):
22 | min_metric = float('inf')
23 | best_bucket = None
24 | for (bucket_h, bucket_w) in bucket_options[resolution]:
25 | metric = abs(h * bucket_w - w * bucket_h)
26 | if metric <= min_metric:
27 | min_metric = metric
28 | best_bucket = (bucket_h, bucket_w)
29 | return best_bucket
30 |
31 |
--------------------------------------------------------------------------------
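
A quick usage sketch of `find_nearest_bucket`: the metric `abs(h * bucket_w - w * bucket_h)` is zero exactly when the aspect ratios match, so the function returns the 640-area bucket closest to the input's aspect ratio (the `<=` comparison means later, wider buckets win ties). This assumes the plugin root is importable as `diffusers_helper`.

```python
from diffusers_helper.bucket_tools import find_nearest_bucket

# A 1280x720 frame (h=720, w=1280, 16:9 landscape) maps to the closest bucket.
bucket_h, bucket_w = find_nearest_bucket(h=720, w=1280, resolution=640)
print(bucket_h, bucket_w)  # (480, 832): aspect 0.577, nearest to 720/1280 = 0.5625
```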
/diffusers_helper/clip_vision.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def hf_clip_vision_encode(image, feature_extractor, image_encoder):
5 | assert isinstance(image, np.ndarray)
6 | assert image.ndim == 3 and image.shape[2] == 3
7 | assert image.dtype == np.uint8
8 |
9 | preprocessed = feature_extractor.preprocess(images=image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)
10 | image_encoder_output = image_encoder(**preprocessed)
11 |
12 | return image_encoder_output
13 |
--------------------------------------------------------------------------------
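
A hedged usage sketch, assuming the `flux_redux_bfl` folder from the download guide above and the SigLIP classes from `transformers` (a CUDA device is assumed; the plugin's own loading code lives in `nodes.py`):

```python
import numpy as np
import torch
from transformers import SiglipImageProcessor, SiglipVisionModel

from diffusers_helper.clip_vision import hf_clip_vision_encode

# Assumption: flux_redux_bfl was downloaded locally as described in the README.
feature_extractor = SiglipImageProcessor.from_pretrained("flux_redux_bfl", subfolder="feature_extractor")
image_encoder = SiglipVisionModel.from_pretrained(
    "flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
).cuda()

image = np.zeros((480, 832, 3), dtype=np.uint8)  # HWC uint8, as the asserts require
out = hf_clip_vision_encode(image, feature_extractor, image_encoder)
print(out.last_hidden_state.shape)  # per-patch image features used as conditioning
```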
/diffusers_helper/dit_common.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import accelerate.accelerator
3 |
4 | from diffusers.models.normalization import RMSNorm, LayerNorm, FP32LayerNorm, AdaLayerNormContinuous
5 |
6 |
7 | accelerate.accelerator.convert_outputs_to_fp32 = lambda x: x
8 |
9 |
10 | def LayerNorm_forward(self, x):
11 | return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps).to(x)
12 |
13 |
14 | LayerNorm.forward = LayerNorm_forward
15 | torch.nn.LayerNorm.forward = LayerNorm_forward
16 |
17 |
18 | def FP32LayerNorm_forward(self, x):
19 | origin_dtype = x.dtype
20 | return torch.nn.functional.layer_norm(
21 | x.float(),
22 | self.normalized_shape,
23 | self.weight.float() if self.weight is not None else None,
24 | self.bias.float() if self.bias is not None else None,
25 | self.eps,
26 | ).to(origin_dtype)
27 |
28 |
29 | FP32LayerNorm.forward = FP32LayerNorm_forward
30 |
31 |
32 | def RMSNorm_forward(self, hidden_states):
33 | input_dtype = hidden_states.dtype
34 | variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
35 | hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
36 |
37 | if self.weight is None:
38 | return hidden_states.to(input_dtype)
39 |
40 | return hidden_states.to(input_dtype) * self.weight.to(input_dtype)
41 |
42 |
43 | RMSNorm.forward = RMSNorm_forward
44 |
45 |
46 | def AdaLayerNormContinuous_forward(self, x, conditioning_embedding):
47 | emb = self.linear(self.silu(conditioning_embedding))
48 | scale, shift = emb.chunk(2, dim=1)
49 | x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
50 | return x
51 |
52 |
53 | AdaLayerNormContinuous.forward = AdaLayerNormContinuous_forward
54 |
--------------------------------------------------------------------------------
/diffusers_helper/gradio/__pycache__/progress_bar.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HM-RunningHub/ComfyUI_RH_FramePack/ad48591fcdf6af01489305a339e25fe6fcff8869/diffusers_helper/gradio/__pycache__/progress_bar.cpython-310.pyc
--------------------------------------------------------------------------------
/diffusers_helper/gradio/progress_bar.py:
--------------------------------------------------------------------------------
1 | progress_html = '''
2 | <div class="loader-container">
3 |   <div class="loader"></div>
4 |   <div class="progress-container">
5 |     <progress value="*number*" max="100"></progress>
6 |   </div>
7 |   <span>*text*</span>
8 | </div>
9 | '''
10 |
11 | css = '''
12 | .loader-container {
13 | display: flex; /* Use flex to align items horizontally */
14 | align-items: center; /* Center items vertically within the container */
15 | white-space: nowrap; /* Prevent line breaks within the container */
16 | }
17 |
18 | .loader {
19 | border: 8px solid #f3f3f3; /* Light grey */
20 | border-top: 8px solid #3498db; /* Blue */
21 | border-radius: 50%;
22 | width: 30px;
23 | height: 30px;
24 | animation: spin 2s linear infinite;
25 | }
26 |
27 | @keyframes spin {
28 | 0% { transform: rotate(0deg); }
29 | 100% { transform: rotate(360deg); }
30 | }
31 |
32 | /* Style the progress bar */
33 | progress {
34 | appearance: none; /* Remove default styling */
35 | height: 20px; /* Set the height of the progress bar */
36 | border-radius: 5px; /* Round the corners of the progress bar */
37 | background-color: #f3f3f3; /* Light grey background */
38 | width: 100%;
39 | vertical-align: middle !important;
40 | }
41 |
42 | /* Style the progress bar container */
43 | .progress-container {
44 | margin-left: 20px;
45 | margin-right: 20px;
46 | flex-grow: 1; /* Allow the progress container to take up remaining space */
47 | }
48 |
49 | /* Set the color of the progress bar fill */
50 | progress::-webkit-progress-value {
51 | background-color: #3498db; /* Blue color for the fill */
52 | }
53 |
54 | progress::-moz-progress-bar {
55 | background-color: #3498db; /* Blue color for the fill in Firefox */
56 | }
57 |
58 | /* Style the text on the progress bar */
59 | progress::after {
60 | content: attr(value '%'); /* Display the progress value followed by '%' */
61 | position: absolute;
62 | top: 50%;
63 | left: 50%;
64 | transform: translate(-50%, -50%);
65 | color: white; /* Set text color */
66 | font-size: 14px; /* Set font size */
67 | }
68 |
69 | /* Style other texts */
70 | .loader-container > span {
71 | margin-left: 5px; /* Add spacing between the progress bar and the text */
72 | }
73 |
74 | .no-generating-animation > .generating {
75 | display: none !important;
76 | }
77 |
78 | '''
79 |
80 |
81 | def make_progress_bar_html(number, text):
82 | return progress_html.replace('*number*', str(number)).replace('*text*', text)
83 |
84 |
85 | def make_progress_bar_css():
86 | return css
87 |
--------------------------------------------------------------------------------
/diffusers_helper/hf_login.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def login(token):
5 | from huggingface_hub import login
6 | import time
7 |
8 | while True:
9 | try:
10 | login(token)
11 | print('HF login ok.')
12 | break
13 | except Exception as e:
14 | print(f'HF login failed: {e}. Retrying')
15 | time.sleep(0.5)
16 |
17 |
18 | hf_token = os.environ.get('HF_TOKEN', None)
19 |
20 | if hf_token is not None:
21 | login(hf_token)
22 |
--------------------------------------------------------------------------------
/diffusers_helper/hunyuan.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
4 | from diffusers_helper.utils import crop_or_pad_yield_mask
5 |
6 |
7 | @torch.no_grad()
8 | def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256):
9 | assert isinstance(prompt, str)
10 |
11 | prompt = [prompt]
12 |
13 | # LLAMA
14 |
15 | prompt_llama = [DEFAULT_PROMPT_TEMPLATE["template"].format(p) for p in prompt]
16 | crop_start = DEFAULT_PROMPT_TEMPLATE["crop_start"]
17 |
18 | llama_inputs = tokenizer(
19 | prompt_llama,
20 | padding="max_length",
21 | max_length=max_length + crop_start,
22 | truncation=True,
23 | return_tensors="pt",
24 | return_length=False,
25 | return_overflowing_tokens=False,
26 | return_attention_mask=True,
27 | )
28 |
29 | llama_input_ids = llama_inputs.input_ids.to(text_encoder.device)
30 | llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device)
31 | llama_attention_length = int(llama_attention_mask.sum())
32 |
33 | llama_outputs = text_encoder(
34 | input_ids=llama_input_ids,
35 | attention_mask=llama_attention_mask,
36 | output_hidden_states=True,
37 | )
38 |
39 | llama_vec = llama_outputs.hidden_states[-3][:, crop_start:llama_attention_length]
40 | # llama_vec_remaining = llama_outputs.hidden_states[-3][:, llama_attention_length:]
41 | llama_attention_mask = llama_attention_mask[:, crop_start:llama_attention_length]
42 |
43 | assert torch.all(llama_attention_mask.bool())
44 |
45 | # CLIP
46 |
47 | clip_l_input_ids = tokenizer_2(
48 | prompt,
49 | padding="max_length",
50 | max_length=77,
51 | truncation=True,
52 | return_overflowing_tokens=False,
53 | return_length=False,
54 | return_tensors="pt",
55 | ).input_ids
56 | clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output
57 |
58 | return llama_vec, clip_l_pooler
59 |
60 |
61 | @torch.no_grad()
62 | def vae_decode_fake(latents):
63 | latent_rgb_factors = [
64 | [-0.0395, -0.0331, 0.0445],
65 | [0.0696, 0.0795, 0.0518],
66 | [0.0135, -0.0945, -0.0282],
67 | [0.0108, -0.0250, -0.0765],
68 | [-0.0209, 0.0032, 0.0224],
69 | [-0.0804, -0.0254, -0.0639],
70 | [-0.0991, 0.0271, -0.0669],
71 | [-0.0646, -0.0422, -0.0400],
72 | [-0.0696, -0.0595, -0.0894],
73 | [-0.0799, -0.0208, -0.0375],
74 | [0.1166, 0.1627, 0.0962],
75 | [0.1165, 0.0432, 0.0407],
76 | [-0.2315, -0.1920, -0.1355],
77 | [-0.0270, 0.0401, -0.0821],
78 | [-0.0616, -0.0997, -0.0727],
79 | [0.0249, -0.0469, -0.1703]
80 | ] # From comfyui
81 |
82 | latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761]
83 |
84 | weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None]
85 | bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
86 |
87 | images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1)
88 | images = images.clamp(0.0, 1.0)
89 |
90 | return images
91 |
92 |
93 | @torch.no_grad()
94 | def vae_decode(latents, vae, image_mode=False):
95 | latents = latents / vae.config.scaling_factor
96 |
97 | if not image_mode:
98 | image = vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample
99 | else:
100 | latents = latents.to(device=vae.device, dtype=vae.dtype).unbind(2)
101 | image = [vae.decode(l.unsqueeze(2)).sample for l in latents]
102 | image = torch.cat(image, dim=2)
103 |
104 | return image
105 |
106 |
107 | @torch.no_grad()
108 | def vae_encode(image, vae):
109 | latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
110 | latents = latents * vae.config.scaling_factor
111 | return latents
112 |
--------------------------------------------------------------------------------
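
A hedged sketch of how these helpers compose. The folder layout follows the README download guide; the encoder/tokenizer/VAE classes are assumptions based on the HunyuanVideo diffusers release, and a CUDA device with enough memory is assumed.

```python
import torch
from diffusers import AutoencoderKLHunyuanVideo
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast

from diffusers_helper.hunyuan import encode_prompt_conds, vae_decode, vae_encode

# Assumption: "HunyuanVideo" is the locally downloaded folder from the README.
text_encoder = LlamaModel.from_pretrained("HunyuanVideo", subfolder="text_encoder", torch_dtype=torch.float16).cuda()
text_encoder_2 = CLIPTextModel.from_pretrained("HunyuanVideo", subfolder="text_encoder_2", torch_dtype=torch.float16).cuda()
tokenizer = LlamaTokenizerFast.from_pretrained("HunyuanVideo", subfolder="tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained("HunyuanVideo", subfolder="tokenizer_2")
vae = AutoencoderKLHunyuanVideo.from_pretrained("HunyuanVideo", subfolder="vae", torch_dtype=torch.float16).cuda()

# Text conditioning: Llama hidden states (cropped past the template) plus the CLIP pooled vector.
llama_vec, clip_pooler = encode_prompt_conds("a cat walking on grass", text_encoder, text_encoder_2, tokenizer, tokenizer_2)

# VAE round trip on a dummy clip: B C T H W in [-1, 1].
frames = torch.zeros(1, 3, 1, 480, 832, dtype=torch.float16, device="cuda")
latents = vae_encode(frames, vae)
recon = vae_decode(latents, vae)
print(llama_vec.shape, clip_pooler.shape, latents.shape, recon.shape)
```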
/diffusers_helper/k_diffusion/__pycache__/uni_pc_fm.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HM-RunningHub/ComfyUI_RH_FramePack/ad48591fcdf6af01489305a339e25fe6fcff8869/diffusers_helper/k_diffusion/__pycache__/uni_pc_fm.cpython-310.pyc
--------------------------------------------------------------------------------
/diffusers_helper/k_diffusion/__pycache__/uni_pc_fm.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HM-RunningHub/ComfyUI_RH_FramePack/ad48591fcdf6af01489305a339e25fe6fcff8869/diffusers_helper/k_diffusion/__pycache__/uni_pc_fm.cpython-312.pyc
--------------------------------------------------------------------------------
/diffusers_helper/k_diffusion/__pycache__/wrapper.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HM-RunningHub/ComfyUI_RH_FramePack/ad48591fcdf6af01489305a339e25fe6fcff8869/diffusers_helper/k_diffusion/__pycache__/wrapper.cpython-310.pyc
--------------------------------------------------------------------------------
/diffusers_helper/k_diffusion/__pycache__/wrapper.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HM-RunningHub/ComfyUI_RH_FramePack/ad48591fcdf6af01489305a339e25fe6fcff8869/diffusers_helper/k_diffusion/__pycache__/wrapper.cpython-312.pyc
--------------------------------------------------------------------------------
/diffusers_helper/k_diffusion/uni_pc_fm.py:
--------------------------------------------------------------------------------
1 | # Better Flow Matching UniPC by Lvmin Zhang
2 | # (c) 2025
3 | # CC BY-SA 4.0
4 | # Attribution-ShareAlike 4.0 International Licence
5 |
6 |
7 | import torch
8 |
9 | from tqdm.auto import trange
10 |
11 |
12 | def expand_dims(v, dims):
13 | return v[(...,) + (None,) * (dims - 1)]
14 |
15 |
16 | class FlowMatchUniPC:
17 | def __init__(self, model, extra_args, variant='bh1'):
18 | self.model = model
19 | self.variant = variant
20 | self.extra_args = extra_args
21 |
22 | def model_fn(self, x, t):
23 | return self.model(x, t, **self.extra_args)
24 |
25 | def update_fn(self, x, model_prev_list, t_prev_list, t, order):
26 | assert order <= len(model_prev_list)
27 | dims = x.dim()
28 |
29 | t_prev_0 = t_prev_list[-1]
30 | lambda_prev_0 = - torch.log(t_prev_0)
31 | lambda_t = - torch.log(t)
32 | model_prev_0 = model_prev_list[-1]
33 |
34 | h = lambda_t - lambda_prev_0
35 |
36 | rks = []
37 | D1s = []
38 | for i in range(1, order):
39 | t_prev_i = t_prev_list[-(i + 1)]
40 | model_prev_i = model_prev_list[-(i + 1)]
41 | lambda_prev_i = - torch.log(t_prev_i)
42 | rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
43 | rks.append(rk)
44 | D1s.append((model_prev_i - model_prev_0) / rk)
45 |
46 | rks.append(1.)
47 | rks = torch.tensor(rks, device=x.device)
48 |
49 | R = []
50 | b = []
51 |
52 | hh = -h[0]
53 | h_phi_1 = torch.expm1(hh)
54 | h_phi_k = h_phi_1 / hh - 1
55 |
56 | factorial_i = 1
57 |
58 | if self.variant == 'bh1':
59 | B_h = hh
60 | elif self.variant == 'bh2':
61 | B_h = torch.expm1(hh)
62 | else:
63 | raise NotImplementedError('Bad variant!')
64 |
65 | for i in range(1, order + 1):
66 | R.append(torch.pow(rks, i - 1))
67 | b.append(h_phi_k * factorial_i / B_h)
68 | factorial_i *= (i + 1)
69 | h_phi_k = h_phi_k / hh - 1 / factorial_i
70 |
71 | R = torch.stack(R)
72 | b = torch.tensor(b, device=x.device)
73 |
74 | use_predictor = len(D1s) > 0
75 |
76 | if use_predictor:
77 | D1s = torch.stack(D1s, dim=1)
78 | if order == 2:
79 | rhos_p = torch.tensor([0.5], device=b.device)
80 | else:
81 | rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
82 | else:
83 | D1s = None
84 | rhos_p = None
85 |
86 | if order == 1:
87 | rhos_c = torch.tensor([0.5], device=b.device)
88 | else:
89 | rhos_c = torch.linalg.solve(R, b)
90 |
91 | x_t_ = expand_dims(t / t_prev_0, dims) * x - expand_dims(h_phi_1, dims) * model_prev_0
92 |
93 | if use_predictor:
94 | pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0]))
95 | else:
96 | pred_res = 0
97 |
98 | x_t = x_t_ - expand_dims(B_h, dims) * pred_res
99 | model_t = self.model_fn(x_t, t)
100 |
101 | if D1s is not None:
102 | corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0]))
103 | else:
104 | corr_res = 0
105 |
106 | D1_t = (model_t - model_prev_0)
107 | x_t = x_t_ - expand_dims(B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
108 |
109 | return x_t, model_t
110 |
111 | def sample(self, x, sigmas, callback=None, disable_pbar=False):
112 | order = min(3, len(sigmas) - 2)
113 | model_prev_list, t_prev_list = [], []
114 | for i in trange(len(sigmas) - 1, disable=disable_pbar):
115 | vec_t = sigmas[i].expand(x.shape[0])
116 |
117 | if i == 0:
118 | model_prev_list = [self.model_fn(x, vec_t)]
119 | t_prev_list = [vec_t]
120 | elif i < order:
121 | init_order = i
122 | x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order)
123 | model_prev_list.append(model_x)
124 | t_prev_list.append(vec_t)
125 | else:
126 | x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, order)
127 | model_prev_list.append(model_x)
128 | t_prev_list.append(vec_t)
129 |
130 | model_prev_list = model_prev_list[-order:]
131 | t_prev_list = t_prev_list[-order:]
132 |
133 | if callback is not None:
134 | callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})
135 |
136 | return model_prev_list[-1]
137 |
138 |
139 | def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
140 | assert variant in ['bh1', 'bh2']
141 | return FlowMatchUniPC(model, extra_args=extra_args, variant=variant).sample(noise, sigmas=sigmas, callback=callback, disable_pbar=disable)
142 |
--------------------------------------------------------------------------------
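
For illustration only, `sample_unipc` can be exercised with a toy velocity model and a linearly spaced sigma schedule; the real model function and schedule come from the wrapper and pipeline modules. This assumes the plugin root is importable as `diffusers_helper`.

```python
import torch

from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc


def toy_model(x, t, **kwargs):
    # Trivial stand-in for the flow-matching model: just echo the input.
    return x


noise = torch.randn(1, 4, 8, 8)
sigmas = torch.linspace(1.0, 0.0, 26)  # 25 steps from pure noise toward data
out = sample_unipc(toy_model, noise, sigmas, extra_args={}, disable=True)
print(out.shape)  # same shape as the input latents: torch.Size([1, 4, 8, 8])
```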
/diffusers_helper/k_diffusion/wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def append_dims(x, target_dims):
5 | return x[(...,) + (None,) * (target_dims - x.ndim)]
6 |
7 |
8 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0):
9 | if guidance_rescale == 0:
10 | return noise_cfg
11 |
12 | std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
13 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
14 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
15 | noise_cfg = guidance_rescale * noise_pred_rescaled + (1.0 - guidance_rescale) * noise_cfg
16 | return noise_cfg
17 |
18 |
19 | def fm_wrapper(transformer, t_scale=1000.0):
20 | def k_model(x, sigma, **extra_args):
21 | dtype = extra_args['dtype']
22 | cfg_scale = extra_args['cfg_scale']
23 | cfg_rescale = extra_args['cfg_rescale']
24 | concat_latent = extra_args['concat_latent']
25 |
26 | original_dtype = x.dtype
27 | sigma = sigma.float()
28 |
29 | x = x.to(dtype)
30 | timestep = (sigma * t_scale).to(dtype)
31 |
32 | if concat_latent is None:
33 | hidden_states = x
34 | else:
35 | hidden_states = torch.cat([x, concat_latent.to(x)], dim=1)
36 |
37 | pred_positive = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['positive'])[0].float()
38 |
39 | if cfg_scale == 1.0:
40 | pred_negative = torch.zeros_like(pred_positive)
41 | else:
42 | pred_negative = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['negative'])[0].float()
43 |
44 | pred_cfg = pred_negative + cfg_scale * (pred_positive - pred_negative)
45 | pred = rescale_noise_cfg(pred_cfg, pred_positive, guidance_rescale=cfg_rescale)
46 |
47 | x0 = x.float() - pred.float() * append_dims(sigma, x.ndim)
48 |
49 | return x0.to(dtype=original_dtype)
50 |
51 | return k_model
52 |
--------------------------------------------------------------------------------
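
A small runnable illustration of `rescale_noise_cfg`, which pulls the CFG output's per-sample standard deviation back toward that of the text-conditioned prediction (`guidance_rescale` interpolates between the raw and rescaled results):

```python
import torch

from diffusers_helper.k_diffusion.wrapper import rescale_noise_cfg

noise_pred_text = torch.randn(2, 4, 8, 8)
noise_cfg = 5.0 * noise_pred_text  # over-amplified, as strong CFG tends to produce
rescaled = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)
print(noise_cfg.std().item(), rescaled.std().item())  # the rescaled std sits much closer to ~1
```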
/diffusers_helper/memory.py:
--------------------------------------------------------------------------------
1 | # By lllyasviel
2 |
3 |
4 | import torch
5 |
6 |
7 | #cpu = torch.device('cpu')
8 | #gpu = torch.device(f'cuda:{torch.cuda.current_device()}')
9 | cpu = 'cpu'
10 | gpu = 'cuda:0'
11 | gpu_complete_modules = []
12 |
13 |
14 | class DynamicSwapInstaller:
15 | @staticmethod
16 | def _install_module(module: torch.nn.Module, **kwargs):
17 | original_class = module.__class__
18 | module.__dict__['forge_backup_original_class'] = original_class
19 |
20 | def hacked_get_attr(self, name: str):
21 | if '_parameters' in self.__dict__:
22 | _parameters = self.__dict__['_parameters']
23 | if name in _parameters:
24 | p = _parameters[name]
25 | if p is None:
26 | return None
27 | if p.__class__ == torch.nn.Parameter:
28 | return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad)
29 | else:
30 | return p.to(**kwargs)
31 | if '_buffers' in self.__dict__:
32 | _buffers = self.__dict__['_buffers']
33 | if name in _buffers:
34 | return _buffers[name].to(**kwargs)
35 | return super(original_class, self).__getattr__(name)
36 |
37 | module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
38 | '__getattr__': hacked_get_attr,
39 | })
40 |
41 | return
42 |
43 | @staticmethod
44 | def _uninstall_module(module: torch.nn.Module):
45 | if 'forge_backup_original_class' in module.__dict__:
46 | module.__class__ = module.__dict__.pop('forge_backup_original_class')
47 | return
48 |
49 | @staticmethod
50 | def install_model(model: torch.nn.Module, **kwargs):
51 | for m in model.modules():
52 | DynamicSwapInstaller._install_module(m, **kwargs)
53 | return
54 |
55 | @staticmethod
56 | def uninstall_model(model: torch.nn.Module):
57 | for m in model.modules():
58 | DynamicSwapInstaller._uninstall_module(m)
59 | return
60 |
61 |
62 | def fake_diffusers_current_device(model: torch.nn.Module, target_device: torch.device):
63 | if hasattr(model, 'scale_shift_table'):
64 | model.scale_shift_table.data = model.scale_shift_table.data.to(target_device)
65 | return
66 |
67 | for k, p in model.named_modules():
68 | if hasattr(p, 'weight'):
69 | p.to(target_device)
70 | return
71 |
72 |
73 | def get_cuda_free_memory_gb(device=None):
74 | if device is None:
75 | device = gpu
76 |
77 | memory_stats = torch.cuda.memory_stats(device)
78 | bytes_active = memory_stats['active_bytes.all.current']
79 | bytes_reserved = memory_stats['reserved_bytes.all.current']
80 | bytes_free_cuda, _ = torch.cuda.mem_get_info(device)
81 | bytes_inactive_reserved = bytes_reserved - bytes_active
82 | bytes_total_available = bytes_free_cuda + bytes_inactive_reserved
83 | return bytes_total_available / (1024 ** 3)
84 |
85 |
86 | def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0):
87 | print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB')
88 |
89 | for m in model.modules():
90 | if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
91 | torch.cuda.empty_cache()
92 | return
93 |
94 | if hasattr(m, 'weight'):
95 | m.to(device=target_device)
96 |
97 | model.to(device=target_device)
98 | torch.cuda.empty_cache()
99 | return
100 |
101 |
102 | def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0):
103 | print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB')
104 |
105 | for m in model.modules():
106 | if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
107 | torch.cuda.empty_cache()
108 | return
109 |
110 | if hasattr(m, 'weight'):
111 | m.to(device=cpu)
112 |
113 | model.to(device=cpu)
114 | torch.cuda.empty_cache()
115 | return
116 |
117 |
118 | def unload_complete_models(*args):
119 | for m in gpu_complete_modules + list(args):
120 | m.to(device=cpu)
121 | print(f'Unloaded {m.__class__.__name__} as complete.')
122 |
123 | gpu_complete_modules.clear()
124 | torch.cuda.empty_cache()
125 | return
126 |
127 |
128 | def load_model_as_complete(model, target_device, unload=True):
129 | if unload:
130 | unload_complete_models()
131 |
132 | model.to(device=target_device)
133 | print(f'Loaded {model.__class__.__name__} to {target_device} as complete.')
134 |
135 | gpu_complete_modules.append(model)
136 | return
137 |
--------------------------------------------------------------------------------
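
A hedged sketch of the intended usage, exercised here with toy `nn.Linear` stand-ins instead of the real transformer and VAE (a CUDA device is required; the plugin's nodes do the actual orchestration):

```python
import torch

from diffusers_helper.memory import (
    DynamicSwapInstaller, get_cuda_free_memory_gb, gpu,
    load_model_as_complete, unload_complete_models,
)

# Toy stand-ins for the real transformer / VAE, just to exercise the helpers.
big_model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))
small_model = torch.nn.Linear(64, 64)

x = torch.randn(1, 64, device=gpu)  # also initializes the CUDA context
print(f'Free VRAM: {get_cuda_free_memory_gb(gpu):.1f} GB')

# Keep weights on the CPU but copy them to the GPU on attribute access.
DynamicSwapInstaller.install_model(big_model, device=gpu)
y = big_model(x)  # parameters are streamed to cuda:0 as each layer touches them

# Small models are moved whole and tracked so they can be unloaded together later.
load_model_as_complete(small_model, target_device=gpu)
unload_complete_models()
DynamicSwapInstaller.uninstall_model(big_model)
```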
/diffusers_helper/models/__pycache__/hunyuan_video_packed.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HM-RunningHub/ComfyUI_RH_FramePack/ad48591fcdf6af01489305a339e25fe6fcff8869/diffusers_helper/models/__pycache__/hunyuan_video_packed.cpython-310.pyc
--------------------------------------------------------------------------------
/diffusers_helper/models/__pycache__/hunyuan_video_packed.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HM-RunningHub/ComfyUI_RH_FramePack/ad48591fcdf6af01489305a339e25fe6fcff8869/diffusers_helper/models/__pycache__/hunyuan_video_packed.cpython-312.pyc
--------------------------------------------------------------------------------
/diffusers_helper/models/hunyuan_video_packed.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Tuple, Union
2 |
3 | import torch
4 | import einops
5 | import torch.nn as nn
6 | import numpy as np
7 |
8 | from diffusers.loaders import FromOriginalModelMixin
9 | from diffusers.configuration_utils import ConfigMixin, register_to_config
10 | from diffusers.loaders import PeftAdapterMixin
11 | from diffusers.utils import logging
12 | from diffusers.models.attention import FeedForward
13 | from diffusers.models.attention_processor import Attention
14 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection
15 | from diffusers.models.modeling_outputs import Transformer2DModelOutput
16 | from diffusers.models.modeling_utils import ModelMixin
17 | from diffusers_helper.dit_common import LayerNorm
18 | from diffusers_helper.utils import zero_module
19 |
20 |
21 | enabled_backends = []
22 |
23 | if torch.backends.cuda.flash_sdp_enabled():
24 | enabled_backends.append("flash")
25 | if torch.backends.cuda.math_sdp_enabled():
26 | enabled_backends.append("math")
27 | if torch.backends.cuda.mem_efficient_sdp_enabled():
28 | enabled_backends.append("mem_efficient")
29 | if torch.backends.cuda.cudnn_sdp_enabled():
30 | enabled_backends.append("cudnn")
31 |
32 | print("Currently enabled native sdp backends:", enabled_backends)
33 |
34 | try:
35 | # raise NotImplementedError
36 | from xformers.ops import memory_efficient_attention as xformers_attn_func
37 | print('Xformers is installed!')
38 | except:
39 | print('Xformers is not installed!')
40 | xformers_attn_func = None
41 |
42 | try:
43 | # raise NotImplementedError
44 | from flash_attn import flash_attn_varlen_func, flash_attn_func
45 | print('Flash Attn is installed!')
46 | except:
47 | print('Flash Attn is not installed!')
48 | flash_attn_varlen_func = None
49 | flash_attn_func = None
50 |
51 | try:
52 | # raise NotImplementedError
53 | from sageattention import sageattn_varlen, sageattn
54 | print('Sage Attn is installed!')
55 | except:
56 | print('Sage Attn is not installed!')
57 | sageattn_varlen = None
58 | sageattn = None
59 |
60 |
61 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
62 |
63 |
64 | def pad_for_3d_conv(x, kernel_size):
65 | b, c, t, h, w = x.shape
66 | pt, ph, pw = kernel_size
67 | pad_t = (pt - (t % pt)) % pt
68 | pad_h = (ph - (h % ph)) % ph
69 | pad_w = (pw - (w % pw)) % pw
70 | return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode='replicate')
71 |
72 |
73 | def center_down_sample_3d(x, kernel_size):
74 | # pt, ph, pw = kernel_size
75 | # cp = (pt * ph * pw) // 2
76 | # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw)
77 | # xc = xp[cp]
78 | # return xc
79 | return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
80 |
81 |
82 | def get_cu_seqlens(text_mask, img_len):
83 | batch_size = text_mask.shape[0]
84 | text_len = text_mask.sum(dim=1)
85 | max_len = text_mask.shape[1] + img_len
86 |
87 | cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
88 |
89 | for i in range(batch_size):
90 | s = text_len[i] + img_len
91 | s1 = i * max_len + s
92 | s2 = (i + 1) * max_len
93 | cu_seqlens[2 * i + 1] = s1
94 | cu_seqlens[2 * i + 2] = s2
95 |
96 | return cu_seqlens
97 |
98 |
99 | def apply_rotary_emb_transposed(x, freqs_cis):
100 | cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
101 | x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1)
102 | x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
103 | out = x.float() * cos + x_rotated.float() * sin
104 | out = out.to(x)
105 | return out
106 |
107 |
108 | def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv):
109 | if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None:
110 | if sageattn is not None:
111 | x = sageattn(q, k, v, tensor_layout='NHD')
112 | return x
113 |
114 | if flash_attn_func is not None:
115 | x = flash_attn_func(q, k, v)
116 | return x
117 |
118 | if xformers_attn_func is not None:
119 | x = xformers_attn_func(q, k, v)
120 | return x
121 |
122 | x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
123 | return x
124 |
125 | batch_size = q.shape[0]
126 | q = q.view(q.shape[0] * q.shape[1], *q.shape[2:])
127 | k = k.view(k.shape[0] * k.shape[1], *k.shape[2:])
128 | v = v.view(v.shape[0] * v.shape[1], *v.shape[2:])
129 | if sageattn_varlen is not None:
130 | x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
131 | elif flash_attn_varlen_func is not None:
132 | x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
133 | else:
134 | raise NotImplementedError('No Attn Installed!')
135 | x = x.view(batch_size, max_seqlen_q, *x.shape[2:])
136 | return x
137 |
138 |
139 | class HunyuanAttnProcessorFlashAttnDouble:
140 | def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
141 | cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
142 |
143 | query = attn.to_q(hidden_states)
144 | key = attn.to_k(hidden_states)
145 | value = attn.to_v(hidden_states)
146 |
147 | query = query.unflatten(2, (attn.heads, -1))
148 | key = key.unflatten(2, (attn.heads, -1))
149 | value = value.unflatten(2, (attn.heads, -1))
150 |
151 | query = attn.norm_q(query)
152 | key = attn.norm_k(key)
153 |
154 | query = apply_rotary_emb_transposed(query, image_rotary_emb)
155 | key = apply_rotary_emb_transposed(key, image_rotary_emb)
156 |
157 | encoder_query = attn.add_q_proj(encoder_hidden_states)
158 | encoder_key = attn.add_k_proj(encoder_hidden_states)
159 | encoder_value = attn.add_v_proj(encoder_hidden_states)
160 |
161 | encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
162 | encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
163 | encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
164 |
165 | encoder_query = attn.norm_added_q(encoder_query)
166 | encoder_key = attn.norm_added_k(encoder_key)
167 |
168 | query = torch.cat([query, encoder_query], dim=1)
169 | key = torch.cat([key, encoder_key], dim=1)
170 | value = torch.cat([value, encoder_value], dim=1)
171 |
172 | hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
173 | hidden_states = hidden_states.flatten(-2)
174 |
175 | txt_length = encoder_hidden_states.shape[1]
176 | hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
177 |
178 | hidden_states = attn.to_out[0](hidden_states)
179 | hidden_states = attn.to_out[1](hidden_states)
180 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
181 |
182 | return hidden_states, encoder_hidden_states
183 |
184 |
185 | class HunyuanAttnProcessorFlashAttnSingle:
186 | def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
187 | cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
188 |
189 | hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
190 |
191 | query = attn.to_q(hidden_states)
192 | key = attn.to_k(hidden_states)
193 | value = attn.to_v(hidden_states)
194 |
195 | query = query.unflatten(2, (attn.heads, -1))
196 | key = key.unflatten(2, (attn.heads, -1))
197 | value = value.unflatten(2, (attn.heads, -1))
198 |
199 | query = attn.norm_q(query)
200 | key = attn.norm_k(key)
201 |
202 | txt_length = encoder_hidden_states.shape[1]
203 |
204 | query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1)
205 | key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1)
206 |
207 | hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
208 | hidden_states = hidden_states.flatten(-2)
209 |
210 | hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
211 |
212 | return hidden_states, encoder_hidden_states
213 |
214 |
215 | class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
216 | def __init__(self, embedding_dim, pooled_projection_dim):
217 | super().__init__()
218 |
219 | self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
220 | self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
221 | self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
222 | self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
223 |
224 | def forward(self, timestep, guidance, pooled_projection):
225 | timesteps_proj = self.time_proj(timestep)
226 | timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
227 |
228 | guidance_proj = self.time_proj(guidance)
229 | guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
230 |
231 | time_guidance_emb = timesteps_emb + guidance_emb
232 |
233 | pooled_projections = self.text_embedder(pooled_projection)
234 | conditioning = time_guidance_emb + pooled_projections
235 |
236 | return conditioning
237 |
238 |
239 | class CombinedTimestepTextProjEmbeddings(nn.Module):
240 | def __init__(self, embedding_dim, pooled_projection_dim):
241 | super().__init__()
242 |
243 | self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
244 | self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
245 | self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
246 |
247 | def forward(self, timestep, pooled_projection):
248 | timesteps_proj = self.time_proj(timestep)
249 | timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
250 |
251 | pooled_projections = self.text_embedder(pooled_projection)
252 |
253 | conditioning = timesteps_emb + pooled_projections
254 |
255 | return conditioning
256 |
257 |
258 | class HunyuanVideoAdaNorm(nn.Module):
259 | def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
260 | super().__init__()
261 |
262 | out_features = out_features or 2 * in_features
263 | self.linear = nn.Linear(in_features, out_features)
264 | self.nonlinearity = nn.SiLU()
265 |
266 | def forward(
267 | self, temb: torch.Tensor
268 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
269 | temb = self.linear(self.nonlinearity(temb))
270 | gate_msa, gate_mlp = temb.chunk(2, dim=-1)
271 | gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
272 | return gate_msa, gate_mlp
273 |
274 |
275 | class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
276 | def __init__(
277 | self,
278 | num_attention_heads: int,
279 | attention_head_dim: int,
280 | mlp_width_ratio: str = 4.0,
281 | mlp_drop_rate: float = 0.0,
282 | attention_bias: bool = True,
283 | ) -> None:
284 | super().__init__()
285 |
286 | hidden_size = num_attention_heads * attention_head_dim
287 |
288 | self.norm1 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
289 | self.attn = Attention(
290 | query_dim=hidden_size,
291 | cross_attention_dim=None,
292 | heads=num_attention_heads,
293 | dim_head=attention_head_dim,
294 | bias=attention_bias,
295 | )
296 |
297 | self.norm2 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
298 | self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
299 |
300 | self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
301 |
302 | def forward(
303 | self,
304 | hidden_states: torch.Tensor,
305 | temb: torch.Tensor,
306 | attention_mask: Optional[torch.Tensor] = None,
307 | ) -> torch.Tensor:
308 | norm_hidden_states = self.norm1(hidden_states)
309 |
310 | attn_output = self.attn(
311 | hidden_states=norm_hidden_states,
312 | encoder_hidden_states=None,
313 | attention_mask=attention_mask,
314 | )
315 |
316 | gate_msa, gate_mlp = self.norm_out(temb)
317 | hidden_states = hidden_states + attn_output * gate_msa
318 |
319 | ff_output = self.ff(self.norm2(hidden_states))
320 | hidden_states = hidden_states + ff_output * gate_mlp
321 |
322 | return hidden_states
323 |
324 |
325 | class HunyuanVideoIndividualTokenRefiner(nn.Module):
326 | def __init__(
327 | self,
328 | num_attention_heads: int,
329 | attention_head_dim: int,
330 | num_layers: int,
331 | mlp_width_ratio: float = 4.0,
332 | mlp_drop_rate: float = 0.0,
333 | attention_bias: bool = True,
334 | ) -> None:
335 | super().__init__()
336 |
337 | self.refiner_blocks = nn.ModuleList(
338 | [
339 | HunyuanVideoIndividualTokenRefinerBlock(
340 | num_attention_heads=num_attention_heads,
341 | attention_head_dim=attention_head_dim,
342 | mlp_width_ratio=mlp_width_ratio,
343 | mlp_drop_rate=mlp_drop_rate,
344 | attention_bias=attention_bias,
345 | )
346 | for _ in range(num_layers)
347 | ]
348 | )
349 |
350 | def forward(
351 | self,
352 | hidden_states: torch.Tensor,
353 | temb: torch.Tensor,
354 | attention_mask: Optional[torch.Tensor] = None,
355 | ) -> None:
356 | self_attn_mask = None
357 | if attention_mask is not None:
358 | batch_size = attention_mask.shape[0]
359 | seq_len = attention_mask.shape[1]
360 | attention_mask = attention_mask.to(hidden_states.device).bool()
361 | self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
362 | self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
363 | self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
364 | self_attn_mask[:, :, :, 0] = True
365 |
366 | for block in self.refiner_blocks:
367 | hidden_states = block(hidden_states, temb, self_attn_mask)
368 |
369 | return hidden_states
370 |
371 |
372 | class HunyuanVideoTokenRefiner(nn.Module):
373 | def __init__(
374 | self,
375 | in_channels: int,
376 | num_attention_heads: int,
377 | attention_head_dim: int,
378 | num_layers: int,
379 | mlp_ratio: float = 4.0,
380 | mlp_drop_rate: float = 0.0,
381 | attention_bias: bool = True,
382 | ) -> None:
383 | super().__init__()
384 |
385 | hidden_size = num_attention_heads * attention_head_dim
386 |
387 | self.time_text_embed = CombinedTimestepTextProjEmbeddings(
388 | embedding_dim=hidden_size, pooled_projection_dim=in_channels
389 | )
390 | self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
391 | self.token_refiner = HunyuanVideoIndividualTokenRefiner(
392 | num_attention_heads=num_attention_heads,
393 | attention_head_dim=attention_head_dim,
394 | num_layers=num_layers,
395 | mlp_width_ratio=mlp_ratio,
396 | mlp_drop_rate=mlp_drop_rate,
397 | attention_bias=attention_bias,
398 | )
399 |
400 | def forward(
401 | self,
402 | hidden_states: torch.Tensor,
403 | timestep: torch.LongTensor,
404 | attention_mask: Optional[torch.LongTensor] = None,
405 | ) -> torch.Tensor:
406 | if attention_mask is None:
407 | pooled_projections = hidden_states.mean(dim=1)
408 | else:
409 | original_dtype = hidden_states.dtype
410 | mask_float = attention_mask.float().unsqueeze(-1)
411 | pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
412 | pooled_projections = pooled_projections.to(original_dtype)
413 |
414 | temb = self.time_text_embed(timestep, pooled_projections)
415 | hidden_states = self.proj_in(hidden_states)
416 | hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
417 |
418 | return hidden_states
419 |
420 |
421 | class HunyuanVideoRotaryPosEmbed(nn.Module):
422 | def __init__(self, rope_dim, theta):
423 | super().__init__()
424 | self.DT, self.DY, self.DX = rope_dim
425 | self.theta = theta
426 |
427 | @torch.no_grad()
428 | def get_frequency(self, dim, pos):
429 | T, H, W = pos.shape
430 | freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim))
431 | freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0)
432 | return freqs.cos(), freqs.sin()
433 |
434 | @torch.no_grad()
435 | def forward_inner(self, frame_indices, height, width, device):
436 | GT, GY, GX = torch.meshgrid(
437 | frame_indices.to(device=device, dtype=torch.float32),
438 | torch.arange(0, height, device=device, dtype=torch.float32),
439 | torch.arange(0, width, device=device, dtype=torch.float32),
440 | indexing="ij"
441 | )
442 |
443 | FCT, FST = self.get_frequency(self.DT, GT)
444 | FCY, FSY = self.get_frequency(self.DY, GY)
445 | FCX, FSX = self.get_frequency(self.DX, GX)
446 |
447 | result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0)
448 |
449 | return result.to(device)
450 |
451 | @torch.no_grad()
452 | def forward(self, frame_indices, height, width, device):
453 | frame_indices = frame_indices.unbind(0)
454 | results = [self.forward_inner(f, height, width, device) for f in frame_indices]
455 | results = torch.stack(results, dim=0)
456 | return results
457 |
458 |
459 | class AdaLayerNormZero(nn.Module):
460 | def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
461 | super().__init__()
462 | self.silu = nn.SiLU()
463 | self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
464 | if norm_type == "layer_norm":
465 | self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
466 | else:
467 | raise ValueError(f"unknown norm_type {norm_type}")
468 |
469 | def forward(
470 | self,
471 | x: torch.Tensor,
472 | emb: Optional[torch.Tensor] = None,
473 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
474 | emb = emb.unsqueeze(-2)
475 | emb = self.linear(self.silu(emb))
476 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
477 | x = self.norm(x) * (1 + scale_msa) + shift_msa
478 | return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
479 |
480 |
481 | class AdaLayerNormZeroSingle(nn.Module):
482 | def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
483 | super().__init__()
484 |
485 | self.silu = nn.SiLU()
486 | self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
487 | if norm_type == "layer_norm":
488 | self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
489 | else:
490 | raise ValueError(f"unknown norm_type {norm_type}")
491 |
492 | def forward(
493 | self,
494 | x: torch.Tensor,
495 | emb: Optional[torch.Tensor] = None,
496 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
497 | emb = emb.unsqueeze(-2)
498 | emb = self.linear(self.silu(emb))
499 | shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
500 | x = self.norm(x) * (1 + scale_msa) + shift_msa
501 | return x, gate_msa
502 |
503 |
504 | class AdaLayerNormContinuous(nn.Module):
505 | def __init__(
506 | self,
507 | embedding_dim: int,
508 | conditioning_embedding_dim: int,
509 | elementwise_affine=True,
510 | eps=1e-5,
511 | bias=True,
512 | norm_type="layer_norm",
513 | ):
514 | super().__init__()
515 | self.silu = nn.SiLU()
516 | self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
517 | if norm_type == "layer_norm":
518 | self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
519 | else:
520 | raise ValueError(f"unknown norm_type {norm_type}")
521 |
522 | def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
523 | emb = emb.unsqueeze(-2)
524 | emb = self.linear(self.silu(emb))
525 | scale, shift = emb.chunk(2, dim=-1)
526 | x = self.norm(x) * (1 + scale) + shift
527 | return x
528 |
529 |
530 | class HunyuanVideoSingleTransformerBlock(nn.Module):
531 | def __init__(
532 | self,
533 | num_attention_heads: int,
534 | attention_head_dim: int,
535 | mlp_ratio: float = 4.0,
536 | qk_norm: str = "rms_norm",
537 | ) -> None:
538 | super().__init__()
539 |
540 | hidden_size = num_attention_heads * attention_head_dim
541 | mlp_dim = int(hidden_size * mlp_ratio)
542 |
543 | self.attn = Attention(
544 | query_dim=hidden_size,
545 | cross_attention_dim=None,
546 | dim_head=attention_head_dim,
547 | heads=num_attention_heads,
548 | out_dim=hidden_size,
549 | bias=True,
550 | processor=HunyuanAttnProcessorFlashAttnSingle(),
551 | qk_norm=qk_norm,
552 | eps=1e-6,
553 | pre_only=True,
554 | )
555 |
556 | self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
557 | self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
558 | self.act_mlp = nn.GELU(approximate="tanh")
559 | self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
560 |
561 | def forward(
562 | self,
563 | hidden_states: torch.Tensor,
564 | encoder_hidden_states: torch.Tensor,
565 | temb: torch.Tensor,
566 | attention_mask: Optional[torch.Tensor] = None,
567 | image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
568 | ) -> torch.Tensor:
569 | text_seq_length = encoder_hidden_states.shape[1]
570 | hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
571 |
572 | residual = hidden_states
573 |
574 | # 1. Input normalization
575 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
576 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
577 |
578 | norm_hidden_states, norm_encoder_hidden_states = (
579 | norm_hidden_states[:, :-text_seq_length, :],
580 | norm_hidden_states[:, -text_seq_length:, :],
581 | )
582 |
583 | # 2. Attention
584 | attn_output, context_attn_output = self.attn(
585 | hidden_states=norm_hidden_states,
586 | encoder_hidden_states=norm_encoder_hidden_states,
587 | attention_mask=attention_mask,
588 | image_rotary_emb=image_rotary_emb,
589 | )
590 | attn_output = torch.cat([attn_output, context_attn_output], dim=1)
591 |
592 | # 3. Modulation and residual connection
593 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
594 | hidden_states = gate * self.proj_out(hidden_states)
595 | hidden_states = hidden_states + residual
596 |
597 | hidden_states, encoder_hidden_states = (
598 | hidden_states[:, :-text_seq_length, :],
599 | hidden_states[:, -text_seq_length:, :],
600 | )
601 | return hidden_states, encoder_hidden_states
602 |
603 |
604 | class HunyuanVideoTransformerBlock(nn.Module):
605 | def __init__(
606 | self,
607 | num_attention_heads: int,
608 | attention_head_dim: int,
609 | mlp_ratio: float,
610 | qk_norm: str = "rms_norm",
611 | ) -> None:
612 | super().__init__()
613 |
614 | hidden_size = num_attention_heads * attention_head_dim
615 |
616 | self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
617 | self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
618 |
619 | self.attn = Attention(
620 | query_dim=hidden_size,
621 | cross_attention_dim=None,
622 | added_kv_proj_dim=hidden_size,
623 | dim_head=attention_head_dim,
624 | heads=num_attention_heads,
625 | out_dim=hidden_size,
626 | context_pre_only=False,
627 | bias=True,
628 | processor=HunyuanAttnProcessorFlashAttnDouble(),
629 | qk_norm=qk_norm,
630 | eps=1e-6,
631 | )
632 |
633 | self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
634 | self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
635 |
636 | self.norm2_context = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
637 | self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
638 |
639 | def forward(
640 | self,
641 | hidden_states: torch.Tensor,
642 | encoder_hidden_states: torch.Tensor,
643 | temb: torch.Tensor,
644 | attention_mask: Optional[torch.Tensor] = None,
645 | freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
646 | ) -> Tuple[torch.Tensor, torch.Tensor]:
647 | # 1. Input normalization
648 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
649 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb)
650 |
651 | # 2. Joint attention
652 | attn_output, context_attn_output = self.attn(
653 | hidden_states=norm_hidden_states,
654 | encoder_hidden_states=norm_encoder_hidden_states,
655 | attention_mask=attention_mask,
656 | image_rotary_emb=freqs_cis,
657 | )
658 |
659 | # 3. Modulation and residual connection
660 | hidden_states = hidden_states + attn_output * gate_msa
661 | encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa
662 |
663 | norm_hidden_states = self.norm2(hidden_states)
664 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
665 |
666 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
667 | norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
668 |
669 | # 4. Feed-forward
670 | ff_output = self.ff(norm_hidden_states)
671 | context_ff_output = self.ff_context(norm_encoder_hidden_states)
672 |
673 | hidden_states = hidden_states + gate_mlp * ff_output
674 | encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
675 |
676 | return hidden_states, encoder_hidden_states
677 |
678 |
679 | class ClipVisionProjection(nn.Module):
680 | def __init__(self, in_channels, out_channels):
681 | super().__init__()
682 | self.up = nn.Linear(in_channels, out_channels * 3)
683 | self.down = nn.Linear(out_channels * 3, out_channels)
684 |
685 | def forward(self, x):
686 | projected_x = self.down(nn.functional.silu(self.up(x)))
687 | return projected_x
688 |
689 |
690 | class HunyuanVideoPatchEmbed(nn.Module):
691 | def __init__(self, patch_size, in_chans, embed_dim):
692 | super().__init__()
693 | self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
694 |
695 |
696 | class HunyuanVideoPatchEmbedForCleanLatents(nn.Module):
697 | def __init__(self, inner_dim):
698 | super().__init__()
699 | self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
700 | self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
701 | self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
702 |
703 | @torch.no_grad()
704 | def initialize_weight_from_another_conv3d(self, another_layer):
705 | weight = another_layer.weight.detach().clone()
706 | bias = another_layer.bias.detach().clone()
707 |
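# The 2x and 4x projections reuse the base 1x kernel: each weight is tiled across the
# larger (t, h, w) patch and divided by the number of copies (2*2*2 = 8 for proj_2x,
# 4*4*4 = 64 for proj_4x), so a constant input yields the same activation scale at
# every downsampling level.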
708 | sd = {
709 | 'proj.weight': weight.clone(),
710 | 'proj.bias': bias.clone(),
711 | 'proj_2x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=2, hk=2, wk=2) / 8.0,
712 | 'proj_2x.bias': bias.clone(),
713 | 'proj_4x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=4, hk=4, wk=4) / 64.0,
714 | 'proj_4x.bias': bias.clone(),
715 | }
716 |
717 | sd = {k: v.clone() for k, v in sd.items()}
718 |
719 | self.load_state_dict(sd)
720 | return
721 |
722 |
723 | class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
724 | @register_to_config
725 | def __init__(
726 | self,
727 | in_channels: int = 16,
728 | out_channels: int = 16,
729 | num_attention_heads: int = 24,
730 | attention_head_dim: int = 128,
731 | num_layers: int = 20,
732 | num_single_layers: int = 40,
733 | num_refiner_layers: int = 2,
734 | mlp_ratio: float = 4.0,
735 | patch_size: int = 2,
736 | patch_size_t: int = 1,
737 | qk_norm: str = "rms_norm",
738 | guidance_embeds: bool = True,
739 | text_embed_dim: int = 4096,
740 | pooled_projection_dim: int = 768,
741 | rope_theta: float = 256.0,
742 | rope_axes_dim: Tuple[int] = (16, 56, 56),
743 | has_image_proj=False,
744 | image_proj_dim=1152,
745 | has_clean_x_embedder=False,
746 | ) -> None:
747 | super().__init__()
748 |
749 | inner_dim = num_attention_heads * attention_head_dim
750 | out_channels = out_channels or in_channels
751 |
752 | # 1. Latent and condition embedders
753 | self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
754 | self.context_embedder = HunyuanVideoTokenRefiner(
755 | text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
756 | )
757 | self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
758 |
759 | self.clean_x_embedder = None
760 | self.image_projection = None
761 |
762 | # 2. RoPE
763 | self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)
764 |
765 | # 3. Dual stream transformer blocks
766 | self.transformer_blocks = nn.ModuleList(
767 | [
768 | HunyuanVideoTransformerBlock(
769 | num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
770 | )
771 | for _ in range(num_layers)
772 | ]
773 | )
774 |
775 | # 4. Single stream transformer blocks
776 | self.single_transformer_blocks = nn.ModuleList(
777 | [
778 | HunyuanVideoSingleTransformerBlock(
779 | num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
780 | )
781 | for _ in range(num_single_layers)
782 | ]
783 | )
784 |
785 | # 5. Output projection
786 | self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
787 | self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
788 |
789 | self.inner_dim = inner_dim
790 | self.use_gradient_checkpointing = False
791 | self.enable_teacache = False
792 |
793 | if has_image_proj:
794 | self.install_image_projection(image_proj_dim)
795 |
796 | if has_clean_x_embedder:
797 | self.install_clean_x_embedder()
798 |
799 | self.high_quality_fp32_output_for_inference = False
800 |
801 | def install_image_projection(self, in_channels):
802 | self.image_projection = ClipVisionProjection(in_channels=in_channels, out_channels=self.inner_dim)
803 | self.config['has_image_proj'] = True
804 | self.config['image_proj_dim'] = in_channels
805 |
806 | def install_clean_x_embedder(self):
807 | self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim)
808 | self.config['has_clean_x_embedder'] = True
809 |
810 | def enable_gradient_checkpointing(self):
811 | self.use_gradient_checkpointing = True
812 | print('self.use_gradient_checkpointing = True')
813 |
814 | def disable_gradient_checkpointing(self):
815 | self.use_gradient_checkpointing = False
816 | print('self.use_gradient_checkpointing = False')
817 |
818 | def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15):
819 | self.enable_teacache = enable_teacache
820 | self.cnt = 0
821 | self.num_steps = num_steps
822 | self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
823 | self.accumulated_rel_l1_distance = 0
824 | self.previous_modulated_input = None
825 | self.previous_residual = None
826 | self.teacache_rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02])
827 |
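# TeaCache sketch: forward() tracks the relative L1 change of the first transformer
# block's modulated input, maps it through teacache_rescale_func, and accumulates it;
# while the accumulated value stays below rel_l1_thresh, all transformer blocks are
# skipped and previous_residual is reused instead. num_steps should match the sampler's
# num_inference_steps so the internal counter resets after each denoising pass, e.g.
#   transformer.initialize_teacache(enable_teacache=True, num_steps=25, rel_l1_thresh=0.15)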
828 | def gradient_checkpointing_method(self, block, *args):
829 | if self.use_gradient_checkpointing:
830 | result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
831 | else:
832 | result = block(*args)
833 | return result
834 |
835 | def process_input_hidden_states(
836 | self,
837 | latents, latent_indices=None,
838 | clean_latents=None, clean_latent_indices=None,
839 | clean_latents_2x=None, clean_latent_2x_indices=None,
840 | clean_latents_4x=None, clean_latent_4x_indices=None
841 | ):
842 | hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents)
843 | B, C, T, H, W = hidden_states.shape
844 |
845 | if latent_indices is None:
846 | latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1)
847 |
848 | hidden_states = hidden_states.flatten(2).transpose(1, 2)
849 |
850 | rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device)
851 | rope_freqs = rope_freqs.flatten(2).transpose(1, 2)
852 |
853 | if clean_latents is not None and clean_latent_indices is not None:
854 | clean_latents = clean_latents.to(hidden_states)
855 | clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents)
856 | clean_latents = clean_latents.flatten(2).transpose(1, 2)
857 |
858 | clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device)
859 | clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2)
860 |
861 | hidden_states = torch.cat([clean_latents, hidden_states], dim=1)
862 | rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1)
863 |
864 | if clean_latents_2x is not None and clean_latent_2x_indices is not None:
865 | clean_latents_2x = clean_latents_2x.to(hidden_states)
866 | clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4))
867 | clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x)
868 | clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
869 |
870 | clean_latent_2x_rope_freqs = self.rope(frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device)
871 | clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2))
872 | clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2))
873 | clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2)
874 |
875 | hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1)
876 | rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1)
877 |
878 | if clean_latents_4x is not None and clean_latent_4x_indices is not None:
879 | clean_latents_4x = clean_latents_4x.to(hidden_states)
880 | clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8))
881 | clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x)
882 | clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
883 |
884 | clean_latent_4x_rope_freqs = self.rope(frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device)
885 | clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4))
886 | clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4))
887 | clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2)
888 |
889 | hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1)
890 | rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1)
891 |
892 | return hidden_states, rope_freqs
893 |
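# Note on token layout: the concatenations above place the coarsest context first, i.e.
# [clean_latents_4x | clean_latents_2x | clean_latents | noisy latents], with matching
# RoPE frequencies; forward() later recovers only the denoised tokens by slicing the
# last original_context_length positions.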
894 | def forward(
895 | self,
896 | hidden_states, timestep, encoder_hidden_states, encoder_attention_mask, pooled_projections, guidance,
897 | latent_indices=None,
898 | clean_latents=None, clean_latent_indices=None,
899 | clean_latents_2x=None, clean_latent_2x_indices=None,
900 | clean_latents_4x=None, clean_latent_4x_indices=None,
901 | image_embeddings=None,
902 | attention_kwargs=None, return_dict=True
903 | ):
904 |
905 | if attention_kwargs is None:
906 | attention_kwargs = {}
907 |
908 | batch_size, num_channels, num_frames, height, width = hidden_states.shape
909 | p, p_t = self.config['patch_size'], self.config['patch_size_t']
910 | post_patch_num_frames = num_frames // p_t
911 | post_patch_height = height // p
912 | post_patch_width = width // p
913 | original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
914 |
915 | hidden_states, rope_freqs = self.process_input_hidden_states(hidden_states, latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices)
916 |
917 | temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections)
918 | encoder_hidden_states = self.gradient_checkpointing_method(self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask)
919 |
920 | if self.image_projection is not None:
921 | assert image_embeddings is not None, 'You must use image embeddings!'
922 | extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings)
923 | extra_attention_mask = torch.ones((batch_size, extra_encoder_hidden_states.shape[1]), dtype=encoder_attention_mask.dtype, device=encoder_attention_mask.device)
924 |
925 | # must cat before (not after) encoder_hidden_states, due to attn masking
926 | encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
927 | encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)
928 |
929 | with torch.no_grad():
930 | if batch_size == 1:
931 | # When batch size is 1, we do not need any masks or var-len attention functions, since cropping the padded text tokens is mathematically equivalent to masking them
932 | # If the cropped and masked paths ever disagree, the masked implementation is wrong; this cropped path is the reference behavior.
933 | text_len = encoder_attention_mask.sum().item()
934 | encoder_hidden_states = encoder_hidden_states[:, :text_len]
935 | attention_mask = None, None, None, None
936 | else:
937 | img_seq_len = hidden_states.shape[1]
938 | txt_seq_len = encoder_hidden_states.shape[1]
939 |
940 | cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
941 | cu_seqlens_kv = cu_seqlens_q
942 | max_seqlen_q = img_seq_len + txt_seq_len
943 | max_seqlen_kv = max_seqlen_q
944 |
945 | attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
946 |
947 | if self.enable_teacache:
948 | modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
949 |
950 | if self.cnt == 0 or self.cnt == self.num_steps-1:
951 | should_calc = True
952 | self.accumulated_rel_l1_distance = 0
953 | else:
954 | curr_rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
955 | self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
956 | should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh
957 |
958 | if should_calc:
959 | self.accumulated_rel_l1_distance = 0
960 |
961 | self.previous_modulated_input = modulated_inp
962 | self.cnt += 1
963 |
964 | if self.cnt == self.num_steps:
965 | self.cnt = 0
966 |
967 | if not should_calc:
968 | hidden_states = hidden_states + self.previous_residual
969 | else:
970 | ori_hidden_states = hidden_states.clone()
971 |
972 | for block_id, block in enumerate(self.transformer_blocks):
973 | hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
974 | block,
975 | hidden_states,
976 | encoder_hidden_states,
977 | temb,
978 | attention_mask,
979 | rope_freqs
980 | )
981 |
982 | for block_id, block in enumerate(self.single_transformer_blocks):
983 | hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
984 | block,
985 | hidden_states,
986 | encoder_hidden_states,
987 | temb,
988 | attention_mask,
989 | rope_freqs
990 | )
991 |
992 | self.previous_residual = hidden_states - ori_hidden_states
993 | else:
994 | for block_id, block in enumerate(self.transformer_blocks):
995 | hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
996 | block,
997 | hidden_states,
998 | encoder_hidden_states,
999 | temb,
1000 | attention_mask,
1001 | rope_freqs
1002 | )
1003 |
1004 | for block_id, block in enumerate(self.single_transformer_blocks):
1005 | hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
1006 | block,
1007 | hidden_states,
1008 | encoder_hidden_states,
1009 | temb,
1010 | attention_mask,
1011 | rope_freqs
1012 | )
1013 |
1014 | hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb)
1015 |
1016 | hidden_states = hidden_states[:, -original_context_length:, :]
1017 |
1018 | if self.high_quality_fp32_output_for_inference:
1019 | hidden_states = hidden_states.to(dtype=torch.float32)
1020 | if self.proj_out.weight.dtype != torch.float32:
1021 | self.proj_out.to(dtype=torch.float32)
1022 |
1023 | hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states)
1024 |
1025 | hidden_states = einops.rearrange(hidden_states, 'b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)',
1026 | t=post_patch_num_frames, h=post_patch_height, w=post_patch_width,
1027 | pt=p_t, ph=p, pw=p)
1028 |
1029 | if return_dict:
1030 | return Transformer2DModelOutput(sample=hidden_states)
1031 |
1032 | return hidden_states,
1033 |
--------------------------------------------------------------------------------
/diffusers_helper/pipelines/__pycache__/k_diffusion_hunyuan.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HM-RunningHub/ComfyUI_RH_FramePack/ad48591fcdf6af01489305a339e25fe6fcff8869/diffusers_helper/pipelines/__pycache__/k_diffusion_hunyuan.cpython-310.pyc
--------------------------------------------------------------------------------
/diffusers_helper/pipelines/__pycache__/k_diffusion_hunyuan.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HM-RunningHub/ComfyUI_RH_FramePack/ad48591fcdf6af01489305a339e25fe6fcff8869/diffusers_helper/pipelines/__pycache__/k_diffusion_hunyuan.cpython-312.pyc
--------------------------------------------------------------------------------
/diffusers_helper/pipelines/k_diffusion_hunyuan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 | from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc
5 | from diffusers_helper.k_diffusion.wrapper import fm_wrapper
6 | from diffusers_helper.utils import repeat_to_batch_size
7 |
8 |
9 | def flux_time_shift(t, mu=1.15, sigma=1.0):
10 | return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
11 |
12 |
13 | def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0):
14 | k = (y2 - y1) / (x2 - x1)
15 | b = y1 - k * x1
16 | mu = k * context_length + b
17 | mu = min(mu, math.log(exp_max))
18 | return mu
19 |
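# Worked example: with the defaults, mu interpolates linearly from 0.5 at a context
# length of 256 tokens to 1.15 at 4096 tokens, clamped at log(exp_max) = log(7) ~= 1.946
# for very long contexts. flux_time_shift(t, mu) then warps the sigma schedule: mu = 0
# leaves it unchanged, while larger mu pushes sigmas toward 1 (more steps at high noise).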
20 |
21 | def get_flux_sigmas_from_mu(n, mu):
22 | sigmas = torch.linspace(1, 0, steps=n + 1)
23 | sigmas = flux_time_shift(sigmas, mu=mu)
24 | return sigmas
25 |
26 |
27 | @torch.inference_mode()
28 | def sample_hunyuan(
29 | transformer,
30 | sampler='unipc',
31 | initial_latent=None,
32 | concat_latent=None,
33 | strength=1.0,
34 | width=512,
35 | height=512,
36 | frames=16,
37 | real_guidance_scale=1.0,
38 | distilled_guidance_scale=6.0,
39 | guidance_rescale=0.0,
40 | shift=None,
41 | num_inference_steps=25,
42 | batch_size=None,
43 | generator=None,
44 | prompt_embeds=None,
45 | prompt_embeds_mask=None,
46 | prompt_poolers=None,
47 | negative_prompt_embeds=None,
48 | negative_prompt_embeds_mask=None,
49 | negative_prompt_poolers=None,
50 | dtype=torch.bfloat16,
51 | device=None,
52 | negative_kwargs=None,
53 | callback=None,
54 | **kwargs,
55 | ):
56 | device = device or transformer.device
57 |
58 | if batch_size is None:
59 | batch_size = int(prompt_embeds.shape[0])
60 |
61 | latents = torch.randn((batch_size, 16, (frames + 3) // 4, height // 8, width // 8), generator=generator, device=generator.device if generator is not None else device).to(device=device, dtype=torch.float32)
62 |
63 | B, C, T, H, W = latents.shape
64 | seq_length = T * H * W // 4
65 |
66 | if shift is None:
67 | mu = calculate_flux_mu(seq_length, exp_max=7.0)
68 | else:
69 | mu = math.log(shift)
70 |
71 | sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device)
72 |
73 | k_model = fm_wrapper(transformer)
74 |
75 | if initial_latent is not None:
76 | sigmas = sigmas * strength
77 | first_sigma = sigmas[0].to(device=device, dtype=torch.float32)
78 | initial_latent = initial_latent.to(device=device, dtype=torch.float32)
79 | latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma
80 |
81 | if concat_latent is not None:
82 | concat_latent = concat_latent.to(latents)
83 |
84 | distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype)
85 |
86 | prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size)
87 | prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size)
88 | prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size)
89 | negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size)
90 | negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size)
91 | negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size)
92 | concat_latent = repeat_to_batch_size(concat_latent, batch_size)
93 |
94 | sampler_kwargs = dict(
95 | dtype=dtype,
96 | cfg_scale=real_guidance_scale,
97 | cfg_rescale=guidance_rescale,
98 | concat_latent=concat_latent,
99 | positive=dict(
100 | pooled_projections=prompt_poolers,
101 | encoder_hidden_states=prompt_embeds,
102 | encoder_attention_mask=prompt_embeds_mask,
103 | guidance=distilled_guidance,
104 | **kwargs,
105 | ),
106 | negative=dict(
107 | pooled_projections=negative_prompt_poolers,
108 | encoder_hidden_states=negative_prompt_embeds,
109 | encoder_attention_mask=negative_prompt_embeds_mask,
110 | guidance=distilled_guidance,
111 | **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}),
112 | )
113 | )
114 |
115 | if sampler == 'unipc':
116 | results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback)
117 | else:
118 | raise NotImplementedError(f'Sampler {sampler} is not supported.')
119 |
120 | return results
121 |
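# Usage sketch (illustrative; the tensor names are assumptions): the prompt tensors are
# the outputs of encode_prompt_conds / crop_or_pad_yield_mask, and FramePack conditioning
# (image_embeddings, clean_latents, latent_indices, ...) is forwarded through **kwargs to
# HunyuanVideoTransformer3DModelPacked.forward:
#
#   frames = sample_hunyuan(
#       transformer, width=640, height=640, frames=33, num_inference_steps=25,
#       real_guidance_scale=1.0, distilled_guidance_scale=10.0,
#       generator=torch.Generator('cuda').manual_seed(0),
#       prompt_embeds=llama_vec, prompt_embeds_mask=llama_attention_mask, prompt_poolers=clip_l_pooler,
#       negative_prompt_embeds=llama_vec_n, negative_prompt_embeds_mask=llama_attention_mask_n,
#       negative_prompt_poolers=clip_l_pooler_n,
#       image_embeddings=image_encoder_last_hidden_state,
#   )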
--------------------------------------------------------------------------------
/diffusers_helper/thread_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from threading import Thread, Lock
4 |
5 |
6 | class Listener:
7 | task_queue = []
8 | lock = Lock()
9 | thread = None
10 |
11 | @classmethod
12 | def _process_tasks(cls):
13 | while True:
14 | task = None
15 | with cls.lock:
16 | if cls.task_queue:
17 | task = cls.task_queue.pop(0)
18 |
19 | if task is None:
20 | time.sleep(0.001)
21 | continue
22 |
23 | func, args, kwargs = task
24 | try:
25 | func(*args, **kwargs)
26 | except Exception as e:
27 | print(f"Error in listener thread: {e}")
28 |
29 | @classmethod
30 | def add_task(cls, func, *args, **kwargs):
31 | with cls.lock:
32 | cls.task_queue.append((func, args, kwargs))
33 |
34 | if cls.thread is None:
35 | cls.thread = Thread(target=cls._process_tasks, daemon=True)
36 | cls.thread.start()
37 |
38 |
39 | def async_run(func, *args, **kwargs):
40 | Listener.add_task(func, *args, **kwargs)
41 |
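# Usage sketch: async_run(fn, *args, **kwargs) appends the call to a shared queue that a
# single daemon worker thread drains, so queued tasks execute sequentially in FIFO order.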
42 |
43 | class FIFOQueue:
44 | def __init__(self):
45 | self.queue = []
46 | self.lock = Lock()
47 |
48 | def push(self, item):
49 | with self.lock:
50 | self.queue.append(item)
51 |
52 | def pop(self):
53 | with self.lock:
54 | if self.queue:
55 | return self.queue.pop(0)
56 | return None
57 |
58 | def top(self):
59 | with self.lock:
60 | if self.queue:
61 | return self.queue[0]
62 | return None
63 |
64 | def next(self):
65 | while True:
66 | with self.lock:
67 | if self.queue:
68 | return self.queue.pop(0)
69 |
70 | time.sleep(0.001)
71 |
72 |
73 | class AsyncStream:
74 | def __init__(self):
75 | self.input_queue = FIFOQueue()
76 | self.output_queue = FIFOQueue()
77 |
--------------------------------------------------------------------------------
/diffusers_helper/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import json
4 | import random
5 | import glob
6 | import torch
7 | import einops
8 | import numpy as np
9 | import datetime
10 | import torchvision
11 |
12 | import safetensors.torch as sf
13 | from PIL import Image
14 |
15 |
16 | def min_resize(x, m):
17 | if x.shape[0] < x.shape[1]:
18 | s0 = m
19 | s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
20 | else:
21 | s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
22 | s1 = m
23 | new_max = max(s1, s0)
24 | raw_max = max(x.shape[0], x.shape[1])
25 | if new_max < raw_max:
26 | interpolation = cv2.INTER_AREA
27 | else:
28 | interpolation = cv2.INTER_LANCZOS4
29 | y = cv2.resize(x, (s1, s0), interpolation=interpolation)
30 | return y
31 |
32 |
33 | def d_resize(x, y):
34 | H, W, C = y.shape
35 | new_min = min(H, W)
36 | raw_min = min(x.shape[0], x.shape[1])
37 | if new_min < raw_min:
38 | interpolation = cv2.INTER_AREA
39 | else:
40 | interpolation = cv2.INTER_LANCZOS4
41 | y = cv2.resize(x, (W, H), interpolation=interpolation)
42 | return y
43 |
44 |
45 | def resize_and_center_crop(image, target_width, target_height):
46 | if target_height == image.shape[0] and target_width == image.shape[1]:
47 | return image
48 |
49 | pil_image = Image.fromarray(image)
50 | original_width, original_height = pil_image.size
51 | scale_factor = max(target_width / original_width, target_height / original_height)
52 | resized_width = int(round(original_width * scale_factor))
53 | resized_height = int(round(original_height * scale_factor))
54 | resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
55 | left = (resized_width - target_width) / 2
56 | top = (resized_height - target_height) / 2
57 | right = (resized_width + target_width) / 2
58 | bottom = (resized_height + target_height) / 2
59 | cropped_image = resized_image.crop((left, top, right, bottom))
60 | return np.array(cropped_image)
61 |
62 |
63 | def resize_and_center_crop_pytorch(image, target_width, target_height):
64 | B, C, H, W = image.shape
65 |
66 | if H == target_height and W == target_width:
67 | return image
68 |
69 | scale_factor = max(target_width / W, target_height / H)
70 | resized_width = int(round(W * scale_factor))
71 | resized_height = int(round(H * scale_factor))
72 |
73 | resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode='bilinear', align_corners=False)
74 |
75 | top = (resized_height - target_height) // 2
76 | left = (resized_width - target_width) // 2
77 | cropped = resized[:, :, top:top + target_height, left:left + target_width]
78 |
79 | return cropped
80 |
81 |
82 | def resize_without_crop(image, target_width, target_height):
83 | if target_height == image.shape[0] and target_width == image.shape[1]:
84 | return image
85 |
86 | pil_image = Image.fromarray(image)
87 | resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
88 | return np.array(resized_image)
89 |
90 |
91 | def just_crop(image, w, h):
92 | if h == image.shape[0] and w == image.shape[1]:
93 | return image
94 |
95 | original_height, original_width = image.shape[:2]
96 | k = min(original_height / h, original_width / w)
97 | new_width = int(round(w * k))
98 | new_height = int(round(h * k))
99 | x_start = (original_width - new_width) // 2
100 | y_start = (original_height - new_height) // 2
101 | cropped_image = image[y_start:y_start + new_height, x_start:x_start + new_width]
102 | return cropped_image
103 |
104 |
105 | def write_to_json(data, file_path):
106 | temp_file_path = file_path + ".tmp"
107 | with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
108 | json.dump(data, temp_file, indent=4)
109 | os.replace(temp_file_path, file_path)
110 | return
111 |
112 |
113 | def read_from_json(file_path):
114 | with open(file_path, 'rt', encoding='utf-8') as file:
115 | data = json.load(file)
116 | return data
117 |
118 |
119 | def get_active_parameters(m):
120 | return {k: v for k, v in m.named_parameters() if v.requires_grad}
121 |
122 |
123 | def cast_training_params(m, dtype=torch.float32):
124 | result = {}
125 | for n, param in m.named_parameters():
126 | if param.requires_grad:
127 | param.data = param.to(dtype)
128 | result[n] = param
129 | return result
130 |
131 |
132 | def separate_lora_AB(parameters, B_patterns=None):
133 | parameters_normal = {}
134 | parameters_B = {}
135 |
136 | if B_patterns is None:
137 | B_patterns = ['.lora_B.', '__zero__']
138 |
139 | for k, v in parameters.items():
140 | if any(B_pattern in k for B_pattern in B_patterns):
141 | parameters_B[k] = v
142 | else:
143 | parameters_normal[k] = v
144 |
145 | return parameters_normal, parameters_B
146 |
147 |
148 | def set_attr_recursive(obj, attr, value):
149 | attrs = attr.split(".")
150 | for name in attrs[:-1]:
151 | obj = getattr(obj, name)
152 | setattr(obj, attrs[-1], value)
153 | return
154 |
155 |
156 | def print_tensor_list_size(tensors):
157 | total_size = 0
158 | total_elements = 0
159 |
160 | if isinstance(tensors, dict):
161 | tensors = tensors.values()
162 |
163 | for tensor in tensors:
164 | total_size += tensor.nelement() * tensor.element_size()
165 | total_elements += tensor.nelement()
166 |
167 | total_size_MB = total_size / (1024 ** 2)
168 | total_elements_B = total_elements / 1e9
169 |
170 | print(f"Total number of tensors: {len(tensors)}")
171 | print(f"Total size of tensors: {total_size_MB:.2f} MB")
172 | print(f"Total number of parameters: {total_elements_B:.3f} billion")
173 | return
174 |
175 |
176 | @torch.no_grad()
177 | def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
178 | batch_size = a.size(0)
179 |
180 | if b is None:
181 | b = torch.zeros_like(a)
182 |
183 | if mask_a is None:
184 | mask_a = torch.rand(batch_size) < probability_a
185 |
186 | mask_a = mask_a.to(a.device)
187 | mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
188 | result = torch.where(mask_a, a, b)
189 | return result
190 |
191 |
192 | @torch.no_grad()
193 | def zero_module(module):
194 | for p in module.parameters():
195 | p.detach().zero_()
196 | return module
197 |
198 |
199 | @torch.no_grad()
200 | def supress_lower_channels(m, k, alpha=0.01):
201 | data = m.weight.data.clone()
202 |
203 | assert int(data.shape[1]) >= k
204 |
205 | data[:, :k] = data[:, :k] * alpha
206 | m.weight.data = data.contiguous().clone()
207 | return m
208 |
209 |
210 | def freeze_module(m):
211 | if not hasattr(m, '_forward_inside_frozen_module'):
212 | m._forward_inside_frozen_module = m.forward
213 | m.requires_grad_(False)
214 | m.forward = torch.no_grad()(m.forward)
215 | return m
216 |
217 |
218 | def get_latest_safetensors(folder_path):
219 | safetensors_files = glob.glob(os.path.join(folder_path, '*.safetensors'))
220 |
221 | if not safetensors_files:
222 | raise ValueError('No file to resume!')
223 |
224 | latest_file = max(safetensors_files, key=os.path.getmtime)
225 | latest_file = os.path.abspath(os.path.realpath(latest_file))
226 | return latest_file
227 |
228 |
229 | def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
230 | tags = tags_str.split(', ')
231 | tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
232 | prompt = ', '.join(tags)
233 | return prompt
234 |
235 |
236 | def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
237 | numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma)
238 | if round_to_int:
239 | numbers = np.round(numbers).astype(int)
240 | return numbers.tolist()
241 |
242 |
243 | def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
244 | edges = np.linspace(0, 1, n + 1)
245 | points = np.random.uniform(edges[:-1], edges[1:])
246 | numbers = inclusive + (exclusive - inclusive) * points
247 | if round_to_int:
248 | numbers = np.round(numbers).astype(int)
249 | return numbers.tolist()
250 |
251 |
252 | def soft_append_bcthw(history, current, overlap=0):
253 | if overlap <= 0:
254 | return torch.cat([history, current], dim=2)
255 |
256 | assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
257 | assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"
258 |
259 | weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
260 | blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap]
261 | output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2)
262 |
263 | return output.to(history)
264 |
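# Worked example: with overlap=3 the blend weights are torch.linspace(1, 0, 3) = [1.0, 0.5, 0.0],
# so the last 3 frames of `history` crossfade linearly into the first 3 frames of `current`
# along the time dimension before the remaining frames are concatenated.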
265 |
266 | def save_bcthw_as_mp4(x, output_filename, fps=10):
267 | b, c, t, h, w = x.shape
268 |
269 | per_row = b
270 | for p in [6, 5, 4, 3, 2]:
271 | if b % p == 0:
272 | per_row = p
273 | break
274 |
275 | os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
276 | x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
277 | x = x.detach().cpu().to(torch.uint8)
278 | x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
279 | torchvision.io.write_video(output_filename, x, fps=fps, video_codec='h264', options={'crf': '0'})
280 | return x
281 |
282 |
283 | def save_bcthw_as_png(x, output_filename):
284 | os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
285 | x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
286 | x = x.detach().cpu().to(torch.uint8)
287 | x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
288 | torchvision.io.write_png(x, output_filename)
289 | return output_filename
290 |
291 |
292 | def save_bchw_as_png(x, output_filename):
293 | os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
294 | x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
295 | x = x.detach().cpu().to(torch.uint8)
296 | x = einops.rearrange(x, 'b c h w -> c h (b w)')
297 | torchvision.io.write_png(x, output_filename)
298 | return output_filename
299 |
300 |
301 | def add_tensors_with_padding(tensor1, tensor2):
302 | if tensor1.shape == tensor2.shape:
303 | return tensor1 + tensor2
304 |
305 | shape1 = tensor1.shape
306 | shape2 = tensor2.shape
307 |
308 | new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
309 |
310 | padded_tensor1 = torch.zeros(new_shape, dtype=tensor1.dtype, device=tensor1.device)
311 | padded_tensor2 = torch.zeros(new_shape, dtype=tensor2.dtype, device=tensor2.device)
312 |
313 | padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
314 | padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2
315 |
316 | result = padded_tensor1 + padded_tensor2
317 | return result
318 |
319 |
320 | def print_free_mem():
321 | torch.cuda.empty_cache()
322 | free_mem, total_mem = torch.cuda.mem_get_info(0)
323 | free_mem_mb = free_mem / (1024 ** 2)
324 | total_mem_mb = total_mem / (1024 ** 2)
325 | print(f"Free memory: {free_mem_mb:.2f} MB")
326 | print(f"Total memory: {total_mem_mb:.2f} MB")
327 | return
328 |
329 |
330 | def print_gpu_parameters(device, state_dict, log_count=1):
331 | summary = {"device": device, "keys_count": len(state_dict)}
332 |
333 | logged_params = {}
334 | for i, (key, tensor) in enumerate(state_dict.items()):
335 | if i >= log_count:
336 | break
337 | logged_params[key] = tensor.flatten()[:3].tolist()
338 |
339 | summary["params"] = logged_params
340 |
341 | print(str(summary))
342 | return
343 |
344 |
345 | def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18):
346 | from PIL import Image, ImageDraw, ImageFont
347 |
348 | txt = Image.new("RGB", (width, height), color="white")
349 | draw = ImageDraw.Draw(txt)
350 | font = ImageFont.truetype(font_path, size=size)
351 |
352 | if text == '':
353 | return np.array(txt)
354 |
355 | # Split text into lines that fit within the image width
356 | lines = []
357 | words = text.split()
358 | current_line = words[0]
359 |
360 | for word in words[1:]:
361 | line_with_word = f"{current_line} {word}"
362 | if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
363 | current_line = line_with_word
364 | else:
365 | lines.append(current_line)
366 | current_line = word
367 |
368 | lines.append(current_line)
369 |
370 | # Draw the text line by line
371 | y = 0
372 | line_height = draw.textbbox((0, 0), "A", font=font)[3]
373 |
374 | for line in lines:
375 | if y + line_height > height:
376 | break # stop drawing if the next line will be outside the image
377 | draw.text((0, y), line, fill="black", font=font)
378 | y += line_height
379 |
380 | return np.array(txt)
381 |
382 |
383 | def blue_mark(x):
384 | x = x.copy()
385 | c = x[:, :, 2]
386 | b = cv2.blur(c, (9, 9))
387 | x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1)
388 | return x
389 |
390 |
391 | def green_mark(x):
392 | x = x.copy()
393 | x[:, :, 2] = -1
394 | x[:, :, 0] = -1
395 | return x
396 |
397 |
398 | def frame_mark(x):
399 | x = x.copy()
400 | x[:64] = -1
401 | x[-64:] = -1
402 | x[:, :8] = 1
403 | x[:, -8:] = 1
404 | return x
405 |
406 |
407 | @torch.inference_mode()
408 | def pytorch2numpy(imgs):
409 | results = []
410 | for x in imgs:
411 | y = x.movedim(0, -1)
412 | y = y * 127.5 + 127.5
413 | y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
414 | results.append(y)
415 | return results
416 |
417 |
418 | @torch.inference_mode()
419 | def numpy2pytorch(imgs):
420 | h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
421 | h = h.movedim(-1, 1)
422 | return h
423 |
424 |
425 | @torch.no_grad()
426 | def duplicate_prefix_to_suffix(x, count, zero_out=False):
427 | if zero_out:
428 | return torch.cat([x, torch.zeros_like(x[:count])], dim=0)
429 | else:
430 | return torch.cat([x, x[:count]], dim=0)
431 |
432 |
433 | def weighted_mse(a, b, weight):
434 | return torch.mean(weight.float() * (a.float() - b.float()) ** 2)
435 |
436 |
437 | def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
438 | x = (x - x_min) / (x_max - x_min)
439 | x = max(0.0, min(x, 1.0))
440 | x = x ** sigma
441 | return y_min + x * (y_max - y_min)
442 |
443 |
444 | def expand_to_dims(x, target_dims):
445 | return x.view(*x.shape, *([1] * max(0, target_dims - x.dim())))
446 |
447 |
448 | def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
449 | if tensor is None:
450 | return None
451 |
452 | first_dim = tensor.shape[0]
453 |
454 | if first_dim == batch_size:
455 | return tensor
456 |
457 | if batch_size % first_dim != 0:
458 | raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")
459 |
460 | repeat_times = batch_size // first_dim
461 |
462 | return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1))
463 |
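# Example: a (1, 77, 768) prompt embedding with batch_size=2 becomes (2, 77, 768); a first
# dimension that does not evenly divide batch_size (e.g. 2 -> 3) raises ValueError.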
464 |
465 | def dim5(x):
466 | return expand_to_dims(x, 5)
467 |
468 |
469 | def dim4(x):
470 | return expand_to_dims(x, 4)
471 |
472 |
473 | def dim3(x):
474 | return expand_to_dims(x, 3)
475 |
476 |
477 | def crop_or_pad_yield_mask(x, length):
478 | B, F, C = x.shape
479 | device = x.device
480 | dtype = x.dtype
481 |
482 | if F < length:
483 | y = torch.zeros((B, length, C), dtype=dtype, device=device)
484 | mask = torch.zeros((B, length), dtype=torch.bool, device=device)
485 | y[:, :F, :] = x
486 | mask[:, :F] = True
487 | return y, mask
488 |
489 | return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device)
490 |
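# Example: x of shape (B, 60, C) with length=512 is zero-padded to (B, 512, C) with a boolean
# mask that is True only for the first 60 positions; if F >= length, x is cropped and the
# mask is all True.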
491 |
492 | def extend_dim(x, dim, minimal_length, zero_pad=False):
493 | original_length = int(x.shape[dim])
494 |
495 | if original_length >= minimal_length:
496 | return x
497 |
498 | if zero_pad:
499 | padding_shape = list(x.shape)
500 | padding_shape[dim] = minimal_length - original_length
501 | padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
502 | else:
503 | idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1)
504 | last_element = x[idx]
505 | padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim)
506 |
507 | return torch.cat([x, padding], dim=dim)
508 |
509 |
510 | def lazy_positional_encoding(t, repeats=None):
511 | if not isinstance(t, list):
512 | t = [t]
513 |
514 | from diffusers.models.embeddings import get_timestep_embedding
515 |
516 | te = torch.tensor(t)
517 | te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)
518 |
519 | if repeats is None:
520 | return te
521 |
522 | te = te[:, None, :].expand(-1, repeats, -1)
523 |
524 | return te
525 |
526 |
527 | def state_dict_offset_merge(A, B, C=None):
528 | result = {}
529 | keys = A.keys()
530 |
531 | for key in keys:
532 | A_value = A[key]
533 | B_value = B[key].to(A_value)
534 |
535 | if C is None:
536 | result[key] = A_value + B_value
537 | else:
538 | C_value = C[key].to(A_value)
539 | result[key] = A_value + B_value - C_value
540 |
541 | return result
542 |
543 |
544 | def state_dict_weighted_merge(state_dicts, weights):
545 | if len(state_dicts) != len(weights):
546 | raise ValueError("Number of state dictionaries must match number of weights")
547 |
548 | if not state_dicts:
549 | return {}
550 |
551 | total_weight = sum(weights)
552 |
553 | if total_weight == 0:
554 | raise ValueError("Sum of weights cannot be zero")
555 |
556 | normalized_weights = [w / total_weight for w in weights]
557 |
558 | keys = state_dicts[0].keys()
559 | result = {}
560 |
561 | for key in keys:
562 | result[key] = state_dicts[0][key] * normalized_weights[0]
563 |
564 | for i in range(1, len(state_dicts)):
565 | state_dict_value = state_dicts[i][key].to(result[key])
566 | result[key] += state_dict_value * normalized_weights[i]
567 |
568 | return result
569 |
570 |
571 | def group_files_by_folder(all_files):
572 | grouped_files = {}
573 |
574 | for file in all_files:
575 | folder_name = os.path.basename(os.path.dirname(file))
576 | if folder_name not in grouped_files:
577 | grouped_files[folder_name] = []
578 | grouped_files[folder_name].append(file)
579 |
580 | list_of_lists = list(grouped_files.values())
581 | return list_of_lists
582 |
583 |
584 | def generate_timestamp():
585 | now = datetime.datetime.now()
586 | timestamp = now.strftime('%y%m%d_%H%M%S')
587 | milliseconds = f"{int(now.microsecond / 1000):03d}"
588 | random_number = random.randint(0, 9999)
589 | return f"{timestamp}_{milliseconds}_{random_number}"
590 |
591 |
592 | def write_PIL_image_with_png_info(image, metadata, path):
593 | from PIL.PngImagePlugin import PngInfo
594 |
595 | png_info = PngInfo()
596 | for key, value in metadata.items():
597 | png_info.add_text(key, value)
598 |
599 | image.save(path, "PNG", pnginfo=png_info)
600 | return image
601 |
602 |
603 | def torch_safe_save(content, path):
604 | torch.save(content, path + '_tmp')
605 | os.replace(path + '_tmp', path)
606 | return path
607 |
608 |
609 | def move_optimizer_to_device(optimizer, device):
610 | for state in optimizer.state.values():
611 | for k, v in state.items():
612 | if isinstance(v, torch.Tensor):
613 | state[k] = v.to(device)
614 |
--------------------------------------------------------------------------------
/examples/FramePack_endimage.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 18,
3 | "last_link_id": 24,
4 | "nodes": [
5 | {
6 | "id": 15,
7 | "type": "LoadImage",
8 | "pos": [
9 | 1632.1324462890625,
10 | 756.5489501953125
11 | ],
12 | "size": [
13 | 467.89154052734375,
14 | 535.3126220703125
15 | ],
16 | "flags": {},
17 | "order": 0,
18 | "mode": 0,
19 | "inputs": [],
20 | "outputs": [
21 | {
22 | "name": "IMAGE",
23 | "label": "IMAGE",
24 | "type": "IMAGE",
25 | "links": [
26 | 20
27 | ],
28 | "slot_index": 0
29 | },
30 | {
31 | "name": "MASK",
32 | "label": "MASK",
33 | "type": "MASK"
34 | }
35 | ],
36 | "properties": {
37 | "cnr_id": "comfy-core",
38 | "ver": "0.3.28",
39 | "Node name for S&R": "LoadImage"
40 | },
41 | "widgets_values": [
42 | "ComfyUI_temp_ibecr_00005_lqaex_1744979262.png",
43 | "image"
44 | ]
45 | },
46 | {
47 | "id": 2,
48 | "type": "LoadImage",
49 | "pos": [
50 | 1617.58544921875,
51 | 152.19729614257812
52 | ],
53 | "size": [
54 | 467.89154052734375,
55 | 535.3126220703125
56 | ],
57 | "flags": {},
58 | "order": 1,
59 | "mode": 0,
60 | "inputs": [],
61 | "outputs": [
62 | {
63 | "name": "IMAGE",
64 | "label": "IMAGE",
65 | "type": "IMAGE",
66 | "links": [
67 | 21
68 | ]
69 | },
70 | {
71 | "name": "MASK",
72 | "label": "MASK",
73 | "type": "MASK",
74 | "slot_index": 1
75 | }
76 | ],
77 | "properties": {
78 | "cnr_id": "comfy-core",
79 | "ver": "0.3.28",
80 | "Node name for S&R": "LoadImage"
81 | },
82 | "widgets_values": [
83 | "ComfyUI_temp_ibecr_00009_kzhcr_1744979266.png",
84 | "image"
85 | ]
86 | },
87 | {
88 | "id": 18,
89 | "type": "PreviewImage",
90 | "pos": [
91 | 2159.90087890625,
92 | 805.84716796875
93 | ],
94 | "size": [
95 | 591.8181762695312,
96 | 519.8636474609375
97 | ],
98 | "flags": {},
99 | "order": 3,
100 | "mode": 0,
101 | "inputs": [
102 | {
103 | "name": "images",
104 | "label": "images",
105 | "type": "IMAGE",
106 | "link": 24
107 | }
108 | ],
109 | "outputs": [],
110 | "properties": {
111 | "cnr_id": "comfy-core",
112 | "ver": "0.3.28",
113 | "Node name for S&R": "PreviewImage"
114 | },
115 | "widgets_values": []
116 | },
117 | {
118 | "id": 17,
119 | "type": "RunningHub_FramePack",
120 | "pos": [
121 | 2207.928466796875,
122 | 465.5581970214844
123 | ],
124 | "size": [
125 | 400,
126 | 252
127 | ],
128 | "flags": {},
129 | "order": 2,
130 | "mode": 0,
131 | "inputs": [
132 | {
133 | "name": "ref_image",
134 | "label": "ref_image",
135 | "type": "IMAGE",
136 | "link": 21
137 | },
138 | {
139 | "name": "end_image",
140 | "label": "end_image",
141 | "type": "IMAGE",
142 | "shape": 7,
143 | "link": 20
144 | }
145 | ],
146 | "outputs": [
147 | {
148 | "name": "frames",
149 | "label": "frames",
150 | "type": "IMAGE",
151 | "links": [
152 | 22,
153 | 24
154 | ],
155 | "slot_index": 0
156 | },
157 | {
158 | "name": "fps",
159 | "label": "fps",
160 | "type": "FLOAT",
161 | "links": [
162 | 23
163 | ],
164 | "slot_index": 1
165 | }
166 | ],
167 | "properties": {
168 | "aux_id": "HM-RunningHub/ComfyUI_RH_FramePack",
169 | "ver": "c688eb1533f8984a5ea5d2db08496ebb6da0a602",
170 | "Node name for S&R": "RunningHub_FramePack"
171 | },
172 | "widgets_values": [
173 | "Advanced video dynamic shots\n\n",
174 | 3,
175 | 932,
176 | "randomize",
177 | 25,
178 | true,
179 | 1,
180 | [
181 | false,
182 | true
183 | ]
184 | ]
185 | },
186 | {
187 | "id": 3,
188 | "type": "VHS_VideoCombine",
189 | "pos": [
190 | 2928.890625,
191 | 427.95306396484375
192 | ],
193 | "size": [
194 | 472.8837890625,
195 | 820.8837890625
196 | ],
197 | "flags": {},
198 | "order": 4,
199 | "mode": 0,
200 | "inputs": [
201 | {
202 | "name": "images",
203 | "label": "images",
204 | "type": "IMAGE",
205 | "link": 22
206 | },
207 | {
208 | "name": "audio",
209 | "label": "audio",
210 | "type": "AUDIO",
211 | "shape": 7
212 | },
213 | {
214 | "name": "meta_batch",
215 | "label": "meta_batch",
216 | "type": "VHS_BatchManager",
217 | "shape": 7
218 | },
219 | {
220 | "name": "vae",
221 | "label": "vae",
222 | "type": "VAE",
223 | "shape": 7
224 | },
225 | {
226 | "name": "frame_rate",
227 | "label": "frame_rate",
228 | "type": "FLOAT",
229 | "widget": {
230 | "name": "frame_rate"
231 | },
232 | "link": 23
233 | }
234 | ],
235 | "outputs": [
236 | {
237 | "name": "Filenames",
238 | "label": "Filenames",
239 | "type": "VHS_FILENAMES"
240 | }
241 | ],
242 | "properties": {
243 | "cnr_id": "comfyui-videohelpersuite",
244 | "ver": "df55f01d1df2f7bf5cc772294bc2e6d8bab22d66",
245 | "Node name for S&R": "VHS_VideoCombine"
246 | },
247 | "widgets_values": {
248 | "frame_rate": 8,
249 | "loop_count": 0,
250 | "filename_prefix": "AnimateDiff",
251 | "format": "video/h264-mp4",
252 | "pix_fmt": "yuv420p",
253 | "crf": 19,
254 | "save_metadata": true,
255 | "trim_to_audio": false,
256 | "pingpong": false,
257 | "save_output": true,
258 | "videopreview": {
259 | "paused": false,
260 | "hidden": false,
261 | "params": {
262 | "filename": "AnimateDiff_00004.mp4",
263 | "workflow": "AnimateDiff_00004.png",
264 | "fullpath": "D:\\ComfyUI_windows_portable\\ComfyUI\\output\\AnimateDiff_00004.mp4",
265 | "format": "video/h264-mp4",
266 | "subfolder": "",
267 | "type": "output",
268 | "frame_rate": 30
269 | },
270 | "muted": false
271 | }
272 | }
273 | }
274 | ],
275 | "links": [
276 | [
277 | 20,
278 | 15,
279 | 0,
280 | 17,
281 | 1,
282 | "IMAGE"
283 | ],
284 | [
285 | 21,
286 | 2,
287 | 0,
288 | 17,
289 | 0,
290 | "IMAGE"
291 | ],
292 | [
293 | 22,
294 | 17,
295 | 0,
296 | 3,
297 | 0,
298 | "IMAGE"
299 | ],
300 | [
301 | 23,
302 | 17,
303 | 1,
304 | 3,
305 | 4,
306 | "FLOAT"
307 | ],
308 | [
309 | 24,
310 | 17,
311 | 0,
312 | 18,
313 | 0,
314 | "IMAGE"
315 | ]
316 | ],
317 | "groups": [],
318 | "config": {},
319 | "extra": {
320 | "ds": {
321 | "scale": 0.8264462809917354,
322 | "offset": [
323 | -1427.551752071999,
324 | -102.40491658569128
325 | ]
326 | },
327 | "ue_links": [],
328 | "0246.VERSION": [
329 | 0,
330 | 0,
331 | 4
332 | ],
333 | "VHS_latentpreview": false,
334 | "VHS_latentpreviewrate": 0,
335 | "VHS_MetadataImage": true,
336 | "VHS_KeepIntermediate": true
337 | },
338 | "version": 0.4
339 | }
--------------------------------------------------------------------------------
/examples/FramePack_regular.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 16,
3 | "last_link_id": 19,
4 | "nodes": [
5 | {
6 | "id": 7,
7 | "type": "SeargePromptText",
8 | "pos": [
9 | 1632.283203125,
10 | 799.6035766601562
11 | ],
12 | "size": [
13 | 400,
14 | 200
15 | ],
16 | "flags": {},
17 | "order": 0,
18 | "mode": 0,
19 | "inputs": [],
20 | "outputs": [
21 | {
22 | "name": "prompt",
23 | "label": "prompt",
24 | "type": "STRING",
25 | "links": [
26 | 16
27 | ],
28 | "slot_index": 0
29 | }
30 | ],
31 | "properties": {
32 | "cnr_id": "SeargeSDXL",
33 | "ver": "2eb5edbc712329d77d1a2f5f1e6c5e64397a4a83",
34 | "Node name for S&R": "SeargePromptText"
35 | },
36 | "widgets_values": [
37 | "A woman with a loving smile, making a heart gesture toward the camera with both hands",
38 | [
39 | false,
40 | true
41 | ]
42 | ]
43 | },
44 | {
45 | "id": 15,
46 | "type": "RunningHub_FramePack",
47 | "pos": [
48 | 2313.555419921875,
49 | 505.303955078125
50 | ],
51 | "size": [
52 | 400,
53 | 252
54 | ],
55 | "flags": {},
56 | "order": 2,
57 | "mode": 0,
58 | "inputs": [
59 | {
60 | "name": "ref_image",
61 | "label": "ref_image",
62 | "type": "IMAGE",
63 | "link": 15
64 | },
65 | {
66 | "name": "end_image",
67 | "label": "end_image",
68 | "type": "IMAGE",
69 | "shape": 7,
70 | "link": null
71 | },
72 | {
73 | "name": "prompt",
74 | "label": "prompt",
75 | "type": "STRING",
76 | "widget": {
77 | "name": "prompt"
78 | },
79 | "link": 16
80 | }
81 | ],
82 | "outputs": [
83 | {
84 | "name": "frames",
85 | "label": "frames",
86 | "type": "IMAGE",
87 | "links": [
88 | 17
89 | ],
90 | "slot_index": 0
91 | },
92 | {
93 | "name": "fps",
94 | "label": "fps",
95 | "type": "FLOAT",
96 | "links": [
97 | 18
98 | ],
99 | "slot_index": 1
100 | }
101 | ],
102 | "properties": {
103 | "aux_id": "HM-RunningHub/ComfyUI_RH_FramePack",
104 | "ver": "c688eb1533f8984a5ea5d2db08496ebb6da0a602",
105 | "Node name for S&R": "RunningHub_FramePack"
106 | },
107 | "widgets_values": [
108 | "",
109 | 5,
110 | 1378,
111 | "randomize",
112 | 25,
113 | true,
114 | 1.2,
115 | [
116 | false,
117 | true
118 | ]
119 | ]
120 | },
121 | {
122 | "id": 16,
123 | "type": "VHS_VideoCombine",
124 | "pos": [
125 | 2928.94970703125,
126 | 180.33053588867188
127 | ],
128 | "size": [
129 | 419.164794921875,
130 | 803.4525146484375
131 | ],
132 | "flags": {},
133 | "order": 3,
134 | "mode": 0,
135 | "inputs": [
136 | {
137 | "name": "images",
138 | "label": "images",
139 | "type": "IMAGE",
140 | "link": 17
141 | },
142 | {
143 | "name": "audio",
144 | "label": "audio",
145 | "type": "AUDIO",
146 | "shape": 7,
147 | "link": null
148 | },
149 | {
150 | "name": "meta_batch",
151 | "label": "meta_batch",
152 | "type": "VHS_BatchManager",
153 | "shape": 7,
154 | "link": null
155 | },
156 | {
157 | "name": "vae",
158 | "label": "vae",
159 | "type": "VAE",
160 | "shape": 7,
161 | "link": null
162 | },
163 | {
164 | "name": "frame_rate",
165 | "label": "frame_rate",
166 | "type": "FLOAT",
167 | "widget": {
168 | "name": "frame_rate"
169 | },
170 | "link": 18
171 | }
172 | ],
173 | "outputs": [
174 | {
175 | "name": "Filenames",
176 | "label": "Filenames",
177 | "type": "VHS_FILENAMES"
178 | }
179 | ],
180 | "properties": {
181 | "cnr_id": "comfyui-videohelpersuite",
182 | "ver": "df55f01d1df2f7bf5cc772294bc2e6d8bab22d66",
183 | "Node name for S&R": "VHS_VideoCombine"
184 | },
185 | "widgets_values": {
186 | "frame_rate": 8,
187 | "loop_count": 0,
188 | "filename_prefix": "AnimateDiff",
189 | "format": "video/h264-mp4",
190 | "pix_fmt": "yuv420p",
191 | "crf": 19,
192 | "save_metadata": true,
193 | "trim_to_audio": false,
194 | "pingpong": false,
195 | "save_output": true,
196 | "videopreview": {
197 | "paused": false,
198 | "hidden": false,
199 | "params": {
200 | "filename": "AnimateDiff_00009.mp4",
201 | "workflow": "AnimateDiff_00009.png",
202 | "fullpath": "D:\\ComfyUI_windows_portable\\ComfyUI\\output\\AnimateDiff_00009.mp4",
203 | "format": "video/h264-mp4",
204 | "subfolder": "",
205 | "type": "output",
206 | "frame_rate": 30
207 | },
208 | "muted": false
209 | }
210 | }
211 | },
212 | {
213 | "id": 2,
214 | "type": "LoadImage",
215 | "pos": [
216 | 1614.83544921875,
217 | 164.57229614257812
218 | ],
219 | "size": [
220 | 467.89154052734375,
221 | 535.3126220703125
222 | ],
223 | "flags": {},
224 | "order": 1,
225 | "mode": 0,
226 | "inputs": [],
227 | "outputs": [
228 | {
229 | "name": "IMAGE",
230 | "label": "IMAGE",
231 | "type": "IMAGE",
232 | "links": [
233 | 15
234 | ],
235 | "slot_index": 0
236 | },
237 | {
238 | "name": "MASK",
239 | "label": "MASK",
240 | "type": "MASK"
241 | }
242 | ],
243 | "properties": {
244 | "cnr_id": "comfy-core",
245 | "ver": "0.3.28",
246 | "Node name for S&R": "LoadImage"
247 | },
248 | "widgets_values": [
249 | "DONGYUJIE.png",
250 | "image"
251 | ]
252 | }
253 | ],
254 | "links": [
255 | [
256 | 15,
257 | 2,
258 | 0,
259 | 15,
260 | 0,
261 | "IMAGE"
262 | ],
263 | [
264 | 16,
265 | 7,
266 | 0,
267 | 15,
268 | 2,
269 | "STRING"
270 | ],
271 | [
272 | 17,
273 | 15,
274 | 0,
275 | 16,
276 | 0,
277 | "IMAGE"
278 | ],
279 | [
280 | 18,
281 | 15,
282 | 1,
283 | 16,
284 | 4,
285 | "FLOAT"
286 | ]
287 | ],
288 | "groups": [],
289 | "config": {},
290 | "extra": {
291 | "ds": {
292 | "scale": 0.9090909090909091,
293 | "offset": [
294 | -1563.0115557309155,
295 | -19.100634686877555
296 | ]
297 | },
298 | "ue_links": [],
299 | "0246.VERSION": [
300 | 0,
301 | 0,
302 | 4
303 | ],
304 | "VHS_latentpreview": false,
305 | "VHS_latentpreviewrate": 0,
306 | "VHS_MetadataImage": true,
307 | "VHS_KeepIntermediate": true
308 | },
309 | "version": 0.4
310 | }
--------------------------------------------------------------------------------
/nodes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | current_dir = os.path.dirname(os.path.abspath(__file__))
4 | sys.path.insert(0, current_dir)
5 |
6 | import torch
7 | import traceback
8 | import einops
9 | import safetensors.torch as sf
10 | import numpy as np
11 | import argparse
12 | import math
13 | import time
14 |
15 | from PIL import Image
16 | from diffusers import AutoencoderKLHunyuanVideo
17 | from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer
18 | from diffusers_helper.hunyuan import encode_prompt_conds, vae_decode, vae_encode, vae_decode_fake
19 | from diffusers_helper.utils import save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, state_dict_weighted_merge, state_dict_offset_merge, generate_timestamp
20 | from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
21 | from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
22 | from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, DynamicSwapInstaller, unload_complete_models, load_model_as_complete
23 | from transformers import SiglipImageProcessor, SiglipVisionModel
24 | from diffusers_helper.clip_vision import hf_clip_vision_encode
25 | from diffusers_helper.bucket_tools import find_nearest_bucket
26 | import hashlib
27 | import random
28 | import string
29 | import torchvision
30 | from torchvision.transforms.functional import to_pil_image
31 | import comfy.utils
32 |
33 | from PIL import Image
34 | import folder_paths
35 |
36 | class Kiki_FramePack:
37 | @classmethod
38 | def INPUT_TYPES(s):
39 | return {
40 | "required": {
41 | "ref_image": ("IMAGE", ),
42 | "prompt": ("STRING", {"multiline": True}),
43 | # "n_prompt": ("STRING", {"multiline": True}),
44 | "total_second_length": ("INT", {"default": 5, "min": 1, "max": 120, "step": 1}),
45 | "seed": ("INT", {"default": 3407}),
46 | "steps": ("INT", {"default": 25, "min": 1, "max": 100, "step": 1}),
47 | "use_teacache": ("BOOLEAN", {"default": True}),
48 | "upscale": ("FLOAT", {"default": 1.2, "min": 0.1, "max": 2.0, "step": 0.1, "description": "Resolution scaling factor. 1.0 = original size, >1.0 = upscale, <1.0 = downscale"}),
49 | },
50 | "optional": {
51 | "end_image": ("IMAGE", ),
52 | },
53 | }
54 |
55 | RETURN_TYPES = ("IMAGE", "FLOAT")
56 | RETURN_NAMES = ("frames", "fps")
57 | CATEGORY = "Runninghub/FramePack"
58 | FUNCTION = "run"
59 |
60 | TITLE = 'RunningHub FramePack'
61 | OUTPUT_NODE = True
62 |
63 | def __init__(self):
64 | self.high_vram = False
65 | self.frames = None
66 | self.fps = None
67 |
68 | hunyuan_root = os.path.join(folder_paths.models_dir, 'HunyuanVideo')
69 | flux_redux_bfl_root = os.path.join(folder_paths.models_dir, 'flux_redux_bfl')
70 | framePackI2V_root = os.path.join(folder_paths.models_dir, 'FramePackI2V_HY')
71 |
72 | self.text_encoder = LlamaModel.from_pretrained(hunyuan_root, subfolder='text_encoder', torch_dtype=torch.float16).cpu()
73 | self.text_encoder_2 = CLIPTextModel.from_pretrained(hunyuan_root, subfolder='text_encoder_2', torch_dtype=torch.float16).cpu()
74 | self.tokenizer = LlamaTokenizerFast.from_pretrained(hunyuan_root, subfolder='tokenizer')
75 | self.tokenizer_2 = CLIPTokenizer.from_pretrained(hunyuan_root, subfolder='tokenizer_2')
76 | self.vae = AutoencoderKLHunyuanVideo.from_pretrained(hunyuan_root, subfolder='vae', torch_dtype=torch.float16).cpu()
77 |
78 | self.feature_extractor = SiglipImageProcessor.from_pretrained(flux_redux_bfl_root, subfolder='feature_extractor')
79 | self.image_encoder = SiglipVisionModel.from_pretrained(flux_redux_bfl_root, subfolder='image_encoder', torch_dtype=torch.float16).cpu()
80 |
81 | self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(framePackI2V_root, torch_dtype=torch.bfloat16).cpu()
82 |
83 | self.vae.eval()
84 | self.text_encoder.eval()
85 | self.text_encoder_2.eval()
86 | self.image_encoder.eval()
87 | self.transformer.eval()
88 |
89 | if not self.high_vram:
90 | self.vae.enable_slicing()
91 | self.vae.enable_tiling()
92 |
93 | self.transformer.high_quality_fp32_output_for_inference = True
94 | print('transformer.high_quality_fp32_output_for_inference = True')
95 |
96 | self.transformer.to(dtype=torch.bfloat16)
97 | self.vae.to(dtype=torch.float16)
98 | self.image_encoder.to(dtype=torch.float16)
99 | self.text_encoder.to(dtype=torch.float16)
100 | self.text_encoder_2.to(dtype=torch.float16)
101 |
102 | self.vae.requires_grad_(False)
103 | self.text_encoder.requires_grad_(False)
104 | self.text_encoder_2.requires_grad_(False)
105 | self.image_encoder.requires_grad_(False)
106 | self.transformer.requires_grad_(False)
107 |
108 | if not self.high_vram:
109 |             # DynamicSwapInstaller plays the same role as Hugging Face's enable_sequential_cpu_offload, but is roughly 3x faster
110 | DynamicSwapInstaller.install_model(self.transformer, device=gpu)
111 | DynamicSwapInstaller.install_model(self.text_encoder, device=gpu)
112 |
113 | def strict_align(self, h, w, scale):
114 | raw_h = h * scale
115 | raw_w = w * scale
116 |
117 | aligned_h = int(round(raw_h / 64)) * 64
118 | aligned_w = int(round(raw_w / 64)) * 64
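        # Illustrative example (hypothetical numbers): h=416, w=640, scale=1.2 gives 499.2 x 768, which rounds to 512 x 768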
119 |
120 |         assert (aligned_h % 64 == 0) and (aligned_w % 64 == 0), "Height and width must be multiples of 64"
121 |         assert (aligned_h//8) % 8 == 0 and (aligned_w//8) % 8 == 0, "Latent dimensions must be multiples of 8"
122 | return aligned_h, aligned_w
123 |
124 | def preprocess_image(self, image):
125 | if image is None:
126 | return None
127 | image_np = 255. * image[0].cpu().numpy()
128 | image = Image.fromarray(np.clip(image_np, 0, 255).astype(np.uint8)).convert("RGB")
129 | input_image = np.array(image)
130 | return input_image
131 |
132 | def run(self, **kwargs):
133 | try:
134 | image = kwargs['ref_image']
135 | end_image = kwargs.get('end_image', None) # Use get with None as default
136 | image_np = self.preprocess_image(image)
137 | end_image_np = self.preprocess_image(end_image) if end_image is not None else None
138 | prompt = kwargs['prompt']
139 | seed = kwargs['seed']
140 | total_second_length = kwargs['total_second_length']
141 | steps = kwargs['steps']
142 | use_teacache = kwargs['use_teacache']
143 | upscale = kwargs['upscale']
144 | random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=16))
145 | video_path = os.path.join(folder_paths.get_output_directory(), f'{random_str}.mp4')
146 |
147 |             self.pbar = comfy.utils.ProgressBar(steps * total_second_length)  # approximate total; the sampling callback advances it once per step
148 |
149 | self.exec(input_image=image_np, end_image=end_image_np, prompt=prompt, seed=seed, total_second_length=total_second_length, video_path=video_path, steps=steps, use_teacache=use_teacache, scale=upscale)
150 |
151 | if os.path.exists(video_path):
152 | self.fps = self.get_fps_with_torchvision(video_path)
153 | self.frames = self.extract_frames_as_pil(video_path)
154 |                 print(f'Video saved: {video_path} | FPS: {self.fps} | Frames: {len(self.frames)}')
155 | else:
156 | self.frames = []
157 | self.fps = 0.0
158 | except Exception as e:
159 | print(f"Error in run: {str(e)}")
160 | traceback.print_exc()
161 | self.frames = []
162 | self.fps = 0.0
163 |
164 | return (self.frames, self.fps)
165 |
166 | @torch.no_grad()
167 | def exec(self, input_image, video_path,
168 | end_image=None,
169 | prompt="The girl dances gracefully, with clear movements, full of charm.",
170 | n_prompt="",
171 | seed=31337,
172 | total_second_length=5,
173 | latent_window_size=9,
174 | steps=25,
175 | cfg=1,
176 | gs=32,
177 | rs=0,
178 | gpu_memory_preservation=6,
179 | use_teacache=True,
180 | scale=1.0):
181 |
182 | total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
183 | total_latent_sections = int(max(round(total_latent_sections), 1))
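        # Each section yields latent_window_size * 4 - 3 pixel frames (33 with the default window of 9),
        # so at 30 fps one section covers roughly 1.1 seconds of output video.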
184 |
185 | try:
186 | # Clean GPU
187 | if not self.high_vram:
188 | unload_complete_models(
189 | self.text_encoder, self.text_encoder_2, self.image_encoder, self.vae, self.transformer
190 | )
191 |
192 | # Text encoding
193 | print('Text encoding')
194 |
195 | if not self.high_vram:
196 | fake_diffusers_current_device(self.text_encoder, gpu)
197 | load_model_as_complete(self.text_encoder_2, target_device=gpu)
198 |
199 | llama_vec, clip_l_pooler = encode_prompt_conds(prompt, self.text_encoder, self.text_encoder_2, self.tokenizer, self.tokenizer_2)
200 |
201 | if cfg == 1:
202 | llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
203 | else:
204 | llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, self.text_encoder, self.text_encoder_2, self.tokenizer, self.tokenizer_2)
205 |
206 | llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
207 | llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
208 |
209 | # Processing input image (start frame)
210 | print('Processing start frame ...')
211 |
212 | H, W, C = input_image.shape
213 | height, width = find_nearest_bucket(H, W, resolution=640)
214 | print(f"Resized height: {height}, Resized width: {width}")
215 |
216 | height, width = self.strict_align(height, width, scale)
217 | print(f"After Resized height: {height}, Resized width: {width}")
218 |
219 | input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
220 | input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
221 | input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
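            # input_image_pt is now (1, C, 1, H, W) -- batch, channels, time, height, width -- in [-1, 1]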
222 |
223 | # Processing end image if provided
224 | has_end_image = end_image is not None
225 | end_image_np = None
226 | end_image_pt = None
227 |
228 | if has_end_image:
229 | print('Processing end frame ...')
230 | H_end, W_end, C_end = end_image.shape
231 | end_image_np = resize_and_center_crop(end_image, target_width=width, target_height=height)
232 | end_image_pt = torch.from_numpy(end_image_np).float() / 127.5 - 1
233 | end_image_pt = end_image_pt.permute(2, 0, 1)[None, :, None]
234 |
235 | # VAE encoding
236 | print('VAE encoding ...')
237 |
238 | if not self.high_vram:
239 | load_model_as_complete(self.vae, target_device=gpu)
240 |
241 | start_latent = vae_encode(input_image_pt, self.vae)
242 | end_latent = None
243 | if has_end_image:
244 | end_latent = vae_encode(end_image_pt, self.vae)
245 |
246 | # CLIP Vision
247 | print('CLIP Vision encoding ...')
248 |
249 | if not self.high_vram:
250 | load_model_as_complete(self.image_encoder, target_device=gpu)
251 |
252 | # Start image encoding
253 | image_encoder_output = hf_clip_vision_encode(input_image_np, self.feature_extractor, self.image_encoder)
254 | image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
255 |
256 | # End image encoding if available
257 | if has_end_image:
258 | end_image_encoder_output = hf_clip_vision_encode(end_image_np, self.feature_extractor, self.image_encoder)
259 | end_image_encoder_last_hidden_state = end_image_encoder_output.last_hidden_state
260 |                 # Average the start and end image embeddings, matching the upstream first/last-frame implementation
261 | image_encoder_last_hidden_state = (image_encoder_last_hidden_state + end_image_encoder_last_hidden_state) / 2
262 |
263 | # Dtype
264 | llama_vec = llama_vec.to(self.transformer.dtype)
265 | llama_vec_n = llama_vec_n.to(self.transformer.dtype)
266 | clip_l_pooler = clip_l_pooler.to(self.transformer.dtype)
267 | clip_l_pooler_n = clip_l_pooler_n.to(self.transformer.dtype)
268 | image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(self.transformer.dtype)
269 |
270 | print('Start Sample')
271 |
272 | rnd = torch.Generator("cpu").manual_seed(seed)
273 | num_frames = latent_window_size * 4 - 3
274 |
275 | history_latents = torch.zeros(size=(1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32).cpu()
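            # The history buffer holds 1 + 2 + 16 context frames (1x, 2x, 4x temporal compression);
            # newly generated sections are prepended because this variant generates backwards in time.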
276 | history_pixels = None
277 | total_generated_latent_frames = 0
278 |
279 | latent_paddings = list(reversed(range(total_latent_sections)))
280 |
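            # Following the upstream FramePack demo: for longer videos the padding schedule is flattened to
            # [3, 2, ..., 2, 1, 0]; the demo notes that duplicating the middle values looks better than
            # expanding the full reversed range, while keeping the special first/last sections.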
281 | if total_latent_sections > 4:
282 | latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
283 |
284 | for i, latent_padding in enumerate(latent_paddings):
285 | is_last_section = latent_padding == 0
286 |                 is_first_section = latent_padding == latent_paddings[0]  # the first value in latent_paddings marks the first section
287 | latent_padding_size = latent_padding * latent_window_size
288 |
289 | print(f'latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}, is_first_section = {is_first_section}')
290 |
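                # Conditioning layout for this section (reverse/anti-drifting order): 1 slot for the start
                # latent, `latent_padding_size` blank slots, the `latent_window_size` frames being denoised,
                # then 1 + 2 + 16 history-context slots at 1x, 2x and 4x temporal compression.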
291 | indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
292 | clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
293 | clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
294 |
295 | # Always use start_latent for the first position (exactly like in the original code)
296 | clean_latents_pre = start_latent.to(history_latents)
297 |
298 | # For the second position, use history
299 | clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
300 |
301 | # Create clean_latents first
302 | clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
303 |
304 | # Then if we have end_image and this is the first section, override clean_latents_post with end_latent
305 | if has_end_image and is_first_section:
306 | clean_latents_post = end_latent.to(history_latents)
307 | clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
308 |
309 | if not self.high_vram:
310 | unload_complete_models()
311 | move_model_to_device_with_memory_preservation(self.transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
312 |
313 | if use_teacache:
314 | self.transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
315 | else:
316 | self.transformer.initialize_teacache(enable_teacache=False)
317 |
318 | def callback(d):
319 | self.update(1)
320 | return
321 |
322 | generated_latents = sample_hunyuan(
323 | transformer=self.transformer,
324 | sampler='unipc',
325 | width=width,
326 | height=height,
327 | frames=num_frames,
328 | real_guidance_scale=cfg,
329 | distilled_guidance_scale=gs,
330 | guidance_rescale=rs,
331 | num_inference_steps=steps,
332 | generator=rnd,
333 | prompt_embeds=llama_vec,
334 | prompt_embeds_mask=llama_attention_mask,
335 | prompt_poolers=clip_l_pooler,
336 | negative_prompt_embeds=llama_vec_n,
337 | negative_prompt_embeds_mask=llama_attention_mask_n,
338 | negative_prompt_poolers=clip_l_pooler_n,
339 | device=gpu,
340 | dtype=torch.bfloat16,
341 | image_embeddings=image_encoder_last_hidden_state,
342 | latent_indices=latent_indices,
343 | clean_latents=clean_latents,
344 | clean_latent_indices=clean_latent_indices,
345 | clean_latents_2x=clean_latents_2x,
346 | clean_latent_2x_indices=clean_latent_2x_indices,
347 | clean_latents_4x=clean_latents_4x,
348 | clean_latent_4x_indices=clean_latent_4x_indices,
349 | callback=callback,
350 | )
351 |
352 | # For the last section, add start_latent back to the beginning - just like in the original
353 | if is_last_section:
354 | generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
355 |
356 | # Accumulate generated frames
357 | total_generated_latent_frames += int(generated_latents.shape[2])
358 | history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
359 |
360 | if not self.high_vram:
361 | offload_model_from_device_for_memory_preservation(self.transformer, target_device=gpu, preserved_memory_gb=8)
362 | load_model_as_complete(self.vae, target_device=gpu)
363 |
364 | # Only decode up to the total number of frames we've generated
365 | real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
366 |
367 | # Decode latents to pixels
368 | if history_pixels is None:
369 | history_pixels = vae_decode(real_history_latents, self.vae).cpu()
370 | else:
371 | # For appending new frames to existing ones
372 | section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
373 | overlapped_frames = latent_window_size * 4 - 3
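                    # soft_append_bcthw blends the two clips over the overlapped frames to avoid a visible seam between sections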
374 |
375 | current_pixels = vae_decode(real_history_latents[:, :, :section_latent_frames], self.vae).cpu()
376 | history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
377 |
378 | if not self.high_vram:
379 | unload_complete_models()
380 |
381 | # If this is the last section, save the video
382 | if is_last_section:
383 | save_bcthw_as_mp4(history_pixels, video_path, fps=30)
384 | break
385 |
386 | except Exception as e:
387 | print(f"Error in exec: {str(e)}")
388 | traceback.print_exc()
389 | finally:
390 | unload_complete_models()
391 |
392 | def update(self, in_progress):
393 | self.pbar.update(in_progress)
394 |
395 | def extract_frames_as_pil(self, video_path):
396 | video, _, _ = torchvision.io.read_video(video_path, pts_unit='sec') # (T, H, W, C)
397 |         # Despite the name, this returns per-frame float tensors in [0, 1]; the PIL round-trip is not needed
398 |         frames = [frame.float() / 255.0 for frame in video]
399 | return frames
400 |
401 | def get_fps_with_torchvision(self, video_path):
402 | _, _, info = torchvision.io.read_video(video_path, pts_unit='sec')
403 | return info['video_fps']
404 |
405 | # --- Start of Kiki_FramePack_F1 Class ---
406 | class Kiki_FramePack_F1:
407 | @classmethod
408 | def INPUT_TYPES(s):
409 | return {
410 | "required": {
411 | "ref_image": ("IMAGE", ),
412 | "prompt": ("STRING", {"multiline": True}),
413 | "total_second_length": ("INT", {"default": 5, "min": 1, "max": 120, "step": 1}),
414 | "fps": ("INT", {"default": 30, "min": 1, "max": 60, "step": 1}),
415 | "seed": ("INT", {"default": 3407}),
416 | "steps": ("INT", {"default": 25, "min": 1, "max": 100, "step": 1}),
417 | "gs": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 32.0, "step": 0.1, "round": 0.01, "label": "Distilled CFG Scale"}),
418 | "use_teacache": ("BOOLEAN", {"default": True}),
419 | "upscale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 2.0, "step": 0.1, "description": "Resolution scaling factor."}),
420 | },
421 | "optional": {
422 | "n_prompt": ("STRING", {"multiline": True, "default": ""}),
423 | }
424 | }
425 |
426 | RETURN_TYPES = ("IMAGE", "FLOAT")
427 | RETURN_NAMES = ("frames", "fps")
428 | CATEGORY = "Runninghub/FramePack"
429 | FUNCTION = "run_f1"
430 |
431 | TITLE = 'RunningHub FramePack F1'
432 | OUTPUT_NODE = True
433 |
434 | def __init__(self):
435 | self.high_vram = False
436 | self.frames = None
437 | self.fps = None
438 |
439 | hunyuan_root = os.path.join(folder_paths.models_dir, 'HunyuanVideo')
440 | flux_redux_bfl_root = os.path.join(folder_paths.models_dir, 'flux_redux_bfl')
441 | framePackF1_root = os.path.join(folder_paths.models_dir, 'FramePackF1_HY')
442 |
443 | if not os.path.isdir(framePackF1_root):
444 | print(f"Warning: FramePack F1 model directory not found at {framePackF1_root}")
445 |
446 | self.text_encoder = LlamaModel.from_pretrained(hunyuan_root, subfolder='text_encoder', torch_dtype=torch.float16).cpu()
447 | self.text_encoder_2 = CLIPTextModel.from_pretrained(hunyuan_root, subfolder='text_encoder_2', torch_dtype=torch.float16).cpu()
448 | self.tokenizer = LlamaTokenizerFast.from_pretrained(hunyuan_root, subfolder='tokenizer')
449 | self.tokenizer_2 = CLIPTokenizer.from_pretrained(hunyuan_root, subfolder='tokenizer_2')
450 | self.vae = AutoencoderKLHunyuanVideo.from_pretrained(hunyuan_root, subfolder='vae', torch_dtype=torch.float16).cpu()
451 |
452 | self.feature_extractor = SiglipImageProcessor.from_pretrained(flux_redux_bfl_root, subfolder='feature_extractor')
453 | self.image_encoder = SiglipVisionModel.from_pretrained(flux_redux_bfl_root, subfolder='image_encoder', torch_dtype=torch.float16).cpu()
454 |
455 | try:
456 | self.transformer_f1 = HunyuanVideoTransformer3DModelPacked.from_pretrained(framePackF1_root, torch_dtype=torch.bfloat16).cpu()
457 | except Exception as e:
458 | print(f"Error loading FramePack F1 transformer model from {framePackF1_root}: {e}")
459 | print("Please ensure the F1 model weights (e.g., transformer.safetensors) are correctly placed in the directory.")
460 | self.transformer_f1 = None
461 |
462 | self.vae.eval()
463 | self.text_encoder.eval()
464 | self.text_encoder_2.eval()
465 | self.image_encoder.eval()
466 | if self.transformer_f1:
467 | self.transformer_f1.eval()
468 |
469 | if not self.high_vram:
470 | self.vae.enable_slicing()
471 | self.vae.enable_tiling()
472 |
473 | if self.transformer_f1:
474 | self.transformer_f1.high_quality_fp32_output_for_inference = True
475 | print('F1 transformer.high_quality_fp32_output_for_inference = True')
476 |
477 | self.transformer_f1.to(dtype=torch.bfloat16)
478 |
479 | self.transformer_f1.requires_grad_(False)
480 |
481 | if not self.high_vram:
482 | DynamicSwapInstaller.install_model(self.transformer_f1, device=gpu)
483 |
484 | self.vae.to(dtype=torch.float16)
485 | self.image_encoder.to(dtype=torch.float16)
486 | self.text_encoder.to(dtype=torch.float16)
487 | self.text_encoder_2.to(dtype=torch.float16)
488 | self.vae.requires_grad_(False)
489 | self.text_encoder.requires_grad_(False)
490 | self.text_encoder_2.requires_grad_(False)
491 | self.image_encoder.requires_grad_(False)
492 |
493 | if not self.high_vram:
494 | DynamicSwapInstaller.install_model(self.text_encoder, device=gpu)
495 |
496 | def strict_align(self, h, w, scale):
497 | raw_h = h * scale
498 | raw_w = w * scale
499 | aligned_h = int(round(raw_h / 64)) * 64
500 | aligned_w = int(round(raw_w / 64)) * 64
501 |         assert (aligned_h % 64 == 0) and (aligned_w % 64 == 0), "Height and width must be multiples of 64"
502 |         assert (aligned_h//8) % 8 == 0 and (aligned_w//8) % 8 == 0, "Latent dimensions must be multiples of 8"
503 | return aligned_h, aligned_w
504 |
505 | def preprocess_image(self, image):
506 | if image is None: return None
507 | if image.dim() == 4 and image.shape[0] == 1:
508 | img_tensor = image[0]
509 | else:
510 | img_tensor = image
511 | print(f"Warning: Unexpected input image tensor shape: {image.shape}. Assuming HWC.")
512 |
513 | image_np = 255. * img_tensor.cpu().numpy()
514 | image = Image.fromarray(np.clip(image_np, 0, 255).astype(np.uint8)).convert("RGB")
515 | input_image = np.array(image)
516 | return input_image
517 |
518 | def run_f1(self, **kwargs):
519 | if not self.transformer_f1:
520 | print("Error: Kiki_FramePack_F1 cannot run because the transformer model failed to load.")
521 | return (torch.empty((0, 1, 1, 3), dtype=torch.float32), 0.0)
522 |
523 | try:
524 | image = kwargs['ref_image']
525 | image_np = self.preprocess_image(image)
526 | prompt = kwargs['prompt']
527 | n_prompt = kwargs.get('n_prompt', "")
528 | seed = kwargs['seed']
529 | total_second_length = kwargs['total_second_length']
530 | fps = kwargs['fps']
531 | steps = kwargs['steps']
532 | gs = kwargs['gs']
533 | use_teacache = kwargs['use_teacache']
534 | upscale = kwargs['upscale']
535 | cfg = 1.0
536 | rs = 0.0
537 | latent_window_size = 9
538 |
539 | random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=16))
540 | video_path = os.path.join(folder_paths.get_output_directory(), f'{random_str}_f1.mp4')
541 |
542 | # --- Initialize Progress Bar (Aligned with demo's section calc) ---
543 | # Use demo's calculation for total_latent_sections, assuming 30fps basis for consistency
544 | total_latent_sections = int(max(round((total_second_length * 30) / (latent_window_size * 4)), 1))
545 | total_progress_steps = total_latent_sections * steps
546 | self.pbar = comfy.utils.ProgressBar(total_progress_steps)
547 |
548 | # Call exec_f1, passing latent_window_size as well
549 | self.exec_f1(input_image=image_np, prompt=prompt, n_prompt=n_prompt, seed=seed,
550 | total_second_length=total_second_length, video_path=video_path, fps=fps,
551 | steps=steps, gs=gs, cfg=cfg, rs=rs, latent_window_size=latent_window_size, # Pass latent_window_size
552 | use_teacache=use_teacache, scale=upscale,
553 | gpu_memory_preservation=6)
554 |
555 | if os.path.exists(video_path):
556 | self.fps = float(fps)
557 | self.frames = self.extract_frames_to_tensor(video_path)
558 | print(f'F1 Video saved: {video_path} | FPS: {self.fps} | Frames: {self.frames.shape[0] if self.frames is not None else 0}')
559 | else:
560 | self.frames = torch.empty((0, 1, 1, 3), dtype=torch.float32)
561 | self.fps = 0.0
562 | print(f'F1 Video generation failed or file not found: {video_path}')
563 |
564 | except Exception as e:
565 | print(f"Error in run_f1: {str(e)}")
566 | traceback.print_exc()
567 | self.frames = torch.empty((0, 1, 1, 3), dtype=torch.float32)
568 | self.fps = 0.0
569 |
570 | return (self.frames, self.fps)
571 |
572 | @torch.no_grad()
573 | def exec_f1(self, input_image, video_path,
574 | prompt, n_prompt, seed, total_second_length, fps,
575 | steps, gs, cfg, rs, latent_window_size, # Receive latent_window_size
576 | use_teacache, scale,
577 | gpu_memory_preservation=6):
578 |
579 | print("--- Starting Kiki_FramePack_F1 exec_f1 (Aligned with Demo Logic) ---")
580 | print(f"Params: seed={seed}, length={total_second_length}s@{fps}fps, steps={steps}, gs={gs}, cfg={cfg}, rs={rs}, lws={latent_window_size}")
581 |
582 | vae_time_stride = 4
583 |
584 | # --- Use Demo's total_latent_sections calculation ---
585 | total_latent_sections = int(max(round((total_second_length * 30) / (latent_window_size * 4)), 1))
586 | print(f"Total generation sections (Demo calc): {total_latent_sections}")
587 |
588 | # --- Calculate target frames needed (still useful for trimming) ---
589 | target_pixel_frames = int(round(total_second_length * fps))
590 |
591 | try:
592 | # --- 1. Initialization & Setup ---
593 | torch.manual_seed(seed)
594 | rnd = torch.Generator("cpu").manual_seed(seed)
595 |
596 |             # Clean GPU before encoding (same unload pattern as the regular FramePack node above)
            if not self.high_vram:
                unload_complete_models(
                    self.text_encoder, self.text_encoder_2, self.image_encoder, self.vae, self.transformer_f1
                )
597 |
598 | # --- 2. Encoding Inputs ---
599 | print('Encoding text prompts...')
600 | if not self.high_vram:
601 | fake_diffusers_current_device(self.text_encoder, gpu)
602 | load_model_as_complete(self.text_encoder_2, target_device=gpu)
603 | llama_vec, clip_l_pooler = encode_prompt_conds(prompt, self.text_encoder, self.text_encoder_2, self.tokenizer, self.tokenizer_2)
604 | llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, self.text_encoder, self.text_encoder_2, self.tokenizer, self.tokenizer_2)
605 | llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
606 | llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
607 |
608 | print('Processing reference image...')
609 | H, W, C = input_image.shape
610 | if scale == 1.0:
611 | height, width = find_nearest_bucket(H, W, resolution=640)
612 | height, width = self.strict_align(height, width, 1.0)
613 | else:
614 | height, width = self.strict_align(H, W, scale)
615 | print(f"Target dimensions: {width}x{height}")
616 | input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
617 | input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
618 | input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
619 |
620 | print('VAE encoding reference image...')
621 | if not self.high_vram: load_model_as_complete(self.vae, target_device=gpu)
622 | start_latent = vae_encode(input_image_pt.to(self.vae.device, dtype=self.vae.dtype), self.vae)
623 | print(f"Start latent shape: {start_latent.shape}")
624 |
625 | print('CLIP Vision encoding reference image...')
626 | if not self.high_vram: load_model_as_complete(self.image_encoder, target_device=gpu)
627 | image_encoder_output = hf_clip_vision_encode(input_image_np, self.feature_extractor, self.image_encoder.to(gpu))
628 | image_embeddings = image_encoder_output.last_hidden_state
629 |
630 | transformer_dtype = self.transformer_f1.dtype
631 | start_latent = start_latent.to(transformer_dtype).cpu()
632 |
633 | # --- 3. Diffusion Loop (Aligned with Demo) ---
634 | print(f'Starting diffusion loop for {total_latent_sections} sections...')
635 |
636 | latent_channels = start_latent.shape[1]
637 | latent_height = start_latent.shape[-2]
638 | latent_width = start_latent.shape[-1]
639 | history_context_size = 16 + 2 + 1
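            # F1 generates forward in time: each section is conditioned on the last 1 + 2 + 16 latent frames
            # of history (4x, 2x and 1x temporal compression) plus the start latent.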
640 |
641 | # --- Initialize history_latents like demo ---
642 | # Start with zeros matching context size
643 |             history_latents = torch.zeros(size=(1, latent_channels, history_context_size, latent_height, latent_width), dtype=torch.float32).cpu()  # float32 on CPU, matching the demo
644 | # Immediately add start_latent
645 | history_latents = torch.cat([history_latents, start_latent.to(history_latents.dtype)], dim=2)
646 | total_generated_latent_frames = 1 # Account for start_latent
647 | history_pixels = None
648 |
649 |             # Progress bar callback: map the per-section step index to an absolute progress value
650 | current_section_step = 0
651 | total_progress_steps = total_latent_sections * steps
652 | def callback_f1(d):
653 |                 # d['i'] is the current denoising step within this section
654 | nonlocal current_section_step
655 | step_in_section = d['i']
656 | current_total_step = current_section_step * steps + step_in_section + 1
657 | if hasattr(self, 'pbar') and self.pbar:
658 | self.pbar.update_absolute(current_total_step, total_progress_steps)
659 |
660 |             # Pixel frames produced per latent window section (demo formula: latent_window_size * 4 - 3)
661 | frames_per_latent_window = latent_window_size * 4 - 3
662 |
663 | for section_index in range(total_latent_sections):
664 | section_start_time = time.time()
665 | print(f'Generating section {section_index + 1} / {total_latent_sections}')
666 | current_section_step = section_index
667 |
668 |                 # Load the F1 transformer for this section (mirrors the regular node's per-section loading)
                if not self.high_vram:
                    unload_complete_models()
                    move_model_to_device_with_memory_preservation(self.transformer_f1, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
669 |
670 | # --- Prepare context and indices (same as before, uses history_latents) ---
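                # F1 index layout: 1 slot for the start latent, 16 + 2 + 1 history slots (4x, 2x, 1x
                # compression), followed by the latent_window_size frames to denoise.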
671 | indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
672 | clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
673 | clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
674 |
675 | # Get history context from the *end* of the current history_latents
676 | # No padding needed here because history starts with context + start_latent
677 | history_context = history_latents[:, :, -history_context_size:, :, :]
678 | clean_latents_4x, clean_latents_2x, clean_latents_1x = history_context.split([16, 2, 1], dim=2)
679 | clean_latents = torch.cat([start_latent.cpu(), clean_latents_1x.cpu()], dim=2)
680 |
681 | # --- Prepare sample_kwargs (same as before) ---
682 | sample_kwargs = dict(
683 | transformer=self.transformer_f1,
684 | sampler='unipc',
685 | width=width,
686 | height=height,
687 | frames=frames_per_latent_window, # Use demo's frame count
688 | real_guidance_scale=cfg,
689 | distilled_guidance_scale=gs,
690 | guidance_rescale=rs,
691 | num_inference_steps=steps,
692 | generator=rnd,
693 |                     # Positive prompt embeddings and attention mask (embeddings cast to the transformer's dtype)
694 | prompt_embeds=llama_vec.to(gpu, dtype=transformer_dtype),
695 | prompt_embeds_mask=llama_attention_mask.to(gpu), # Mask dtype usually okay
696 |                     # Pooled and negative prompt conditioning (cast to the transformer's dtype where needed)
697 | prompt_poolers=clip_l_pooler.to(gpu, dtype=transformer_dtype),
698 | negative_prompt_embeds=llama_vec_n.to(gpu, dtype=transformer_dtype),
699 | negative_prompt_embeds_mask=llama_attention_mask_n.to(gpu), # Mask dtype usually okay
700 | negative_prompt_poolers=clip_l_pooler_n.to(gpu, dtype=transformer_dtype),
701 | device=gpu, # Device is already GPU
702 | dtype=transformer_dtype, # Explicitly passing transformer's dtype
703 | image_embeddings=image_embeddings.to(gpu, dtype=transformer_dtype),
704 | latent_indices=latent_indices.to(gpu), # Indices dtype usually okay
705 | clean_latents=clean_latents.to(gpu, dtype=transformer_dtype), # Ensure correct dtype
706 | clean_latent_indices=clean_latent_indices.to(gpu), # Indices dtype usually okay
707 | clean_latents_2x=clean_latents_2x.to(gpu, dtype=transformer_dtype), # Ensure correct dtype
708 | clean_latent_2x_indices=clean_latent_2x_indices.to(gpu), # Indices dtype usually okay
709 | clean_latents_4x=clean_latents_4x.to(gpu, dtype=transformer_dtype), # Ensure correct dtype
710 | clean_latent_4x_indices=clean_latent_4x_indices.to(gpu), # Indices dtype usually okay
711 | callback=callback_f1,
712 | )
713 |
714 |                 # Initialize TeaCache for this section (enabled or disabled by the node input)
715 | if hasattr(self.transformer_f1, 'initialize_teacache'):
716 | self.transformer_f1.initialize_teacache(enable_teacache=use_teacache, num_steps=steps)
717 |
718 | # --- Call sample_hunyuan ---
719 | generated_latents = sample_hunyuan(**sample_kwargs)
720 |
721 | generated_latents = generated_latents.to(cpu, dtype=torch.float32)
722 | print(f" Sampled latent section shape: {generated_latents.shape}")
723 |
724 | # --- Update history_latents (Aligned with Demo: Always append) ---
725 | total_generated_latent_frames += int(generated_latents.shape[2])
726 | history_latents = torch.cat([history_latents, generated_latents.to(history_latents.dtype)], dim=2)
727 |
728 | # --- Decode and append pixels (Aligned with Demo) ---
729 | if not self.high_vram:
730 | offload_model_from_device_for_memory_preservation(self.transformer_f1, target_device=gpu, preserved_memory_gb=8)
731 | load_model_as_complete(self.vae, target_device=gpu)
732 | else:
733 | if self.vae.device != gpu: self.vae.to(gpu)
734 |
735 | # Calculate the slice of history to decode based on total generated frames
736 | real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :] # Use actual generated frames
737 |
738 | if history_pixels is None:
739 | # First time: decode the current relevant history
740 | history_pixels = vae_decode(real_history_latents.to(gpu, dtype=self.vae.dtype), self.vae).cpu()
741 | print(f" Decoded initial pixels. Shape: {history_pixels.shape}")
742 | else:
743 | # Subsequent times: decode only the part needed for smooth append
744 | section_latent_frames = latent_window_size * 2
745 | overlapped_frames = latent_window_size * 4 - 3 # Use demo's overlap calculation
746 |
747 | # Decode the relevant tail end of the history latents
748 | current_latents_to_decode = real_history_latents[:, :, -section_latent_frames:, :, :]
749 | current_pixels = vae_decode(current_latents_to_decode.to(gpu, dtype=self.vae.dtype), self.vae).cpu()
750 |
751 | # Append smoothly using demo's overlap value
752 | history_pixels = soft_append_bcthw(history_pixels, current_pixels, overlapped_frames)
753 | print(f" Appended pixels. New history shape: {history_pixels.shape}")
754 |
755 |                 # Unload the VAE between sections in low-VRAM mode
756 | if not self.high_vram:
757 | unload_complete_models(self.vae)
758 |
759 | section_end_time = time.time()
760 | print(f" Section {section_index + 1} took {section_end_time - section_start_time:.2f} seconds.")
761 |
762 | # --- 4. Final Saving (Aligned with Demo, keeping variable fps) ---
763 | print('Saving final video...')
764 | if history_pixels is None or history_pixels.shape[2] == 0:
765 | raise ValueError("No pixel frames were generated or decoded.")
766 |
767 | if history_pixels.shape[2] > target_pixel_frames:
768 | print(f"Trimming final video from {history_pixels.shape[2]} to {target_pixel_frames} frames.")
769 | history_pixels = history_pixels[:,:,:target_pixel_frames,:,:]
770 |
771 | save_bcthw_as_mp4(
772 | history_pixels,
773 | video_path,
774 | fps=fps, # Keep user FPS for now
775 |                 # crf is intentionally omitted; pass e.g. crf=18 only if the local save_bcthw_as_mp4 supports it
776 | )
777 | print(f"Final video saved to: {video_path}")
778 |
779 | except Exception as e:
780 | print(f"Error during Kiki_FramePack_F1 execution: {str(e)}")
781 | traceback.print_exc()
782 | if os.path.exists(video_path):
783 | try: os.remove(video_path)
784 | except OSError: pass
785 |             if hasattr(self, 'pbar') and self.pbar and 'total_progress_steps' in locals(): self.pbar.update_absolute(total_progress_steps, total_progress_steps)  # guard: the error may occur before total_progress_steps is set
786 | raise
787 |
788 | finally:
789 | print('Cleaning up models...')
790 | unload_complete_models(
791 | self.text_encoder, self.text_encoder_2, self.image_encoder, self.vae, self.transformer_f1
792 | )
793 | torch.cuda.empty_cache()
794 | print("--- Finished Kiki_FramePack_F1 exec_f1 (Aligned with Demo Logic) ---")
795 |
796 | def extract_frames_to_tensor(self, video_path):
797 | try:
798 | video_tensor, _, metadata = torchvision.io.read_video(video_path, pts_unit='sec', output_format='TCHW')
799 |
800 | video_tensor = video_tensor.permute(0, 2, 3, 1)
801 |
802 | video_tensor = video_tensor.float() / 255.0
803 |
804 | print(f"Extracted video tensor shape: {video_tensor.shape}")
805 | return video_tensor
806 |
807 | except Exception as e:
808 | print(f"Error extracting frames using torchvision.io.read_video: {e}")
809 | traceback.print_exc()
810 | return torch.empty((0, 1, 1, 3), dtype=torch.float32)
811 |
812 | def get_fps_with_torchvision(self, video_path):
813 | try:
814 | _, _, metadata = torchvision.io.read_video(video_path, pts_unit='sec')
815 | fps = metadata.get('video_fps', 30.0)
816 | return float(fps)
817 | except Exception as e:
818 | print(f"Error reading FPS using torchvision.io.read_video: {e}")
819 | traceback.print_exc()
820 | return 30.0
821 |
822 | # NODE CLASS MAPPINGS
823 | NODE_CLASS_MAPPINGS = {
824 | "RunningHub_FramePack": Kiki_FramePack,
825 | "RunningHub_FramePack_F1": Kiki_FramePack_F1
826 | }
827 |
828 | # A dictionary that contains the friendly/humanly readable titles for the nodes
829 | NODE_DISPLAY_NAME_MAPPINGS = {
830 | "RunningHub_FramePack": Kiki_FramePack.TITLE,
831 | "RunningHub_FramePack_F1": Kiki_FramePack_F1.TITLE
832 | }
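# ComfyUI reads NODE_CLASS_MAPPINGS / NODE_DISPLAY_NAME_MAPPINGS when it imports this module at startup,
# which is what makes both nodes appear in the node menu (the CATEGORY attribute above places them under
# "Runninghub/FramePack"). A minimal custom-node package only needs these two dictionaries, e.g.
# (hypothetical names):
#
#     NODE_CLASS_MAPPINGS = {"MyNode": MyNode}
#     NODE_DISPLAY_NAME_MAPPINGS = {"MyNode": "My Node"}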
833 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | numpy
4 | Pillow
5 | diffusers>=0.33.1
6 | transformers>=4.46.2
7 | einops
8 | safetensors
9 | accelerate>=1.6.0
10 | scipy>=1.12.0
11 | torchsde>=0.2.6
12 | opencv-python
13 |
--------------------------------------------------------------------------------