├── .github
│   └── workflows
│       └── publish.yml
├── .gitignore
├── README.md
├── __init__.py
├── assets
│   └── workflow.png
├── input
│   ├── .DS_Store
│   ├── aud.mp3
│   ├── gt.mp4
│   ├── kps.pth
│   └── ref.jpg
├── model_ckpts
│   └── put_V-Express_model_here
├── nodes.py
├── requirements.txt
├── src
│   ├── inference.py
│   ├── inference_v2.yaml
│   ├── modules
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── audio_projection.py
│   │   ├── motion_module.py
│   │   ├── mutual_self_attention.py
│   │   ├── resnet.py
│   │   ├── transformer_2d.py
│   │   ├── transformer_3d.py
│   │   ├── unet_2d_blocks.py
│   │   ├── unet_2d_condition.py
│   │   ├── unet_3d.py
│   │   ├── unet_3d_blocks.py
│   │   └── v_kps_guider.py
│   ├── pipelines
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── context.py
│   │   ├── utils.py
│   │   └── v_express_pipeline.py
│   ├── scripts
│   │   └── extract_kps_sequence_and_audio.py
│   └── util.py
├── web
│   └── js
│       ├── VEpreviewVideo.js
│       ├── VEuploadAudio.js
│       └── VEuploadVideo.js
└── workflow
    └── V-Express.json
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 | workflow_dispatch:
4 | push:
5 | branches:
6 | - main
7 | - master
8 | paths:
9 | - "pyproject.toml"
10 |
11 | jobs:
12 | publish-node:
13 | name: Publish Custom Node to registry
14 | runs-on: ubuntu-latest
15 | steps:
16 | - name: Check out code
17 | uses: actions/checkout@v4
18 | - name: Publish Custom Node
19 | uses: Comfy-Org/publish-node-action@main
20 | with:
21 | ## Add your own personal access token to your Github Repository secrets and reference it here.
22 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
23 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 | *.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # **_V-Express: Conditional Dropout for Progressive Training of Portrait Video Generation_**
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | ---
10 |
11 | ## Introduction
12 |
13 | In the field of portrait video generation, the use of single images to generate portrait videos has become increasingly prevalent.
14 | A common approach involves leveraging generative models to enhance adapters for controlled generation.
15 | However, control signals can vary in strength, including text, audio, image reference, pose, depth map, etc.
16 | Among these, weaker conditions often struggle to be effective due to interference from stronger conditions, posing a challenge in balancing these conditions.
17 | In our work on portrait video generation, we identified audio signals as particularly weak, often overshadowed by stronger signals such as pose and original image.
18 | However, direct training with weak signals often leads to difficulties in convergence.
19 | To address this, we propose V-Express, a simple method that balances different control signals through a series of progressive drop operations.
20 | Our method gradually enables effective control by weak conditions, thereby achieving generation capabilities that simultaneously take into account pose, input image, and audio.
21 |
22 | ## Workflow
23 |
24 | 
25 |
26 | ## Installation
27 |
28 | 1. Clone this repo into `Your_ComfyUI_root_directory\ComfyUI\custom_nodes\` and install the required Python packages as described [here](https://github.com/tencent-ailab/V-Express#installation):
29 |
30 | ```shell
31 | cd Your_ComfyUI_root_directory\ComfyUI\custom_nodes\
32 |
33 | git clone https://github.com/tiankuan93/ComfyUI-V-Express
34 |
35 | pip install -r requirements.txt
36 | ```
37 |
38 | If you are using ComfyUI_windows_portable, use `.\python_embeded\python.exe -m pip` in place of `pip` for the installation.
39 |
40 | If you get an error related to insightface, you may find a solution [here](https://www.youtube.com/watch?v=vCCVxGtCyho).
41 |
42 | - First, download the prebuilt `.whl` file from [here](https://github.com/Gourieff/Assets/tree/main/Insightface).
43 | - Then, install it with `.\python_embeded\python.exe -m pip install --no-deps --target=\your_path_of\python_embeded\Lib\site-packages [path-to-wheel]`.
44 |
45 | 2. Download the V-Express models and the other required models:
46 |
47 | - [model_ckpts](https://huggingface.co/tk93/V-Express)
48 | - You need to replace the **model_ckpts** folder with the downloaded **V-Express/model_ckpts**. Then download all of the `.bin` models and put them into the `model_ckpts/v-express` directory: `audio_projection.bin`, `denoising_unet.bin`, `motion_module.bin`, `reference_net.bin`, and `v_kps_guider.bin`. The final **model_ckpts** folder looks like this:
49 |
50 | ```text
51 | ./model_ckpts/
52 | |-- insightface_models
53 | |-- sd-vae-ft-mse
54 | |-- stable-diffusion-v1-5
55 | |-- v-express
56 | |-- wav2vec2-base-960h
57 | ```
58 |
59 | 3. Put the files from this repo's `input` directory into your ComfyUI input directory, `Your_ComfyUI_root_directory\ComfyUI\input\`.
60 | 4. Set `output_path` to a file under your ComfyUI output directory, e.g. `Your_ComfyUI_root_directory\ComfyUI\output\xxx.mp4`; otherwise the output video will not be displayed in ComfyUI.
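Before queuing the workflow, it can help to sanity-check the checkpoint layout from step 2. A minimal sketch (illustrative only; it assumes the default `./model_ckpts` path used by the "Set V-Express Model Path" node):

```python
# Minimal sketch (not part of this node): verify the model_ckpts layout from step 2.
import os

MODEL_ROOT = "./model_ckpts"  # assumption: the default used by "Set V-Express Model Path"
expected_dirs = [
    "insightface_models", "sd-vae-ft-mse", "stable-diffusion-v1-5",
    "v-express", "wav2vec2-base-960h",
]
# nodes.py also accepts .pth variants of these weights; only the .bin names are checked here.
expected_bins = [
    "audio_projection.bin", "denoising_unet.bin", "motion_module.bin",
    "reference_net.bin", "v_kps_guider.bin",
]

missing = [d for d in expected_dirs if not os.path.isdir(os.path.join(MODEL_ROOT, d))]
missing += [f for f in expected_bins
            if not os.path.isfile(os.path.join(MODEL_ROOT, "v-express", f))]
print("All expected checkpoint files found." if not missing else f"Missing: {missing}")
```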
61 |
62 | ## Acknowledgements
63 |
64 | We would like to thank the contributors to [AIFSH/ComfyUI_V-Express](https://github.com/AIFSH/ComfyUI_V-Express) for their open research and exploration.
65 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
2 |
3 | WEB_DIRECTORY = "./web"
4 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
5 |
--------------------------------------------------------------------------------
/assets/workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tiankuan93/ComfyUI-V-Express/ee8f245406e6fa07353dbf36f8ada06f81d48a9e/assets/workflow.png
--------------------------------------------------------------------------------
/input/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tiankuan93/ComfyUI-V-Express/ee8f245406e6fa07353dbf36f8ada06f81d48a9e/input/.DS_Store
--------------------------------------------------------------------------------
/input/aud.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tiankuan93/ComfyUI-V-Express/ee8f245406e6fa07353dbf36f8ada06f81d48a9e/input/aud.mp3
--------------------------------------------------------------------------------
/input/gt.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tiankuan93/ComfyUI-V-Express/ee8f245406e6fa07353dbf36f8ada06f81d48a9e/input/gt.mp4
--------------------------------------------------------------------------------
/input/kps.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tiankuan93/ComfyUI-V-Express/ee8f245406e6fa07353dbf36f8ada06f81d48a9e/input/kps.pth
--------------------------------------------------------------------------------
/input/ref.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tiankuan93/ComfyUI-V-Express/ee8f245406e6fa07353dbf36f8ada06f81d48a9e/input/ref.jpg
--------------------------------------------------------------------------------
/model_ckpts/put_V-Express_model_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tiankuan93/ComfyUI-V-Express/ee8f245406e6fa07353dbf36f8ada06f81d48a9e/model_ckpts/put_V-Express_model_here
--------------------------------------------------------------------------------
/nodes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import sys
4 | import numpy as np
5 | import time
6 | import torch
7 | import torchaudio.functional
8 | import torchvision.io
9 | from imageio_ffmpeg import get_ffmpeg_exe
10 | from PIL import Image
11 |
12 | from diffusers.utils.torch_utils import randn_tensor
13 | from diffusers import AutoencoderKL
14 | from insightface.app import FaceAnalysis
15 | from transformers import Wav2Vec2Model, Wav2Vec2Processor
16 | import accelerate
17 |
18 | import folder_paths
19 | import folder_paths as comfy_paths
20 | from comfy import model_management
21 |
22 | ROOT_PATH = os.path.join(comfy_paths.get_folder_paths("custom_nodes")[0], "ComfyUI-V-Express")
23 | sys.path.append(os.path.join(ROOT_PATH, 'src'))
24 |
25 | from .src.pipelines import VExpressPipeline
26 | from .src.pipelines.utils import draw_kps_image, save_video
27 | from .src.pipelines.utils import retarget_kps
28 | from .src.util import get_ffmpeg
29 | from .src.inference import (
30 | get_scheduler,
31 | load_reference_net,
32 | load_denoising_unet,
33 | load_v_kps_guider,
34 | load_audio_projection,
35 | )
36 |
37 | INPUT_PATH = folder_paths.get_input_directory()
38 | OUTPUT_PATH = folder_paths.get_output_directory()
39 |
40 |
41 | INFERENCE_CONFIG_PATH = os.path.join(ROOT_PATH, "src/inference_v2.yaml")
42 |
43 | load_device = model_management.get_torch_device()
44 | offload_device = model_management.unet_offload_device()
45 | DEVICE = load_device
46 | WEIGHT_DTYPE = torch.float16
47 | GPU_ID = 0
48 |
49 | STANDARD_AUDIO_SAMPLING_RATE = 16000
50 | NUM_PAD_AUDIO_FRAMES = 2
51 |
52 |
53 | def get_all_model_path(vexpress_model_path):
54 | if not os.path.isabs(vexpress_model_path):
55 | vexpress_model_path = os.path.join(ROOT_PATH, vexpress_model_path)
56 |
57 | unet_config_path = os.path.join(vexpress_model_path, 'stable-diffusion-v1-5/unet/config.json')
58 | vae_path = os.path.join(vexpress_model_path, 'sd-vae-ft-mse')
59 | audio_encoder_path = os.path.join(vexpress_model_path, 'wav2vec2-base-960h')
60 | insightface_model_path = os.path.join(vexpress_model_path, 'insightface_models')
61 |
62 | denoising_unet_path = os.path.join(vexpress_model_path, 'v-express/denoising_unet.bin')
63 | reference_net_path = os.path.join(vexpress_model_path, 'v-express/reference_net.bin')
64 | v_kps_guider_path = os.path.join(vexpress_model_path, 'v-express/v_kps_guider.bin')
65 | audio_projection_path = os.path.join(vexpress_model_path, 'v-express/audio_projection.bin')
66 | motion_module_path = os.path.join(vexpress_model_path, 'v-express/motion_module.bin')
67 |
68 | if not os.path.isfile(denoising_unet_path):
69 | denoising_unet_path = os.path.join(vexpress_model_path, 'v-express/denoising_unet.pth')
70 | if not os.path.isfile(reference_net_path):
71 | reference_net_path = os.path.join(vexpress_model_path, 'v-express/reference_net.pth')
72 | if not os.path.isfile(v_kps_guider_path):
73 | v_kps_guider_path = os.path.join(vexpress_model_path, 'v-express/v_kps_guider.pth')
74 | if not os.path.isfile(audio_projection_path):
75 | audio_projection_path = os.path.join(vexpress_model_path, 'v-express/audio_projection.pth')
76 | if not os.path.isfile(motion_module_path):
77 | motion_module_path = os.path.join(vexpress_model_path, 'v-express/motion_module.pth')
78 |
79 | model_dict = {
80 | "unet_config_path": unet_config_path,
81 | "vae_path": vae_path,
82 | "audio_encoder_path": audio_encoder_path,
83 | "insightface_model_path": insightface_model_path,
84 | "denoising_unet_path": denoising_unet_path,
85 | "reference_net_path": reference_net_path,
86 | "v_kps_guider_path": v_kps_guider_path,
87 | "audio_projection_path": audio_projection_path,
88 | "motion_module_path": motion_module_path,
89 | }
90 | return model_dict
91 |
92 |
93 | class VEINTConstant:
94 | @classmethod
95 | def INPUT_TYPES(s):
96 | return {"required": {
97 | "image_size": ("INT", {"default": 512, "min": 512, "max": 2048}),
98 | },
99 | }
100 | RETURN_TYPES = ("INT_INPUT",)
101 | RETURN_NAMES = ("image_size",)
102 | FUNCTION = "get_value"
103 | CATEGORY = "V-Express"
104 |
105 | def get_value(self, image_size):
106 | return (image_size,)
107 |
108 |
109 | class VEStringConstant:
110 | @classmethod
111 | def INPUT_TYPES(cls):
112 | return {
113 | "required": {
114 | "string": ("STRING", {"default": './model_ckpts', "multiline": False}),
115 | }
116 | }
117 | RETURN_TYPES = ("STRING_INPUT",)
118 | FUNCTION = "passtring"
119 | CATEGORY = "V-Express"
120 |
121 | def passtring(self, string):
122 | return (string, )
123 |
124 |
125 | class V_Express_Sampler:
126 | @classmethod
127 | def INPUT_TYPES(s):
128 | return {
129 | "required": {
130 | "v_express_pipeline": ("V_EXPRESS_PIPELINE",),
131 | "vexpress_model_path": ("STRING_INPUT", ),
132 | "audio_path": ("AUDIO_PATH",),
133 | "kps_path": ("VKPS_PATH",),
134 | "ref_image_path": ("IMAGE_PATH",),
135 | "output_path": ("STRING",{
136 | "default": os.path.join(OUTPUT_PATH,f"{time.time()}_vexpress.mp4")
137 | }),
138 | "image_size": ("INT_INPUT",),
139 | "retarget_strategy": (
140 | ["fix_face", "no_retarget", "offset_retarget", "naive_retarget"],
141 | {"default": "fix_face"}
142 | ),
143 | "fps": ("FLOAT", {"default": 30.0, "min": 20.0, "max": 60.0}),
144 | "seed": ("INT",{
145 | "default": 42
146 | }),
147 | "num_inference_steps": ("INT",{
148 | "default": 20
149 | }),
150 | "guidance_scale": ("FLOAT",{
151 | "default": 3.5
152 | }),
153 | "context_frames": ("INT",{
154 | "default": 12
155 | }),
156 | "context_stride": ("INT",{
157 | "default": 1
158 | }),
159 | "context_overlap": ("INT",{
160 | "default": 4
161 | }),
162 | "reference_attention_weight": ("FLOAT",{
163 | "default": 0.95
164 | }),
165 | "audio_attention_weight": ("FLOAT",{
166 | "default": 3.
167 | }),
168 | }
169 | }
170 |
171 | RETURN_TYPES = (
172 | "STRING_INPUT",
173 | )
174 | RETURN_NAMES = (
175 | "output_path",
176 | )
177 | OUTPUT_NODE = True
178 | # OUTPUT_NODE = False
179 | CATEGORY = "V-Express"
180 | FUNCTION = "v_express"
181 | def v_express(
182 | self,
183 | v_express_pipeline,
184 | vexpress_model_path,
185 | audio_path,
186 | kps_path,
187 | ref_image_path,
188 | output_path,
189 | image_size,
190 | retarget_strategy,
191 | fps,
192 | seed,
193 | num_inference_steps,
194 | guidance_scale,
195 | context_frames,
196 | context_stride,
197 | context_overlap,
198 | reference_attention_weight,
199 | audio_attention_weight,
200 | save_gpu_memory=True,
201 | do_multi_devices_inference=False,
202 | ):
203 | start_time = time.time()
204 |
205 | accelerator = None
206 |
207 | reference_image_path = ref_image_path
208 | model_dict = get_all_model_path(vexpress_model_path)
209 |
210 | insightface_model_path = model_dict['insightface_model_path']
211 |
212 | app = FaceAnalysis(
213 | providers=['CUDAExecutionProvider' if DEVICE == 'cuda' else 'CPUExecutionProvider'],
214 | provider_options=[{'device_id': GPU_ID}] if DEVICE == 'cuda' else [],
215 | root=insightface_model_path,
216 | )
217 | app.prepare(ctx_id=0, det_size=(image_size, image_size))
218 |
219 | reference_image = Image.open(reference_image_path).convert('RGB')
220 | reference_image = reference_image.resize((image_size, image_size))
221 |
222 | reference_image_for_kps = cv2.imread(reference_image_path)
223 | reference_image_for_kps = cv2.resize(reference_image_for_kps, (image_size, image_size))
224 | reference_kps = app.get(reference_image_for_kps)[0].kps[:3]
225 | if save_gpu_memory:
226 | del app
227 | torch.cuda.empty_cache()
228 |
229 | _, audio_waveform, meta_info = torchvision.io.read_video(audio_path, pts_unit='sec')
230 | audio_sampling_rate = meta_info['audio_fps']
231 | print(f'Length of audio is {audio_waveform.shape[1]} with the sampling rate of {audio_sampling_rate}.')
232 | if audio_sampling_rate != STANDARD_AUDIO_SAMPLING_RATE:
233 | audio_waveform = torchaudio.functional.resample(
234 | audio_waveform,
235 | orig_freq=audio_sampling_rate,
236 | new_freq=STANDARD_AUDIO_SAMPLING_RATE,
237 | )
238 | audio_waveform = audio_waveform.mean(dim=0)
239 |
240 | duration = audio_waveform.shape[0] / STANDARD_AUDIO_SAMPLING_RATE
241 | init_video_length = int(duration * fps)
242 | num_contexts = np.around((init_video_length + context_overlap) / context_frames)
243 | video_length = int(num_contexts * context_frames - context_overlap)
244 | fps = video_length / duration
245 | print(f'The corresponding video length is {video_length}.')
246 |
247 | kps_sequence = None
248 | if kps_path != "":
249 | assert os.path.exists(kps_path), f'{kps_path} does not exist'
250 | kps_sequence = torch.tensor(torch.load(kps_path)) # [len, 3, 2]
251 | print(f'The original length of kps sequence is {kps_sequence.shape[0]}.')
252 |
253 | if kps_sequence.shape[0] > video_length:
254 | kps_sequence = kps_sequence[:video_length, :, :]
255 |
256 | kps_sequence = torch.nn.functional.interpolate(kps_sequence.permute(1, 2, 0), size=video_length, mode='linear')
257 | kps_sequence = kps_sequence.permute(2, 0, 1)
258 | print(f'The interpolated length of kps sequence is {kps_sequence.shape[0]}.')
259 |
260 | if retarget_strategy == 'fix_face':
261 | kps_sequence = torch.tensor([reference_kps] * video_length)
262 | elif retarget_strategy == 'no_retarget':
263 | kps_sequence = kps_sequence
264 | elif retarget_strategy == 'offset_retarget':
265 | kps_sequence = retarget_kps(reference_kps, kps_sequence, only_offset=True)
266 | elif retarget_strategy == 'naive_retarget':
267 | kps_sequence = retarget_kps(reference_kps, kps_sequence, only_offset=False)
268 | else:
269 | raise ValueError(f'The retarget strategy {retarget_strategy} is not supported.')
270 |
271 | kps_images = []
272 | for i in range(video_length):
273 | kps_image = draw_kps_image(image_size, image_size, kps_sequence[i])
274 | kps_images.append(Image.fromarray(kps_image))
275 |
276 | generator = torch.manual_seed(seed)
277 | video_tensor = v_express_pipeline(
278 | reference_image=reference_image,
279 | kps_images=kps_images,
280 | audio_waveform=audio_waveform,
281 | width=image_size,
282 | height=image_size,
283 | video_length=video_length,
284 | num_inference_steps=num_inference_steps,
285 | guidance_scale=guidance_scale,
286 | context_frames=context_frames,
287 | context_overlap=context_overlap,
288 | reference_attention_weight=reference_attention_weight,
289 | audio_attention_weight=audio_attention_weight,
290 | num_pad_audio_frames=NUM_PAD_AUDIO_FRAMES,
291 | generator=generator,
292 | do_multi_devices_inference=do_multi_devices_inference,
293 | save_gpu_memory=save_gpu_memory,
294 | )
295 |
296 | if accelerator is None or accelerator.is_main_process:
297 | save_video(video_tensor, audio_path, output_path, DEVICE, fps)
298 | consumed_time = time.time() - start_time
299 | generation_fps = video_tensor.shape[2] / consumed_time
300 | print(f'The generated video has been saved at {output_path}. '
301 | f'The generation time is {consumed_time:.1f} seconds. '
302 | f'The generation FPS is {generation_fps:.2f}.')
303 |
304 | return (output_path, )
305 |
306 |
307 | class V_Express_Loader:
308 | @classmethod
309 | def INPUT_TYPES(s):
310 | return {
311 | "required": {
312 | "vexpress_model_path": ("STRING_INPUT", ),
313 | },
314 | }
315 |
316 | RETURN_TYPES = (
317 | "V_EXPRESS_PIPELINE",
318 | )
319 | RETURN_NAMES = (
320 | "v_express_pipeline",
321 | )
322 |
323 | CATEGORY = "V-Express"
324 | FUNCTION = "load_vexpress_pipeline"
325 | def load_vexpress_pipeline(self, vexpress_model_path):
326 |
327 | model_dict = get_all_model_path(vexpress_model_path)
328 |
329 | unet_config_path = model_dict['unet_config_path']
330 | reference_net_path = model_dict['reference_net_path']
331 | denoising_unet_path = model_dict['denoising_unet_path']
332 | v_kps_guider_path = model_dict['v_kps_guider_path']
333 | audio_projection_path = model_dict['audio_projection_path']
334 | motion_module_path = model_dict['motion_module_path']
335 |
336 | vae_path = model_dict['vae_path']
337 | audio_encoder_path = model_dict['audio_encoder_path']
338 |
339 | dtype = WEIGHT_DTYPE
340 | device = DEVICE
341 | inference_config_path = INFERENCE_CONFIG_PATH
342 |
343 | scheduler = get_scheduler(inference_config_path)
344 | reference_net = load_reference_net(unet_config_path, reference_net_path, dtype, device)
345 | denoising_unet = load_denoising_unet(
346 | inference_config_path, unet_config_path, denoising_unet_path, motion_module_path,
347 | dtype, device
348 | )
349 | v_kps_guider = load_v_kps_guider(v_kps_guider_path, dtype, device)
350 | audio_projection = load_audio_projection(
351 | audio_projection_path,
352 | dtype,
353 | device,
354 | inp_dim=denoising_unet.config.cross_attention_dim,
355 | mid_dim=denoising_unet.config.cross_attention_dim,
356 | out_dim=denoising_unet.config.cross_attention_dim,
357 | inp_seq_len=2 * (2 * NUM_PAD_AUDIO_FRAMES + 1),
358 | out_seq_len=2 * NUM_PAD_AUDIO_FRAMES + 1,
359 | )
360 |
361 | vae = AutoencoderKL.from_pretrained(vae_path).to(dtype=dtype, device=device)
362 | audio_encoder = Wav2Vec2Model.from_pretrained(audio_encoder_path).to(dtype=dtype, device=device)
363 | audio_processor = Wav2Vec2Processor.from_pretrained(audio_encoder_path)
364 |
365 | v_express_pipeline = VExpressPipeline(
366 | vae=vae,
367 | reference_net=reference_net,
368 | denoising_unet=denoising_unet,
369 | v_kps_guider=v_kps_guider,
370 | audio_processor=audio_processor,
371 | audio_encoder=audio_encoder,
372 | audio_projection=audio_projection,
373 | scheduler=scheduler,
374 | ).to(dtype=dtype, device=device)
375 |
376 | return (v_express_pipeline,)
377 |
378 |
379 | class Load_Audio_Path:
380 | @classmethod
381 | def INPUT_TYPES(s):
382 | files = []
383 | for f in os.listdir(INPUT_PATH):
384 | if os.path.isfile(os.path.join(INPUT_PATH, f)) and f.split('.')[-1] in ["mp3"]: # only support mp3
385 | files.append(f)
386 |
387 | return {"required":{
388 | "audio_path": (files,),
389 | }}
390 |
391 | CATEGORY = "V-Express"
392 |
393 | RETURN_TYPES = ("AUDIO_PATH",)
394 |
395 | FUNCTION = "load_audio_path"
396 | def load_audio_path(self, audio_path):
397 | audio_path = os.path.join(INPUT_PATH, audio_path)
398 | return (audio_path,)
399 |
400 |
401 | class Load_Audio_Path_From_Video:
402 | @classmethod
403 | def INPUT_TYPES(s):
404 | files = []
405 | for f in os.listdir(INPUT_PATH):
406 | if os.path.isfile(os.path.join(INPUT_PATH, f)) and f.split('.')[-1] in ["mp4", "webm","mkv","avi"]:
407 | files.append(f)
408 |
409 | return {"required":{
410 | "video_path": (files,),
411 | }}
412 |
413 | CATEGORY = "V-Express"
414 |
415 | RETURN_TYPES = ("AUDIO_PATH",)
416 |
417 | FUNCTION = "load_audio_path_from_video"
418 | def load_audio_path_from_video(self, video_path):
419 | video_path = os.path.join(INPUT_PATH, video_path)
420 | video_base_name = video_path[:video_path.rfind('.')]
421 | audio_name = f'{video_base_name}_audio.mp3'
422 | audio_path = os.path.join(INPUT_PATH, audio_name)
423 | os.system(f'{get_ffmpeg_exe()} -i "{video_path}" -y -vn "{audio_path}"')
424 | if not os.path.isfile(audio_path):
425 |             raise ValueError(f'{audio_path} does not exist! Please check whether the video contains an audio track!')
426 | return (audio_path,)
427 |
428 |
429 | class Load_Kps_Path:
430 | @classmethod
431 | def INPUT_TYPES(s):
432 | files = []
433 | for f in os.listdir(INPUT_PATH):
434 | if os.path.isfile(os.path.join(INPUT_PATH, f)) and f.split('.')[-1] in ["pth"]:
435 | files.append(f)
436 |
437 | return {"required":{
438 | "kps_path": (files,),
439 | }}
440 |
441 | CATEGORY = "V-Express"
442 |
443 | RETURN_TYPES = ("VKPS_PATH",)
444 |
445 | FUNCTION = "load_kps_path"
446 | def load_kps_path(self, kps_path):
447 | kps_path = os.path.join(INPUT_PATH, kps_path)
448 | return (kps_path,)
449 |
450 |
451 | class Load_Kps_Path_From_Video:
452 | @classmethod
453 | def INPUT_TYPES(s):
454 | files = []
455 | for f in os.listdir(INPUT_PATH):
456 | if os.path.isfile(os.path.join(INPUT_PATH, f)) and f.split('.')[-1] in ["mp4", "webm","mkv","avi"]:
457 | files.append(f)
458 |
459 | return {"required":{
460 | "vexpress_model_path": ("STRING_INPUT", ),
461 | "video_path": (files,),
462 | "image_size": ("INT_INPUT",),
463 | }}
464 |
465 | CATEGORY = "V-Express"
466 |
467 | RETURN_TYPES = ("VKPS_PATH",)
468 |
469 | FUNCTION = "load_kps_path_from_video"
470 | def load_kps_path_from_video(self, vexpress_model_path, video_path, image_size):
471 | video_path = os.path.join(INPUT_PATH, video_path)
472 | video_base_name = video_path[:video_path.rfind('.')]
473 | kps_name = f'{video_base_name}_kps.pth'
474 | kps_path = os.path.join(INPUT_PATH, kps_name)
475 |
476 | model_dict = get_all_model_path(vexpress_model_path)
477 | insightface_model_path = model_dict['insightface_model_path']
478 |
479 | app = FaceAnalysis(
480 | providers=['CUDAExecutionProvider' if DEVICE == 'cuda' else 'CPUExecutionProvider'],
481 | provider_options=[{'device_id': GPU_ID}] if DEVICE == 'cuda' else [],
482 | root=insightface_model_path,
483 | )
484 | app.prepare(ctx_id=0, det_size=(image_size, image_size))
485 |
486 | kps_sequence = []
487 | video_capture = cv2.VideoCapture(video_path)
488 | frame_idx = 0
489 | while video_capture.isOpened():
490 | ret, frame = video_capture.read()
491 | if not ret:
492 | break
493 | frame = cv2.resize(frame, (image_size, image_size))
494 | faces = app.get(frame)
495 | assert len(faces) == 1, f'There are {len(faces)} faces in the {frame_idx}-th frame. Only one face is supported.'
496 |
497 | kps = faces[0].kps[:3]
498 | kps_sequence.append(kps)
499 | frame_idx += 1
500 | torch.save(kps_sequence, kps_path)
501 |
502 | if not os.path.isfile(kps_path):
503 |             raise ValueError(f'{kps_path} does not exist! Please check the input!')
504 | return (kps_path,)
505 |
506 |
507 | class Load_Image_Path:
508 | @classmethod
509 | def INPUT_TYPES(s):
510 | input_dir = INPUT_PATH
511 | files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
512 | return {"required":
513 | {"image": (sorted(files), {"image_upload": True})},
514 | }
515 |
516 | CATEGORY = "V-Express"
517 |
518 | RETURN_TYPES = ("IMAGE_PATH",)
519 | FUNCTION = "load_image_path"
520 | def load_image_path(self, image):
521 | image_path = os.path.join(INPUT_PATH, image)
522 | return (image_path,)
523 |
524 |
525 | class Load_Video_Path:
526 | @classmethod
527 | def INPUT_TYPES(s):
528 | files = []
529 | for f in os.listdir(INPUT_PATH):
530 | if os.path.isfile(os.path.join(INPUT_PATH, f)) and f.split('.')[-1] in ["mp4", "webm","mkv","avi"]:
531 | files.append(f)
532 |
533 | return {"required":{
534 | "video_path": (files,),
535 | }}
536 |
537 | CATEGORY = "V-Express"
538 |
539 | RETURN_TYPES = ("STRING_INPUT",)
540 |
541 | FUNCTION = "load_video_path"
542 | def load_video_path(self, video_path):
543 | video_path = os.path.join(INPUT_PATH, video_path)
544 | return (video_path,)
545 |
546 |
547 | class VEPreview_Video:
548 | @classmethod
549 | def INPUT_TYPES(s):
550 | return {"required":{
551 | "video":("STRING_INPUT",),
552 | }}
553 |
554 | CATEGORY = "V-Express"
555 | DESCRIPTION = "show result"
556 |
557 | RETURN_TYPES = ()
558 |
559 | OUTPUT_NODE = True
560 |
561 | FUNCTION = "load_video"
562 | def load_video(self, video):
563 | video_name = os.path.basename(video)
564 | video_path_name = os.path.basename(os.path.dirname(video))
565 | return {"ui":{"video":[video_name, video_path_name]}}
566 |
567 | @classmethod
568 | def IS_CHANGED(s,):
569 | return ""
570 |
571 |
572 | # A dictionary that contains all nodes you want to export with their names
573 | # NOTE: names should be globally unique
574 | NODE_CLASS_MAPPINGS = {
575 | "V_Express_Loader": V_Express_Loader,
576 | "V_Express_Sampler": V_Express_Sampler,
577 | "Load_Audio_Path": Load_Audio_Path,
578 | "Load_Audio_Path_From_Video": Load_Audio_Path_From_Video,
579 | "Load_Kps_Path": Load_Kps_Path,
580 | "Load_Kps_Path_From_Video": Load_Kps_Path_From_Video,
581 | "Load_Image_Path": Load_Image_Path,
582 | "Load_Video_Path": Load_Video_Path,
583 | "VEINTConstant": VEINTConstant,
584 | "VEStringConstant": VEStringConstant,
585 | "VEPreview_Video": VEPreview_Video,
586 | }
587 |
588 | # A dictionary that contains the friendly/humanly readable titles for the nodes
589 | NODE_DISPLAY_NAME_MAPPINGS = {
590 | "V_Express_Loader": "V-Express Loader",
591 | "V_Express_Sampler": "V-Express Sampler",
592 | "Load_Audio_Path": "Load Audio Path",
593 | "Load_Audio_Path_From_Video": "Load Audio Path From Video",
594 | "Load_Kps_Path": "Load V-Kps Path",
595 | "Load_Kps_Path_From_Video": "Load V-Kps Path From Video",
596 | "Load_Image_Path": "Load Reference Image Path",
597 | "Load_Video_Path": "Load Video Path",
598 | "VEINTConstant": "Set Image Size",
599 | "VEStringConstant": "Set V-Express Model Path",
600 | "VEPreview_Video": "Preview Output Video",
601 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers==0.24.0
2 | imageio-ffmpeg==0.4.9
3 | insightface==0.7.3
4 | omegaconf==2.2.3
5 | onnxruntime==1.16.3
6 | safetensors==0.4.2
7 | torch==2.0.1
8 | torchaudio==2.0.2
9 | torchvision==0.15.2
10 | transformers==4.30.2
11 | einops==0.4.1
12 | tqdm==4.66.1
13 | xformers==0.0.22
14 | av==11.0.0
--------------------------------------------------------------------------------
/src/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 |
5 | import accelerate
6 | import cv2
7 | import numpy as np
8 | import torch
9 | import torchaudio.functional
10 | import torchvision.io
11 | from PIL import Image
12 | from diffusers import AutoencoderKL, DDIMScheduler
13 | from diffusers.utils.import_utils import is_xformers_available
14 | from insightface.app import FaceAnalysis
15 | from omegaconf import OmegaConf
16 | from transformers import Wav2Vec2Model, Wav2Vec2Processor
17 |
18 | from modules import UNet2DConditionModel, UNet3DConditionModel, VKpsGuider, AudioProjection
19 | from pipelines import VExpressPipeline
20 | from pipelines.utils import draw_kps_image, save_video
21 | from pipelines.utils import retarget_kps
22 |
23 |
24 | def parse_args():
25 | parser = argparse.ArgumentParser()
26 |
27 | parser.add_argument('--unet_config_path', type=str, default='./model_ckpts/stable-diffusion-v1-5/unet/config.json')
28 | parser.add_argument('--vae_path', type=str, default='./model_ckpts/sd-vae-ft-mse/')
29 | parser.add_argument('--audio_encoder_path', type=str, default='./model_ckpts/wav2vec2-base-960h/')
30 | parser.add_argument('--insightface_model_path', type=str, default='./model_ckpts/insightface_models/')
31 |
32 | parser.add_argument('--denoising_unet_path', type=str, default='./model_ckpts/v-express/denoising_unet.bin')
33 | parser.add_argument('--reference_net_path', type=str, default='./model_ckpts/v-express/reference_net.bin')
34 | parser.add_argument('--v_kps_guider_path', type=str, default='./model_ckpts/v-express/v_kps_guider.bin')
35 | parser.add_argument('--audio_projection_path', type=str, default='./model_ckpts/v-express/audio_projection.bin')
36 | parser.add_argument('--motion_module_path', type=str, default='./model_ckpts/v-express/motion_module.bin')
37 |
38 | parser.add_argument('--retarget_strategy', type=str, default='fix_face',
39 | help='{fix_face, no_retarget, offset_retarget, naive_retarget}')
40 |
41 | parser.add_argument('--dtype', type=str, default='fp16')
42 | parser.add_argument('--device', type=str, default='cuda')
43 | parser.add_argument('--gpu_id', type=int, default=0)
44 | parser.add_argument('--do_multi_devices_inference', action='store_true')
45 | parser.add_argument('--save_gpu_memory', action='store_true')
46 |
47 | parser.add_argument('--num_pad_audio_frames', type=int, default=2)
48 | parser.add_argument('--standard_audio_sampling_rate', type=int, default=16000)
49 |
50 | parser.add_argument('--reference_image_path', type=str, default='./test_samples/emo/talk_emotion/ref.jpg')
51 | parser.add_argument('--audio_path', type=str, default='./test_samples/emo/talk_emotion/aud.mp3')
52 | parser.add_argument('--kps_path', type=str, default='./test_samples/emo/talk_emotion/kps.pth')
53 | parser.add_argument('--output_path', type=str, default='./output/emo/talk_emotion.mp4')
54 |
55 | parser.add_argument('--image_width', type=int, default=512)
56 | parser.add_argument('--image_height', type=int, default=512)
57 | parser.add_argument('--fps', type=float, default=30.0)
58 | parser.add_argument('--seed', type=int, default=42)
59 | parser.add_argument('--num_inference_steps', type=int, default=25)
60 | parser.add_argument('--guidance_scale', type=float, default=3.5)
61 | parser.add_argument('--context_frames', type=int, default=12)
62 | parser.add_argument('--context_overlap', type=int, default=4)
63 | parser.add_argument('--reference_attention_weight', default=0.95, type=float)
64 | parser.add_argument('--audio_attention_weight', default=3., type=float)
65 |
66 | args = parser.parse_args()
67 |
68 | return args
69 |
70 |
71 | def load_reference_net(unet_config_path, reference_net_path, dtype, device):
72 | reference_net = UNet2DConditionModel.from_config(unet_config_path).to(dtype=dtype, device=device)
73 | reference_net.load_state_dict(torch.load(reference_net_path, map_location="cpu"), strict=False)
74 | print(f'Loaded weights of Reference Net from {reference_net_path}.')
75 | return reference_net
76 |
77 |
78 | def load_denoising_unet(inf_config_path, unet_config_path, denoising_unet_path, motion_module_path, dtype, device):
79 | inference_config = OmegaConf.load(inf_config_path)
80 | denoising_unet = UNet3DConditionModel.from_config_2d(
81 | unet_config_path,
82 | unet_additional_kwargs=inference_config.unet_additional_kwargs,
83 | ).to(dtype=dtype, device=device)
84 | denoising_unet.load_state_dict(torch.load(denoising_unet_path, map_location="cpu"), strict=False)
85 | print(f'Loaded weights of Denoising U-Net from {denoising_unet_path}.')
86 |
87 | denoising_unet.load_state_dict(torch.load(motion_module_path, map_location="cpu"), strict=False)
88 | print(f'Loaded weights of Denoising U-Net Motion Module from {motion_module_path}.')
89 |
90 | return denoising_unet
91 |
92 |
93 | def load_v_kps_guider(v_kps_guider_path, dtype, device):
94 | v_kps_guider = VKpsGuider(320, block_out_channels=(16, 32, 96, 256)).to(dtype=dtype, device=device)
95 | v_kps_guider.load_state_dict(torch.load(v_kps_guider_path, map_location="cpu"))
96 | print(f'Loaded weights of V-Kps Guider from {v_kps_guider_path}.')
97 | return v_kps_guider
98 |
99 |
100 | def load_audio_projection(
101 | audio_projection_path,
102 | dtype,
103 | device,
104 | inp_dim: int,
105 | mid_dim: int,
106 | out_dim: int,
107 | inp_seq_len: int,
108 | out_seq_len: int,
109 | ):
110 | audio_projection = AudioProjection(
111 | dim=mid_dim,
112 | depth=4,
113 | dim_head=64,
114 | heads=12,
115 | num_queries=out_seq_len,
116 | embedding_dim=inp_dim,
117 | output_dim=out_dim,
118 | ff_mult=4,
119 | max_seq_len=inp_seq_len,
120 | ).to(dtype=dtype, device=device)
121 | audio_projection.load_state_dict(torch.load(audio_projection_path, map_location='cpu'))
122 | print(f'Loaded weights of Audio Projection from {audio_projection_path}.')
123 | return audio_projection
124 |
125 |
126 | def get_scheduler(inference_config_path):
127 | inference_config = OmegaConf.load(inference_config_path)
128 | scheduler_kwargs = OmegaConf.to_container(inference_config.noise_scheduler_kwargs)
129 | scheduler = DDIMScheduler(**scheduler_kwargs)
130 | return scheduler
131 |
132 |
133 | def main():
134 | args = parse_args()
135 | start_time = time.time()
136 |
137 | if not args.do_multi_devices_inference:
138 | # TODO
139 | accelerator = None
140 | device = torch.device(f'{args.device}:{args.gpu_id}' if args.device == 'cuda' else args.device)
141 | else:
142 | accelerator = accelerate.Accelerator()
143 | device = torch.device(f'cuda:{accelerator.process_index}')
144 | dtype = torch.float16 if args.dtype == 'fp16' else torch.float32
145 |
146 | vae_path = args.vae_path
147 | audio_encoder_path = args.audio_encoder_path
148 |
149 | vae = AutoencoderKL.from_pretrained(vae_path).to(dtype=dtype, device=device)
150 | audio_encoder = Wav2Vec2Model.from_pretrained(audio_encoder_path).to(dtype=dtype, device=device)
151 | audio_processor = Wav2Vec2Processor.from_pretrained(audio_encoder_path)
152 |
153 | unet_config_path = args.unet_config_path
154 | reference_net_path = args.reference_net_path
155 | denoising_unet_path = args.denoising_unet_path
156 | v_kps_guider_path = args.v_kps_guider_path
157 | audio_projection_path = args.audio_projection_path
158 | motion_module_path = args.motion_module_path
159 |
160 | inference_config_path = './inference_v2.yaml'
161 | scheduler = get_scheduler(inference_config_path)
162 | reference_net = load_reference_net(unet_config_path, reference_net_path, dtype, device)
163 | denoising_unet = load_denoising_unet(
164 | inference_config_path, unet_config_path, denoising_unet_path, motion_module_path,
165 | dtype, device
166 | )
167 | v_kps_guider = load_v_kps_guider(v_kps_guider_path, dtype, device)
168 | audio_projection = load_audio_projection(
169 | audio_projection_path,
170 | dtype,
171 | device,
172 | inp_dim=denoising_unet.config.cross_attention_dim,
173 | mid_dim=denoising_unet.config.cross_attention_dim,
174 | out_dim=denoising_unet.config.cross_attention_dim,
175 | inp_seq_len=2 * (2 * args.num_pad_audio_frames + 1),
176 | out_seq_len=2 * args.num_pad_audio_frames + 1,
177 | )
178 |
179 | if is_xformers_available():
180 | reference_net.enable_xformers_memory_efficient_attention()
181 | denoising_unet.enable_xformers_memory_efficient_attention()
182 | else:
183 | raise ValueError("xformers is not available. Make sure it is installed correctly")
184 |
185 | generator = torch.manual_seed(args.seed)
186 | pipeline = VExpressPipeline(
187 | vae=vae,
188 | reference_net=reference_net,
189 | denoising_unet=denoising_unet,
190 | v_kps_guider=v_kps_guider,
191 | audio_processor=audio_processor,
192 | audio_encoder=audio_encoder,
193 | audio_projection=audio_projection,
194 | scheduler=scheduler,
195 | ).to(dtype=dtype, device=device)
196 |
197 | app = FaceAnalysis(
198 | providers=['CUDAExecutionProvider' if args.device == 'cuda' else 'CPUExecutionProvider'],
199 | provider_options=[{'device_id': args.gpu_id}] if args.device == 'cuda' else [],
200 | root=args.insightface_model_path,
201 | )
202 | app.prepare(ctx_id=0, det_size=(args.image_height, args.image_width))
203 |
204 | reference_image = Image.open(args.reference_image_path).convert('RGB')
205 | reference_image = reference_image.resize((args.image_height, args.image_width))
206 |
207 | reference_image_for_kps = cv2.imread(args.reference_image_path)
208 | reference_image_for_kps = cv2.resize(reference_image_for_kps, (args.image_width, args.image_height))
209 | reference_kps = app.get(reference_image_for_kps)[0].kps[:3]
210 | if args.save_gpu_memory:
211 | del app
212 | torch.cuda.empty_cache()
213 |
214 | _, audio_waveform, meta_info = torchvision.io.read_video(args.audio_path, pts_unit='sec')
215 | audio_sampling_rate = meta_info['audio_fps']
216 | print(f'Length of audio is {audio_waveform.shape[1]} with the sampling rate of {audio_sampling_rate}.')
217 | if audio_sampling_rate != args.standard_audio_sampling_rate:
218 | audio_waveform = torchaudio.functional.resample(
219 | audio_waveform,
220 | orig_freq=audio_sampling_rate,
221 | new_freq=args.standard_audio_sampling_rate,
222 | )
223 | audio_waveform = audio_waveform.mean(dim=0)
224 |
225 | duration = audio_waveform.shape[0] / args.standard_audio_sampling_rate
226 | init_video_length = int(duration * args.fps)
227 | num_contexts = np.around((init_video_length + args.context_overlap) / args.context_frames)
228 | video_length = int(num_contexts * args.context_frames - args.context_overlap)
229 | fps = video_length / duration
230 | print(f'The corresponding video length is {video_length}.')
231 |
232 | kps_sequence = None
233 | if args.kps_path != "":
234 | assert os.path.exists(args.kps_path), f'{args.kps_path} does not exist'
235 | kps_sequence = torch.tensor(torch.load(args.kps_path)) # [len, 3, 2]
236 | print(f'The original length of kps sequence is {kps_sequence.shape[0]}.')
237 |
238 | if kps_sequence.shape[0] > video_length:
239 | kps_sequence = kps_sequence[:video_length, :, :]
240 |
241 | kps_sequence = torch.nn.functional.interpolate(kps_sequence.permute(1, 2, 0), size=video_length, mode='linear')
242 | kps_sequence = kps_sequence.permute(2, 0, 1)
243 | print(f'The interpolated length of kps sequence is {kps_sequence.shape[0]}.')
244 |
245 | retarget_strategy = args.retarget_strategy
246 | if retarget_strategy == 'fix_face':
247 | kps_sequence = torch.tensor([reference_kps] * video_length)
248 | elif retarget_strategy == 'no_retarget':
249 | kps_sequence = kps_sequence
250 | elif retarget_strategy == 'offset_retarget':
251 | kps_sequence = retarget_kps(reference_kps, kps_sequence, only_offset=True)
252 | elif retarget_strategy == 'naive_retarget':
253 | kps_sequence = retarget_kps(reference_kps, kps_sequence, only_offset=False)
254 | else:
255 | raise ValueError(f'The retarget strategy {retarget_strategy} is not supported.')
256 |
257 | kps_images = []
258 | for i in range(video_length):
259 | kps_image = draw_kps_image(args.image_height, args.image_width, kps_sequence[i])
260 | kps_images.append(Image.fromarray(kps_image))
261 |
262 | video_tensor = pipeline(
263 | reference_image=reference_image,
264 | kps_images=kps_images,
265 | audio_waveform=audio_waveform,
266 | width=args.image_width,
267 | height=args.image_height,
268 | video_length=video_length,
269 | num_inference_steps=args.num_inference_steps,
270 | guidance_scale=args.guidance_scale,
271 | context_frames=args.context_frames,
272 | context_overlap=args.context_overlap,
273 | reference_attention_weight=args.reference_attention_weight,
274 | audio_attention_weight=args.audio_attention_weight,
275 | num_pad_audio_frames=args.num_pad_audio_frames,
276 | generator=generator,
277 | do_multi_devices_inference=args.do_multi_devices_inference,
278 | save_gpu_memory=args.save_gpu_memory,
279 | )
280 |
281 | if accelerator is None or accelerator.is_main_process:
282 | save_video(video_tensor, args.audio_path, args.output_path, device, fps)
283 | consumed_time = time.time() - start_time
284 | generation_fps = video_tensor.shape[2] / consumed_time
285 | print(f'The generated video has been saved at {args.output_path}. '
286 | f'The generation time is {consumed_time:.1f} seconds. '
287 | f'The generation FPS is {generation_fps:.2f}.')
288 |
289 |
290 | if __name__ == '__main__':
291 | main()
292 |
--------------------------------------------------------------------------------
/src/inference_v2.yaml:
--------------------------------------------------------------------------------
1 | unet_additional_kwargs:
2 | use_inflated_groupnorm: true
3 | unet_use_cross_frame_attention: false
4 | unet_use_temporal_attention: false
5 | use_motion_module: true
6 | motion_module_resolutions:
7 | - 1
8 | - 2
9 | - 4
10 | - 8
11 | motion_module_mid_block: true
12 | motion_module_decoder_only: false
13 | motion_module_type: Vanilla
14 | motion_module_kwargs:
15 | num_attention_heads: 8
16 | num_transformer_block: 1
17 | attention_block_types:
18 | - Temporal_Self
19 | - Temporal_Self
20 | temporal_position_encoding: true
21 | temporal_position_encoding_max_len: 32
22 | temporal_attention_dim_div: 1
23 |
24 | noise_scheduler_kwargs:
25 | beta_start: 0.00085
26 | beta_end: 0.012
27 | beta_schedule: "scaled_linear"
28 | clip_sample: false
29 | steps_offset: 1
30 | ### Zero-SNR params
31 | prediction_type: "v_prediction"
32 | rescale_betas_zero_snr: True
33 | timestep_spacing: "trailing"
34 |
35 | sampler: DDIM
36 |
--------------------------------------------------------------------------------
/src/modules/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tiankuan93/ComfyUI-V-Express/ee8f245406e6fa07353dbf36f8ada06f81d48a9e/src/modules/.DS_Store
--------------------------------------------------------------------------------
/src/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .unet_2d_condition import UNet2DConditionModel
2 | from .unet_3d import UNet3DConditionModel
3 | from .v_kps_guider import VKpsGuider
4 | from .audio_projection import AudioProjection
5 | from .mutual_self_attention import ReferenceAttentionControl
6 |
--------------------------------------------------------------------------------
/src/modules/attention.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2 |
3 | from typing import Any, Dict, Optional
4 |
5 | import torch
6 | from diffusers.models.attention import AdaLayerNorm, AdaLayerNormZero, Attention, FeedForward, GatedSelfAttentionDense
7 | from diffusers.models.embeddings import SinusoidalPositionalEmbedding
8 | from einops import rearrange
9 | from torch import nn
10 |
11 |
12 | class BasicTransformerBlock(nn.Module):
13 | r"""
14 | A basic Transformer block.
15 |
16 | Parameters:
17 | dim (`int`): The number of channels in the input and output.
18 | num_attention_heads (`int`): The number of heads to use for multi-head attention.
19 | attention_head_dim (`int`): The number of channels in each head.
20 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
21 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
22 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
23 | num_embeds_ada_norm (:
24 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
25 | attention_bias (:
26 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
27 | only_cross_attention (`bool`, *optional*):
28 | Whether to use only cross-attention layers. In this case two cross attention layers are used.
29 | double_self_attention (`bool`, *optional*):
30 | Whether to use two self-attention layers. In this case no cross attention layers are used.
31 | upcast_attention (`bool`, *optional*):
32 | Whether to upcast the attention computation to float32. This is useful for mixed precision training.
33 | norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
34 | Whether to use learnable elementwise affine parameters for normalization.
35 | norm_type (`str`, *optional*, defaults to `"layer_norm"`):
36 | The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
37 | final_dropout (`bool` *optional*, defaults to False):
38 | Whether to apply a final dropout after the last feed-forward layer.
39 | attention_type (`str`, *optional*, defaults to `"default"`):
40 | The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
41 | positional_embeddings (`str`, *optional*, defaults to `None`):
42 | The type of positional embeddings to apply to.
43 | num_positional_embeddings (`int`, *optional*, defaults to `None`):
44 | The maximum number of positional embeddings to apply.
45 | """
46 |
47 | def __init__(
48 | self,
49 | dim: int,
50 | num_attention_heads: int,
51 | attention_head_dim: int,
52 | dropout=0.0,
53 | cross_attention_dim: Optional[int] = None,
54 | activation_fn: str = "geglu",
55 | num_embeds_ada_norm: Optional[int] = None,
56 | attention_bias: bool = False,
57 | only_cross_attention: bool = False,
58 | double_self_attention: bool = False,
59 | upcast_attention: bool = False,
60 | norm_elementwise_affine: bool = True,
61 | norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
62 | norm_eps: float = 1e-5,
63 | final_dropout: bool = False,
64 | attention_type: str = "default",
65 | positional_embeddings: Optional[str] = None,
66 | num_positional_embeddings: Optional[int] = None,
67 | ):
68 | super().__init__()
69 | self.only_cross_attention = only_cross_attention
70 |
71 | self.use_ada_layer_norm_zero = (
72 | num_embeds_ada_norm is not None
73 | ) and norm_type == "ada_norm_zero"
74 | self.use_ada_layer_norm = (
75 | num_embeds_ada_norm is not None
76 | ) and norm_type == "ada_norm"
77 | self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
78 | self.use_layer_norm = norm_type == "layer_norm"
79 |
80 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
81 | raise ValueError(
82 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
83 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
84 | )
85 |
86 | if positional_embeddings and (num_positional_embeddings is None):
87 | raise ValueError(
88 | "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
89 | )
90 |
91 | if positional_embeddings == "sinusoidal":
92 | self.pos_embed = SinusoidalPositionalEmbedding(
93 | dim, max_seq_length=num_positional_embeddings
94 | )
95 | else:
96 | self.pos_embed = None
97 |
98 | # Define 3 blocks. Each block has its own normalization layer.
99 | # 1. Self-Attn
100 | if self.use_ada_layer_norm:
101 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
102 | elif self.use_ada_layer_norm_zero:
103 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
104 | else:
105 | self.norm1 = nn.LayerNorm(
106 | dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
107 | )
108 |
109 | self.attn1 = Attention(
110 | query_dim=dim,
111 | heads=num_attention_heads,
112 | dim_head=attention_head_dim,
113 | dropout=dropout,
114 | bias=attention_bias,
115 | cross_attention_dim=cross_attention_dim if only_cross_attention else None,
116 | upcast_attention=upcast_attention,
117 | )
118 |
119 | # 2. Cross-Attn
120 | if cross_attention_dim is not None or double_self_attention:
121 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
122 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
123 | # the second cross attention block.
124 | self.norm2 = (
125 | AdaLayerNorm(dim, num_embeds_ada_norm)
126 | if self.use_ada_layer_norm
127 | else nn.LayerNorm(
128 | dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
129 | )
130 | )
131 | self.attn2 = Attention(
132 | query_dim=dim,
133 | cross_attention_dim=cross_attention_dim
134 | if not double_self_attention
135 | else None,
136 | heads=num_attention_heads,
137 | dim_head=attention_head_dim,
138 | dropout=dropout,
139 | bias=attention_bias,
140 | upcast_attention=upcast_attention,
141 | ) # is self-attn if encoder_hidden_states is none
142 | else:
143 | self.norm2 = None
144 | self.attn2 = None
145 |
146 | # 3. Feed-forward
147 | if not self.use_ada_layer_norm_single:
148 | self.norm3 = nn.LayerNorm(
149 | dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
150 | )
151 |
152 | self.ff = FeedForward(
153 | dim,
154 | dropout=dropout,
155 | activation_fn=activation_fn,
156 | final_dropout=final_dropout,
157 | )
158 |
159 | # 4. Fuser
160 | if attention_type == "gated" or attention_type == "gated-text-image":
161 | self.fuser = GatedSelfAttentionDense(
162 | dim, cross_attention_dim, num_attention_heads, attention_head_dim
163 | )
164 |
165 | # 5. Scale-shift for PixArt-Alpha.
166 | if self.use_ada_layer_norm_single:
167 | self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
168 |
169 | # let chunk size default to None
170 | self._chunk_size = None
171 | self._chunk_dim = 0
172 |
173 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
174 | # Sets chunk feed-forward
175 | self._chunk_size = chunk_size
176 | self._chunk_dim = dim
177 |
178 | def forward(
179 | self,
180 | hidden_states: torch.FloatTensor,
181 | attention_mask: Optional[torch.FloatTensor] = None,
182 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
183 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
184 | timestep: Optional[torch.LongTensor] = None,
185 | cross_attention_kwargs: Dict[str, Any] = None,
186 | class_labels: Optional[torch.LongTensor] = None,
187 | ) -> torch.FloatTensor:
188 | # Notice that normalization is always applied before the real computation in the following blocks.
189 | # 0. Self-Attention
190 | batch_size = hidden_states.shape[0]
191 |
192 | if self.use_ada_layer_norm:
193 | norm_hidden_states = self.norm1(hidden_states, timestep)
194 | elif self.use_ada_layer_norm_zero:
195 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
196 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
197 | )
198 | elif self.use_layer_norm:
199 | norm_hidden_states = self.norm1(hidden_states)
200 | elif self.use_ada_layer_norm_single:
201 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
202 | self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
203 | ).chunk(6, dim=1)
204 | norm_hidden_states = self.norm1(hidden_states)
205 | norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
206 | norm_hidden_states = norm_hidden_states.squeeze(1)
207 | else:
208 | raise ValueError("Incorrect norm used")
209 |
210 | if self.pos_embed is not None:
211 | norm_hidden_states = self.pos_embed(norm_hidden_states)
212 |
213 | # 1. Retrieve lora scale.
214 | lora_scale = (
215 | cross_attention_kwargs.get("scale", 1.0)
216 | if cross_attention_kwargs is not None
217 | else 1.0
218 | )
219 |
220 | # 2. Prepare GLIGEN inputs
221 | cross_attention_kwargs = (
222 | cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
223 | )
224 | gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
225 |
226 | attn_output = self.attn1(
227 | norm_hidden_states,
228 | encoder_hidden_states=encoder_hidden_states
229 | if self.only_cross_attention
230 | else None,
231 | attention_mask=attention_mask,
232 | **cross_attention_kwargs,
233 | )
234 | if self.use_ada_layer_norm_zero:
235 | attn_output = gate_msa.unsqueeze(1) * attn_output
236 | elif self.use_ada_layer_norm_single:
237 | attn_output = gate_msa * attn_output
238 |
239 | hidden_states = attn_output + hidden_states
240 | if hidden_states.ndim == 4:
241 | hidden_states = hidden_states.squeeze(1)
242 |
243 | # 2.5 GLIGEN Control
244 | if gligen_kwargs is not None:
245 | hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
246 |
247 | # 3. Cross-Attention
248 | if self.attn2 is not None:
249 | if self.use_ada_layer_norm:
250 | norm_hidden_states = self.norm2(hidden_states, timestep)
251 | elif self.use_ada_layer_norm_zero or self.use_layer_norm:
252 | norm_hidden_states = self.norm2(hidden_states)
253 | elif self.use_ada_layer_norm_single:
254 | # For PixArt norm2 isn't applied here:
255 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
256 | norm_hidden_states = hidden_states
257 | else:
258 | raise ValueError("Incorrect norm")
259 |
260 | if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
261 | norm_hidden_states = self.pos_embed(norm_hidden_states)
262 |
263 | attn_output = self.attn2(
264 | norm_hidden_states,
265 | encoder_hidden_states=encoder_hidden_states,
266 | attention_mask=encoder_attention_mask,
267 | **cross_attention_kwargs,
268 | )
269 | hidden_states = attn_output + hidden_states
270 |
271 | # 4. Feed-forward
272 | if not self.use_ada_layer_norm_single:
273 | norm_hidden_states = self.norm3(hidden_states)
274 |
275 | if self.use_ada_layer_norm_zero:
276 | norm_hidden_states = (
277 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
278 | )
279 |
280 | if self.use_ada_layer_norm_single:
281 | norm_hidden_states = self.norm2(hidden_states)
282 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
283 |
284 | ff_output = self.ff(norm_hidden_states, scale=lora_scale)
285 |
286 | if self.use_ada_layer_norm_zero:
287 | ff_output = gate_mlp.unsqueeze(1) * ff_output
288 | elif self.use_ada_layer_norm_single:
289 | ff_output = gate_mlp * ff_output
290 |
291 | hidden_states = ff_output + hidden_states
292 | if hidden_states.ndim == 4:
293 | hidden_states = hidden_states.squeeze(1)
294 |
295 | return hidden_states
296 |
297 |
298 | class TemporalBasicTransformerBlock(nn.Module):
299 | def __init__(
300 | self,
301 | dim: int,
302 | num_attention_heads: int,
303 | attention_head_dim: int,
304 | dropout=0.0,
305 | cross_attention_dim: Optional[int] = None,
306 | activation_fn: str = "geglu",
307 | num_embeds_ada_norm: Optional[int] = None,
308 | attention_bias: bool = False,
309 | only_cross_attention: bool = False,
310 | upcast_attention: bool = False,
311 | unet_use_cross_frame_attention=None,
312 | unet_use_temporal_attention=None,
313 | ):
314 | super().__init__()
315 | self.only_cross_attention = only_cross_attention
316 | self.use_ada_layer_norm = num_embeds_ada_norm is not None
317 | self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
318 | self.unet_use_temporal_attention = unet_use_temporal_attention
319 |
320 | # original self-attention layer (attends over the frame tokens only)
321 | self.attn1 = Attention(
322 | query_dim=dim,
323 | heads=num_attention_heads,
324 | dim_head=attention_head_dim,
325 | dropout=dropout,
326 | bias=attention_bias,
327 | upcast_attention=upcast_attention,
328 | )
329 | self.norm1 = (
330 | AdaLayerNorm(dim, num_embeds_ada_norm)
331 | if self.use_ada_layer_norm
332 | else nn.LayerNorm(dim)
333 | )
334 |
335 | # additional self-attention layer (attn1_5) that attends to reference features
336 | self.attn1_5 = Attention(
337 | query_dim=dim,
338 | heads=num_attention_heads,
339 | dim_head=attention_head_dim,
340 | dropout=dropout,
341 | bias=attention_bias,
342 | upcast_attention=upcast_attention,
343 | )
344 | self.norm1_5 = (
345 | AdaLayerNorm(dim, num_embeds_ada_norm)
346 | if self.use_ada_layer_norm
347 | else nn.LayerNorm(dim)
348 | )
349 |
350 | # Cross-Attn
351 | if cross_attention_dim is not None:
352 | self.attn2 = Attention(
353 | query_dim=dim,
354 | cross_attention_dim=cross_attention_dim,
355 | heads=num_attention_heads,
356 | dim_head=attention_head_dim,
357 | dropout=dropout,
358 | bias=attention_bias,
359 | upcast_attention=upcast_attention,
360 | )
361 | else:
362 | self.attn2 = None
363 |
364 | if cross_attention_dim is not None:
365 | self.norm2 = (
366 | AdaLayerNorm(dim, num_embeds_ada_norm)
367 | if self.use_ada_layer_norm
368 | else nn.LayerNorm(dim)
369 | )
370 | else:
371 | self.norm2 = None
372 |
373 | # Feed-forward
374 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
375 | self.norm3 = nn.LayerNorm(dim)
376 | self.use_ada_layer_norm_zero = False
377 |
378 | # Temp-Attn
379 | assert unet_use_temporal_attention is not None
380 | if unet_use_temporal_attention:
381 | self.attn_temp = Attention(
382 | query_dim=dim,
383 | heads=num_attention_heads,
384 | dim_head=attention_head_dim,
385 | dropout=dropout,
386 | bias=attention_bias,
387 | upcast_attention=upcast_attention,
388 | )
389 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
390 | self.norm_temp = (
391 | AdaLayerNorm(dim, num_embeds_ada_norm)
392 | if self.use_ada_layer_norm
393 | else nn.LayerNorm(dim)
394 | )
395 |
396 | def forward(
397 | self,
398 | hidden_states,
399 | encoder_hidden_states=None,
400 | timestep=None,
401 | attention_mask=None,
402 | video_length=None,
403 | ):
404 | norm_hidden_states = (
405 | self.norm1(hidden_states, timestep)
406 | if self.use_ada_layer_norm
407 | else self.norm1(hidden_states)
408 | )
409 |
410 | if self.unet_use_cross_frame_attention:
411 | hidden_states = (
412 | self.attn1(
413 | norm_hidden_states,
414 | attention_mask=attention_mask,
415 | video_length=video_length,
416 | )
417 | + hidden_states
418 | )
419 | else:
420 | hidden_states = (
421 | self.attn1(norm_hidden_states, attention_mask=attention_mask)
422 | + hidden_states
423 | )
424 |
425 | norm_hidden_states = (
426 | self.norm1_5(hidden_states, timestep)
427 | if self.use_ada_layer_norm
428 | else self.norm1_5(hidden_states)
429 | )
430 |
431 | if self.unet_use_cross_frame_attention:
432 | hidden_states = (
433 | self.attn1_5(
434 | norm_hidden_states,
435 | attention_mask=attention_mask,
436 | video_length=video_length,
437 | )
438 | + hidden_states
439 | )
440 | else:
441 | hidden_states = (
442 | self.attn1_5(norm_hidden_states, attention_mask=attention_mask)
443 | + hidden_states
444 | )
445 |
446 | if self.attn2 is not None:
447 | # Cross-Attention
448 | norm_hidden_states = (
449 | self.norm2(hidden_states, timestep)
450 | if self.use_ada_layer_norm
451 | else self.norm2(hidden_states)
452 | )
453 | hidden_states = (
454 | self.attn2(
455 | norm_hidden_states,
456 | encoder_hidden_states=encoder_hidden_states,
457 | attention_mask=attention_mask,
458 | )
459 | + hidden_states
460 | )
461 |
462 | # Feed-forward
463 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
464 |
465 | # Temporal-Attention
466 | if self.unet_use_temporal_attention:
467 | d = hidden_states.shape[1]
468 | hidden_states = rearrange(
469 | hidden_states, "(b f) d c -> (b d) f c", f=video_length
470 | )
471 | norm_hidden_states = (
472 | self.norm_temp(hidden_states, timestep)
473 | if self.use_ada_layer_norm
474 | else self.norm_temp(hidden_states)
475 | )
476 | hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
477 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
478 |
479 | return hidden_states
480 |
481 | class TemporalBasicTransformerBlockOld(nn.Module):
482 | def __init__(
483 | self,
484 | dim: int,
485 | num_attention_heads: int,
486 | attention_head_dim: int,
487 | dropout=0.0,
488 | cross_attention_dim: Optional[int] = None,
489 | activation_fn: str = "geglu",
490 | num_embeds_ada_norm: Optional[int] = None,
491 | attention_bias: bool = False,
492 | only_cross_attention: bool = False,
493 | upcast_attention: bool = False,
494 | unet_use_cross_frame_attention=None,
495 | unet_use_temporal_attention=None,
496 | ):
497 | super().__init__()
498 | self.only_cross_attention = only_cross_attention
499 | self.use_ada_layer_norm = num_embeds_ada_norm is not None
500 | self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
501 | self.unet_use_temporal_attention = unet_use_temporal_attention
502 |
503 | # SC-Attn
504 | self.attn1 = Attention(
505 | query_dim=dim,
506 | heads=num_attention_heads,
507 | dim_head=attention_head_dim,
508 | dropout=dropout,
509 | bias=attention_bias,
510 | upcast_attention=upcast_attention,
511 | )
512 | self.norm1 = (
513 | AdaLayerNorm(dim, num_embeds_ada_norm)
514 | if self.use_ada_layer_norm
515 | else nn.LayerNorm(dim)
516 | )
517 |
518 | # Cross-Attn
519 | if cross_attention_dim is not None:
520 | self.attn2 = Attention(
521 | query_dim=dim,
522 | cross_attention_dim=cross_attention_dim,
523 | heads=num_attention_heads,
524 | dim_head=attention_head_dim,
525 | dropout=dropout,
526 | bias=attention_bias,
527 | upcast_attention=upcast_attention,
528 | )
529 | else:
530 | self.attn2 = None
531 |
532 | if cross_attention_dim is not None:
533 | self.norm2 = (
534 | AdaLayerNorm(dim, num_embeds_ada_norm)
535 | if self.use_ada_layer_norm
536 | else nn.LayerNorm(dim)
537 | )
538 | else:
539 | self.norm2 = None
540 |
541 | # Feed-forward
542 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
543 | self.norm3 = nn.LayerNorm(dim)
544 | self.use_ada_layer_norm_zero = False
545 |
546 | # Temp-Attn
547 | assert unet_use_temporal_attention is not None
548 | if unet_use_temporal_attention:
549 | self.attn_temp = Attention(
550 | query_dim=dim,
551 | heads=num_attention_heads,
552 | dim_head=attention_head_dim,
553 | dropout=dropout,
554 | bias=attention_bias,
555 | upcast_attention=upcast_attention,
556 | )
557 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
558 | self.norm_temp = (
559 | AdaLayerNorm(dim, num_embeds_ada_norm)
560 | if self.use_ada_layer_norm
561 | else nn.LayerNorm(dim)
562 | )
563 |
564 | def forward(
565 | self,
566 | hidden_states,
567 | encoder_hidden_states=None,
568 | timestep=None,
569 | attention_mask=None,
570 | video_length=None,
571 | ):
572 | norm_hidden_states = (
573 | self.norm1(hidden_states, timestep)
574 | if self.use_ada_layer_norm
575 | else self.norm1(hidden_states)
576 | )
577 |
578 | if self.unet_use_cross_frame_attention:
579 | hidden_states = (
580 | self.attn1(
581 | norm_hidden_states,
582 | attention_mask=attention_mask,
583 | video_length=video_length,
584 | )
585 | + hidden_states
586 | )
587 | else:
588 | hidden_states = (
589 | self.attn1(norm_hidden_states, attention_mask=attention_mask)
590 | + hidden_states
591 | )
592 |
593 | if self.attn2 is not None:
594 | # Cross-Attention
595 | norm_hidden_states = (
596 | self.norm2(hidden_states, timestep)
597 | if self.use_ada_layer_norm
598 | else self.norm2(hidden_states)
599 | )
600 | hidden_states = (
601 | self.attn2(
602 | norm_hidden_states,
603 | encoder_hidden_states=encoder_hidden_states,
604 | attention_mask=attention_mask,
605 | )
606 | + hidden_states
607 | )
608 |
609 | # Feed-forward
610 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
611 |
612 | # Temporal-Attention
613 | if self.unet_use_temporal_attention:
614 | d = hidden_states.shape[1]
615 | hidden_states = rearrange(
616 | hidden_states, "(b f) d c -> (b d) f c", f=video_length
617 | )
618 | norm_hidden_states = (
619 | self.norm_temp(hidden_states, timestep)
620 | if self.use_ada_layer_norm
621 | else self.norm_temp(hidden_states)
622 | )
623 | hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
624 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
625 |
626 | return hidden_states
--------------------------------------------------------------------------------
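A minimal usage sketch for the blocks above (hypothetical; the import path, channel sizes, and context width are assumptions rather than the repo's actual configuration): it instantiates a TemporalBasicTransformerBlock and pushes a batch of per-frame tokens through it, which is roughly how the 3D UNet drives these blocks.

# Hypothetical sketch: exercising TemporalBasicTransformerBlock on dummy tensors.
import torch
from src.modules.attention import TemporalBasicTransformerBlock  # import path assumed

block = TemporalBasicTransformerBlock(
    dim=320,
    num_attention_heads=8,
    attention_head_dim=40,
    cross_attention_dim=768,            # width of the conditioning (e.g. audio) embeddings
    unet_use_cross_frame_attention=False,
    unet_use_temporal_attention=True,   # enables the zero-initialized attn_temp branch
)

batch, frames, tokens = 2, 8, 64
hidden_states = torch.randn(batch * frames, tokens, 320)   # (b*f, d, c) layout used above
context = torch.randn(batch * frames, 77, 768)             # per-frame conditioning tokens
out = block(hidden_states, encoder_hidden_states=context, video_length=frames)
print(out.shape)  # torch.Size([16, 64, 320])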
/src/modules/audio_projection.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | from diffusers.models.modeling_utils import ModelMixin
6 | from einops import rearrange
7 | from einops.layers.torch import Rearrange
8 |
9 |
10 | def reshape_tensor(x, heads):
11 | bs, length, width = x.shape
12 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
13 | x = x.view(bs, length, heads, -1)
14 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
15 | x = x.transpose(1, 2)
16 | # ensure a contiguous (bs, n_heads, length, dim_per_head) layout
17 | x = x.reshape(bs, heads, length, -1)
18 | return x
19 |
20 |
21 | def masked_mean(t, *, dim, mask=None):
22 | if mask is None:
23 | return t.mean(dim=dim)
24 |
25 | denom = mask.sum(dim=dim, keepdim=True)
26 | mask = rearrange(mask, "b n -> b n 1")
27 | masked_t = t.masked_fill(~mask, 0.0)
28 |
29 | return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
30 |
31 |
32 | class PerceiverAttention(nn.Module):
33 | def __init__(self, *, dim, dim_head=64, heads=8):
34 | super().__init__()
35 | self.scale = dim_head ** -0.5
36 | self.dim_head = dim_head
37 | self.heads = heads
38 | inner_dim = dim_head * heads
39 |
40 | self.norm1 = nn.LayerNorm(dim)
41 | self.norm2 = nn.LayerNorm(dim)
42 |
43 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
44 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
45 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
46 |
47 | def forward(self, x, latents):
48 | """
49 | Args:
50 | x (torch.Tensor): input features (audio embeddings in this module)
51 | shape (b, n1, D)
52 | latents (torch.Tensor): latent query features
53 | shape (b, n2, D)
54 | """
55 | x = self.norm1(x)
56 | latents = self.norm2(latents)
57 |
58 | b, l, _ = latents.shape
59 |
60 | q = self.to_q(latents)
61 | kv_input = torch.cat((x, latents), dim=-2)
62 | k, v = self.to_kv(kv_input).chunk(2, dim=-1)
63 |
64 | q = reshape_tensor(q, self.heads)
65 | k = reshape_tensor(k, self.heads)
66 | v = reshape_tensor(v, self.heads)
67 |
68 | # attention
69 | scale = 1 / math.sqrt(math.sqrt(self.dim_head))
70 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
71 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
72 | out = weight @ v
73 |
74 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
75 |
76 | return self.to_out(out)
77 |
78 |
79 | def FeedForward(dim, mult=4):
80 | inner_dim = int(dim * mult)
81 | return nn.Sequential(
82 | nn.LayerNorm(dim),
83 | nn.Linear(dim, inner_dim, bias=False),
84 | nn.GELU(),
85 | nn.Linear(inner_dim, dim, bias=False),
86 | )
87 |
88 |
89 | class AudioProjection(ModelMixin):
90 | def __init__(
91 | self,
92 | dim=1024,
93 | depth=8,
94 | dim_head=64,
95 | heads=16,
96 | num_queries=8,
97 | embedding_dim=768,
98 | output_dim=1024,
99 | ff_mult=4,
100 | max_seq_len: int = 257,
101 | num_latents_mean_pooled: int = 0,
102 | ):
103 | super().__init__()
104 |
105 | self.pos_emb = nn.Embedding(max_seq_len, embedding_dim)
106 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
107 |
108 | self.proj_in = nn.Linear(embedding_dim, dim)
109 |
110 | self.proj_out = nn.Linear(dim, output_dim)
111 | self.norm_out = nn.LayerNorm(output_dim)
112 |
113 | self.to_latents_from_mean_pooled_seq = (
114 | nn.Sequential(
115 | nn.LayerNorm(dim),
116 | nn.Linear(dim, dim * num_latents_mean_pooled),
117 | Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
118 | )
119 | if num_latents_mean_pooled > 0
120 | else None
121 | )
122 |
123 | self.layers = nn.ModuleList([])
124 | for _ in range(depth):
125 | self.layers.append(nn.ModuleList([
126 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
127 | FeedForward(dim=dim, mult=ff_mult),
128 | ]))
129 |
130 | def forward(self, x):
131 | if self.pos_emb is not None:
132 | n, device = x.shape[1], x.device
133 | pos_emb = self.pos_emb(torch.arange(n, device=device))
134 | x = x + pos_emb
135 |
136 | latents = self.latents.repeat(x.size(0), 1, 1)
137 |
138 | x = self.proj_in(x)
139 |
140 | if self.to_latents_from_mean_pooled_seq:
141 | meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
142 | meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
143 | latents = torch.cat((meanpooled_latents, latents), dim=-2)
144 |
145 | for attn, ff in self.layers:
146 | latents = attn(x, latents) + latents
147 | latents = ff(latents) + latents
148 |
149 | latents = self.proj_out(latents)
150 | return self.norm_out(latents)
151 |
--------------------------------------------------------------------------------
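A quick sanity-check sketch for AudioProjection (hypothetical; depth, query count, and widths are illustrative and the import path is assumed): it resamples a variable-length sequence of audio embeddings into a fixed number of tokens for the UNet's audio cross-attention.

# Hypothetical sketch: projecting dummy wav2vec2-style embeddings with AudioProjection.
import torch
from src.modules.audio_projection import AudioProjection  # import path assumed

proj = AudioProjection(
    dim=1024,
    depth=4,
    num_queries=8,      # number of learned latent tokens returned per sample
    embedding_dim=768,  # width of the incoming audio embeddings
    output_dim=768,     # width expected by the cross-attention layers
)

audio_embeddings = torch.randn(2, 50, 768)  # (batch, audio frames, embedding_dim)
audio_tokens = proj(audio_embeddings)
print(audio_tokens.shape)  # torch.Size([2, 8, 768])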
/src/modules/motion_module.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2 | import math
3 | from dataclasses import dataclass
4 | from typing import Callable, Optional
5 |
6 | import torch
7 | from diffusers.models.attention import FeedForward
8 | from diffusers.models.attention_processor import Attention, AttnProcessor
9 | from diffusers.utils import BaseOutput
10 | from diffusers.utils.import_utils import is_xformers_available
11 | from einops import rearrange, repeat
12 | from torch import nn
13 |
14 |
15 | def zero_module(module):
16 | # Zero out the parameters of a module and return it.
17 | for p in module.parameters():
18 | p.detach().zero_()
19 | return module
20 |
21 |
22 | @dataclass
23 | class TemporalTransformer3DModelOutput(BaseOutput):
24 | sample: torch.FloatTensor
25 |
26 |
27 | if is_xformers_available():
28 | import xformers
29 | import xformers.ops
30 | else:
31 | xformers = None
32 |
33 |
34 | def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
35 | if motion_module_type == "Vanilla":
36 | return VanillaTemporalModule(
37 | in_channels=in_channels,
38 | **motion_module_kwargs,
39 | )
40 | else:
41 | raise ValueError(f"Unsupported motion_module_type: {motion_module_type}")
42 |
43 |
44 | class VanillaTemporalModule(nn.Module):
45 | def __init__(
46 | self,
47 | in_channels,
48 | num_attention_heads=8,
49 | num_transformer_block=2,
50 | attention_block_types=("Temporal_Self", "Temporal_Self"),
51 | cross_frame_attention_mode=None,
52 | temporal_position_encoding=False,
53 | temporal_position_encoding_max_len=24,
54 | temporal_attention_dim_div=1,
55 | zero_initialize=True,
56 | ):
57 | super().__init__()
58 |
59 | self.temporal_transformer = TemporalTransformer3DModel(
60 | in_channels=in_channels,
61 | num_attention_heads=num_attention_heads,
62 | attention_head_dim=in_channels
63 | // num_attention_heads
64 | // temporal_attention_dim_div,
65 | num_layers=num_transformer_block,
66 | attention_block_types=attention_block_types,
67 | cross_frame_attention_mode=cross_frame_attention_mode,
68 | temporal_position_encoding=temporal_position_encoding,
69 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
70 | )
71 |
72 | if zero_initialize:
73 | self.temporal_transformer.proj_out = zero_module(
74 | self.temporal_transformer.proj_out
75 | )
76 |
77 | def forward(
78 | self,
79 | input_tensor,
80 | temb,
81 | encoder_hidden_states,
82 | attention_mask=None,
83 | anchor_frame_idx=None,
84 | ):
85 | hidden_states = input_tensor
86 | hidden_states = self.temporal_transformer(
87 | hidden_states, encoder_hidden_states, attention_mask
88 | )
89 |
90 | output = hidden_states
91 | return output
92 |
93 |
94 | class TemporalTransformer3DModel(nn.Module):
95 | def __init__(
96 | self,
97 | in_channels,
98 | num_attention_heads,
99 | attention_head_dim,
100 | num_layers,
101 | attention_block_types=(
102 | "Temporal_Self",
103 | "Temporal_Self",
104 | ),
105 | dropout=0.0,
106 | norm_num_groups=32,
107 | cross_attention_dim=768,
108 | activation_fn="geglu",
109 | attention_bias=False,
110 | upcast_attention=False,
111 | cross_frame_attention_mode=None,
112 | temporal_position_encoding=False,
113 | temporal_position_encoding_max_len=24,
114 | ):
115 | super().__init__()
116 |
117 | inner_dim = num_attention_heads * attention_head_dim
118 |
119 | self.norm = torch.nn.GroupNorm(
120 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
121 | )
122 | self.proj_in = nn.Linear(in_channels, inner_dim)
123 |
124 | self.transformer_blocks = nn.ModuleList(
125 | [
126 | TemporalTransformerBlock(
127 | dim=inner_dim,
128 | num_attention_heads=num_attention_heads,
129 | attention_head_dim=attention_head_dim,
130 | attention_block_types=attention_block_types,
131 | dropout=dropout,
132 | norm_num_groups=norm_num_groups,
133 | cross_attention_dim=cross_attention_dim,
134 | activation_fn=activation_fn,
135 | attention_bias=attention_bias,
136 | upcast_attention=upcast_attention,
137 | cross_frame_attention_mode=cross_frame_attention_mode,
138 | temporal_position_encoding=temporal_position_encoding,
139 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
140 | )
141 | for d in range(num_layers)
142 | ]
143 | )
144 | self.proj_out = nn.Linear(inner_dim, in_channels)
145 |
146 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
147 | assert (
148 | hidden_states.dim() == 5
149 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
150 | video_length = hidden_states.shape[2]
151 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
152 |
153 | batch, channel, height, width = hidden_states.shape
154 | residual = hidden_states
155 |
156 | hidden_states = self.norm(hidden_states)
157 | inner_dim = hidden_states.shape[1]
158 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
159 | batch, height * width, inner_dim
160 | )
161 | hidden_states = self.proj_in(hidden_states)
162 |
163 | # Transformer Blocks
164 | for block in self.transformer_blocks:
165 | hidden_states = block(
166 | hidden_states,
167 | encoder_hidden_states=encoder_hidden_states,
168 | video_length=video_length,
169 | )
170 |
171 | # output
172 | hidden_states = self.proj_out(hidden_states)
173 | hidden_states = (
174 | hidden_states.reshape(batch, height, width, inner_dim)
175 | .permute(0, 3, 1, 2)
176 | .contiguous()
177 | )
178 |
179 | output = hidden_states + residual
180 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
181 |
182 | return output
183 |
184 |
185 | class TemporalTransformerBlock(nn.Module):
186 | def __init__(
187 | self,
188 | dim,
189 | num_attention_heads,
190 | attention_head_dim,
191 | attention_block_types=(
192 | "Temporal_Self",
193 | "Temporal_Self",
194 | ),
195 | dropout=0.0,
196 | norm_num_groups=32,
197 | cross_attention_dim=768,
198 | activation_fn="geglu",
199 | attention_bias=False,
200 | upcast_attention=False,
201 | cross_frame_attention_mode=None,
202 | temporal_position_encoding=False,
203 | temporal_position_encoding_max_len=24,
204 | ):
205 | super().__init__()
206 |
207 | attention_blocks = []
208 | norms = []
209 |
210 | for block_name in attention_block_types:
211 | attention_blocks.append(
212 | VersatileAttention(
213 | attention_mode=block_name.split("_")[0],
214 | cross_attention_dim=cross_attention_dim
215 | if block_name.endswith("_Cross")
216 | else None,
217 | query_dim=dim,
218 | heads=num_attention_heads,
219 | dim_head=attention_head_dim,
220 | dropout=dropout,
221 | bias=attention_bias,
222 | upcast_attention=upcast_attention,
223 | cross_frame_attention_mode=cross_frame_attention_mode,
224 | temporal_position_encoding=temporal_position_encoding,
225 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
226 | )
227 | )
228 | norms.append(nn.LayerNorm(dim))
229 |
230 | self.attention_blocks = nn.ModuleList(attention_blocks)
231 | self.norms = nn.ModuleList(norms)
232 |
233 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
234 | self.ff_norm = nn.LayerNorm(dim)
235 |
236 | def forward(
237 | self,
238 | hidden_states,
239 | encoder_hidden_states=None,
240 | attention_mask=None,
241 | video_length=None,
242 | ):
243 | for attention_block, norm in zip(self.attention_blocks, self.norms):
244 | norm_hidden_states = norm(hidden_states)
245 | hidden_states = (
246 | attention_block(
247 | norm_hidden_states,
248 | encoder_hidden_states=encoder_hidden_states
249 | if attention_block.is_cross_attention
250 | else None,
251 | video_length=video_length,
252 | )
253 | + hidden_states
254 | )
255 |
256 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
257 |
258 | output = hidden_states
259 | return output
260 |
261 |
262 | class PositionalEncoding(nn.Module):
263 | def __init__(self, d_model, dropout=0.0, max_len=24):
264 | super().__init__()
265 | self.dropout = nn.Dropout(p=dropout)
266 | position = torch.arange(max_len).unsqueeze(1)
267 | div_term = torch.exp(
268 | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
269 | )
270 | pe = torch.zeros(1, max_len, d_model)
271 | pe[0, :, 0::2] = torch.sin(position * div_term)
272 | pe[0, :, 1::2] = torch.cos(position * div_term)
273 | self.register_buffer("pe", pe)
274 |
275 | def forward(self, x):
276 | x = x + self.pe[:, : x.size(1)]
277 | return self.dropout(x)
278 |
279 |
280 | class VersatileAttention(Attention):
281 | def __init__(
282 | self,
283 | attention_mode=None,
284 | cross_frame_attention_mode=None,
285 | temporal_position_encoding=False,
286 | temporal_position_encoding_max_len=24,
287 | *args,
288 | **kwargs,
289 | ):
290 | super().__init__(*args, **kwargs)
291 | assert attention_mode == "Temporal"
292 |
293 | self.attention_mode = attention_mode
294 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None
295 |
296 | self.pos_encoder = (
297 | PositionalEncoding(
298 | kwargs["query_dim"],
299 | dropout=0.0,
300 | max_len=temporal_position_encoding_max_len,
301 | )
302 | if (temporal_position_encoding and attention_mode == "Temporal")
303 | else None
304 | )
305 |
306 | def extra_repr(self):
307 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
308 |
309 | def set_use_memory_efficient_attention_xformers(
310 | self,
311 | use_memory_efficient_attention_xformers: bool,
312 | attention_op: Optional[Callable] = None,
313 | ):
314 | if use_memory_efficient_attention_xformers:
315 | if not is_xformers_available():
316 | raise ModuleNotFoundError(
317 | (
318 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
319 | " xformers"
320 | ),
321 | name="xformers",
322 | )
323 | elif not torch.cuda.is_available():
324 | raise ValueError(
325 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
326 | " only available for GPU "
327 | )
328 | else:
329 | try:
330 | # Make sure we can run the memory efficient attention
331 | _ = xformers.ops.memory_efficient_attention(
332 | torch.randn((1, 2, 40), device="cuda"),
333 | torch.randn((1, 2, 40), device="cuda"),
334 | torch.randn((1, 2, 40), device="cuda"),
335 | )
336 | except Exception as e:
337 | raise e
338 |
339 | # XFormersAttnProcessor corrupts video generation and only works with PyTorch 1.13.
340 | # The PyTorch 2.0.1 AttnProcessor behaves the same as XFormersAttnProcessor did under PyTorch 1.13,
341 | # so XFormersAttnProcessor is not needed here.
342 | # processor = XFormersAttnProcessor(
343 | # attention_op=attention_op,
344 | # )
345 | processor = AttnProcessor()
346 | else:
347 | processor = AttnProcessor()
348 |
349 | self.set_processor(processor)
350 |
351 | def forward(
352 | self,
353 | hidden_states,
354 | encoder_hidden_states=None,
355 | attention_mask=None,
356 | video_length=None,
357 | **cross_attention_kwargs,
358 | ):
359 | if self.attention_mode == "Temporal":
360 | d = hidden_states.shape[1] # d means HxW
361 | hidden_states = rearrange(
362 | hidden_states, "(b f) d c -> (b d) f c", f=video_length
363 | )
364 |
365 | if self.pos_encoder is not None:
366 | hidden_states = self.pos_encoder(hidden_states)
367 |
368 | encoder_hidden_states = (
369 | repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
370 | if encoder_hidden_states is not None
371 | else encoder_hidden_states
372 | )
373 |
374 | else:
375 | raise NotImplementedError
376 |
377 | hidden_states = self.processor(
378 | self,
379 | hidden_states,
380 | encoder_hidden_states=encoder_hidden_states,
381 | attention_mask=attention_mask,
382 | **cross_attention_kwargs,
383 | )
384 |
385 | if self.attention_mode == "Temporal":
386 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
387 |
388 | return hidden_states
389 |
--------------------------------------------------------------------------------
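A small sketch of attaching a motion module (hypothetical; in_channels and the kwargs below are illustrative rather than the repo's UNet configuration): get_motion_module builds a VanillaTemporalModule that attends along the frame axis of a 5D latent tensor, and because proj_out is zero-initialized it starts out as an identity mapping.

# Hypothetical sketch: running a "Vanilla" motion module over a 5D latent tensor.
import torch
from src.modules.motion_module import get_motion_module  # import path assumed

motion = get_motion_module(
    in_channels=320,
    motion_module_type="Vanilla",
    motion_module_kwargs=dict(
        num_attention_heads=8,
        num_transformer_block=1,
        temporal_position_encoding=True,
        temporal_position_encoding_max_len=24,
    ),
)

latents = torch.randn(1, 320, 16, 32, 32)  # (batch, channels, frames, height, width)
out = motion(latents, temb=None, encoder_hidden_states=None)
print(out.shape)                     # torch.Size([1, 320, 16, 32, 32])
print(torch.allclose(out, latents))  # True at init: zero-initialized proj_out makes it a no-op residual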
/src/modules/mutual_self_attention.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py
2 | from typing import Any, Dict, Optional
3 |
4 | import torch
5 | from einops import rearrange
6 |
7 | from .attention import BasicTransformerBlock
8 | from .attention import TemporalBasicTransformerBlock
9 |
10 |
11 | def torch_dfs(model: torch.nn.Module):
12 | result = [model]
13 | for child in model.children():
14 | result += torch_dfs(child)
15 | return result
16 |
17 |
18 | class ReferenceAttentionControl:
19 | def __init__(
20 | self,
21 | unet,
22 | mode="write",
23 | do_classifier_free_guidance=False,
24 | attention_auto_machine_weight=float("inf"),
25 | gn_auto_machine_weight=1.0,
26 | style_fidelity=1.0,
27 | reference_attn=True,
28 | reference_adain=False,
29 | fusion_blocks="midup",
30 | batch_size=1,
31 | reference_attention_weight=1.,
32 | audio_attention_weight=1.,
33 | ) -> None:
34 | # 10. Modify self attention and group norm
35 | self.unet = unet
36 | assert mode in ["read", "write"]
37 | assert fusion_blocks in ["midup", "full"]
38 | self.reference_attn = reference_attn
39 | self.reference_adain = reference_adain
40 | self.fusion_blocks = fusion_blocks
41 | self.reference_attention_weight = reference_attention_weight
42 | self.audio_attention_weight = audio_attention_weight
43 | self.register_reference_hooks(
44 | mode,
45 | do_classifier_free_guidance,
46 | attention_auto_machine_weight,
47 | gn_auto_machine_weight,
48 | style_fidelity,
49 | reference_attn,
50 | reference_adain,
51 | fusion_blocks,
52 | batch_size=batch_size,
53 | )
54 |
55 | def register_reference_hooks(
56 | self,
57 | mode,
58 | do_classifier_free_guidance,
59 | attention_auto_machine_weight,
60 | gn_auto_machine_weight,
61 | style_fidelity,
62 | reference_attn,
63 | reference_adain,
64 | dtype=torch.float16,
65 | batch_size=1,
66 | num_images_per_prompt=1,
67 | device=torch.device("cpu"),
68 | fusion_blocks="midup",
69 | ):
70 | MODE = mode
71 | do_classifier_free_guidance = do_classifier_free_guidance
72 | attention_auto_machine_weight = attention_auto_machine_weight
73 | gn_auto_machine_weight = gn_auto_machine_weight
74 | style_fidelity = style_fidelity
75 | reference_attn = reference_attn
76 | reference_adain = reference_adain
77 | fusion_blocks = fusion_blocks
78 | num_images_per_prompt = num_images_per_prompt
79 | reference_attention_weight = self.reference_attention_weight
80 | audio_attention_weight = self.audio_attention_weight
81 | dtype = dtype
82 | if do_classifier_free_guidance:
83 | uc_mask = (
84 | torch.Tensor(
85 | [1] * batch_size * num_images_per_prompt * 16
86 | + [0] * batch_size * num_images_per_prompt * 16
87 | )
88 | .to(device)
89 | .bool()
90 | )
91 | else:
92 | uc_mask = (
93 | torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
94 | .to(device)
95 | .bool()
96 | )
97 |
98 | def hacked_basic_transformer_inner_forward(
99 | self,
100 | hidden_states: torch.FloatTensor,
101 | attention_mask: Optional[torch.FloatTensor] = None,
102 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
103 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
104 | timestep: Optional[torch.LongTensor] = None,
105 | cross_attention_kwargs: Dict[str, Any] = None,
106 | class_labels: Optional[torch.LongTensor] = None,
107 | video_length=None,
108 | ):
109 | if self.use_ada_layer_norm: # False
110 | norm_hidden_states = self.norm1(hidden_states, timestep)
111 | elif self.use_ada_layer_norm_zero:
112 | (
113 | norm_hidden_states,
114 | gate_msa,
115 | shift_mlp,
116 | scale_mlp,
117 | gate_mlp,
118 | ) = self.norm1(
119 | hidden_states,
120 | timestep,
121 | class_labels,
122 | hidden_dtype=hidden_states.dtype,
123 | )
124 | else:
125 | norm_hidden_states = self.norm1(hidden_states)
126 |
127 | # 1. Self-Attention
128 | # self.only_cross_attention = False
129 | cross_attention_kwargs = (
130 | cross_attention_kwargs if cross_attention_kwargs is not None else {}
131 | )
132 | if self.only_cross_attention:
133 | attn_output = self.attn1(
134 | norm_hidden_states,
135 | encoder_hidden_states=encoder_hidden_states
136 | if self.only_cross_attention
137 | else None,
138 | attention_mask=attention_mask,
139 | **cross_attention_kwargs,
140 | )
141 | else:
142 | if MODE == "write":
143 | attn_output = self.attn1(
144 | norm_hidden_states,
145 | encoder_hidden_states=encoder_hidden_states
146 | if self.only_cross_attention
147 | else None,
148 | attention_mask=attention_mask,
149 | **cross_attention_kwargs,
150 | )
151 |
152 | if self.use_ada_layer_norm_zero:
153 | attn_output = gate_msa.unsqueeze(1) * attn_output
154 | hidden_states = attn_output + hidden_states
155 |
156 | if self.attn2 is not None:
157 | norm_hidden_states = (
158 | self.norm2(hidden_states, timestep)
159 | if self.use_ada_layer_norm
160 | else self.norm2(hidden_states)
161 | )
162 | self.bank.append(norm_hidden_states.clone())
163 |
164 | # 2. Cross-Attention
165 | attn_output = self.attn2(
166 | norm_hidden_states,
167 | encoder_hidden_states=encoder_hidden_states,
168 | attention_mask=encoder_attention_mask,
169 | **cross_attention_kwargs,
170 | )
171 | hidden_states = attn_output + hidden_states
172 |
173 | if MODE == "read":
174 | hidden_states = (
175 | self.attn1(
176 | norm_hidden_states,
177 | encoder_hidden_states=norm_hidden_states,
178 | attention_mask=attention_mask,
179 | )
180 | + hidden_states
181 | )
182 |
183 | if self.use_ada_layer_norm: # False
184 | norm_hidden_states = self.norm1_5(hidden_states, timestep)
185 | elif self.use_ada_layer_norm_zero:
186 | (
187 | norm_hidden_states,
188 | gate_msa,
189 | shift_mlp,
190 | scale_mlp,
191 | gate_mlp,
192 | ) = self.norm1_5(
193 | hidden_states,
194 | timestep,
195 | class_labels,
196 | hidden_dtype=hidden_states.dtype,
197 | )
198 | else:
199 | norm_hidden_states = self.norm1_5(hidden_states)
200 |
201 | bank_fea = []
202 | for d in self.bank:
203 | if len(d.shape) == 3:
204 | d = d.unsqueeze(1).repeat(1, video_length, 1, 1)
205 | bank_fea.append(rearrange(d, "b t l c -> (b t) l c"))
206 |
207 | attn_hidden_states = self.attn1_5(
208 | norm_hidden_states,
209 | encoder_hidden_states=bank_fea[0],
210 | attention_mask=attention_mask,
211 | )
212 |
213 | if reference_attention_weight != 1.:
214 | attn_hidden_states *= reference_attention_weight
215 |
216 | hidden_states = (attn_hidden_states + hidden_states)
217 |
218 | # self.bank.clear()
219 | if self.attn2 is not None:
220 | # Cross-Attention
221 | norm_hidden_states = (
222 | self.norm2(hidden_states, timestep)
223 | if self.use_ada_layer_norm
224 | else self.norm2(hidden_states)
225 | )
226 |
227 | attn_hidden_states = self.attn2(
228 | norm_hidden_states,
229 | encoder_hidden_states=encoder_hidden_states,
230 | attention_mask=attention_mask,
231 | )
232 |
233 | if audio_attention_weight != 1.:
234 | attn_hidden_states *= audio_attention_weight
235 |
236 | hidden_states = (attn_hidden_states + hidden_states)
237 |
238 | # Feed-forward
239 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
240 |
241 | # Temporal-Attention
242 | if self.unet_use_temporal_attention:
243 | d = hidden_states.shape[1]
244 | hidden_states = rearrange(
245 | hidden_states, "(b f) d c -> (b d) f c", f=video_length
246 | )
247 | norm_hidden_states = (
248 | self.norm_temp(hidden_states, timestep)
249 | if self.use_ada_layer_norm
250 | else self.norm_temp(hidden_states)
251 | )
252 | hidden_states = (
253 | self.attn_temp(norm_hidden_states) + hidden_states
254 | )
255 | hidden_states = rearrange(
256 | hidden_states, "(b d) f c -> (b f) d c", d=d
257 | )
258 |
259 | return hidden_states
260 |
261 | # 3. Feed-forward
262 | norm_hidden_states = self.norm3(hidden_states)
263 |
264 | if self.use_ada_layer_norm_zero:
265 | norm_hidden_states = (
266 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
267 | )
268 |
269 | ff_output = self.ff(norm_hidden_states)
270 |
271 | if self.use_ada_layer_norm_zero:
272 | ff_output = gate_mlp.unsqueeze(1) * ff_output
273 |
274 | hidden_states = ff_output + hidden_states
275 |
276 | return hidden_states
277 |
278 | if self.reference_attn:
279 | if self.fusion_blocks == "midup":
280 | attn_modules = [
281 | module
282 | for module in (
283 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
284 | )
285 | if isinstance(module, BasicTransformerBlock)
286 | or isinstance(module, TemporalBasicTransformerBlock)
287 | ]
288 | elif self.fusion_blocks == "full":
289 | attn_modules = [
290 | module
291 | for module in torch_dfs(self.unet)
292 | if isinstance(module, BasicTransformerBlock)
293 | or isinstance(module, TemporalBasicTransformerBlock)
294 | ]
295 | attn_modules = sorted(
296 | attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
297 | )
298 |
299 | for i, module in enumerate(attn_modules):
300 | module._original_inner_forward = module.forward
301 | if isinstance(module, BasicTransformerBlock):
302 | module.forward = hacked_basic_transformer_inner_forward.__get__(
303 | module, BasicTransformerBlock
304 | )
305 | if isinstance(module, TemporalBasicTransformerBlock):
306 | module.forward = hacked_basic_transformer_inner_forward.__get__(
307 | module, TemporalBasicTransformerBlock
308 | )
309 |
310 | module.bank = []
311 | module.attn_weight = float(i) / float(len(attn_modules))
312 |
313 | def update(
314 | self,
315 | writer,
316 | do_classifier_free_guidance=True,
317 | dtype=torch.float16,
318 | ):
319 | if self.reference_attn:
320 | if self.fusion_blocks == "midup":
321 | reader_attn_modules = [
322 | module
323 | for module in (torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks))
324 | if isinstance(module, TemporalBasicTransformerBlock)
325 | ]
326 | writer_attn_modules = [
327 | module
328 | for module in (torch_dfs(writer.unet.mid_block) + torch_dfs(writer.unet.up_blocks))
329 | if isinstance(module, BasicTransformerBlock)
330 | ]
331 | elif self.fusion_blocks == "full":
332 | reader_attn_modules = [
333 | module
334 | for module in torch_dfs(self.unet)
335 | if isinstance(module, TemporalBasicTransformerBlock)
336 | ]
337 | writer_attn_modules = [
338 | module
339 | for module in torch_dfs(writer.unet)
340 | if isinstance(module, BasicTransformerBlock)
341 | ]
342 | reader_attn_modules = sorted(
343 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
344 | )
345 | writer_attn_modules = sorted(
346 | writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
347 | )
348 | for r, w in zip(reader_attn_modules, writer_attn_modules):
349 | if do_classifier_free_guidance:
350 | r.bank = [torch.cat([torch.zeros_like(v), v]).to(dtype) for v in w.bank]
351 | else:
352 | r.bank = [v.clone().to(dtype) for v in w.bank]
353 |
354 | def clear(self):
355 | if self.reference_attn:
356 | if self.fusion_blocks == "midup":
357 | reader_attn_modules = [
358 | module
359 | for module in (
360 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
361 | )
362 | if isinstance(module, BasicTransformerBlock)
363 | or isinstance(module, TemporalBasicTransformerBlock)
364 | ]
365 | elif self.fusion_blocks == "full":
366 | reader_attn_modules = [
367 | module
368 | for module in torch_dfs(self.unet)
369 | if isinstance(module, BasicTransformerBlock)
370 | or isinstance(module, TemporalBasicTransformerBlock)
371 | ]
372 | reader_attn_modules = sorted(
373 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
374 | )
375 | for r in reader_attn_modules:
376 | r.bank.clear()
377 |
--------------------------------------------------------------------------------
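The read/write mechanism above is easiest to follow as a lifecycle: a "write" controller on the reference UNet makes each BasicTransformerBlock cache its normalized hidden states in module.bank, update() copies those banks into the matching TemporalBasicTransformerBlock of the "read" (denoising) UNet, and clear() empties them after the step. The sketch below only demonstrates the torch_dfs helper both controllers rely on; the UNets themselves are placeholders and are not constructed here.

# Hypothetical sketch: torch_dfs flattens a module tree in depth-first order,
# which is how ReferenceAttentionControl locates every transformer block.
import torch.nn as nn
from src.modules.mutual_self_attention import torch_dfs  # import path assumed

model = nn.Sequential(
    nn.Linear(4, 4),
    nn.Sequential(nn.ReLU(), nn.Linear(4, 2)),
)
modules = torch_dfs(model)
print(len(modules))  # 5: outer Sequential, Linear, inner Sequential, ReLU, Linear

# Sketch of the intended lifecycle (placeholder names, shown as comments only):
#   reference_unet(ref_latents, t, encoder_hidden_states=ctx)  # "write" controller fills the banks
#   reader_control.update(writer_control)                      # copy banks into the denoising UNet
#   denoising_unet(latents, t, encoder_hidden_states=ctx)      # attn1_5 reads the cached features
#   reader_control.clear(); writer_control.clear()             # reset for the next step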
/src/modules/resnet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from einops import rearrange
7 |
8 |
9 | class InflatedConv3d(nn.Conv2d):
10 | def forward(self, x):
11 | video_length = x.shape[2]
12 |
13 | x = rearrange(x, "b c f h w -> (b f) c h w")
14 | x = super().forward(x)
15 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
16 |
17 | return x
18 |
19 |
20 | class InflatedGroupNorm(nn.GroupNorm):
21 | def forward(self, x):
22 | video_length = x.shape[2]
23 |
24 | x = rearrange(x, "b c f h w -> (b f) c h w")
25 | x = super().forward(x)
26 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
27 |
28 | return x
29 |
30 |
31 | class Upsample3D(nn.Module):
32 | def __init__(
33 | self,
34 | channels,
35 | use_conv=False,
36 | use_conv_transpose=False,
37 | out_channels=None,
38 | name="conv",
39 | ):
40 | super().__init__()
41 | self.channels = channels
42 | self.out_channels = out_channels or channels
43 | self.use_conv = use_conv
44 | self.use_conv_transpose = use_conv_transpose
45 | self.name = name
46 |
47 | conv = None
48 | if use_conv_transpose:
49 | raise NotImplementedError
50 | elif use_conv:
51 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
52 |
53 | def forward(self, hidden_states, output_size=None):
54 | assert hidden_states.shape[1] == self.channels
55 |
56 | if self.use_conv_transpose:
57 | raise NotImplementedError
58 |
59 | # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
60 | dtype = hidden_states.dtype
61 | if dtype == torch.bfloat16:
62 | hidden_states = hidden_states.to(torch.float32)
63 |
64 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
65 | if hidden_states.shape[0] >= 64:
66 | hidden_states = hidden_states.contiguous()
67 |
68 | # if `output_size` is passed we force the interpolation output
69 | # size and do not make use of `scale_factor=2`
70 | if output_size is None:
71 | hidden_states = F.interpolate(
72 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
73 | )
74 | else:
75 | hidden_states = F.interpolate(
76 | hidden_states, size=output_size, mode="nearest"
77 | )
78 |
79 | # If the input is bfloat16, we cast back to bfloat16
80 | if dtype == torch.bfloat16:
81 | hidden_states = hidden_states.to(dtype)
82 |
83 | # if self.use_conv:
84 | # if self.name == "conv":
85 | # hidden_states = self.conv(hidden_states)
86 | # else:
87 | # hidden_states = self.Conv2d_0(hidden_states)
88 | hidden_states = self.conv(hidden_states)
89 |
90 | return hidden_states
91 |
92 |
93 | class Downsample3D(nn.Module):
94 | def __init__(
95 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
96 | ):
97 | super().__init__()
98 | self.channels = channels
99 | self.out_channels = out_channels or channels
100 | self.use_conv = use_conv
101 | self.padding = padding
102 | stride = 2
103 | self.name = name
104 |
105 | if use_conv:
106 | self.conv = InflatedConv3d(
107 | self.channels, self.out_channels, 3, stride=stride, padding=padding
108 | )
109 | else:
110 | raise NotImplementedError
111 |
112 | def forward(self, hidden_states):
113 | assert hidden_states.shape[1] == self.channels
114 | if self.use_conv and self.padding == 0:
115 | raise NotImplementedError
116 |
117 | assert hidden_states.shape[1] == self.channels
118 | hidden_states = self.conv(hidden_states)
119 |
120 | return hidden_states
121 |
122 |
123 | class ResnetBlock3D(nn.Module):
124 | def __init__(
125 | self,
126 | *,
127 | in_channels,
128 | out_channels=None,
129 | conv_shortcut=False,
130 | dropout=0.0,
131 | temb_channels=512,
132 | groups=32,
133 | groups_out=None,
134 | pre_norm=True,
135 | eps=1e-6,
136 | non_linearity="swish",
137 | time_embedding_norm="default",
138 | output_scale_factor=1.0,
139 | use_in_shortcut=None,
140 | use_inflated_groupnorm=None,
141 | ):
142 | super().__init__()
143 | self.pre_norm = pre_norm
144 | self.pre_norm = True
145 | self.in_channels = in_channels
146 | out_channels = in_channels if out_channels is None else out_channels
147 | self.out_channels = out_channels
148 | self.use_conv_shortcut = conv_shortcut
149 | self.time_embedding_norm = time_embedding_norm
150 | self.output_scale_factor = output_scale_factor
151 |
152 | if groups_out is None:
153 | groups_out = groups
154 |
155 | assert use_inflated_groupnorm is not None
156 | if use_inflated_groupnorm:
157 | self.norm1 = InflatedGroupNorm(
158 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
159 | )
160 | else:
161 | self.norm1 = torch.nn.GroupNorm(
162 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
163 | )
164 |
165 | self.conv1 = InflatedConv3d(
166 | in_channels, out_channels, kernel_size=3, stride=1, padding=1
167 | )
168 |
169 | if temb_channels is not None:
170 | if self.time_embedding_norm == "default":
171 | time_emb_proj_out_channels = out_channels
172 | elif self.time_embedding_norm == "scale_shift":
173 | time_emb_proj_out_channels = out_channels * 2
174 | else:
175 | raise ValueError(
176 | f"unknown time_embedding_norm : {self.time_embedding_norm} "
177 | )
178 |
179 | self.time_emb_proj = torch.nn.Linear(
180 | temb_channels, time_emb_proj_out_channels
181 | )
182 | else:
183 | self.time_emb_proj = None
184 |
185 | if use_inflated_groupnorm:
186 | self.norm2 = InflatedGroupNorm(
187 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
188 | )
189 | else:
190 | self.norm2 = torch.nn.GroupNorm(
191 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
192 | )
193 | self.dropout = torch.nn.Dropout(dropout)
194 | self.conv2 = InflatedConv3d(
195 | out_channels, out_channels, kernel_size=3, stride=1, padding=1
196 | )
197 |
198 | if non_linearity == "swish":
199 | self.nonlinearity = lambda x: F.silu(x)
200 | elif non_linearity == "mish":
201 | self.nonlinearity = Mish()
202 | elif non_linearity == "silu":
203 | self.nonlinearity = nn.SiLU()
204 |
205 | self.use_in_shortcut = (
206 | self.in_channels != self.out_channels
207 | if use_in_shortcut is None
208 | else use_in_shortcut
209 | )
210 |
211 | self.conv_shortcut = None
212 | if self.use_in_shortcut:
213 | self.conv_shortcut = InflatedConv3d(
214 | in_channels, out_channels, kernel_size=1, stride=1, padding=0
215 | )
216 |
217 | def forward(self, input_tensor, temb):
218 | hidden_states = input_tensor
219 |
220 | hidden_states = self.norm1(hidden_states)
221 | hidden_states = self.nonlinearity(hidden_states)
222 |
223 | hidden_states = self.conv1(hidden_states)
224 |
225 | if temb is not None:
226 | temb = self.time_emb_proj(self.nonlinearity(temb))
227 | if len(temb.shape) == 2:
228 | temb = temb[:, :, None, None, None]
229 | elif len(temb.shape) == 3:
230 | temb = temb[:, :, :, None, None].permute(0, 2, 1, 3, 4)
231 |
232 | if temb is not None and self.time_embedding_norm == "default":
233 | hidden_states = hidden_states + temb
234 |
235 | hidden_states = self.norm2(hidden_states)
236 |
237 | if temb is not None and self.time_embedding_norm == "scale_shift":
238 | scale, shift = torch.chunk(temb, 2, dim=1)
239 | hidden_states = hidden_states * (1 + scale) + shift
240 |
241 | hidden_states = self.nonlinearity(hidden_states)
242 |
243 | hidden_states = self.dropout(hidden_states)
244 | hidden_states = self.conv2(hidden_states)
245 |
246 | if self.conv_shortcut is not None:
247 | input_tensor = self.conv_shortcut(input_tensor)
248 |
249 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
250 |
251 | return output_tensor
252 |
253 |
254 | class Mish(torch.nn.Module):
255 | def forward(self, hidden_states):
256 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
257 |
--------------------------------------------------------------------------------
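A brief sketch of the "inflation" trick used throughout this file (hypothetical shapes and channel counts): InflatedConv3d folds the frame axis into the batch, applies an ordinary 2D convolution, and unfolds again, so pretrained 2D weights can be reused frame by frame; ResnetBlock3D composes these convolutions with (optionally inflated) group norms and a time-embedding projection.

# Hypothetical sketch: frame-wise 2D ops on (b, c, f, h, w) tensors via inflation.
import torch
from src.modules.resnet import InflatedConv3d, ResnetBlock3D  # import path assumed

conv = InflatedConv3d(4, 8, kernel_size=3, padding=1)
video = torch.randn(1, 4, 16, 32, 32)   # (batch, channels, frames, height, width)
print(conv(video).shape)                # torch.Size([1, 8, 16, 32, 32])

block = ResnetBlock3D(
    in_channels=64,
    out_channels=64,
    temb_channels=512,
    groups=32,
    use_inflated_groupnorm=True,        # must be set explicitly (see the assert above)
)
feats = torch.randn(1, 64, 16, 32, 32)
temb = torch.randn(1, 512)              # time embedding, broadcast over frames and space
print(block(feats, temb).shape)         # torch.Size([1, 64, 16, 32, 32])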
/src/modules/transformer_2d.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py
2 | from dataclasses import dataclass
3 | from typing import Any, Dict, Optional
4 |
5 | import torch
6 | from diffusers.configuration_utils import ConfigMixin, register_to_config
7 |
8 | try:
9 | from diffusers.models.embeddings import CaptionProjection
10 | except ImportError:
11 | from diffusers.models.embeddings import PixArtAlphaTextProjection as CaptionProjection
12 |
13 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
14 | from diffusers.models.modeling_utils import ModelMixin
15 | from diffusers.models.normalization import AdaLayerNormSingle
16 | from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
17 | from torch import nn
18 |
19 | from .attention import BasicTransformerBlock
20 |
21 |
22 | @dataclass
23 | class Transformer2DModelOutput(BaseOutput):
24 | """
25 | The output of [`Transformer2DModel`].
26 |
27 | Args:
28 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
29 | The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
30 | distributions for the unnoised latent pixels.
31 | """
32 |
33 | sample: torch.FloatTensor
34 | ref_feature: torch.FloatTensor
35 |
36 |
37 | class Transformer2DModel(ModelMixin, ConfigMixin):
38 | """
39 | A 2D Transformer model for image-like data.
40 |
41 | Parameters:
42 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
43 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
44 | in_channels (`int`, *optional*):
45 | The number of channels in the input and output (specify if the input is **continuous**).
46 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
47 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
48 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
49 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
50 | This is fixed during training since it is used to learn a number of position embeddings.
51 | num_vector_embeds (`int`, *optional*):
52 | The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
53 | Includes the class for the masked latent pixel.
54 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
55 | num_embeds_ada_norm ( `int`, *optional*):
56 | The number of diffusion steps used during training. Pass if at least one of the norm_layers is
57 | `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
58 | added to the hidden states.
59 |
60 | During inference, you can denoise for up to but not more than `num_embeds_ada_norm` steps.
61 | attention_bias (`bool`, *optional*):
62 | Configure if the `TransformerBlocks` attention should contain a bias parameter.
63 | """
64 |
65 | _supports_gradient_checkpointing = True
66 |
67 | @register_to_config
68 | def __init__(
69 | self,
70 | num_attention_heads: int = 16,
71 | attention_head_dim: int = 88,
72 | in_channels: Optional[int] = None,
73 | out_channels: Optional[int] = None,
74 | num_layers: int = 1,
75 | dropout: float = 0.0,
76 | norm_num_groups: int = 32,
77 | cross_attention_dim: Optional[int] = None,
78 | attention_bias: bool = False,
79 | sample_size: Optional[int] = None,
80 | num_vector_embeds: Optional[int] = None,
81 | patch_size: Optional[int] = None,
82 | activation_fn: str = "geglu",
83 | num_embeds_ada_norm: Optional[int] = None,
84 | use_linear_projection: bool = False,
85 | only_cross_attention: bool = False,
86 | double_self_attention: bool = False,
87 | upcast_attention: bool = False,
88 | norm_type: str = "layer_norm",
89 | norm_elementwise_affine: bool = True,
90 | norm_eps: float = 1e-5,
91 | attention_type: str = "default",
92 | caption_channels: Optional[int] = None,
93 | ):
94 | super().__init__()
95 | self.use_linear_projection = use_linear_projection
96 | self.num_attention_heads = num_attention_heads
97 | self.attention_head_dim = attention_head_dim
98 | inner_dim = num_attention_heads * attention_head_dim
99 |
100 | conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
101 | linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
102 |
103 | # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
104 | # Define whether input is continuous or discrete depending on configuration
105 | self.is_input_continuous = (in_channels is not None) and (patch_size is None)
106 | self.is_input_vectorized = num_vector_embeds is not None
107 | self.is_input_patches = in_channels is not None and patch_size is not None
108 |
109 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
110 | deprecation_message = (
111 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
112 | " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
113 | " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
114 | " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
115 | " would be very nice if you could open a Pull request for the `transformer/config.json` file"
116 | )
117 | deprecate(
118 | "norm_type!=num_embeds_ada_norm",
119 | "1.0.0",
120 | deprecation_message,
121 | standard_warn=False,
122 | )
123 | norm_type = "ada_norm"
124 |
125 | if self.is_input_continuous and self.is_input_vectorized:
126 | raise ValueError(
127 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
128 | " sure that either `in_channels` or `num_vector_embeds` is None."
129 | )
130 | elif self.is_input_vectorized and self.is_input_patches:
131 | raise ValueError(
132 | f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
133 | " sure that either `num_vector_embeds` or `num_patches` is None."
134 | )
135 | elif (
136 | not self.is_input_continuous
137 | and not self.is_input_vectorized
138 | and not self.is_input_patches
139 | ):
140 | raise ValueError(
141 | f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
142 | f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
143 | )
144 |
145 | # 2. Define input layers
146 | self.in_channels = in_channels
147 |
148 | self.norm = torch.nn.GroupNorm(
149 | num_groups=norm_num_groups,
150 | num_channels=in_channels,
151 | eps=1e-6,
152 | affine=True,
153 | )
154 | if use_linear_projection:
155 | self.proj_in = linear_cls(in_channels, inner_dim)
156 | else:
157 | self.proj_in = conv_cls(
158 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0
159 | )
160 |
161 | # 3. Define transformers blocks
162 | self.transformer_blocks = nn.ModuleList(
163 | [
164 | BasicTransformerBlock(
165 | inner_dim,
166 | num_attention_heads,
167 | attention_head_dim,
168 | dropout=dropout,
169 | cross_attention_dim=cross_attention_dim,
170 | activation_fn=activation_fn,
171 | num_embeds_ada_norm=num_embeds_ada_norm,
172 | attention_bias=attention_bias,
173 | only_cross_attention=only_cross_attention,
174 | double_self_attention=double_self_attention,
175 | upcast_attention=upcast_attention,
176 | norm_type=norm_type,
177 | norm_elementwise_affine=norm_elementwise_affine,
178 | norm_eps=norm_eps,
179 | attention_type=attention_type,
180 | )
181 | for d in range(num_layers)
182 | ]
183 | )
184 |
185 | # 4. Define output layers
186 | self.out_channels = in_channels if out_channels is None else out_channels
187 | # TODO: should use out_channels for continuous projections
188 | if use_linear_projection:
189 | self.proj_out = linear_cls(inner_dim, in_channels)
190 | else:
191 | self.proj_out = conv_cls(
192 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0
193 | )
194 |
195 | # 5. PixArt-Alpha blocks.
196 | self.adaln_single = None
197 | self.use_additional_conditions = False
198 | if norm_type == "ada_norm_single":
199 | self.use_additional_conditions = self.config.sample_size == 128
200 | # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
201 | # additional conditions until we find better name
202 | self.adaln_single = AdaLayerNormSingle(
203 | inner_dim, use_additional_conditions=self.use_additional_conditions
204 | )
205 |
206 | self.caption_projection = None
207 | if caption_channels is not None:
208 | self.caption_projection = CaptionProjection(
209 | in_features=caption_channels, hidden_size=inner_dim
210 | )
211 |
212 | self.gradient_checkpointing = False
213 |
214 | def _set_gradient_checkpointing(self, module, value=False):
215 | if hasattr(module, "gradient_checkpointing"):
216 | module.gradient_checkpointing = value
217 |
218 | def forward(
219 | self,
220 | hidden_states: torch.Tensor,
221 | encoder_hidden_states: Optional[torch.Tensor] = None,
222 | timestep: Optional[torch.LongTensor] = None,
223 | added_cond_kwargs: Dict[str, torch.Tensor] = None,
224 | class_labels: Optional[torch.LongTensor] = None,
225 | cross_attention_kwargs: Dict[str, Any] = None,
226 | attention_mask: Optional[torch.Tensor] = None,
227 | encoder_attention_mask: Optional[torch.Tensor] = None,
228 | return_dict: bool = True,
229 | ):
230 | """
231 | The [`Transformer2DModel`] forward method.
232 |
233 | Args:
234 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
235 | Input `hidden_states`.
236 | encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
237 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
238 | self-attention.
239 | timestep ( `torch.LongTensor`, *optional*):
240 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
241 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
242 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
243 |                 `AdaLayerNormZero`.
244 | cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
245 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
246 | `self.processor` in
247 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
248 | attention_mask ( `torch.Tensor`, *optional*):
249 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
250 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
251 | negative values to the attention scores corresponding to "discard" tokens.
252 | encoder_attention_mask ( `torch.Tensor`, *optional*):
253 | Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
254 |
255 | * Mask `(batch, sequence_length)` True = keep, False = discard.
256 | * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
257 |
258 | If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
259 | above. This bias will be added to the cross-attention scores.
260 | return_dict (`bool`, *optional*, defaults to `True`):
261 |                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
262 | tuple.
263 |
264 | Returns:
265 |             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] containing `sample` and
266 |             `ref_feature` is returned, otherwise a `(sample, ref_feature)` tuple is returned.
267 | """
268 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
269 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
270 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
271 | # expects mask of shape:
272 | # [batch, key_tokens]
273 | # adds singleton query_tokens dimension:
274 | # [batch, 1, key_tokens]
275 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
276 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
277 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
278 | if attention_mask is not None and attention_mask.ndim == 2:
279 | # assume that mask is expressed as:
280 | # (1 = keep, 0 = discard)
281 | # convert mask into a bias that can be added to attention scores:
282 | # (keep = +0, discard = -10000.0)
283 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
284 | attention_mask = attention_mask.unsqueeze(1)
285 |
286 | # convert encoder_attention_mask to a bias the same way we do for attention_mask
287 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
288 | encoder_attention_mask = (
289 | 1 - encoder_attention_mask.to(hidden_states.dtype)
290 | ) * -10000.0
291 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
292 |
293 | # Retrieve lora scale.
294 | lora_scale = (
295 | cross_attention_kwargs.get("scale", 1.0)
296 | if cross_attention_kwargs is not None
297 | else 1.0
298 | )
299 |
300 | # 1. Input
301 | batch, _, height, width = hidden_states.shape
302 | residual = hidden_states
303 |
304 | hidden_states = self.norm(hidden_states)
305 | if not self.use_linear_projection:
306 | hidden_states = (
307 | self.proj_in(hidden_states, scale=lora_scale)
308 | if not USE_PEFT_BACKEND
309 | else self.proj_in(hidden_states)
310 | )
311 | inner_dim = hidden_states.shape[1]
312 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
313 | batch, height * width, inner_dim
314 | )
315 | else:
316 | inner_dim = hidden_states.shape[1]
317 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
318 | batch, height * width, inner_dim
319 | )
320 | hidden_states = (
321 | self.proj_in(hidden_states, scale=lora_scale)
322 | if not USE_PEFT_BACKEND
323 | else self.proj_in(hidden_states)
324 | )
325 |
326 | # 2. Blocks
327 | if self.caption_projection is not None:
328 | batch_size = hidden_states.shape[0]
329 | encoder_hidden_states = self.caption_projection(encoder_hidden_states)
330 | encoder_hidden_states = encoder_hidden_states.view(
331 | batch_size, -1, hidden_states.shape[-1]
332 | )
333 |
334 | ref_feature = hidden_states.reshape(batch, height, width, inner_dim)
335 | for block in self.transformer_blocks:
336 | if self.training and self.gradient_checkpointing:
337 |
338 | def create_custom_forward(module, return_dict=None):
339 | def custom_forward(*inputs):
340 | if return_dict is not None:
341 | return module(*inputs, return_dict=return_dict)
342 | else:
343 | return module(*inputs)
344 |
345 | return custom_forward
346 |
347 | ckpt_kwargs: Dict[str, Any] = (
348 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
349 | )
350 | hidden_states = torch.utils.checkpoint.checkpoint(
351 | create_custom_forward(block),
352 | hidden_states,
353 | attention_mask,
354 | encoder_hidden_states,
355 | encoder_attention_mask,
356 | timestep,
357 | cross_attention_kwargs,
358 | class_labels,
359 | **ckpt_kwargs,
360 | )
361 | else:
362 | hidden_states = block(
363 | hidden_states,
364 | attention_mask=attention_mask,
365 | encoder_hidden_states=encoder_hidden_states,
366 | encoder_attention_mask=encoder_attention_mask,
367 | timestep=timestep,
368 | cross_attention_kwargs=cross_attention_kwargs,
369 | class_labels=class_labels,
370 | )
371 |
372 | # 3. Output
373 | if self.is_input_continuous:
374 | if not self.use_linear_projection:
375 | hidden_states = (
376 | hidden_states.reshape(batch, height, width, inner_dim)
377 | .permute(0, 3, 1, 2)
378 | .contiguous()
379 | )
380 | hidden_states = (
381 | self.proj_out(hidden_states, scale=lora_scale)
382 | if not USE_PEFT_BACKEND
383 | else self.proj_out(hidden_states)
384 | )
385 | else:
386 | hidden_states = (
387 | self.proj_out(hidden_states, scale=lora_scale)
388 | if not USE_PEFT_BACKEND
389 | else self.proj_out(hidden_states)
390 | )
391 | hidden_states = (
392 | hidden_states.reshape(batch, height, width, inner_dim)
393 | .permute(0, 3, 1, 2)
394 | .contiguous()
395 | )
396 |
397 | output = hidden_states + residual
398 | if not return_dict:
399 | return (output, ref_feature)
400 |
401 | return Transformer2DModelOutput(sample=output, ref_feature=ref_feature)
402 |
--------------------------------------------------------------------------------
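A minimal usage sketch for the `Transformer2DModel` above (not taken from the repository; the import path, channel counts, and token shapes are assumptions chosen to satisfy the defaults such as `norm_num_groups=32`):

import torch
from src.modules.transformer_2d import Transformer2DModel  # assumed import path

model = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=40,        # inner_dim = 8 * 40 = 320
    in_channels=320,              # must be divisible by norm_num_groups (32)
    cross_attention_dim=768,
)
hidden = torch.randn(1, 320, 32, 32)   # (batch, channels, height, width)
context = torch.randn(1, 77, 768)      # (batch, tokens, cross_attention_dim)

with torch.no_grad():
    out = model(hidden, encoder_hidden_states=context)

print(out.sample.shape)       # torch.Size([1, 320, 32, 32]) -- residual-added output
print(out.ref_feature.shape)  # torch.Size([1, 32, 32, 320]) -- feature captured before the blocks
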
/src/modules/transformer_3d.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 |
4 | import torch
5 | from diffusers.configuration_utils import ConfigMixin, register_to_config
6 | from diffusers.models import ModelMixin
7 | from diffusers.utils import BaseOutput
8 | from diffusers.utils.import_utils import is_xformers_available
9 | from einops import rearrange, repeat
10 | from torch import nn
11 |
12 | from .attention import TemporalBasicTransformerBlock
13 |
14 |
15 | @dataclass
16 | class Transformer3DModelOutput(BaseOutput):
17 | sample: torch.FloatTensor
18 |
19 |
20 | if is_xformers_available():
21 | import xformers
22 | import xformers.ops
23 | else:
24 | xformers = None
25 |
26 |
27 | class Transformer3DModel(ModelMixin, ConfigMixin):
28 | _supports_gradient_checkpointing = True
29 |
30 | @register_to_config
31 | def __init__(
32 | self,
33 | num_attention_heads: int = 16,
34 | attention_head_dim: int = 88,
35 | in_channels: Optional[int] = None,
36 | num_layers: int = 1,
37 | dropout: float = 0.0,
38 | norm_num_groups: int = 32,
39 | cross_attention_dim: Optional[int] = None,
40 | attention_bias: bool = False,
41 | activation_fn: str = "geglu",
42 | num_embeds_ada_norm: Optional[int] = None,
43 | use_linear_projection: bool = False,
44 | only_cross_attention: bool = False,
45 | upcast_attention: bool = False,
46 | unet_use_cross_frame_attention=None,
47 | unet_use_temporal_attention=None,
48 | ):
49 | super().__init__()
50 | self.use_linear_projection = use_linear_projection
51 | self.num_attention_heads = num_attention_heads
52 | self.attention_head_dim = attention_head_dim
53 | inner_dim = num_attention_heads * attention_head_dim
54 |
55 | # Define input layers
56 | self.in_channels = in_channels
57 |
58 | self.norm = torch.nn.GroupNorm(
59 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
60 | )
61 | if use_linear_projection:
62 | self.proj_in = nn.Linear(in_channels, inner_dim)
63 | else:
64 | self.proj_in = nn.Conv2d(
65 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0
66 | )
67 |
68 | # Define transformers blocks
69 | self.transformer_blocks = nn.ModuleList(
70 | [
71 | TemporalBasicTransformerBlock(
72 | inner_dim,
73 | num_attention_heads,
74 | attention_head_dim,
75 | dropout=dropout,
76 | cross_attention_dim=cross_attention_dim,
77 | activation_fn=activation_fn,
78 | num_embeds_ada_norm=num_embeds_ada_norm,
79 | attention_bias=attention_bias,
80 | only_cross_attention=only_cross_attention,
81 | upcast_attention=upcast_attention,
82 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
83 | unet_use_temporal_attention=unet_use_temporal_attention,
84 | )
85 | for d in range(num_layers)
86 | ]
87 | )
88 |
89 | # 4. Define output layers
90 | if use_linear_projection:
91 |             self.proj_out = nn.Linear(inner_dim, in_channels)
92 | else:
93 | self.proj_out = nn.Conv2d(
94 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0
95 | )
96 |
97 | self.gradient_checkpointing = False
98 |
99 | def _set_gradient_checkpointing(self, module, value=False):
100 | if hasattr(module, "gradient_checkpointing"):
101 | module.gradient_checkpointing = value
102 |
103 | def forward(
104 | self,
105 | hidden_states,
106 | encoder_hidden_states=None,
107 | timestep=None,
108 | return_dict: bool = True,
109 | ):
110 | # Input
111 | assert (
112 | hidden_states.dim() == 5
113 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
114 | video_length = hidden_states.shape[2]
115 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
116 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
117 | encoder_hidden_states = repeat(
118 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length
119 | )
120 |
121 |         batch, channel, height, width = hidden_states.shape
122 | residual = hidden_states
123 |
124 | hidden_states = self.norm(hidden_states)
125 | if not self.use_linear_projection:
126 | hidden_states = self.proj_in(hidden_states)
127 | inner_dim = hidden_states.shape[1]
128 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
129 |                 batch, height * width, inner_dim
130 | )
131 | else:
132 | inner_dim = hidden_states.shape[1]
133 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
134 |                 batch, height * width, inner_dim
135 | )
136 | hidden_states = self.proj_in(hidden_states)
137 |
138 | # Blocks
139 | for i, block in enumerate(self.transformer_blocks):
140 | hidden_states = block(
141 | hidden_states,
142 | encoder_hidden_states=encoder_hidden_states,
143 | timestep=timestep,
144 | video_length=video_length,
145 | )
146 |
147 | # Output
148 | if not self.use_linear_projection:
149 | hidden_states = (
150 |                 hidden_states.reshape(batch, height, width, inner_dim)
151 | .permute(0, 3, 1, 2)
152 | .contiguous()
153 | )
154 | hidden_states = self.proj_out(hidden_states)
155 | else:
156 | hidden_states = self.proj_out(hidden_states)
157 | hidden_states = (
158 |                 hidden_states.reshape(batch, height, width, inner_dim)
159 | .permute(0, 3, 1, 2)
160 | .contiguous()
161 | )
162 |
163 | output = hidden_states + residual
164 |
165 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
166 | if not return_dict:
167 | return (output,)
168 |
169 | return Transformer3DModelOutput(sample=output)
170 |
--------------------------------------------------------------------------------
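A similar sketch for `Transformer3DModel`, whose `forward` asserts a 5-D `(batch, channels, frames, height, width)` input and repeats `encoder_hidden_states` once per frame; the import path and shapes are assumptions:

import torch
from src.modules.transformer_3d import Transformer3DModel  # assumed import path

model = Transformer3DModel(
    num_attention_heads=8,
    attention_head_dim=40,    # inner_dim = 320
    in_channels=320,
    cross_attention_dim=768,
)
video = torch.randn(1, 320, 8, 32, 32)   # (batch, channels, frames, height, width)
context = torch.randn(1, 77, 768)        # repeated to (batch * frames, 77, 768) inside forward()

with torch.no_grad():
    out = model(video, encoder_hidden_states=context)

print(out.sample.shape)  # torch.Size([1, 320, 8, 32, 32])
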
/src/modules/v_kps_guider.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from diffusers.models.modeling_utils import ModelMixin
6 | from .motion_module import zero_module
7 | from .resnet import InflatedConv3d
8 |
9 |
10 | class VKpsGuider(ModelMixin):
11 | def __init__(
12 | self,
13 | conditioning_embedding_channels: int,
14 | conditioning_channels: int = 3,
15 | block_out_channels: Tuple[int] = (16, 32, 64, 128),
16 | ):
17 | super().__init__()
18 | self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
19 |
20 | self.blocks = nn.ModuleList([])
21 |
22 | for i in range(len(block_out_channels) - 1):
23 | channel_in = block_out_channels[i]
24 | channel_out = block_out_channels[i + 1]
25 | self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1))
26 | self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
27 |
28 | self.conv_out = zero_module(InflatedConv3d(
29 | block_out_channels[-1],
30 | conditioning_embedding_channels,
31 | kernel_size=3,
32 | padding=1,
33 | ))
34 |
35 | def forward(self, conditioning):
36 | embedding = self.conv_in(conditioning)
37 | embedding = F.silu(embedding)
38 |
39 | for block in self.blocks:
40 | embedding = block(embedding)
41 | embedding = F.silu(embedding)
42 |
43 | embedding = self.conv_out(embedding)
44 |
45 | return embedding
46 |
--------------------------------------------------------------------------------
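A shape sketch for `VKpsGuider` (values are assumptions): each of the three stride-2 convolutions halves the spatial resolution, so a 512x512 keypoint map comes out at 64x64, and `conv_out` is wrapped in `zero_module`, so a freshly initialized guider outputs zeros.

import torch
from src.modules.v_kps_guider import VKpsGuider  # assumed import path

guider = VKpsGuider(conditioning_embedding_channels=320)
kps_maps = torch.randn(1, 3, 16, 512, 512)   # (batch, rgb, frames, height, width)

with torch.no_grad():
    features = guider(kps_maps)

print(features.shape)  # torch.Size([1, 320, 16, 64, 64])
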
/src/pipelines/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tiankuan93/ComfyUI-V-Express/ee8f245406e6fa07353dbf36f8ada06f81d48a9e/src/pipelines/.DS_Store
--------------------------------------------------------------------------------
/src/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from .v_express_pipeline import VExpressPipeline
2 |
--------------------------------------------------------------------------------
/src/pipelines/context.py:
--------------------------------------------------------------------------------
1 | # TODO: Adapted from cli
2 | from typing import Callable, List, Optional
3 |
4 | import numpy as np
5 |
6 |
7 | def ordered_halving(val):
8 | bin_str = f"{val:064b}"
9 | bin_flip = bin_str[::-1]
10 | as_int = int(bin_flip, 2)
11 |
12 | return as_int / (1 << 64)
13 |
14 |
15 | def uniform(
16 | step: int = ...,
17 | num_steps: Optional[int] = None,
18 | num_frames: int = ...,
19 | context_size: Optional[int] = None,
20 | context_stride: int = 3,
21 | context_overlap: int = 4,
22 | closed_loop: bool = True,
23 | ):
24 | if num_frames <= context_size:
25 | yield list(range(num_frames))
26 | return
27 |
28 | context_stride = min(
29 | context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
30 | )
31 |
32 | for context_step in 1 << np.arange(context_stride):
33 | pad = int(round(num_frames * ordered_halving(step)))
34 | for j in range(
35 | int(ordered_halving(step) * context_step) + pad,
36 | num_frames + pad + (0 if closed_loop else -context_overlap),
37 | (context_size * context_step - context_overlap),
38 | ):
39 | next_itr = []
40 | for e in range(j, j + context_size * context_step, context_step):
41 | if e >= num_frames:
42 | e = num_frames - 2 - e % num_frames
43 | next_itr.append(e)
44 |
45 | yield next_itr
46 |
47 |
48 | def get_context_scheduler(name: str) -> Callable:
49 | if name == "uniform":
50 | return uniform
51 | else:
52 |         raise ValueError(f"Unknown context scheduler {name}")
53 |
54 |
55 | def get_total_steps(
56 | scheduler,
57 | timesteps: List[int],
58 | num_steps: Optional[int] = None,
59 | num_frames: int = ...,
60 | context_size: Optional[int] = None,
61 | context_stride: int = 3,
62 | context_overlap: int = 4,
63 | closed_loop: bool = True,
64 | ):
65 | return sum(
66 | len(
67 | list(
68 | scheduler(
69 | i,
70 | num_steps,
71 | num_frames,
72 | context_size,
73 | context_stride,
74 | context_overlap,
75 | )
76 | )
77 | )
78 | for i in range(len(timesteps))
79 | )
80 |
--------------------------------------------------------------------------------
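To make the windowing above concrete, a sketch (illustrative numbers only) that lists the frame windows the `uniform` scheduler yields for one denoising step; the pipeline later averages the noise predictions of frames that fall into more than one window:

from src.pipelines.context import get_context_scheduler  # assumed import path

scheduler = get_context_scheduler("uniform")
windows = list(scheduler(step=0, num_frames=40, context_size=24,
                         context_stride=1, context_overlap=4, closed_loop=False))
for w in windows:
    # every window holds `context_size` frame indices; neighbouring windows overlap
    print(len(w), w[:3], "...", w[-3:])
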
/src/pipelines/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import pathlib
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as func
9 | import tqdm
10 | from imageio_ffmpeg import get_ffmpeg_exe
11 |
12 | tensor_interpolation = None
13 |
14 |
15 | def get_tensor_interpolation_method():
16 | return tensor_interpolation
17 |
18 |
19 | def set_tensor_interpolation_method(is_slerp):
20 | global tensor_interpolation
21 | tensor_interpolation = slerp if is_slerp else linear
22 |
23 |
24 | def linear(v1, v2, t):
25 | return (1.0 - t) * v1 + t * v2
26 |
27 |
28 | def slerp(v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995) -> torch.Tensor:
29 | u0 = v0 / v0.norm()
30 | u1 = v1 / v1.norm()
31 | dot = (u0 * u1).sum()
32 | if dot.abs() > DOT_THRESHOLD:
33 | # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.')
34 | return (1.0 - t) * v0 + t * v1
35 | omega = dot.acos()
36 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
37 |
38 |
39 | def draw_kps_image(height, width, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255)]):
40 | stick_width = 4
41 | limb_seq = np.array([[0, 2], [1, 2]])
42 | kps = np.array(kps)
43 |
44 | canvas = np.zeros((height, width, 3), dtype=np.uint8)
45 |
46 | for i in range(len(limb_seq)):
47 | index = limb_seq[i]
48 | color = color_list[index[0]]
49 |
50 | x = kps[index][:, 0]
51 | y = kps[index][:, 1]
52 | length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
53 | angle = int(math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])))
54 | polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stick_width), angle, 0, 360, 1)
55 | cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color])
56 |
57 | for idx_kp, kp in enumerate(kps):
58 | color = color_list[idx_kp]
59 | x, y = kp
60 | cv2.circle(canvas, (int(x), int(y)), 4, color, -1)
61 |
62 | return canvas
63 |
64 |
65 | def median_filter_3d(video_tensor, kernel_size, device):
66 | _, video_length, height, width = video_tensor.shape
67 |
68 | pad_size = kernel_size // 2
69 | video_tensor = func.pad(video_tensor, (pad_size, pad_size, pad_size, pad_size, pad_size, pad_size), mode='reflect')
70 |
71 | filtered_video_tensor = []
72 | for i in tqdm.tqdm(range(video_length), desc='Median Filtering'):
73 | video_segment = video_tensor[:, i:i + kernel_size, ...].to(device)
74 | video_segment = video_segment.unfold(dimension=2, size=kernel_size, step=1)
75 | video_segment = video_segment.unfold(dimension=3, size=kernel_size, step=1)
76 | video_segment = video_segment.permute(0, 2, 3, 1, 4, 5).reshape(3, height, width, -1)
77 | filtered_video_frame = torch.median(video_segment, dim=-1)[0]
78 | filtered_video_tensor.append(filtered_video_frame.cpu())
79 | filtered_video_tensor = torch.stack(filtered_video_tensor, dim=1)
80 | return filtered_video_tensor
81 |
82 |
83 | def save_video(video_tensor, audio_path, output_path, device, fps=30.0):
84 | pathlib.Path(output_path).parent.mkdir(exist_ok=True, parents=True)
85 |
86 | video_tensor = video_tensor[0, ...]
87 | _, num_frames, height, width = video_tensor.shape
88 |
89 | video_tensor = median_filter_3d(video_tensor, kernel_size=3, device=device)
90 | video_tensor = video_tensor.permute(1, 2, 3, 0)
91 | video_frames = (video_tensor * 255).numpy().astype(np.uint8)
92 |
93 | output_name = pathlib.Path(output_path).stem
94 | temp_output_path = output_path.replace(output_name, output_name + '-temp')
95 | video_writer = cv2.VideoWriter(temp_output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
96 |
97 | for i in tqdm.tqdm(range(num_frames), 'Writing frames into file'):
98 | frame_image = video_frames[i, ...]
99 | frame_image = cv2.cvtColor(frame_image, cv2.COLOR_RGB2BGR)
100 | video_writer.write(frame_image)
101 | video_writer.release()
102 |
103 | cmd = (f'{get_ffmpeg_exe()} -i "{temp_output_path}" -i "{audio_path}" '
104 | f'-map 0:v -map 1:a -c:v h264 -shortest -y "{output_path}" -loglevel quiet')
105 | os.system(cmd)
106 | os.remove(temp_output_path)
107 |
108 |
109 | def compute_dist(x1, y1, x2, y2):
110 | return math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
111 |
112 |
113 | def compute_ratio(kps):
114 | l_eye_x, l_eye_y = kps[0][0], kps[0][1]
115 | r_eye_x, r_eye_y = kps[1][0], kps[1][1]
116 | nose_x, nose_y = kps[2][0], kps[2][1]
117 | d_left = compute_dist(l_eye_x, l_eye_y, nose_x, nose_y)
118 | d_right = compute_dist(r_eye_x, r_eye_y, nose_x, nose_y)
119 | ratio = d_left / (d_right + 1e-6)
120 | return ratio
121 |
122 |
123 | def point_to_line_dist(point, line_points):
124 | point = np.array(point)
125 | line_points = np.array(line_points)
126 | line_vec = line_points[1] - line_points[0]
127 | point_vec = point - line_points[0]
128 | line_norm = line_vec / np.sqrt(np.sum(line_vec ** 2))
129 | point_vec_scaled = point_vec * 1.0 / np.sqrt(np.sum(line_vec ** 2))
130 | t = np.dot(line_norm, point_vec_scaled)
131 | if t < 0.0:
132 | t = 0.0
133 | elif t > 1.0:
134 | t = 1.0
135 | nearest = line_points[0] + t * line_vec
136 | dist = np.sqrt(np.sum((point - nearest) ** 2))
137 | return dist
138 |
139 |
140 | def get_face_size(kps):
141 | # 0: left eye, 1: right eye, 2: nose
142 | A = kps[0, :]
143 | B = kps[1, :]
144 | C = kps[2, :]
145 |
146 | AB_dist = math.sqrt((A[0] - B[0]) ** 2 + (A[1] - B[1]) ** 2)
147 | C_AB_dist = point_to_line_dist(C, [A, B])
148 | return AB_dist, C_AB_dist
149 |
150 |
151 | def get_rescale_params(kps_ref, kps_target):
152 | kps_ref = np.array(kps_ref)
153 | kps_target = np.array(kps_target)
154 |
155 | ref_AB_dist, ref_C_AB_dist = get_face_size(kps_ref)
156 | target_AB_dist, target_C_AB_dist = get_face_size(kps_target)
157 |
158 | scale_width = ref_AB_dist / target_AB_dist
159 | scale_height = ref_C_AB_dist / target_C_AB_dist
160 |
161 | return scale_width, scale_height
162 |
163 |
164 | def retarget_kps(ref_kps, tgt_kps_list, only_offset=True):
165 | ref_kps = np.array(ref_kps)
166 | tgt_kps_list = np.array(tgt_kps_list)
167 |
168 | ref_ratio = compute_ratio(ref_kps)
169 |
170 | ratio_delta = 10000
171 | selected_tgt_kps_idx = None
172 | for idx, tgt_kps in enumerate(tgt_kps_list):
173 | tgt_ratio = compute_ratio(tgt_kps)
174 | if math.fabs(tgt_ratio - ref_ratio) < ratio_delta:
175 | selected_tgt_kps_idx = idx
176 |             ratio_delta = math.fabs(tgt_ratio - ref_ratio)  # keep the smallest ratio difference seen so far
177 |
178 | scale_width, scale_height = get_rescale_params(
179 | kps_ref=ref_kps,
180 | kps_target=tgt_kps_list[selected_tgt_kps_idx],
181 | )
182 |
183 | rescaled_tgt_kps_list = np.array(tgt_kps_list)
184 | rescaled_tgt_kps_list[:, :, 0] *= scale_width
185 | rescaled_tgt_kps_list[:, :, 1] *= scale_height
186 |
187 | if only_offset:
188 | nose_offset = rescaled_tgt_kps_list[:, 2, :] - rescaled_tgt_kps_list[0, 2, :]
189 | nose_offset = nose_offset[:, np.newaxis, :]
190 | ref_kps_repeat = np.tile(ref_kps, (tgt_kps_list.shape[0], 1, 1))
191 |
192 | ref_kps_repeat[:, :, :] -= (nose_offset / 2.0)
193 | rescaled_tgt_kps_list = ref_kps_repeat
194 | else:
195 | nose_offset_x = rescaled_tgt_kps_list[0, 2, 0] - ref_kps[2][0]
196 | nose_offset_y = rescaled_tgt_kps_list[0, 2, 1] - ref_kps[2][1]
197 |
198 | rescaled_tgt_kps_list[:, :, 0] -= nose_offset_x
199 | rescaled_tgt_kps_list[:, :, 1] -= nose_offset_y
200 |
201 | return rescaled_tgt_kps_list
202 |
--------------------------------------------------------------------------------
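A small sketch of the keypoint-drawing helper (coordinates invented for illustration); `draw_kps_image` expects three `(x, y)` points ordered left eye, right eye, nose, the same convention used by `compute_ratio` and `get_face_size`:

import cv2
from src.pipelines.utils import draw_kps_image  # assumed import path

kps = [(200, 200), (312, 200), (256, 280)]   # left eye, right eye, nose as (x, y)
canvas = draw_kps_image(height=512, width=512, kps=kps)
cv2.imwrite("kps_vis.png", canvas)            # two eye-to-nose limbs plus three key points
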
/src/pipelines/v_express_pipeline.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/pipelines/pipeline_animation.py
2 | import inspect
3 | import math
4 | from typing import Callable, List, Optional, Union
5 |
6 | import torch
7 | from diffusers import DiffusionPipeline
8 | from diffusers.image_processor import VaeImageProcessor
9 | from diffusers.schedulers import (
10 | DDIMScheduler,
11 | DPMSolverMultistepScheduler,
12 | EulerAncestralDiscreteScheduler,
13 | EulerDiscreteScheduler,
14 | LMSDiscreteScheduler,
15 | PNDMScheduler,
16 | )
17 | from diffusers.utils import is_accelerate_available
18 | from diffusers.utils.torch_utils import randn_tensor
19 | from einops import rearrange
20 | from tqdm import tqdm
21 | from transformers import CLIPImageProcessor
22 |
23 | from modules import ReferenceAttentionControl
24 | from .context import get_context_scheduler
25 |
26 |
27 | def retrieve_timesteps(
28 | scheduler,
29 | num_inference_steps: Optional[int] = None,
30 | device: Optional[Union[str, torch.device]] = None,
31 | timesteps: Optional[List[int]] = None,
32 | **kwargs,
33 | ):
34 | """
35 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
36 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
37 |
38 | Args:
39 | scheduler (`SchedulerMixin`):
40 | The scheduler to get timesteps from.
41 | num_inference_steps (`int`):
42 | The number of diffusion steps used when generating samples with a pre-trained model. If used,
43 | `timesteps` must be `None`.
44 | device (`str` or `torch.device`, *optional*):
45 |             The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
46 | timesteps (`List[int]`, *optional*):
47 | Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
48 | timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
49 | must be `None`.
50 |
51 | Returns:
52 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
53 | second element is the number of inference steps.
54 | """
55 | if timesteps is not None:
56 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
57 | if not accepts_timesteps:
58 | raise ValueError(
59 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
60 | f" timestep schedules. Please check whether you are using the correct scheduler."
61 | )
62 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
63 | timesteps = scheduler.timesteps
64 | num_inference_steps = len(timesteps)
65 | else:
66 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
67 | timesteps = scheduler.timesteps
68 | return timesteps, num_inference_steps
69 |
70 |
71 | class VExpressPipeline(DiffusionPipeline):
72 | _optional_components = []
73 |
74 | def __init__(
75 | self,
76 | vae,
77 | reference_net,
78 | denoising_unet,
79 | v_kps_guider,
80 | audio_processor,
81 | audio_encoder,
82 | audio_projection,
83 | scheduler: Union[
84 | DDIMScheduler,
85 | PNDMScheduler,
86 | LMSDiscreteScheduler,
87 | EulerDiscreteScheduler,
88 | EulerAncestralDiscreteScheduler,
89 | DPMSolverMultistepScheduler,
90 | ],
91 | image_proj_model=None,
92 | tokenizer=None,
93 | text_encoder=None,
94 | ):
95 | super().__init__()
96 |
97 | self.register_modules(
98 | vae=vae,
99 | reference_net=reference_net,
100 | denoising_unet=denoising_unet,
101 | v_kps_guider=v_kps_guider,
102 | audio_processor=audio_processor,
103 | audio_encoder=audio_encoder,
104 | audio_projection=audio_projection,
105 | scheduler=scheduler,
106 | image_proj_model=image_proj_model,
107 | tokenizer=tokenizer,
108 | text_encoder=text_encoder,
109 | )
110 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
111 | self.clip_image_processor = CLIPImageProcessor()
112 | self.reference_image_processor = VaeImageProcessor(
113 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
114 | )
115 | self.condition_image_processor = VaeImageProcessor(
116 | vae_scale_factor=self.vae_scale_factor,
117 | do_convert_rgb=True,
118 | do_normalize=False,
119 | )
120 |
121 | def enable_vae_slicing(self):
122 | self.vae.enable_slicing()
123 |
124 | def disable_vae_slicing(self):
125 | self.vae.disable_slicing()
126 |
127 | def enable_sequential_cpu_offload(self, gpu_id=0):
128 | if is_accelerate_available():
129 | from accelerate import cpu_offload
130 | else:
131 | raise ImportError("Please install accelerate via `pip install accelerate`")
132 |
133 | device = torch.device(f"cuda:{gpu_id}")
134 |
135 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
136 | if cpu_offloaded_model is not None:
137 | cpu_offload(cpu_offloaded_model, device)
138 |
139 | @property
140 | def _execution_device(self):
141 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
142 | return self.device
143 | for module in self.unet.modules():
144 | if (
145 | hasattr(module, "_hf_hook")
146 | and hasattr(module._hf_hook, "execution_device")
147 | and module._hf_hook.execution_device is not None
148 | ):
149 | return torch.device(module._hf_hook.execution_device)
150 | return self.device
151 |
152 | @torch.no_grad()
153 | def decode_latents(self, latents):
154 | video_length = latents.shape[2]
155 | latents = 1 / 0.18215 * latents
156 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
157 | video = []
158 | for frame_idx in tqdm(range(latents.shape[0]), desc='Decoding latents into frames'):
159 | image = self.vae.decode(latents[frame_idx: frame_idx + 1].to(self.vae.device)).sample
160 | image = (image / 2 + 0.5).clamp(0, 1)
161 | image = image.cpu().float()
162 | video.append(image)
163 | video = torch.cat(video)
164 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
165 |
166 | return video
167 |
168 | def prepare_extra_step_kwargs(self, generator, eta):
169 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
170 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
171 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
172 | # and should be between [0, 1]
173 |
174 | accepts_eta = "eta" in set(
175 | inspect.signature(self.scheduler.step).parameters.keys()
176 | )
177 | extra_step_kwargs = {}
178 | if accepts_eta:
179 | extra_step_kwargs["eta"] = eta
180 |
181 | # check if the scheduler accepts generator
182 | accepts_generator = "generator" in set(
183 | inspect.signature(self.scheduler.step).parameters.keys()
184 | )
185 | if accepts_generator:
186 | extra_step_kwargs["generator"] = generator
187 | return extra_step_kwargs
188 |
189 | def prepare_latents(
190 | self,
191 | batch_size,
192 | num_channels_latents,
193 | width,
194 | height,
195 | video_length,
196 | dtype,
197 | device,
198 | generator,
199 | latents=None
200 | ):
201 | shape = (
202 | batch_size,
203 | num_channels_latents,
204 | video_length,
205 | height // self.vae_scale_factor,
206 | width // self.vae_scale_factor,
207 | )
208 | if isinstance(generator, list) and len(generator) != batch_size:
209 | raise ValueError(
210 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
211 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
212 | )
213 |
214 | if latents is None:
215 | latents = randn_tensor(
216 | shape, generator=generator, device=device, dtype=dtype
217 | )
218 |
219 | else:
220 | latents = latents.to(device)
221 |
222 | # scale the initial noise by the standard deviation required by the scheduler
223 | latents = latents * self.scheduler.init_noise_sigma
224 | return latents
225 |
226 | def _encode_prompt(
227 | self,
228 | prompt,
229 | device,
230 | num_videos_per_prompt,
231 | do_classifier_free_guidance,
232 | negative_prompt,
233 | ):
234 | batch_size = len(prompt) if isinstance(prompt, list) else 1
235 |
236 | text_inputs = self.tokenizer(
237 | prompt,
238 | padding="max_length",
239 | max_length=self.tokenizer.model_max_length,
240 | truncation=True,
241 | return_tensors="pt",
242 | )
243 | text_input_ids = text_inputs.input_ids
244 | untruncated_ids = self.tokenizer(
245 | prompt, padding="longest", return_tensors="pt"
246 | ).input_ids
247 |
248 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
249 | text_input_ids, untruncated_ids
250 | ):
251 | removed_text = self.tokenizer.batch_decode(
252 | untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
253 | )
254 |
255 | if (
256 | hasattr(self.text_encoder.config, "use_attention_mask")
257 | and self.text_encoder.config.use_attention_mask
258 | ):
259 | attention_mask = text_inputs.attention_mask.to(device)
260 | else:
261 | attention_mask = None
262 |
263 | text_embeddings = self.text_encoder(
264 | text_input_ids.to(device),
265 | attention_mask=attention_mask,
266 | )
267 | text_embeddings = text_embeddings[0]
268 |
269 | # duplicate text embeddings for each generation per prompt, using mps friendly method
270 | bs_embed, seq_len, _ = text_embeddings.shape
271 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
272 | text_embeddings = text_embeddings.view(
273 | bs_embed * num_videos_per_prompt, seq_len, -1
274 | )
275 |
276 | # get unconditional embeddings for classifier free guidance
277 | if do_classifier_free_guidance:
278 | uncond_tokens: List[str]
279 | if negative_prompt is None:
280 | uncond_tokens = [""] * batch_size
281 | elif type(prompt) is not type(negative_prompt):
282 | raise TypeError(
283 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
284 | f" {type(prompt)}."
285 | )
286 | elif isinstance(negative_prompt, str):
287 | uncond_tokens = [negative_prompt]
288 | elif batch_size != len(negative_prompt):
289 | raise ValueError(
290 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
291 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
292 | " the batch size of `prompt`."
293 | )
294 | else:
295 | uncond_tokens = negative_prompt
296 |
297 | max_length = text_input_ids.shape[-1]
298 | uncond_input = self.tokenizer(
299 | uncond_tokens,
300 | padding="max_length",
301 | max_length=max_length,
302 | truncation=True,
303 | return_tensors="pt",
304 | )
305 |
306 | if (
307 | hasattr(self.text_encoder.config, "use_attention_mask")
308 | and self.text_encoder.config.use_attention_mask
309 | ):
310 | attention_mask = uncond_input.attention_mask.to(device)
311 | else:
312 | attention_mask = None
313 |
314 | uncond_embeddings = self.text_encoder(
315 | uncond_input.input_ids.to(device),
316 | attention_mask=attention_mask,
317 | )
318 | uncond_embeddings = uncond_embeddings[0]
319 |
320 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
321 | seq_len = uncond_embeddings.shape[1]
322 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
323 | uncond_embeddings = uncond_embeddings.view(
324 | batch_size * num_videos_per_prompt, seq_len, -1
325 | )
326 |
327 | # For classifier free guidance, we need to do two forward passes.
328 | # Here we concatenate the unconditional and text embeddings into a single batch
329 | # to avoid doing two forward passes
330 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
331 |
332 | return text_embeddings
333 |
334 | def get_timesteps(self, num_inference_steps, strength, device):
335 | # get the original timestep using init_timestep
336 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
337 |
338 | t_start = max(num_inference_steps - init_timestep, 0)
339 | timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
340 |
341 | return timesteps, num_inference_steps - t_start
342 |
343 | def prepare_reference_latent(self, reference_image, height, width):
344 | reference_image_tensor = self.reference_image_processor.preprocess(reference_image, height=height, width=width)
345 | reference_image_tensor = reference_image_tensor.to(dtype=self.dtype, device=self.device)
346 | reference_image_latents = self.vae.encode(reference_image_tensor).latent_dist.mean
347 | reference_image_latents = reference_image_latents * 0.18215
348 | return reference_image_latents
349 |
350 | def prepare_kps_feature(self, kps_images, height, width, do_classifier_free_guidance):
351 | kps_image_tensors = []
352 | for idx, kps_image in enumerate(kps_images):
353 | kps_image_tensor = self.condition_image_processor.preprocess(kps_image, height=height, width=width)
354 | kps_image_tensor = kps_image_tensor.unsqueeze(2) # [bs, c, 1, h, w]
355 | kps_image_tensors.append(kps_image_tensor)
356 | kps_images_tensor = torch.cat(kps_image_tensors, dim=2) # [bs, c, t, h, w]
357 |
358 | bs = 16
359 | num_forward = math.ceil(kps_images_tensor.shape[2] / bs)
360 | kps_feature = []
361 | for i in range(num_forward):
362 | tensor = kps_images_tensor[:, :, i * bs:(i + 1) * bs, ...].to(device=self.device, dtype=self.dtype)
363 | feature = self.v_kps_guider(tensor).cpu()
364 | kps_feature.append(feature)
365 | torch.cuda.empty_cache()
366 | kps_feature = torch.cat(kps_feature, dim=2)
367 |
368 | if do_classifier_free_guidance:
369 | uc_kps_feature = torch.zeros_like(kps_feature)
370 | kps_feature = torch.cat([uc_kps_feature, kps_feature], dim=0)
371 |
372 | return kps_feature
373 |
374 | def prepare_audio_embeddings(self, audio_waveform, video_length, num_pad_audio_frames, do_classifier_free_guidance):
375 | audio_waveform = self.audio_processor(audio_waveform, return_tensors="pt", sampling_rate=16000)['input_values']
376 | audio_waveform = audio_waveform.to(self.device, self.dtype)
377 | audio_embeddings = self.audio_encoder(audio_waveform).last_hidden_state # [1, num_embeds, d]
378 |
379 | audio_embeddings = torch.nn.functional.interpolate(
380 | audio_embeddings.permute(0, 2, 1),
381 | size=2 * video_length,
382 | mode='linear',
383 | )[0, :, :].permute(1, 0) # [2*vid_len, dim]
384 |
385 | audio_embeddings = torch.cat([
386 | torch.zeros_like(audio_embeddings)[:2 * num_pad_audio_frames, :],
387 | audio_embeddings,
388 | torch.zeros_like(audio_embeddings)[:2 * num_pad_audio_frames, :],
389 | ], dim=0) # [2*num_pad+2*vid_len+2*num_pad, dim]
390 |
391 | frame_audio_embeddings = []
392 | for frame_idx in range(video_length):
393 | start_sample = frame_idx
394 | end_sample = frame_idx + 2 * num_pad_audio_frames
395 |
396 | frame_audio_embedding = audio_embeddings[2 * start_sample:2 * (end_sample + 1), :] # [2*num_pad+1, dim]
397 | frame_audio_embeddings.append(frame_audio_embedding)
398 | audio_embeddings = torch.stack(frame_audio_embeddings, dim=0) # [vid_len, 2*num_pad+1, dim]
399 |
400 | audio_embeddings = self.audio_projection(audio_embeddings).unsqueeze(0)
401 | if do_classifier_free_guidance:
402 | uc_audio_embeddings = torch.zeros_like(audio_embeddings)
403 | audio_embeddings = torch.cat([uc_audio_embeddings, audio_embeddings], dim=0)
404 | return audio_embeddings
405 |
406 | @torch.no_grad()
407 | def __call__(
408 | self,
409 | reference_image,
410 | kps_images,
411 | audio_waveform,
412 | width,
413 | height,
414 | video_length,
415 | num_inference_steps,
416 | guidance_scale,
417 | strength=1.,
418 | num_images_per_prompt=1,
419 | eta: float = 0.0,
420 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
421 | output_type: Optional[str] = "tensor",
422 | return_dict: bool = True,
423 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
424 | callback_steps: Optional[int] = 1,
425 | context_schedule="uniform",
426 | context_frames=24,
427 | context_overlap=4,
428 | reference_attention_weight=1.,
429 | audio_attention_weight=1.,
430 | num_pad_audio_frames=2,
431 | do_multi_devices_inference=False,
432 | save_gpu_memory=False,
433 | **kwargs,
434 | ):
435 | # Default height and width to unet
436 | height = height or self.unet.config.sample_size * self.vae_scale_factor
437 | width = width or self.unet.config.sample_size * self.vae_scale_factor
438 |
439 | device = self._execution_device
440 |
441 | do_classifier_free_guidance = guidance_scale > 1.0
442 | batch_size = 1
443 |
444 | # Prepare timesteps
445 | timesteps = None
446 | timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
447 | timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
448 |
449 | reference_control_writer = ReferenceAttentionControl(
450 | self.reference_net,
451 | do_classifier_free_guidance=do_classifier_free_guidance,
452 | mode="write",
453 | batch_size=batch_size,
454 | fusion_blocks="full",
455 | )
456 | reference_control_reader = ReferenceAttentionControl(
457 | self.denoising_unet,
458 | do_classifier_free_guidance=do_classifier_free_guidance,
459 | mode="read",
460 | batch_size=batch_size,
461 | fusion_blocks="full",
462 | reference_attention_weight=reference_attention_weight,
463 | audio_attention_weight=audio_attention_weight,
464 | )
465 |
466 | num_channels_latents = self.denoising_unet.in_channels
467 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
468 |
469 | reference_image_latents = self.prepare_reference_latent(reference_image, height, width)
470 | kps_feature = self.prepare_kps_feature(kps_images, height, width, do_classifier_free_guidance)
471 | # if save_gpu_memory:
472 | # del self.v_kps_guider
473 | torch.cuda.empty_cache()
474 | audio_embeddings = self.prepare_audio_embeddings(
475 | audio_waveform,
476 | video_length,
477 | num_pad_audio_frames,
478 | do_classifier_free_guidance,
479 | )
480 | # if save_gpu_memory:
481 | # del self.audio_processor, self.audio_encoder, self.audio_projection
482 | torch.cuda.empty_cache()
483 |
484 | context_scheduler = get_context_scheduler(context_schedule)
485 | context_queue = list(
486 | context_scheduler(
487 | step=0,
488 | num_frames=video_length,
489 | context_size=context_frames,
490 | context_stride=1,
491 | context_overlap=context_overlap,
492 | closed_loop=False,
493 | )
494 | )
495 |
496 | num_frame_context = torch.zeros(video_length, device=device, dtype=torch.long)
497 | for context in context_queue:
498 | num_frame_context[context] += 1
499 |
500 | encoder_hidden_states = torch.zeros((1, 1, 768), dtype=self.dtype, device=self.device)
501 | self.reference_net(
502 | reference_image_latents,
503 | timestep=0,
504 | encoder_hidden_states=encoder_hidden_states,
505 | return_dict=False,
506 | )
507 | reference_control_reader.update(reference_control_writer, do_classifier_free_guidance)
508 | # if save_gpu_memory:
509 | # del self.reference_net
510 | torch.cuda.empty_cache()
511 |
512 | latents = self.prepare_latents(
513 | batch_size * num_images_per_prompt,
514 | num_channels_latents,
515 | width,
516 | height,
517 | video_length,
518 | self.dtype,
519 | torch.device('cpu'),
520 | generator,
521 | )
522 |
523 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
524 | with self.progress_bar(total=num_inference_steps) as progress_bar:
525 | for i, t in enumerate(timesteps):
526 | context_counter = torch.zeros(video_length, device=device, dtype=torch.long)
527 | noise_preds = [None] * video_length
528 | for context_idx, context in enumerate(context_queue):
529 | latent_kps_feature = kps_feature[:, :, context].to(device, self.dtype)
530 |
531 | latent_audio_embeddings = audio_embeddings[:, context, ...]
532 | _, _, num_tokens, dim = latent_audio_embeddings.shape
533 | latent_audio_embeddings = latent_audio_embeddings.reshape(-1, num_tokens, dim)
534 |
535 | input_latents = latents[:, :, context, ...].to(device)
536 | input_latents = input_latents.repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
537 | input_latents = self.scheduler.scale_model_input(input_latents, t)
538 | noise_pred = self.denoising_unet(
539 | input_latents,
540 | t,
541 | encoder_hidden_states=latent_audio_embeddings.reshape(-1, num_tokens, dim),
542 | kps_features=latent_kps_feature,
543 | return_dict=False,
544 | )[0]
545 | if do_classifier_free_guidance:
546 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
547 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
548 |
549 | context_counter[context] += 1
550 | noise_pred /= num_frame_context[context][None, None, :, None, None]
551 | step_frame_ids = []
552 | step_noise_preds = []
553 | for latent_idx, frame_idx in enumerate(context):
554 | if noise_preds[frame_idx] is None:
555 | noise_preds[frame_idx] = noise_pred[:, :, latent_idx, ...]
556 | else:
557 | noise_preds[frame_idx] += noise_pred[:, :, latent_idx, ...]
558 | if context_counter[frame_idx] == num_frame_context[frame_idx]:
559 | step_frame_ids.append(frame_idx)
560 | step_noise_preds.append(noise_preds[frame_idx])
561 | noise_preds[frame_idx] = None
562 | step_noise_preds = torch.stack(step_noise_preds, dim=2)
563 | output_latents = self.scheduler.step(
564 | step_noise_preds,
565 | t,
566 | latents[:, :, step_frame_ids, ...].to(device),
567 | **extra_step_kwargs,
568 | ).prev_sample
569 | latents[:, :, step_frame_ids, ...] = output_latents.cpu()
570 |
571 | progress_bar.set_description(
572 | f'Denoising Step Index: {i + 1} / {len(timesteps)}, '
573 | f'Context Index: {context_idx + 1} / {len(context_queue)}'
574 | )
575 |
576 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
577 | progress_bar.update()
578 | if callback is not None and i % callback_steps == 0:
579 | step_idx = i // getattr(self.scheduler, "order", 1)
580 | callback(step_idx, t, latents)
581 |
582 | reference_control_reader.clear()
583 | reference_control_writer.clear()
584 |
585 | video_tensor = self.decode_latents(latents)
586 | return video_tensor
587 |
--------------------------------------------------------------------------------
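The `retrieve_timesteps` helper at the top of this file only forwards to the scheduler; a sketch of its behaviour with a stock diffusers scheduler (the scheduler choice is an assumption, and the import assumes the repo's `src` directory is on `PYTHONPATH`):

from diffusers import DDIMScheduler
from src.pipelines.v_express_pipeline import retrieve_timesteps  # assumed import path

scheduler = DDIMScheduler(num_train_timesteps=1000)
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=25, device="cpu")
print(num_steps, timesteps[:3])   # 25 and the first few (descending) timesteps
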
/src/scripts/extract_kps_sequence_and_audio.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import os
4 | import cv2
5 | import torch
6 | from insightface.app import FaceAnalysis
7 | from imageio_ffmpeg import get_ffmpeg_exe
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--video_path', type=str, default='')
11 | parser.add_argument('--kps_sequence_save_path', type=str, default='')
12 | parser.add_argument('--audio_save_path', type=str, default='')
13 | parser.add_argument('--device', type=str, default='cuda')
14 | parser.add_argument('--gpu_id', type=int, default=0)
15 | parser.add_argument('--insightface_model_path', type=str, default='./model_ckpts/insightface_models/')
16 | parser.add_argument('--height', type=int, default=512)
17 | parser.add_argument('--width', type=int, default=512)
18 | args = parser.parse_args()
19 |
20 | app = FaceAnalysis(
21 | providers=['CUDAExecutionProvider' if args.device == 'cuda' else 'CPUExecutionProvider'],
22 | provider_options=[{'device_id': args.gpu_id}] if args.device == 'cuda' else [],
23 | root=args.insightface_model_path,
24 | )
25 | app.prepare(ctx_id=0, det_size=(args.height, args.width))
26 |
27 | os.system(f'{get_ffmpeg_exe()} -i "{args.video_path}" -y -vn "{args.audio_save_path}"')
28 |
29 | kps_sequence = []
30 | video_capture = cv2.VideoCapture(args.video_path)
31 | frame_idx = 0
32 | while video_capture.isOpened():
33 | ret, frame = video_capture.read()
34 | if not ret:
35 | break
36 | frame = cv2.resize(frame, (args.width, args.height))
37 | faces = app.get(frame)
38 | assert len(faces) == 1, f'There are {len(faces)} faces in the {frame_idx}-th frame. Only one face is supported.'
39 |
40 | kps = faces[0].kps[:3]
41 | kps_sequence.append(kps)
42 | frame_idx += 1
43 | torch.save(kps_sequence, args.kps_sequence_save_path)
44 |
--------------------------------------------------------------------------------
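A typical invocation of the extraction script (paths are examples only, not mandated by the script); it writes the audio track with ffmpeg and the per-frame 3-point keypoints with `torch.save`:

import subprocess

subprocess.run([
    "python", "src/scripts/extract_kps_sequence_and_audio.py",
    "--video_path", "input/gt.mp4",                  # example input video
    "--kps_sequence_save_path", "input/kps.pth",     # example output keypoint sequence
    "--audio_save_path", "input/aud.mp3",            # example output audio
], check=True)
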
/src/util.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | from typing import Iterable
4 | import shutil
5 | import subprocess
6 | import re
7 |
8 |
9 | def ffmpeg_suitability(path):
10 | try:
11 | version = subprocess.run([path, "-version"], check=True,
12 | capture_output=True).stdout.decode("utf-8")
13 | except:
14 | return 0
15 | score = 0
16 | #rough layout of the importance of various features
17 | simple_criterion = [("libvpx", 20),("264",10), ("265",3),
18 | ("svtav1",5),("libopus", 1)]
19 | for criterion in simple_criterion:
20 | if version.find(criterion[0]) >= 0:
21 | score += criterion[1]
22 | #obtain rough compile year from copyright information
23 | copyright_index = version.find('2000-2')
24 | if copyright_index >= 0:
25 | copyright_year = version[copyright_index+6:copyright_index+9]
26 | if copyright_year.isnumeric():
27 | score += int(copyright_year)
28 | return score
29 |
30 |
31 | def get_ffmpeg():
32 | ffmpeg_paths = []
33 | try:
34 | from imageio_ffmpeg import get_ffmpeg_exe
35 | imageio_ffmpeg_path = get_ffmpeg_exe()
36 | ffmpeg_paths.append(imageio_ffmpeg_path)
37 | except:
38 | print("Failed to import imageio_ffmpeg")
39 | if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
40 | ffmpeg_path = imageio_ffmpeg_path
41 | else:
42 | system_ffmpeg = shutil.which("ffmpeg")
43 | if system_ffmpeg is not None:
44 | ffmpeg_paths.append(system_ffmpeg)
45 | if os.path.isfile("ffmpeg"):
46 | ffmpeg_paths.append(os.path.abspath("ffmpeg"))
47 | if os.path.isfile("ffmpeg.exe"):
48 | ffmpeg_paths.append(os.path.abspath("ffmpeg.exe"))
49 | if len(ffmpeg_paths) == 0:
50 | print("No valid ffmpeg found.")
51 | ffmpeg_path = None
52 | elif len(ffmpeg_paths) == 1:
53 | #Evaluation of suitability isn't required, can take sole option
54 | #to reduce startup time
55 | ffmpeg_path = ffmpeg_paths[0]
56 | else:
57 | ffmpeg_path = max(ffmpeg_paths, key=ffmpeg_suitability)
58 | return ffmpeg_path
59 |
--------------------------------------------------------------------------------
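A sketch of how the two helpers above combine (import path assumed): `get_ffmpeg` gathers candidate binaries and, when several are found, ranks them with `ffmpeg_suitability`, which scores codec support and the compile year parsed from the copyright string.

from src.util import get_ffmpeg, ffmpeg_suitability  # assumed import path

ffmpeg_path = get_ffmpeg()
if ffmpeg_path is not None:
    print(ffmpeg_path, ffmpeg_suitability(ffmpeg_path))
else:
    print("No ffmpeg binary found")
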
/web/js/VEpreviewVideo.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../../scripts/app.js";
2 | import { api } from "../../../scripts/api.js";
3 |
4 | // Define a function that previews a video
5 | function previewVideo(node, file, type) {
6 | // Remove any existing preview element
7 | try {
8 | var el = document.getElementById("VEpreviewVideo");
9 | el.remove();
10 | } catch (error) {
11 | console.log(error);
12 | }
13 | var element = document.createElement("div");
14 | element.id = "VEpreviewVideo";
15 | const previewNode = node;
16 |
17 |   // Create a new video element
18 | let videoEl = document.createElement("video");
19 |
20 |   // Set the video element's attributes
21 | videoEl.controls = true;
22 | videoEl.style.width = "100%";
23 |
24 | let params = {
25 | filename: file,
26 | type: type,
27 | };
28 |   // Update the video element's src attribute
29 | videoEl.src = api.apiURL("/view?" + new URLSearchParams(params));
30 |
31 |   // Reload and play the video
32 | videoEl.load();
33 | videoEl.play();
34 |
35 |   // Clear all child elements from the div element
36 | while (element.firstChild) {
37 | element.removeChild(element.firstChild);
38 | }
39 |
40 |   // Append the video element to the div element
41 | element.appendChild(videoEl);
42 |
43 | node.previewWidget = node.addDOMWidget("videopreview", "preview", element, {
44 | serialize: false,
45 | hideOnZoom: false,
46 | getValue() {
47 | return element.value;
48 | },
49 | setValue(v) {
50 | element.value = v;
51 | },
52 | });
53 |
54 | var previewWidget = node.previewWidget;
55 |
56 | previewWidget.computeSize = function (width) {
57 | if (this.aspectRatio && !this.parentEl.hidden) {
58 | let height = (previewNode.size[0] - 20) / this.aspectRatio + 10;
59 | if (!(height > 0)) {
60 | height = 0;
61 | }
62 | this.computedHeight = height + 10;
63 | return [width, height];
64 | }
65 | return [width, -4]; //no loaded src, widget should not display
66 | };
67 | }
68 |
69 | app.registerExtension({
70 | name: "ComfyUI-V-Express.VideoPreviewer",
71 | async beforeRegisterNodeDef(nodeType, nodeData, app) {
72 | if (nodeData?.name == "VEPreview_Video") {
73 | nodeType.prototype.onExecuted = function (data) {
74 | previewVideo(this, data.video[0], data.video[1]);
75 | };
76 | }
77 | },
78 | });
79 |
--------------------------------------------------------------------------------
/web/js/VEuploadAudio.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../../scripts/app.js";
2 | import { api } from "../../../scripts/api.js";
3 | import { ComfyWidgets } from "../../../scripts/widgets.js";
4 |
5 | function previewAudio(node, file) {
6 | while (node.widgets.length > 2) {
7 | node.widgets.pop();
8 | }
9 | try {
10 | var el = document.getElementById("VEuploadAudio");
11 | el.remove();
12 | } catch (error) {
13 | console.log(error);
14 | }
15 | var element = document.createElement("div");
16 | element.id = "VEuploadAudio";
17 | const previewNode = node;
18 | var previewWidget = node.addDOMWidget("audiopreview", "preview", element, {
19 | serialize: false,
20 | hideOnZoom: false,
21 | getValue() {
22 | return element.value;
23 | },
24 | setValue(v) {
25 | element.value = v;
26 | },
27 | });
28 | previewWidget.computeSize = function (width) {
29 | if (this.aspectRatio && !this.parentEl.hidden) {
30 | let height = (previewNode.size[0] - 20) / this.aspectRatio + 10;
31 | if (!(height > 0)) {
32 | height = 0;
33 | }
34 | this.computedHeight = height + 10;
35 | return [width, height];
36 | }
37 | return [width, -4]; //no loaded src, widget should not display
38 | };
39 | // element.style['pointer-events'] = "none"
40 | previewWidget.value = { hidden: false, paused: false, params: {} };
41 | previewWidget.parentEl = document.createElement("div");
42 | previewWidget.parentEl.className = "audio_preview";
43 | previewWidget.parentEl.style["width"] = "100%";
44 | element.appendChild(previewWidget.parentEl);
45 | previewWidget.audioEl = document.createElement("audio");
46 | previewWidget.audioEl.controls = true;
47 | previewWidget.audioEl.loop = false;
48 | previewWidget.audioEl.muted = false;
49 | previewWidget.audioEl.style["width"] = "100%";
50 |   previewWidget.audioEl.addEventListener("loadedmetadata", () => {
51 |     // <audio> has no intrinsic dimensions, so assume a wide, short controls bar
52 |     previewWidget.aspectRatio = 8;
53 |   });
54 | previewWidget.audioEl.addEventListener("error", () => {
55 | //TODO: consider a way to properly notify the user why a preview isn't shown.
56 | previewWidget.parentEl.hidden = true;
57 | });
58 |
59 | let params = {
60 | filename: file,
61 | type: "input",
62 | };
63 |
64 | previewWidget.parentEl.hidden = previewWidget.value.hidden;
65 | previewWidget.audioEl.autoplay =
66 | !previewWidget.value.paused && !previewWidget.value.hidden;
67 | let target_width = 256;
68 | if (element.style?.width) {
69 | //overscale to allow scrolling. Endpoint won't return higher than native
70 | target_width = element.style.width.slice(0, -2) * 2;
71 | }
72 | if (
73 | !params.force_size ||
74 | params.force_size.includes("?") ||
75 | params.force_size == "Disabled"
76 | ) {
77 | params.force_size = target_width + "x?";
78 | } else {
79 | let size = params.force_size.split("x");
80 | let ar = parseInt(size[0]) / parseInt(size[1]);
81 | params.force_size = target_width + "x" + target_width / ar;
82 | }
83 |
84 | previewWidget.audioEl.src = api.apiURL(
85 | "/view?" + new URLSearchParams(params)
86 | );
87 |
88 | previewWidget.audioEl.hidden = false;
89 | previewWidget.parentEl.appendChild(previewWidget.audioEl);
90 | }
91 |
92 | function audioUpload(node, inputName, inputData, app) {
93 | const audioWidget = node.widgets.find((w) => w.name === "audio_path");
94 | let uploadWidget;
95 |   /*
96 |     Expose a getter/setter that normalizes {filename, subfolder, type} values into a "subfolder/filename [type]" path string.
97 |   */
98 | var default_value = audioWidget.value;
99 | Object.defineProperty(audioWidget, "audio_path", {
100 | set: function (value) {
101 | this._real_value = value;
102 | },
103 |
104 | get: function () {
105 | let value = "";
106 | if (this._real_value) {
107 | value = this._real_value;
108 | } else {
109 | return default_value;
110 | }
111 |
112 | if (value.filename) {
113 | let real_value = value;
114 | value = "";
115 | if (real_value.subfolder) {
116 | value = real_value.subfolder + "/";
117 | }
118 |
119 | value += real_value.filename;
120 |
121 | if (real_value.type && real_value.type !== "input")
122 | value += ` [${real_value.type}]`;
123 | }
124 | return value;
125 | },
126 | });
127 | async function uploadFile(file, updateNode, pasted = false) {
128 | try {
129 | // Wrap file in formdata so it includes filename
130 | const body = new FormData();
131 | body.append("image", file);
132 | if (pasted) body.append("subfolder", "pasted");
133 | const resp = await api.fetchApi("/upload/image", {
134 | method: "POST",
135 | body,
136 | });
137 |
138 | if (resp.status === 200) {
139 | const data = await resp.json();
140 | // Add the file to the dropdown list and update the widget value
141 | let path = data.name;
142 | if (data.subfolder) path = data.subfolder + "/" + path;
143 |
144 | if (!audioWidget.options.values.includes(path)) {
145 | audioWidget.options.values.push(path);
146 | }
147 |
148 | if (updateNode) {
149 | audioWidget.value = path;
150 | previewAudio(node, path);
151 | }
152 | } else {
153 | alert(resp.status + " - " + resp.statusText);
154 | }
155 | } catch (error) {
156 | alert(error);
157 | }
158 | }
159 |
160 | const fileInput = document.createElement("input");
161 | Object.assign(fileInput, {
162 | type: "file",
163 |     accept: "audio/mpeg,audio/mp3,.mp3",
164 | style: "display: none",
165 | onchange: async () => {
166 | if (fileInput.files.length) {
167 | await uploadFile(fileInput.files[0], true);
168 | }
169 | },
170 | });
171 | document.body.append(fileInput);
172 |
173 | // Create the button widget for selecting the files
174 | uploadWidget = node.addWidget(
175 | "button",
176 | "choose audio file to upload",
177 | "Audio",
178 | () => {
179 | fileInput.click();
180 | }
181 | );
182 |
183 | uploadWidget.serialize = false;
184 |
185 | previewAudio(node, audioWidget.value);
186 | const cb = node.callback;
187 | audioWidget.callback = function () {
188 | previewAudio(node, audioWidget.value);
189 | if (cb) {
190 | return cb.apply(this, arguments);
191 | }
192 | };
193 |
194 | return { widget: uploadWidget };
195 | }
196 | ComfyWidgets.VEAUDIOUPLOAD = audioUpload;
197 |
198 | app.registerExtension({
199 | name: "ComfyUI-V-Express.UploadAudio",
200 | async beforeRegisterNodeDef(nodeType, nodeData, app) {
201 | if (nodeData?.name == "Load_Audio_Path") {
202 | nodeData.input.required.upload = ["VEAUDIOUPLOAD"];
203 | }
204 | },
205 | });
206 |
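`uploadFile` above leans on ComfyUI's stock `/upload/image` route, which stores whatever file is posted under the `image` form field and answers with `{name, subfolder, type}`. A minimal sketch of that round trip as a standalone function, assuming a reachable ComfyUI server (`uploadToComfy` and `origin` are illustrative names):

```js
// Minimal sketch of the upload round trip used above; assumes a reachable ComfyUI
// server at `origin` with the default /upload/image route enabled.
async function uploadToComfy(origin, file) {
  const body = new FormData();
  body.append("image", file); // the endpoint also accepts non-image files under this field
  const resp = await fetch(`${origin}/upload/image`, { method: "POST", body });
  if (!resp.ok) throw new Error(`${resp.status} - ${resp.statusText}`);
  const data = await resp.json(); // { name, subfolder, type }
  return data.subfolder ? `${data.subfolder}/${data.name}` : data.name;
}
```

The returned string already has the `subfolder/filename` shape the `audio_path` combo stores, which is why `previewAudio(node, path)` can be called with it directly.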
--------------------------------------------------------------------------------
/web/js/VEuploadVideo.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../../scripts/app.js";
2 | import { api } from "../../../scripts/api.js";
3 | import { ComfyWidgets } from "../../../scripts/widgets.js";
4 |
5 | function fitHeight(node) {
6 | node.setSize([
7 | node.size[0],
8 | node.computeSize([node.size[0], node.size[1]])[1],
9 | ]);
10 | node?.graph?.setDirtyCanvas(true);
11 | }
12 |
13 | function previewVideo(node, file) {
14 | while (node.widgets.length > 2) {
15 | node.widgets.pop();
16 | }
17 | try {
18 | var el = document.getElementById("VEuploadVideo");
19 | el.remove();
20 | } catch (error) {
21 | console.log(error);
22 | }
23 | var element = document.createElement("div");
24 | element.id = "VEuploadVideo";
25 | const previewNode = node;
26 | var previewWidget = node.addDOMWidget("videopreview", "preview", element, {
27 | serialize: false,
28 | hideOnZoom: false,
29 | getValue() {
30 | return element.value;
31 | },
32 | setValue(v) {
33 | element.value = v;
34 | },
35 | });
36 | previewWidget.computeSize = function (width) {
37 | if (this.aspectRatio && !this.parentEl.hidden) {
38 | let height = (previewNode.size[0] - 20) / this.aspectRatio + 10;
39 | if (!(height > 0)) {
40 | height = 0;
41 | }
42 | this.computedHeight = height + 10;
43 | return [width, height];
44 | }
45 | return [width, -4]; //no loaded src, widget should not display
46 | };
47 | // element.style['pointer-events'] = "none"
48 | previewWidget.value = { hidden: false, paused: false, params: {} };
49 | previewWidget.parentEl = document.createElement("div");
50 | previewWidget.parentEl.className = "video_preview";
51 | previewWidget.parentEl.style["width"] = "100%";
52 | element.appendChild(previewWidget.parentEl);
53 | previewWidget.videoEl = document.createElement("video");
54 | previewWidget.videoEl.controls = true;
55 | previewWidget.videoEl.loop = false;
56 | previewWidget.videoEl.muted = false;
57 | previewWidget.videoEl.style["width"] = "100%";
58 | previewWidget.videoEl.addEventListener("loadedmetadata", () => {
59 | previewWidget.aspectRatio =
60 | previewWidget.videoEl.videoWidth / previewWidget.videoEl.videoHeight;
61 |     fitHeight(previewNode); // "this" is undefined in module scope; resize the node itself
62 | });
63 | previewWidget.videoEl.addEventListener("error", () => {
64 | //TODO: consider a way to properly notify the user why a preview isn't shown.
65 | previewWidget.parentEl.hidden = true;
66 |     fitHeight(previewNode);
67 | });
68 |
69 | let params = {
70 | filename: file,
71 | type: "input",
72 | };
73 |
74 | previewWidget.parentEl.hidden = previewWidget.value.hidden;
75 | previewWidget.videoEl.autoplay =
76 | !previewWidget.value.paused && !previewWidget.value.hidden;
77 | let target_width = 256;
78 | if (element.style?.width) {
79 | //overscale to allow scrolling. Endpoint won't return higher than native
80 | target_width = element.style.width.slice(0, -2) * 2;
81 | }
82 | if (
83 | !params.force_size ||
84 | params.force_size.includes("?") ||
85 | params.force_size == "Disabled"
86 | ) {
87 | params.force_size = target_width + "x?";
88 | } else {
89 | let size = params.force_size.split("x");
90 | let ar = parseInt(size[0]) / parseInt(size[1]);
91 | params.force_size = target_width + "x" + target_width / ar;
92 | }
93 |
94 | previewWidget.videoEl.src = api.apiURL(
95 | "/view?" + new URLSearchParams(params)
96 | );
97 |
98 | previewWidget.videoEl.hidden = false;
99 | previewWidget.parentEl.appendChild(previewWidget.videoEl);
100 | }
101 |
102 | function videoUpload(node, inputName, inputData, app) {
103 | const videoWidget = node.widgets.find((w) => w.name === "video_path");
104 | let uploadWidget;
105 |   /*
106 |     Expose a getter/setter that normalizes {filename, subfolder, type} values into a "subfolder/filename [type]" path string.
107 |   */
108 | var default_value = videoWidget.value;
109 | Object.defineProperty(videoWidget, "video_path", {
110 | set: function (value) {
111 | this._real_value = value;
112 | },
113 |
114 | get: function () {
115 | let value = "";
116 | if (this._real_value) {
117 | value = this._real_value;
118 | } else {
119 | return default_value;
120 | }
121 |
122 | if (value.filename) {
123 | let real_value = value;
124 | value = "";
125 | if (real_value.subfolder) {
126 | value = real_value.subfolder + "/";
127 | }
128 |
129 | value += real_value.filename;
130 |
131 | if (real_value.type && real_value.type !== "input")
132 | value += ` [${real_value.type}]`;
133 | }
134 | return value;
135 | },
136 | });
137 | async function uploadFile(file, updateNode, pasted = false) {
138 | try {
139 | // Wrap file in formdata so it includes filename
140 | const body = new FormData();
141 | body.append("image", file);
142 | if (pasted) body.append("subfolder", "pasted");
143 | const resp = await api.fetchApi("/upload/image", {
144 | method: "POST",
145 | body,
146 | });
147 |
148 | if (resp.status === 200) {
149 | const data = await resp.json();
150 | // Add the file to the dropdown list and update the widget value
151 | let path = data.name;
152 | if (data.subfolder) path = data.subfolder + "/" + path;
153 |
154 | if (!videoWidget.options.values.includes(path)) {
155 | videoWidget.options.values.push(path);
156 | }
157 |
158 | if (updateNode) {
159 | videoWidget.value = path;
160 | previewVideo(node, path);
161 | }
162 | } else {
163 | alert(resp.status + " - " + resp.statusText);
164 | }
165 | } catch (error) {
166 | alert(error);
167 | }
168 | }
169 |
170 | const fileInput = document.createElement("input");
171 | Object.assign(fileInput, {
172 | type: "file",
173 |     accept: "video/webm,video/mp4,video/x-matroska,video/x-msvideo",
174 | style: "display: none",
175 | onchange: async () => {
176 | if (fileInput.files.length) {
177 | await uploadFile(fileInput.files[0], true);
178 | }
179 | },
180 | });
181 | document.body.append(fileInput);
182 |
183 | // Create the button widget for selecting the files
184 | uploadWidget = node.addWidget(
185 | "button",
186 | "choose video file to upload",
187 | "Video",
188 | () => {
189 | fileInput.click();
190 | }
191 | );
192 |
193 | uploadWidget.serialize = false;
194 |
195 | previewVideo(node, videoWidget.value);
196 | const cb = node.callback;
197 | videoWidget.callback = function () {
198 | previewVideo(node, videoWidget.value);
199 | if (cb) {
200 | return cb.apply(this, arguments);
201 | }
202 | };
203 |
204 | return { widget: uploadWidget };
205 | }
206 |
207 | ComfyWidgets.VEVIDEOUPLOAD = videoUpload;
208 |
209 | app.registerExtension({
210 | name: "ComfyUI-V-Express.UploadVideo",
211 | async beforeRegisterNodeDef(nodeType, nodeData, app) {
212 | if (
213 | nodeData?.name == "Load_Audio_Path_From_Video" ||
214 | nodeData?.name == "Load_Kps_Path_From_Video"
215 | ) {
216 | nodeData.input.required.upload = ["VEVIDEOUPLOAD"];
217 | }
218 | },
219 | });
220 |
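`computeSize` above derives the preview height from the node width and the clip's aspect ratio, and returns a height of `-4` to keep the widget from displaying while no source is loaded. The same arithmetic as a small pure function, for illustration only (`previewHeight` is not part of the extension):

```js
// Mirrors computeSize() above: derive the preview height from the node width and
// the video's width/height ratio; collapse to -4 when no metadata is available.
function previewHeight(nodeWidth, aspectRatio) {
  if (!aspectRatio) return -4;
  const height = (nodeWidth - 20) / aspectRatio + 10;
  return height > 0 ? height : 0;
}

console.log(previewHeight(512, 16 / 9)); // ≈ 286.75 for a 16:9 clip in a 512px-wide node
console.log(previewHeight(512, NaN)); // -4: widget stays collapsed until loadedmetadata fires
```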
--------------------------------------------------------------------------------
/workflow/V-Express.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 20,
3 | "last_link_id": 33,
4 | "nodes": [
5 | {
6 | "id": 4,
7 | "type": "V_Express_Loader",
8 | "pos": [
9 | 443,
10 | 14
11 | ],
12 | "size": {
13 | "0": 320.79998779296875,
14 | "1": 26
15 | },
16 | "flags": {},
17 | "order": 4,
18 | "mode": 0,
19 | "inputs": [
20 | {
21 | "name": "vexpress_model_path",
22 | "type": "STRING_INPUT",
23 | "link": 28
24 | }
25 | ],
26 | "outputs": [
27 | {
28 | "name": "v_express_pipeline",
29 | "type": "V_EXPRESS_PIPELINE",
30 | "links": [
31 | 11
32 | ],
33 | "shape": 3,
34 | "slot_index": 0
35 | }
36 | ],
37 | "properties": {
38 | "Node name for S&R": "V_Express_Loader"
39 | }
40 | },
41 | {
42 | "id": 6,
43 | "type": "Load_Audio_Path",
44 | "pos": [
45 | -30,
46 | 138
47 | ],
48 | "size": {
49 | "0": 315,
50 | "1": 82
51 | },
52 | "flags": {},
53 | "order": 0,
54 | "mode": 0,
55 | "outputs": [
56 | {
57 | "name": "AUDIO_PATH",
58 | "type": "AUDIO_PATH",
59 | "links": [
60 | 13
61 | ],
62 | "shape": 3,
63 | "slot_index": 0
64 | }
65 | ],
66 | "properties": {
67 | "Node name for S&R": "Load_Audio_Path"
68 | },
69 | "widgets_values": [
70 | "aud.mp3",
71 | "Audio",
72 | {
73 | "hidden": false,
74 | "paused": false,
75 | "params": {}
76 | }
77 | ]
78 | },
79 | {
80 | "id": 8,
81 | "type": "Load_Image_Path",
82 | "pos": [
83 | -32,
84 | 368
85 | ],
86 | "size": {
87 | "0": 315,
88 | "1": 294
89 | },
90 | "flags": {},
91 | "order": 1,
92 | "mode": 0,
93 | "outputs": [
94 | {
95 | "name": "IMAGE_PATH",
96 | "type": "IMAGE_PATH",
97 | "links": [
98 | 15
99 | ],
100 | "shape": 3,
101 | "slot_index": 0
102 | }
103 | ],
104 | "properties": {
105 | "Node name for S&R": "Load_Image_Path"
106 | },
107 | "widgets_values": [
108 | "ref.jpg",
109 | "image"
110 | ]
111 | },
112 | {
113 | "id": 13,
114 | "type": "Load_Kps_Path_From_Video",
115 | "pos": [
116 | 496,
117 | 361
118 | ],
119 | "size": {
120 | "0": 367.79998779296875,
121 | "1": 455.79998779296875
122 | },
123 | "flags": {},
124 | "order": 5,
125 | "mode": 0,
126 | "inputs": [
127 | {
128 | "name": "vexpress_model_path",
129 | "type": "STRING_INPUT",
130 | "link": 29
131 | },
132 | {
133 | "name": "image_size",
134 | "type": "INT_INPUT",
135 | "link": 31
136 | }
137 | ],
138 | "outputs": [
139 | {
140 | "name": "VKPS_PATH",
141 | "type": "VKPS_PATH",
142 | "links": [
143 | 20
144 | ],
145 | "shape": 3,
146 | "slot_index": 0
147 | }
148 | ],
149 | "properties": {
150 | "Node name for S&R": "Load_Kps_Path_From_Video"
151 | },
152 | "widgets_values": [
153 | "00000gt.mp4",
154 | "Video",
155 | {
156 | "hidden": false,
157 | "paused": false,
158 | "params": {}
159 | }
160 | ]
161 | },
162 | {
163 | "id": 18,
164 | "type": "VEStringConstant",
165 | "pos": [
166 | -18,
167 | -50
168 | ],
169 | "size": {
170 | "0": 315,
171 | "1": 58
172 | },
173 | "flags": {},
174 | "order": 2,
175 | "mode": 0,
176 | "outputs": [
177 | {
178 | "name": "STRING_INPUT",
179 | "type": "STRING_INPUT",
180 | "links": [
181 | 28,
182 | 29,
183 | 30
184 | ],
185 | "shape": 3,
186 | "slot_index": 0
187 | }
188 | ],
189 | "properties": {
190 | "Node name for S&R": "VEStringConstant"
191 | },
192 | "widgets_values": [
193 | "./model_ckpts"
194 | ]
195 | },
196 | {
197 | "id": 19,
198 | "type": "VEINTConstant",
199 | "pos": [
200 | -43,
201 | 773
202 | ],
203 | "size": {
204 | "0": 315,
205 | "1": 58
206 | },
207 | "flags": {},
208 | "order": 3,
209 | "mode": 0,
210 | "outputs": [
211 | {
212 | "name": "image_size",
213 | "type": "INT_INPUT",
214 | "links": [
215 | 31,
216 | 32
217 | ],
218 | "shape": 3,
219 | "slot_index": 0
220 | }
221 | ],
222 | "properties": {
223 | "Node name for S&R": "VEINTConstant"
224 | },
225 | "widgets_values": [
226 | 512
227 | ]
228 | },
229 | {
230 | "id": 10,
231 | "type": "V_Express_Sampler",
232 | "pos": [
233 | 1046,
234 | 13
235 | ],
236 | "size": {
237 | "0": 393,
238 | "1": 422
239 | },
240 | "flags": {},
241 | "order": 6,
242 | "mode": 0,
243 | "inputs": [
244 | {
245 | "name": "v_express_pipeline",
246 | "type": "V_EXPRESS_PIPELINE",
247 | "link": 11
248 | },
249 | {
250 | "name": "vexpress_model_path",
251 | "type": "STRING_INPUT",
252 | "link": 30
253 | },
254 | {
255 | "name": "audio_path",
256 | "type": "AUDIO_PATH",
257 | "link": 13
258 | },
259 | {
260 | "name": "kps_path",
261 | "type": "VKPS_PATH",
262 | "link": 20
263 | },
264 | {
265 | "name": "ref_image_path",
266 | "type": "IMAGE_PATH",
267 | "link": 15
268 | },
269 | {
270 | "name": "image_size",
271 | "type": "INT_INPUT",
272 | "link": 32
273 | }
274 | ],
275 | "outputs": [
276 | {
277 | "name": "output_path",
278 | "type": "STRING_INPUT",
279 | "links": [
280 | 33
281 | ],
282 | "shape": 3,
283 | "slot_index": 0
284 | }
285 | ],
286 | "properties": {
287 | "Node name for S&R": "V_Express_Sampler"
288 | },
289 | "widgets_values": [
290 | "E:\\ComfyUI_windows_portable\\ComfyUI\\output\\test3.mp4",
291 | "fix_face",
292 | 30,
293 | 42,
294 | "fixed",
295 | 20,
296 | 3.5,
297 | 12,
298 | 1,
299 | 4,
300 | 0.95,
301 | 3
302 | ]
303 | },
304 | {
305 | "id": 20,
306 | "type": "VEPreview_Video",
307 | "pos": [
308 | 1527,
309 | 33
310 | ],
311 | "size": {
312 | "0": 210,
313 | "1": 26
314 | },
315 | "flags": {},
316 | "order": 7,
317 | "mode": 0,
318 | "inputs": [
319 | {
320 | "name": "video",
321 | "type": "STRING_INPUT",
322 | "link": 33
323 | }
324 | ],
325 | "properties": {
326 | "Node name for S&R": "VEPreview_Video"
327 | },
328 | "widgets_values": [
329 | null
330 | ]
331 | }
332 | ],
333 | "links": [
334 | [
335 | 11,
336 | 4,
337 | 0,
338 | 10,
339 | 0,
340 | "V_EXPRESS_PIPELINE"
341 | ],
342 | [
343 | 13,
344 | 6,
345 | 0,
346 | 10,
347 | 2,
348 | "AUDIO_PATH"
349 | ],
350 | [
351 | 15,
352 | 8,
353 | 0,
354 | 10,
355 | 4,
356 | "IMAGE_PATH"
357 | ],
358 | [
359 | 20,
360 | 13,
361 | 0,
362 | 10,
363 | 3,
364 | "VKPS_PATH"
365 | ],
366 | [
367 | 28,
368 | 18,
369 | 0,
370 | 4,
371 | 0,
372 | "STRING_INPUT"
373 | ],
374 | [
375 | 29,
376 | 18,
377 | 0,
378 | 13,
379 | 0,
380 | "STRING_INPUT"
381 | ],
382 | [
383 | 30,
384 | 18,
385 | 0,
386 | 10,
387 | 1,
388 | "STRING_INPUT"
389 | ],
390 | [
391 | 31,
392 | 19,
393 | 0,
394 | 13,
395 | 1,
396 | "INT_INPUT"
397 | ],
398 | [
399 | 32,
400 | 19,
401 | 0,
402 | 10,
403 | 5,
404 | "INT_INPUT"
405 | ],
406 | [
407 | 33,
408 | 10,
409 | 0,
410 | 20,
411 | 0,
412 | "STRING_INPUT"
413 | ]
414 | ],
415 | "groups": [],
416 | "config": {},
417 | "extra": {},
418 | "version": 0.4
419 | }
--------------------------------------------------------------------------------
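In the serialized workflow above, every entry of the `links` array follows LiteGraph's `[link_id, from_node, from_slot, to_node, to_slot, type]` layout, so the wiring can be inspected without opening ComfyUI. A minimal sketch in Node.js, assuming the script runs from the repository root where `workflow/V-Express.json` lives:

```js
// Illustrative Node.js snippet: list the graph wiring from the serialized workflow.
// Each "links" entry is [link_id, from_node, from_slot, to_node, to_slot, type].
import { readFileSync } from "node:fs";

const wf = JSON.parse(readFileSync("workflow/V-Express.json", "utf8"));
const typeById = Object.fromEntries(wf.nodes.map((n) => [n.id, n.type]));

for (const [id, from, fromSlot, to, toSlot, type] of wf.links) {
  console.log(`link ${id}: ${typeById[from]}[${fromSlot}] -> ${typeById[to]}[${toSlot}] (${type})`);
}
// e.g. link 11: V_Express_Loader[0] -> V_Express_Sampler[0] (V_EXPRESS_PIPELINE)
```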