├── .github └── workflows │ └── publish.yml ├── .gitignore ├── README.md ├── __init__.py ├── assets └── workflow.png ├── input ├── .DS_Store ├── aud.mp3 ├── gt.mp4 ├── kps.pth └── ref.jpg ├── model_ckpts └── put_V-Express_model_here ├── nodes.py ├── requirements.txt ├── src ├── inference.py ├── inference_v2.yaml ├── modules │ ├── .DS_Store │ ├── __init__.py │ ├── attention.py │ ├── audio_projection.py │ ├── motion_module.py │ ├── mutual_self_attention.py │ ├── resnet.py │ ├── transformer_2d.py │ ├── transformer_3d.py │ ├── unet_2d_blocks.py │ ├── unet_2d_condition.py │ ├── unet_3d.py │ ├── unet_3d_blocks.py │ └── v_kps_guider.py ├── pipelines │ ├── .DS_Store │ ├── __init__.py │ ├── context.py │ ├── utils.py │ └── v_express_pipeline.py ├── scripts │ └── extract_kps_sequence_and_audio.py └── util.py ├── web └── js │ ├── VEpreviewVideo.js │ ├── VEuploadAudio.js │ └── VEuploadVideo.js └── workflow └── V-Express.json /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | jobs: 12 | publish-node: 13 | name: Publish Custom Node to registry 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Check out code 17 | uses: actions/checkout@v4 18 | - name: Publish Custom Node 19 | uses: Comfy-Org/publish-node-action@main 20 | with: 21 | ## Add your own personal access token to your Github Repository secrets and reference it here. 22 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | *.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **_V-Express: Conditional Dropout for Progressive Training of Portrait Video Generation_** 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | --- 10 | 11 | ## Introduction 12 | 13 | In the field of portrait video generation, the use of single images to generate portrait videos has become increasingly prevalent. 14 | A common approach involves leveraging generative models to enhance adapters for controlled generation. 15 | However, control signals can vary in strength, including text, audio, image reference, pose, depth map, etc. 16 | Among these, weaker conditions often struggle to be effective due to interference from stronger conditions, posing a challenge in balancing these conditions. 17 | In our work on portrait video generation, we identified audio signals as particularly weak, often overshadowed by stronger signals such as pose and original image. 18 | However, direct training with weak signals often leads to difficulties in convergence. 19 | To address this, we propose V-Express, a simple method that balances different control signals through a series of progressive drop operations. 
20 | Our method gradually enables effective control by weak conditions, achieving generation that simultaneously takes pose, input image, and audio into account. 21 | 22 | ## Workflow 23 | 24 | ![workflow](./assets/workflow.png) 25 | 26 | ## Installation 27 | 28 | 1. Clone this repo into your ComfyUI root directory under `ComfyUI\custom_nodes\` and install the dependent Python packages as described [here](https://github.com/tencent-ailab/V-Express#installation): 29 | 30 | ```shell 31 | cd Your_ComfyUI_root_directory\ComfyUI\custom_nodes\ 32 | 33 | git clone https://github.com/tiankuan93/ComfyUI-V-Express 34 | 35 | pip install -r requirements.txt 36 | ``` 37 | 38 | If you are using ComfyUI_windows_portable, use `.\python_embeded\python.exe -m pip` in place of `pip` for the installation. 39 | 40 | If you get an error regarding insightface, you may find a solution [here](https://www.youtube.com/watch?v=vCCVxGtCyho). 41 | 42 | - first, download the `.whl` file from [here](https://github.com/Gourieff/Assets/tree/main/Insightface) 43 | - then, install it with `.\python_embeded\python.exe -m pip install --no-deps --target=\your_path_of\python_embeded\Lib\site-packages [path-to-wheel]` 44 | 45 | 2. Download the V-Express models and the other required models: 46 | 47 | - [model_ckpts](https://huggingface.co/tk93/V-Express) 48 | - Replace the **model_ckpts** folder with the downloaded **V-Express/model_ckpts**. Then download all `.bin` models and put them into the `model_ckpts/v-express` directory, i.e. `audio_projection.bin`, `denoising_unet.bin`, `motion_module.bin`, `reference_net.bin`, and `v_kps_guider.bin` (a command-line download sketch is given at the end of this README). The final **model_ckpts** folder looks as follows: 49 | 50 | ```text 51 | ./model_ckpts/ 52 | |-- insightface_models 53 | |-- sd-vae-ft-mse 54 | |-- stable-diffusion-v1-5 55 | |-- v-express 56 | |-- wav2vec2-base-960h 57 | ``` 58 | 59 | 3. Put the files from the `input` directory into your ComfyUI input directory, `ComfyUI\input\`. 60 | 4. Set `output_path` to a file under your ComfyUI output directory, e.g. `ComfyUI\output\xxx.mp4`; otherwise the output video will not be displayed in ComfyUI. 61 | 62 | ## Acknowledgements 63 | 64 | We would like to thank the contributors to [AIFSH/ComfyUI_V-Express](https://github.com/AIFSH/ComfyUI_V-Express) for their open research and exploration. 
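## Downloading the model weights (example)

The sketch below shows one possible way to fetch the weights for step 2 of the installation from the command line instead of a manual browser download. It is only a suggestion: it assumes the `huggingface_hub` CLI is installed and that the `tk93/V-Express` repository keeps the `model_ckpts` layout shown above; the `V-Express-download` folder name is arbitrary, so adjust the paths to your setup.

```shell
# Install the Hugging Face Hub CLI (skip if it is already available).
pip install -U "huggingface_hub[cli]"

# Download the full tk93/V-Express repository into a temporary folder.
huggingface-cli download tk93/V-Express --local-dir V-Express-download

# Copy the downloaded checkpoints into the model_ckpts folder of this custom node.
# On Windows, copy the folder in Explorer or use xcopy instead of cp.
cp -r V-Express-download/model_ckpts/* ./model_ckpts/
```

If you are using ComfyUI_windows_portable, run the `pip` command with `.\python_embeded\python.exe -m pip` as described above.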
65 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | WEB_DIRECTORY = "./web" 4 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"] 5 | -------------------------------------------------------------------------------- /assets/workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiankuan93/ComfyUI-V-Express/ee8f245406e6fa07353dbf36f8ada06f81d48a9e/assets/workflow.png -------------------------------------------------------------------------------- /input/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiankuan93/ComfyUI-V-Express/ee8f245406e6fa07353dbf36f8ada06f81d48a9e/input/.DS_Store -------------------------------------------------------------------------------- /input/aud.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiankuan93/ComfyUI-V-Express/ee8f245406e6fa07353dbf36f8ada06f81d48a9e/input/aud.mp3 -------------------------------------------------------------------------------- /input/gt.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiankuan93/ComfyUI-V-Express/ee8f245406e6fa07353dbf36f8ada06f81d48a9e/input/gt.mp4 -------------------------------------------------------------------------------- /input/kps.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiankuan93/ComfyUI-V-Express/ee8f245406e6fa07353dbf36f8ada06f81d48a9e/input/kps.pth -------------------------------------------------------------------------------- /input/ref.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiankuan93/ComfyUI-V-Express/ee8f245406e6fa07353dbf36f8ada06f81d48a9e/input/ref.jpg -------------------------------------------------------------------------------- /model_ckpts/put_V-Express_model_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiankuan93/ComfyUI-V-Express/ee8f245406e6fa07353dbf36f8ada06f81d48a9e/model_ckpts/put_V-Express_model_here -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import sys 4 | import numpy as np 5 | import time 6 | import torch 7 | import torchaudio.functional 8 | import torchvision.io 9 | from imageio_ffmpeg import get_ffmpeg_exe 10 | from PIL import Image 11 | 12 | from diffusers.utils.torch_utils import randn_tensor 13 | from diffusers import AutoencoderKL 14 | from insightface.app import FaceAnalysis 15 | from transformers import Wav2Vec2Model, Wav2Vec2Processor 16 | import accelerate 17 | 18 | import folder_paths 19 | import folder_paths as comfy_paths 20 | from comfy import model_management 21 | 22 | ROOT_PATH = os.path.join(comfy_paths.get_folder_paths("custom_nodes")[0], "ComfyUI-V-Express") 23 | sys.path.append(os.path.join(ROOT_PATH, 'src')) 24 | 25 | from .src.pipelines import VExpressPipeline 26 | from .src.pipelines.utils import draw_kps_image, save_video 27 | from 
.src.pipelines.utils import retarget_kps 28 | from .src.util import get_ffmpeg 29 | from .src.inference import ( 30 | get_scheduler, 31 | load_reference_net, 32 | load_denoising_unet, 33 | load_v_kps_guider, 34 | load_audio_projection, 35 | ) 36 | 37 | INPUT_PATH = folder_paths.get_input_directory() 38 | OUTPUT_PATH = folder_paths.get_output_directory() 39 | 40 | 41 | INFERENCE_CONFIG_PATH = os.path.join(ROOT_PATH, "src/inference_v2.yaml") 42 | 43 | load_device = model_management.get_torch_device() 44 | offload_device = model_management.unet_offload_device() 45 | DEVICE = load_device 46 | WEIGHT_DTYPE = torch.float16 47 | GPU_ID = 0 48 | 49 | STANDARD_AUDIO_SAMPLING_RATE = 16000 50 | NUM_PAD_AUDIO_FRAMES = 2 51 | 52 | 53 | def get_all_model_path(vexpress_model_path): 54 | if not os.path.isabs(vexpress_model_path): 55 | vexpress_model_path = os.path.join(ROOT_PATH, vexpress_model_path) 56 | 57 | unet_config_path = os.path.join(vexpress_model_path, 'stable-diffusion-v1-5/unet/config.json') 58 | vae_path = os.path.join(vexpress_model_path, 'sd-vae-ft-mse') 59 | audio_encoder_path = os.path.join(vexpress_model_path, 'wav2vec2-base-960h') 60 | insightface_model_path = os.path.join(vexpress_model_path, 'insightface_models') 61 | 62 | denoising_unet_path = os.path.join(vexpress_model_path, 'v-express/denoising_unet.bin') 63 | reference_net_path = os.path.join(vexpress_model_path, 'v-express/reference_net.bin') 64 | v_kps_guider_path = os.path.join(vexpress_model_path, 'v-express/v_kps_guider.bin') 65 | audio_projection_path = os.path.join(vexpress_model_path, 'v-express/audio_projection.bin') 66 | motion_module_path = os.path.join(vexpress_model_path, 'v-express/motion_module.bin') 67 | 68 | if not os.path.isfile(denoising_unet_path): 69 | denoising_unet_path = os.path.join(vexpress_model_path, 'v-express/denoising_unet.pth') 70 | if not os.path.isfile(reference_net_path): 71 | reference_net_path = os.path.join(vexpress_model_path, 'v-express/reference_net.pth') 72 | if not os.path.isfile(v_kps_guider_path): 73 | v_kps_guider_path = os.path.join(vexpress_model_path, 'v-express/v_kps_guider.pth') 74 | if not os.path.isfile(audio_projection_path): 75 | audio_projection_path = os.path.join(vexpress_model_path, 'v-express/audio_projection.pth') 76 | if not os.path.isfile(motion_module_path): 77 | motion_module_path = os.path.join(vexpress_model_path, 'v-express/motion_module.pth') 78 | 79 | model_dict = { 80 | "unet_config_path": unet_config_path, 81 | "vae_path": vae_path, 82 | "audio_encoder_path": audio_encoder_path, 83 | "insightface_model_path": insightface_model_path, 84 | "denoising_unet_path": denoising_unet_path, 85 | "reference_net_path": reference_net_path, 86 | "v_kps_guider_path": v_kps_guider_path, 87 | "audio_projection_path": audio_projection_path, 88 | "motion_module_path": motion_module_path, 89 | } 90 | return model_dict 91 | 92 | 93 | class VEINTConstant: 94 | @classmethod 95 | def INPUT_TYPES(s): 96 | return {"required": { 97 | "image_size": ("INT", {"default": 512, "min": 512, "max": 2048}), 98 | }, 99 | } 100 | RETURN_TYPES = ("INT_INPUT",) 101 | RETURN_NAMES = ("image_size",) 102 | FUNCTION = "get_value" 103 | CATEGORY = "V-Express" 104 | 105 | def get_value(self, image_size): 106 | return (image_size,) 107 | 108 | 109 | class VEStringConstant: 110 | @classmethod 111 | def INPUT_TYPES(cls): 112 | return { 113 | "required": { 114 | "string": ("STRING", {"default": './model_ckpts', "multiline": False}), 115 | } 116 | } 117 | RETURN_TYPES = ("STRING_INPUT",) 118 | FUNCTION = 
"passtring" 119 | CATEGORY = "V-Express" 120 | 121 | def passtring(self, string): 122 | return (string, ) 123 | 124 | 125 | class V_Express_Sampler: 126 | @classmethod 127 | def INPUT_TYPES(s): 128 | return { 129 | "required": { 130 | "v_express_pipeline": ("V_EXPRESS_PIPELINE",), 131 | "vexpress_model_path": ("STRING_INPUT", ), 132 | "audio_path": ("AUDIO_PATH",), 133 | "kps_path": ("VKPS_PATH",), 134 | "ref_image_path": ("IMAGE_PATH",), 135 | "output_path": ("STRING",{ 136 | "default": os.path.join(OUTPUT_PATH,f"{time.time()}_vexpress.mp4") 137 | }), 138 | "image_size": ("INT_INPUT",), 139 | "retarget_strategy": ( 140 | ["fix_face", "no_retarget", "offset_retarget", "naive_retarget"], 141 | {"default": "fix_face"} 142 | ), 143 | "fps": ("FLOAT", {"default": 30.0, "min": 20.0, "max": 60.0}), 144 | "seed": ("INT",{ 145 | "default": 42 146 | }), 147 | "num_inference_steps": ("INT",{ 148 | "default": 20 149 | }), 150 | "guidance_scale": ("FLOAT",{ 151 | "default": 3.5 152 | }), 153 | "context_frames": ("INT",{ 154 | "default": 12 155 | }), 156 | "context_stride": ("INT",{ 157 | "default": 1 158 | }), 159 | "context_overlap": ("INT",{ 160 | "default": 4 161 | }), 162 | "reference_attention_weight": ("FLOAT",{ 163 | "default": 0.95 164 | }), 165 | "audio_attention_weight": ("FLOAT",{ 166 | "default": 3. 167 | }), 168 | } 169 | } 170 | 171 | RETURN_TYPES = ( 172 | "STRING_INPUT", 173 | ) 174 | RETURN_NAMES = ( 175 | "output_path", 176 | ) 177 | OUTPUT_NODE = True 178 | # OUTPUT_NODE = False 179 | CATEGORY = "V-Express" 180 | FUNCTION = "v_express" 181 | def v_express( 182 | self, 183 | v_express_pipeline, 184 | vexpress_model_path, 185 | audio_path, 186 | kps_path, 187 | ref_image_path, 188 | output_path, 189 | image_size, 190 | retarget_strategy, 191 | fps, 192 | seed, 193 | num_inference_steps, 194 | guidance_scale, 195 | context_frames, 196 | context_stride, 197 | context_overlap, 198 | reference_attention_weight, 199 | audio_attention_weight, 200 | save_gpu_memory=True, 201 | do_multi_devices_inference=False, 202 | ): 203 | start_time = time.time() 204 | 205 | accelerator = None 206 | 207 | reference_image_path = ref_image_path 208 | model_dict = get_all_model_path(vexpress_model_path) 209 | 210 | insightface_model_path = model_dict['insightface_model_path'] 211 | 212 | app = FaceAnalysis( 213 | providers=['CUDAExecutionProvider' if DEVICE == 'cuda' else 'CPUExecutionProvider'], 214 | provider_options=[{'device_id': GPU_ID}] if DEVICE == 'cuda' else [], 215 | root=insightface_model_path, 216 | ) 217 | app.prepare(ctx_id=0, det_size=(image_size, image_size)) 218 | 219 | reference_image = Image.open(reference_image_path).convert('RGB') 220 | reference_image = reference_image.resize((image_size, image_size)) 221 | 222 | reference_image_for_kps = cv2.imread(reference_image_path) 223 | reference_image_for_kps = cv2.resize(reference_image_for_kps, (image_size, image_size)) 224 | reference_kps = app.get(reference_image_for_kps)[0].kps[:3] 225 | if save_gpu_memory: 226 | del app 227 | torch.cuda.empty_cache() 228 | 229 | _, audio_waveform, meta_info = torchvision.io.read_video(audio_path, pts_unit='sec') 230 | audio_sampling_rate = meta_info['audio_fps'] 231 | print(f'Length of audio is {audio_waveform.shape[1]} with the sampling rate of {audio_sampling_rate}.') 232 | if audio_sampling_rate != STANDARD_AUDIO_SAMPLING_RATE: 233 | audio_waveform = torchaudio.functional.resample( 234 | audio_waveform, 235 | orig_freq=audio_sampling_rate, 236 | new_freq=STANDARD_AUDIO_SAMPLING_RATE, 237 | ) 238 | 
audio_waveform = audio_waveform.mean(dim=0) 239 | 240 | duration = audio_waveform.shape[0] / STANDARD_AUDIO_SAMPLING_RATE 241 | init_video_length = int(duration * fps) 242 | num_contexts = np.around((init_video_length + context_overlap) / context_frames) 243 | video_length = int(num_contexts * context_frames - context_overlap) 244 | fps = video_length / duration 245 | print(f'The corresponding video length is {video_length}.') 246 | 247 | kps_sequence = None 248 | if kps_path != "": 249 | assert os.path.exists(kps_path), f'{kps_path} does not exist' 250 | kps_sequence = torch.tensor(torch.load(kps_path)) # [len, 3, 2] 251 | print(f'The original length of kps sequence is {kps_sequence.shape[0]}.') 252 | 253 | if kps_sequence.shape[0] > video_length: 254 | kps_sequence = kps_sequence[:video_length, :, :] 255 | 256 | kps_sequence = torch.nn.functional.interpolate(kps_sequence.permute(1, 2, 0), size=video_length, mode='linear') 257 | kps_sequence = kps_sequence.permute(2, 0, 1) 258 | print(f'The interpolated length of kps sequence is {kps_sequence.shape[0]}.') 259 | 260 | if retarget_strategy == 'fix_face': 261 | kps_sequence = torch.tensor([reference_kps] * video_length) 262 | elif retarget_strategy == 'no_retarget': 263 | kps_sequence = kps_sequence 264 | elif retarget_strategy == 'offset_retarget': 265 | kps_sequence = retarget_kps(reference_kps, kps_sequence, only_offset=True) 266 | elif retarget_strategy == 'naive_retarget': 267 | kps_sequence = retarget_kps(reference_kps, kps_sequence, only_offset=False) 268 | else: 269 | raise ValueError(f'The retarget strategy {retarget_strategy} is not supported.') 270 | 271 | kps_images = [] 272 | for i in range(video_length): 273 | kps_image = draw_kps_image(image_size, image_size, kps_sequence[i]) 274 | kps_images.append(Image.fromarray(kps_image)) 275 | 276 | generator = torch.manual_seed(seed) 277 | video_tensor = v_express_pipeline( 278 | reference_image=reference_image, 279 | kps_images=kps_images, 280 | audio_waveform=audio_waveform, 281 | width=image_size, 282 | height=image_size, 283 | video_length=video_length, 284 | num_inference_steps=num_inference_steps, 285 | guidance_scale=guidance_scale, 286 | context_frames=context_frames, 287 | context_overlap=context_overlap, 288 | reference_attention_weight=reference_attention_weight, 289 | audio_attention_weight=audio_attention_weight, 290 | num_pad_audio_frames=NUM_PAD_AUDIO_FRAMES, 291 | generator=generator, 292 | do_multi_devices_inference=do_multi_devices_inference, 293 | save_gpu_memory=save_gpu_memory, 294 | ) 295 | 296 | if accelerator is None or accelerator.is_main_process: 297 | save_video(video_tensor, audio_path, output_path, DEVICE, fps) 298 | consumed_time = time.time() - start_time 299 | generation_fps = video_tensor.shape[2] / consumed_time 300 | print(f'The generated video has been saved at {output_path}. ' 301 | f'The generation time is {consumed_time:.1f} seconds. 
' 302 | f'The generation FPS is {generation_fps:.2f}.') 303 | 304 | return (output_path, ) 305 | 306 | 307 | class V_Express_Loader: 308 | @classmethod 309 | def INPUT_TYPES(s): 310 | return { 311 | "required": { 312 | "vexpress_model_path": ("STRING_INPUT", ), 313 | }, 314 | } 315 | 316 | RETURN_TYPES = ( 317 | "V_EXPRESS_PIPELINE", 318 | ) 319 | RETURN_NAMES = ( 320 | "v_express_pipeline", 321 | ) 322 | 323 | CATEGORY = "V-Express" 324 | FUNCTION = "load_vexpress_pipeline" 325 | def load_vexpress_pipeline(self, vexpress_model_path): 326 | 327 | model_dict = get_all_model_path(vexpress_model_path) 328 | 329 | unet_config_path = model_dict['unet_config_path'] 330 | reference_net_path = model_dict['reference_net_path'] 331 | denoising_unet_path = model_dict['denoising_unet_path'] 332 | v_kps_guider_path = model_dict['v_kps_guider_path'] 333 | audio_projection_path = model_dict['audio_projection_path'] 334 | motion_module_path = model_dict['motion_module_path'] 335 | 336 | vae_path = model_dict['vae_path'] 337 | audio_encoder_path = model_dict['audio_encoder_path'] 338 | 339 | dtype = WEIGHT_DTYPE 340 | device = DEVICE 341 | inference_config_path = INFERENCE_CONFIG_PATH 342 | 343 | scheduler = get_scheduler(inference_config_path) 344 | reference_net = load_reference_net(unet_config_path, reference_net_path, dtype, device) 345 | denoising_unet = load_denoising_unet( 346 | inference_config_path, unet_config_path, denoising_unet_path, motion_module_path, 347 | dtype, device 348 | ) 349 | v_kps_guider = load_v_kps_guider(v_kps_guider_path, dtype, device) 350 | audio_projection = load_audio_projection( 351 | audio_projection_path, 352 | dtype, 353 | device, 354 | inp_dim=denoising_unet.config.cross_attention_dim, 355 | mid_dim=denoising_unet.config.cross_attention_dim, 356 | out_dim=denoising_unet.config.cross_attention_dim, 357 | inp_seq_len=2 * (2 * NUM_PAD_AUDIO_FRAMES + 1), 358 | out_seq_len=2 * NUM_PAD_AUDIO_FRAMES + 1, 359 | ) 360 | 361 | vae = AutoencoderKL.from_pretrained(vae_path).to(dtype=dtype, device=device) 362 | audio_encoder = Wav2Vec2Model.from_pretrained(audio_encoder_path).to(dtype=dtype, device=device) 363 | audio_processor = Wav2Vec2Processor.from_pretrained(audio_encoder_path) 364 | 365 | v_express_pipeline = VExpressPipeline( 366 | vae=vae, 367 | reference_net=reference_net, 368 | denoising_unet=denoising_unet, 369 | v_kps_guider=v_kps_guider, 370 | audio_processor=audio_processor, 371 | audio_encoder=audio_encoder, 372 | audio_projection=audio_projection, 373 | scheduler=scheduler, 374 | ).to(dtype=dtype, device=device) 375 | 376 | return (v_express_pipeline,) 377 | 378 | 379 | class Load_Audio_Path: 380 | @classmethod 381 | def INPUT_TYPES(s): 382 | files = [] 383 | for f in os.listdir(INPUT_PATH): 384 | if os.path.isfile(os.path.join(INPUT_PATH, f)) and f.split('.')[-1] in ["mp3"]: # only support mp3 385 | files.append(f) 386 | 387 | return {"required":{ 388 | "audio_path": (files,), 389 | }} 390 | 391 | CATEGORY = "V-Express" 392 | 393 | RETURN_TYPES = ("AUDIO_PATH",) 394 | 395 | FUNCTION = "load_audio_path" 396 | def load_audio_path(self, audio_path): 397 | audio_path = os.path.join(INPUT_PATH, audio_path) 398 | return (audio_path,) 399 | 400 | 401 | class Load_Audio_Path_From_Video: 402 | @classmethod 403 | def INPUT_TYPES(s): 404 | files = [] 405 | for f in os.listdir(INPUT_PATH): 406 | if os.path.isfile(os.path.join(INPUT_PATH, f)) and f.split('.')[-1] in ["mp4", "webm","mkv","avi"]: 407 | files.append(f) 408 | 409 | return {"required":{ 410 | "video_path": (files,), 
411 | }} 412 | 413 | CATEGORY = "V-Express" 414 | 415 | RETURN_TYPES = ("AUDIO_PATH",) 416 | 417 | FUNCTION = "load_audio_path_from_video" 418 | def load_audio_path_from_video(self, video_path): 419 | video_path = os.path.join(INPUT_PATH, video_path) 420 | video_base_name = video_path[:video_path.rfind('.')] 421 | audio_name = f'{video_base_name}_audio.mp3' 422 | audio_path = os.path.join(INPUT_PATH, audio_name) 423 | os.system(f'{get_ffmpeg_exe()} -i "{video_path}" -y -vn "{audio_path}"') 424 | if not os.path.isfile(audio_path): 425 | raise ValueError(f'{audio_path} not exists! Please check if the video contains audio!') 426 | return (audio_path,) 427 | 428 | 429 | class Load_Kps_Path: 430 | @classmethod 431 | def INPUT_TYPES(s): 432 | files = [] 433 | for f in os.listdir(INPUT_PATH): 434 | if os.path.isfile(os.path.join(INPUT_PATH, f)) and f.split('.')[-1] in ["pth"]: 435 | files.append(f) 436 | 437 | return {"required":{ 438 | "kps_path": (files,), 439 | }} 440 | 441 | CATEGORY = "V-Express" 442 | 443 | RETURN_TYPES = ("VKPS_PATH",) 444 | 445 | FUNCTION = "load_kps_path" 446 | def load_kps_path(self, kps_path): 447 | kps_path = os.path.join(INPUT_PATH, kps_path) 448 | return (kps_path,) 449 | 450 | 451 | class Load_Kps_Path_From_Video: 452 | @classmethod 453 | def INPUT_TYPES(s): 454 | files = [] 455 | for f in os.listdir(INPUT_PATH): 456 | if os.path.isfile(os.path.join(INPUT_PATH, f)) and f.split('.')[-1] in ["mp4", "webm","mkv","avi"]: 457 | files.append(f) 458 | 459 | return {"required":{ 460 | "vexpress_model_path": ("STRING_INPUT", ), 461 | "video_path": (files,), 462 | "image_size": ("INT_INPUT",), 463 | }} 464 | 465 | CATEGORY = "V-Express" 466 | 467 | RETURN_TYPES = ("VKPS_PATH",) 468 | 469 | FUNCTION = "load_kps_path_from_video" 470 | def load_kps_path_from_video(self, vexpress_model_path, video_path, image_size): 471 | video_path = os.path.join(INPUT_PATH, video_path) 472 | video_base_name = video_path[:video_path.rfind('.')] 473 | kps_name = f'{video_base_name}_kps.pth' 474 | kps_path = os.path.join(INPUT_PATH, kps_name) 475 | 476 | model_dict = get_all_model_path(vexpress_model_path) 477 | insightface_model_path = model_dict['insightface_model_path'] 478 | 479 | app = FaceAnalysis( 480 | providers=['CUDAExecutionProvider' if DEVICE == 'cuda' else 'CPUExecutionProvider'], 481 | provider_options=[{'device_id': GPU_ID}] if DEVICE == 'cuda' else [], 482 | root=insightface_model_path, 483 | ) 484 | app.prepare(ctx_id=0, det_size=(image_size, image_size)) 485 | 486 | kps_sequence = [] 487 | video_capture = cv2.VideoCapture(video_path) 488 | frame_idx = 0 489 | while video_capture.isOpened(): 490 | ret, frame = video_capture.read() 491 | if not ret: 492 | break 493 | frame = cv2.resize(frame, (image_size, image_size)) 494 | faces = app.get(frame) 495 | assert len(faces) == 1, f'There are {len(faces)} faces in the {frame_idx}-th frame. Only one face is supported.' 496 | 497 | kps = faces[0].kps[:3] 498 | kps_sequence.append(kps) 499 | frame_idx += 1 500 | torch.save(kps_sequence, kps_path) 501 | 502 | if not os.path.isfile(kps_path): 503 | raise ValueError(f'{kps_path} not exists! 
Please check the input!') 504 | return (kps_path,) 505 | 506 | 507 | class Load_Image_Path: 508 | @classmethod 509 | def INPUT_TYPES(s): 510 | input_dir = INPUT_PATH 511 | files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] 512 | return {"required": 513 | {"image": (sorted(files), {"image_upload": True})}, 514 | } 515 | 516 | CATEGORY = "V-Express" 517 | 518 | RETURN_TYPES = ("IMAGE_PATH",) 519 | FUNCTION = "load_image_path" 520 | def load_image_path(self, image): 521 | image_path = os.path.join(INPUT_PATH, image) 522 | return (image_path,) 523 | 524 | 525 | class Load_Video_Path: 526 | @classmethod 527 | def INPUT_TYPES(s): 528 | files = [] 529 | for f in os.listdir(INPUT_PATH): 530 | if os.path.isfile(os.path.join(INPUT_PATH, f)) and f.split('.')[-1] in ["mp4", "webm","mkv","avi"]: 531 | files.append(f) 532 | 533 | return {"required":{ 534 | "video_path": (files,), 535 | }} 536 | 537 | CATEGORY = "V-Express" 538 | 539 | RETURN_TYPES = ("STRING_INPUT",) 540 | 541 | FUNCTION = "load_video_path" 542 | def load_video_path(self, video_path): 543 | video_path = os.path.join(INPUT_PATH, video_path) 544 | return (video_path,) 545 | 546 | 547 | class VEPreview_Video: 548 | @classmethod 549 | def INPUT_TYPES(s): 550 | return {"required":{ 551 | "video":("STRING_INPUT",), 552 | }} 553 | 554 | CATEGORY = "V-Express" 555 | DESCRIPTION = "show result" 556 | 557 | RETURN_TYPES = () 558 | 559 | OUTPUT_NODE = True 560 | 561 | FUNCTION = "load_video" 562 | def load_video(self, video): 563 | video_name = os.path.basename(video) 564 | video_path_name = os.path.basename(os.path.dirname(video)) 565 | return {"ui":{"video":[video_name, video_path_name]}} 566 | 567 | @classmethod 568 | def IS_CHANGED(s,): 569 | return "" 570 | 571 | 572 | # A dictionary that contains all nodes you want to export with their names 573 | # NOTE: names should be globally unique 574 | NODE_CLASS_MAPPINGS = { 575 | "V_Express_Loader": V_Express_Loader, 576 | "V_Express_Sampler": V_Express_Sampler, 577 | "Load_Audio_Path": Load_Audio_Path, 578 | "Load_Audio_Path_From_Video": Load_Audio_Path_From_Video, 579 | "Load_Kps_Path": Load_Kps_Path, 580 | "Load_Kps_Path_From_Video": Load_Kps_Path_From_Video, 581 | "Load_Image_Path": Load_Image_Path, 582 | "Load_Video_Path": Load_Video_Path, 583 | "VEINTConstant": VEINTConstant, 584 | "VEStringConstant": VEStringConstant, 585 | "VEPreview_Video": VEPreview_Video, 586 | } 587 | 588 | # A dictionary that contains the friendly/humanly readable titles for the nodes 589 | NODE_DISPLAY_NAME_MAPPINGS = { 590 | "V_Express_Loader": "V-Express Loader", 591 | "V_Express_Sampler": "V-Express Sampler", 592 | "Load_Audio_Path": "Load Audio Path", 593 | "Load_Audio_Path_From_Video": "Load Audio Path From Video", 594 | "Load_Kps_Path": "Load V-Kps Path", 595 | "Load_Kps_Path_From_Video": "Load V-Kps Path From Video", 596 | "Load_Image_Path": "Load Reference Image Path", 597 | "Load_Video_Path": "Load Video Path", 598 | "VEINTConstant": "Set Image Size", 599 | "VEStringConstant": "Set V-Express Model Path", 600 | "VEPreview_Video": "Preview Output Video", 601 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.24.0 2 | imageio-ffmpeg==0.4.9 3 | insightface==0.7.3 4 | omegaconf==2.2.3 5 | onnxruntime==1.16.3 6 | safetensors==0.4.2 7 | torch==2.0.1 8 | torchaudio==2.0.2 9 | torchvision==0.15.2 10 | transformers==4.30.2 11 | 
einops==0.4.1 12 | tqdm==4.66.1 13 | xformers==0.0.22 14 | av==11.0.0 -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import accelerate 6 | import cv2 7 | import numpy as np 8 | import torch 9 | import torchaudio.functional 10 | import torchvision.io 11 | from PIL import Image 12 | from diffusers import AutoencoderKL, DDIMScheduler 13 | from diffusers.utils.import_utils import is_xformers_available 14 | from insightface.app import FaceAnalysis 15 | from omegaconf import OmegaConf 16 | from transformers import Wav2Vec2Model, Wav2Vec2Processor 17 | 18 | from modules import UNet2DConditionModel, UNet3DConditionModel, VKpsGuider, AudioProjection 19 | from pipelines import VExpressPipeline 20 | from pipelines.utils import draw_kps_image, save_video 21 | from pipelines.utils import retarget_kps 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser() 26 | 27 | parser.add_argument('--unet_config_path', type=str, default='./model_ckpts/stable-diffusion-v1-5/unet/config.json') 28 | parser.add_argument('--vae_path', type=str, default='./model_ckpts/sd-vae-ft-mse/') 29 | parser.add_argument('--audio_encoder_path', type=str, default='./model_ckpts/wav2vec2-base-960h/') 30 | parser.add_argument('--insightface_model_path', type=str, default='./model_ckpts/insightface_models/') 31 | 32 | parser.add_argument('--denoising_unet_path', type=str, default='./model_ckpts/v-express/denoising_unet.bin') 33 | parser.add_argument('--reference_net_path', type=str, default='./model_ckpts/v-express/reference_net.bin') 34 | parser.add_argument('--v_kps_guider_path', type=str, default='./model_ckpts/v-express/v_kps_guider.bin') 35 | parser.add_argument('--audio_projection_path', type=str, default='./model_ckpts/v-express/audio_projection.bin') 36 | parser.add_argument('--motion_module_path', type=str, default='./model_ckpts/v-express/motion_module.bin') 37 | 38 | parser.add_argument('--retarget_strategy', type=str, default='fix_face', 39 | help='{fix_face, no_retarget, offset_retarget, naive_retarget}') 40 | 41 | parser.add_argument('--dtype', type=str, default='fp16') 42 | parser.add_argument('--device', type=str, default='cuda') 43 | parser.add_argument('--gpu_id', type=int, default=0) 44 | parser.add_argument('--do_multi_devices_inference', action='store_true') 45 | parser.add_argument('--save_gpu_memory', action='store_true') 46 | 47 | parser.add_argument('--num_pad_audio_frames', type=int, default=2) 48 | parser.add_argument('--standard_audio_sampling_rate', type=int, default=16000) 49 | 50 | parser.add_argument('--reference_image_path', type=str, default='./test_samples/emo/talk_emotion/ref.jpg') 51 | parser.add_argument('--audio_path', type=str, default='./test_samples/emo/talk_emotion/aud.mp3') 52 | parser.add_argument('--kps_path', type=str, default='./test_samples/emo/talk_emotion/kps.pth') 53 | parser.add_argument('--output_path', type=str, default='./output/emo/talk_emotion.mp4') 54 | 55 | parser.add_argument('--image_width', type=int, default=512) 56 | parser.add_argument('--image_height', type=int, default=512) 57 | parser.add_argument('--fps', type=float, default=30.0) 58 | parser.add_argument('--seed', type=int, default=42) 59 | parser.add_argument('--num_inference_steps', type=int, default=25) 60 | parser.add_argument('--guidance_scale', type=float, default=3.5) 61 | 
parser.add_argument('--context_frames', type=int, default=12) 62 | parser.add_argument('--context_overlap', type=int, default=4) 63 | parser.add_argument('--reference_attention_weight', default=0.95, type=float) 64 | parser.add_argument('--audio_attention_weight', default=3., type=float) 65 | 66 | args = parser.parse_args() 67 | 68 | return args 69 | 70 | 71 | def load_reference_net(unet_config_path, reference_net_path, dtype, device): 72 | reference_net = UNet2DConditionModel.from_config(unet_config_path).to(dtype=dtype, device=device) 73 | reference_net.load_state_dict(torch.load(reference_net_path, map_location="cpu"), strict=False) 74 | print(f'Loaded weights of Reference Net from {reference_net_path}.') 75 | return reference_net 76 | 77 | 78 | def load_denoising_unet(inf_config_path, unet_config_path, denoising_unet_path, motion_module_path, dtype, device): 79 | inference_config = OmegaConf.load(inf_config_path) 80 | denoising_unet = UNet3DConditionModel.from_config_2d( 81 | unet_config_path, 82 | unet_additional_kwargs=inference_config.unet_additional_kwargs, 83 | ).to(dtype=dtype, device=device) 84 | denoising_unet.load_state_dict(torch.load(denoising_unet_path, map_location="cpu"), strict=False) 85 | print(f'Loaded weights of Denoising U-Net from {denoising_unet_path}.') 86 | 87 | denoising_unet.load_state_dict(torch.load(motion_module_path, map_location="cpu"), strict=False) 88 | print(f'Loaded weights of Denoising U-Net Motion Module from {motion_module_path}.') 89 | 90 | return denoising_unet 91 | 92 | 93 | def load_v_kps_guider(v_kps_guider_path, dtype, device): 94 | v_kps_guider = VKpsGuider(320, block_out_channels=(16, 32, 96, 256)).to(dtype=dtype, device=device) 95 | v_kps_guider.load_state_dict(torch.load(v_kps_guider_path, map_location="cpu")) 96 | print(f'Loaded weights of V-Kps Guider from {v_kps_guider_path}.') 97 | return v_kps_guider 98 | 99 | 100 | def load_audio_projection( 101 | audio_projection_path, 102 | dtype, 103 | device, 104 | inp_dim: int, 105 | mid_dim: int, 106 | out_dim: int, 107 | inp_seq_len: int, 108 | out_seq_len: int, 109 | ): 110 | audio_projection = AudioProjection( 111 | dim=mid_dim, 112 | depth=4, 113 | dim_head=64, 114 | heads=12, 115 | num_queries=out_seq_len, 116 | embedding_dim=inp_dim, 117 | output_dim=out_dim, 118 | ff_mult=4, 119 | max_seq_len=inp_seq_len, 120 | ).to(dtype=dtype, device=device) 121 | audio_projection.load_state_dict(torch.load(audio_projection_path, map_location='cpu')) 122 | print(f'Loaded weights of Audio Projection from {audio_projection_path}.') 123 | return audio_projection 124 | 125 | 126 | def get_scheduler(inference_config_path): 127 | inference_config = OmegaConf.load(inference_config_path) 128 | scheduler_kwargs = OmegaConf.to_container(inference_config.noise_scheduler_kwargs) 129 | scheduler = DDIMScheduler(**scheduler_kwargs) 130 | return scheduler 131 | 132 | 133 | def main(): 134 | args = parse_args() 135 | start_time = time.time() 136 | 137 | if not args.do_multi_devices_inference: 138 | # TODO 139 | accelerator = None 140 | device = torch.device(f'{args.device}:{args.gpu_id}' if args.device == 'cuda' else args.device) 141 | else: 142 | accelerator = accelerate.Accelerator() 143 | device = torch.device(f'cuda:{accelerator.process_index}') 144 | dtype = torch.float16 if args.dtype == 'fp16' else torch.float32 145 | 146 | vae_path = args.vae_path 147 | audio_encoder_path = args.audio_encoder_path 148 | 149 | vae = AutoencoderKL.from_pretrained(vae_path).to(dtype=dtype, device=device) 150 | audio_encoder = 
Wav2Vec2Model.from_pretrained(audio_encoder_path).to(dtype=dtype, device=device) 151 | audio_processor = Wav2Vec2Processor.from_pretrained(audio_encoder_path) 152 | 153 | unet_config_path = args.unet_config_path 154 | reference_net_path = args.reference_net_path 155 | denoising_unet_path = args.denoising_unet_path 156 | v_kps_guider_path = args.v_kps_guider_path 157 | audio_projection_path = args.audio_projection_path 158 | motion_module_path = args.motion_module_path 159 | 160 | inference_config_path = './inference_v2.yaml' 161 | scheduler = get_scheduler(inference_config_path) 162 | reference_net = load_reference_net(unet_config_path, reference_net_path, dtype, device) 163 | denoising_unet = load_denoising_unet( 164 | inference_config_path, unet_config_path, denoising_unet_path, motion_module_path, 165 | dtype, device 166 | ) 167 | v_kps_guider = load_v_kps_guider(v_kps_guider_path, dtype, device) 168 | audio_projection = load_audio_projection( 169 | audio_projection_path, 170 | dtype, 171 | device, 172 | inp_dim=denoising_unet.config.cross_attention_dim, 173 | mid_dim=denoising_unet.config.cross_attention_dim, 174 | out_dim=denoising_unet.config.cross_attention_dim, 175 | inp_seq_len=2 * (2 * args.num_pad_audio_frames + 1), 176 | out_seq_len=2 * args.num_pad_audio_frames + 1, 177 | ) 178 | 179 | if is_xformers_available(): 180 | reference_net.enable_xformers_memory_efficient_attention() 181 | denoising_unet.enable_xformers_memory_efficient_attention() 182 | else: 183 | raise ValueError("xformers is not available. Make sure it is installed correctly") 184 | 185 | generator = torch.manual_seed(args.seed) 186 | pipeline = VExpressPipeline( 187 | vae=vae, 188 | reference_net=reference_net, 189 | denoising_unet=denoising_unet, 190 | v_kps_guider=v_kps_guider, 191 | audio_processor=audio_processor, 192 | audio_encoder=audio_encoder, 193 | audio_projection=audio_projection, 194 | scheduler=scheduler, 195 | ).to(dtype=dtype, device=device) 196 | 197 | app = FaceAnalysis( 198 | providers=['CUDAExecutionProvider' if args.device == 'cuda' else 'CPUExecutionProvider'], 199 | provider_options=[{'device_id': args.gpu_id}] if args.device == 'cuda' else [], 200 | root=args.insightface_model_path, 201 | ) 202 | app.prepare(ctx_id=0, det_size=(args.image_height, args.image_width)) 203 | 204 | reference_image = Image.open(args.reference_image_path).convert('RGB') 205 | reference_image = reference_image.resize((args.image_height, args.image_width)) 206 | 207 | reference_image_for_kps = cv2.imread(args.reference_image_path) 208 | reference_image_for_kps = cv2.resize(reference_image_for_kps, (args.image_width, args.image_height)) 209 | reference_kps = app.get(reference_image_for_kps)[0].kps[:3] 210 | if args.save_gpu_memory: 211 | del app 212 | torch.cuda.empty_cache() 213 | 214 | _, audio_waveform, meta_info = torchvision.io.read_video(args.audio_path, pts_unit='sec') 215 | audio_sampling_rate = meta_info['audio_fps'] 216 | print(f'Length of audio is {audio_waveform.shape[1]} with the sampling rate of {audio_sampling_rate}.') 217 | if audio_sampling_rate != args.standard_audio_sampling_rate: 218 | audio_waveform = torchaudio.functional.resample( 219 | audio_waveform, 220 | orig_freq=audio_sampling_rate, 221 | new_freq=args.standard_audio_sampling_rate, 222 | ) 223 | audio_waveform = audio_waveform.mean(dim=0) 224 | 225 | duration = audio_waveform.shape[0] / args.standard_audio_sampling_rate 226 | init_video_length = int(duration * args.fps) 227 | num_contexts = np.around((init_video_length + 
args.context_overlap) / args.context_frames) 228 | video_length = int(num_contexts * args.context_frames - args.context_overlap) 229 | fps = video_length / duration 230 | print(f'The corresponding video length is {video_length}.') 231 | 232 | kps_sequence = None 233 | if args.kps_path != "": 234 | assert os.path.exists(args.kps_path), f'{args.kps_path} does not exist' 235 | kps_sequence = torch.tensor(torch.load(args.kps_path)) # [len, 3, 2] 236 | print(f'The original length of kps sequence is {kps_sequence.shape[0]}.') 237 | 238 | if kps_sequence.shape[0] > video_length: 239 | kps_sequence = kps_sequence[:video_length, :, :] 240 | 241 | kps_sequence = torch.nn.functional.interpolate(kps_sequence.permute(1, 2, 0), size=video_length, mode='linear') 242 | kps_sequence = kps_sequence.permute(2, 0, 1) 243 | print(f'The interpolated length of kps sequence is {kps_sequence.shape[0]}.') 244 | 245 | retarget_strategy = args.retarget_strategy 246 | if retarget_strategy == 'fix_face': 247 | kps_sequence = torch.tensor([reference_kps] * video_length) 248 | elif retarget_strategy == 'no_retarget': 249 | kps_sequence = kps_sequence 250 | elif retarget_strategy == 'offset_retarget': 251 | kps_sequence = retarget_kps(reference_kps, kps_sequence, only_offset=True) 252 | elif retarget_strategy == 'naive_retarget': 253 | kps_sequence = retarget_kps(reference_kps, kps_sequence, only_offset=False) 254 | else: 255 | raise ValueError(f'The retarget strategy {retarget_strategy} is not supported.') 256 | 257 | kps_images = [] 258 | for i in range(video_length): 259 | kps_image = draw_kps_image(args.image_height, args.image_width, kps_sequence[i]) 260 | kps_images.append(Image.fromarray(kps_image)) 261 | 262 | video_tensor = pipeline( 263 | reference_image=reference_image, 264 | kps_images=kps_images, 265 | audio_waveform=audio_waveform, 266 | width=args.image_width, 267 | height=args.image_height, 268 | video_length=video_length, 269 | num_inference_steps=args.num_inference_steps, 270 | guidance_scale=args.guidance_scale, 271 | context_frames=args.context_frames, 272 | context_overlap=args.context_overlap, 273 | reference_attention_weight=args.reference_attention_weight, 274 | audio_attention_weight=args.audio_attention_weight, 275 | num_pad_audio_frames=args.num_pad_audio_frames, 276 | generator=generator, 277 | do_multi_devices_inference=args.do_multi_devices_inference, 278 | save_gpu_memory=args.save_gpu_memory, 279 | ) 280 | 281 | if accelerator is None or accelerator.is_main_process: 282 | save_video(video_tensor, args.audio_path, args.output_path, device, fps) 283 | consumed_time = time.time() - start_time 284 | generation_fps = video_tensor.shape[2] / consumed_time 285 | print(f'The generated video has been saved at {args.output_path}. ' 286 | f'The generation time is {consumed_time:.1f} seconds. 
' 287 | f'The generation FPS is {generation_fps:.2f}.') 288 | 289 | 290 | if __name__ == '__main__': 291 | main() 292 | -------------------------------------------------------------------------------- /src/inference_v2.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | use_inflated_groupnorm: true 3 | unet_use_cross_frame_attention: false 4 | unet_use_temporal_attention: false 5 | use_motion_module: true 6 | motion_module_resolutions: 7 | - 1 8 | - 2 9 | - 4 10 | - 8 11 | motion_module_mid_block: true 12 | motion_module_decoder_only: false 13 | motion_module_type: Vanilla 14 | motion_module_kwargs: 15 | num_attention_heads: 8 16 | num_transformer_block: 1 17 | attention_block_types: 18 | - Temporal_Self 19 | - Temporal_Self 20 | temporal_position_encoding: true 21 | temporal_position_encoding_max_len: 32 22 | temporal_attention_dim_div: 1 23 | 24 | noise_scheduler_kwargs: 25 | beta_start: 0.00085 26 | beta_end: 0.012 27 | beta_schedule: "scaled_linear" 28 | clip_sample: false 29 | steps_offset: 1 30 | ### Zero-SNR params 31 | prediction_type: "v_prediction" 32 | rescale_betas_zero_snr: True 33 | timestep_spacing: "trailing" 34 | 35 | sampler: DDIM 36 | -------------------------------------------------------------------------------- /src/modules/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiankuan93/ComfyUI-V-Express/ee8f245406e6fa07353dbf36f8ada06f81d48a9e/src/modules/.DS_Store -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet_2d_condition import UNet2DConditionModel 2 | from .unet_3d import UNet3DConditionModel 3 | from .v_kps_guider import VKpsGuider 4 | from .audio_projection import AudioProjection 5 | from .mutual_self_attention import ReferenceAttentionControl 6 | -------------------------------------------------------------------------------- /src/modules/attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py 2 | 3 | from typing import Any, Dict, Optional 4 | 5 | import torch 6 | from diffusers.models.attention import AdaLayerNorm, AdaLayerNormZero, Attention, FeedForward, GatedSelfAttentionDense 7 | from diffusers.models.embeddings import SinusoidalPositionalEmbedding 8 | from einops import rearrange 9 | from torch import nn 10 | 11 | 12 | class BasicTransformerBlock(nn.Module): 13 | r""" 14 | A basic Transformer block. 15 | 16 | Parameters: 17 | dim (`int`): The number of channels in the input and output. 18 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 19 | attention_head_dim (`int`): The number of channels in each head. 20 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 21 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 22 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 23 | num_embeds_ada_norm (: 24 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. 25 | attention_bias (: 26 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 
27 | only_cross_attention (`bool`, *optional*): 28 | Whether to use only cross-attention layers. In this case two cross attention layers are used. 29 | double_self_attention (`bool`, *optional*): 30 | Whether to use two self-attention layers. In this case no cross attention layers are used. 31 | upcast_attention (`bool`, *optional*): 32 | Whether to upcast the attention computation to float32. This is useful for mixed precision training. 33 | norm_elementwise_affine (`bool`, *optional*, defaults to `True`): 34 | Whether to use learnable elementwise affine parameters for normalization. 35 | norm_type (`str`, *optional*, defaults to `"layer_norm"`): 36 | The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. 37 | final_dropout (`bool` *optional*, defaults to False): 38 | Whether to apply a final dropout after the last feed-forward layer. 39 | attention_type (`str`, *optional*, defaults to `"default"`): 40 | The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. 41 | positional_embeddings (`str`, *optional*, defaults to `None`): 42 | The type of positional embeddings to apply to. 43 | num_positional_embeddings (`int`, *optional*, defaults to `None`): 44 | The maximum number of positional embeddings to apply. 45 | """ 46 | 47 | def __init__( 48 | self, 49 | dim: int, 50 | num_attention_heads: int, 51 | attention_head_dim: int, 52 | dropout=0.0, 53 | cross_attention_dim: Optional[int] = None, 54 | activation_fn: str = "geglu", 55 | num_embeds_ada_norm: Optional[int] = None, 56 | attention_bias: bool = False, 57 | only_cross_attention: bool = False, 58 | double_self_attention: bool = False, 59 | upcast_attention: bool = False, 60 | norm_elementwise_affine: bool = True, 61 | norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' 62 | norm_eps: float = 1e-5, 63 | final_dropout: bool = False, 64 | attention_type: str = "default", 65 | positional_embeddings: Optional[str] = None, 66 | num_positional_embeddings: Optional[int] = None, 67 | ): 68 | super().__init__() 69 | self.only_cross_attention = only_cross_attention 70 | 71 | self.use_ada_layer_norm_zero = ( 72 | num_embeds_ada_norm is not None 73 | ) and norm_type == "ada_norm_zero" 74 | self.use_ada_layer_norm = ( 75 | num_embeds_ada_norm is not None 76 | ) and norm_type == "ada_norm" 77 | self.use_ada_layer_norm_single = norm_type == "ada_norm_single" 78 | self.use_layer_norm = norm_type == "layer_norm" 79 | 80 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: 81 | raise ValueError( 82 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" 83 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 84 | ) 85 | 86 | if positional_embeddings and (num_positional_embeddings is None): 87 | raise ValueError( 88 | "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." 89 | ) 90 | 91 | if positional_embeddings == "sinusoidal": 92 | self.pos_embed = SinusoidalPositionalEmbedding( 93 | dim, max_seq_length=num_positional_embeddings 94 | ) 95 | else: 96 | self.pos_embed = None 97 | 98 | # Define 3 blocks. Each block has its own normalization layer. 99 | # 1. 
Self-Attn 100 | if self.use_ada_layer_norm: 101 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) 102 | elif self.use_ada_layer_norm_zero: 103 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) 104 | else: 105 | self.norm1 = nn.LayerNorm( 106 | dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps 107 | ) 108 | 109 | self.attn1 = Attention( 110 | query_dim=dim, 111 | heads=num_attention_heads, 112 | dim_head=attention_head_dim, 113 | dropout=dropout, 114 | bias=attention_bias, 115 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 116 | upcast_attention=upcast_attention, 117 | ) 118 | 119 | # 2. Cross-Attn 120 | if cross_attention_dim is not None or double_self_attention: 121 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 122 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 123 | # the second cross attention block. 124 | self.norm2 = ( 125 | AdaLayerNorm(dim, num_embeds_ada_norm) 126 | if self.use_ada_layer_norm 127 | else nn.LayerNorm( 128 | dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps 129 | ) 130 | ) 131 | self.attn2 = Attention( 132 | query_dim=dim, 133 | cross_attention_dim=cross_attention_dim 134 | if not double_self_attention 135 | else None, 136 | heads=num_attention_heads, 137 | dim_head=attention_head_dim, 138 | dropout=dropout, 139 | bias=attention_bias, 140 | upcast_attention=upcast_attention, 141 | ) # is self-attn if encoder_hidden_states is none 142 | else: 143 | self.norm2 = None 144 | self.attn2 = None 145 | 146 | # 3. Feed-forward 147 | if not self.use_ada_layer_norm_single: 148 | self.norm3 = nn.LayerNorm( 149 | dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps 150 | ) 151 | 152 | self.ff = FeedForward( 153 | dim, 154 | dropout=dropout, 155 | activation_fn=activation_fn, 156 | final_dropout=final_dropout, 157 | ) 158 | 159 | # 4. Fuser 160 | if attention_type == "gated" or attention_type == "gated-text-image": 161 | self.fuser = GatedSelfAttentionDense( 162 | dim, cross_attention_dim, num_attention_heads, attention_head_dim 163 | ) 164 | 165 | # 5. Scale-shift for PixArt-Alpha. 166 | if self.use_ada_layer_norm_single: 167 | self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) 168 | 169 | # let chunk size default to None 170 | self._chunk_size = None 171 | self._chunk_dim = 0 172 | 173 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): 174 | # Sets chunk feed-forward 175 | self._chunk_size = chunk_size 176 | self._chunk_dim = dim 177 | 178 | def forward( 179 | self, 180 | hidden_states: torch.FloatTensor, 181 | attention_mask: Optional[torch.FloatTensor] = None, 182 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 183 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 184 | timestep: Optional[torch.LongTensor] = None, 185 | cross_attention_kwargs: Dict[str, Any] = None, 186 | class_labels: Optional[torch.LongTensor] = None, 187 | ) -> torch.FloatTensor: 188 | # Notice that normalization is always applied before the real computation in the following blocks. 189 | # 0. 
Self-Attention 190 | batch_size = hidden_states.shape[0] 191 | 192 | if self.use_ada_layer_norm: 193 | norm_hidden_states = self.norm1(hidden_states, timestep) 194 | elif self.use_ada_layer_norm_zero: 195 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 196 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 197 | ) 198 | elif self.use_layer_norm: 199 | norm_hidden_states = self.norm1(hidden_states) 200 | elif self.use_ada_layer_norm_single: 201 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 202 | self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) 203 | ).chunk(6, dim=1) 204 | norm_hidden_states = self.norm1(hidden_states) 205 | norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa 206 | norm_hidden_states = norm_hidden_states.squeeze(1) 207 | else: 208 | raise ValueError("Incorrect norm used") 209 | 210 | if self.pos_embed is not None: 211 | norm_hidden_states = self.pos_embed(norm_hidden_states) 212 | 213 | # 1. Retrieve lora scale. 214 | lora_scale = ( 215 | cross_attention_kwargs.get("scale", 1.0) 216 | if cross_attention_kwargs is not None 217 | else 1.0 218 | ) 219 | 220 | # 2. Prepare GLIGEN inputs 221 | cross_attention_kwargs = ( 222 | cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 223 | ) 224 | gligen_kwargs = cross_attention_kwargs.pop("gligen", None) 225 | 226 | attn_output = self.attn1( 227 | norm_hidden_states, 228 | encoder_hidden_states=encoder_hidden_states 229 | if self.only_cross_attention 230 | else None, 231 | attention_mask=attention_mask, 232 | **cross_attention_kwargs, 233 | ) 234 | if self.use_ada_layer_norm_zero: 235 | attn_output = gate_msa.unsqueeze(1) * attn_output 236 | elif self.use_ada_layer_norm_single: 237 | attn_output = gate_msa * attn_output 238 | 239 | hidden_states = attn_output + hidden_states 240 | if hidden_states.ndim == 4: 241 | hidden_states = hidden_states.squeeze(1) 242 | 243 | # 2.5 GLIGEN Control 244 | if gligen_kwargs is not None: 245 | hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) 246 | 247 | # 3. Cross-Attention 248 | if self.attn2 is not None: 249 | if self.use_ada_layer_norm: 250 | norm_hidden_states = self.norm2(hidden_states, timestep) 251 | elif self.use_ada_layer_norm_zero or self.use_layer_norm: 252 | norm_hidden_states = self.norm2(hidden_states) 253 | elif self.use_ada_layer_norm_single: 254 | # For PixArt norm2 isn't applied here: 255 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 256 | norm_hidden_states = hidden_states 257 | else: 258 | raise ValueError("Incorrect norm") 259 | 260 | if self.pos_embed is not None and self.use_ada_layer_norm_single is False: 261 | norm_hidden_states = self.pos_embed(norm_hidden_states) 262 | 263 | attn_output = self.attn2( 264 | norm_hidden_states, 265 | encoder_hidden_states=encoder_hidden_states, 266 | attention_mask=encoder_attention_mask, 267 | **cross_attention_kwargs, 268 | ) 269 | hidden_states = attn_output + hidden_states 270 | 271 | # 4. 
Feed-forward 272 | if not self.use_ada_layer_norm_single: 273 | norm_hidden_states = self.norm3(hidden_states) 274 | 275 | if self.use_ada_layer_norm_zero: 276 | norm_hidden_states = ( 277 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 278 | ) 279 | 280 | if self.use_ada_layer_norm_single: 281 | norm_hidden_states = self.norm2(hidden_states) 282 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp 283 | 284 | ff_output = self.ff(norm_hidden_states, scale=lora_scale) 285 | 286 | if self.use_ada_layer_norm_zero: 287 | ff_output = gate_mlp.unsqueeze(1) * ff_output 288 | elif self.use_ada_layer_norm_single: 289 | ff_output = gate_mlp * ff_output 290 | 291 | hidden_states = ff_output + hidden_states 292 | if hidden_states.ndim == 4: 293 | hidden_states = hidden_states.squeeze(1) 294 | 295 | return hidden_states 296 | 297 | 298 | class TemporalBasicTransformerBlock(nn.Module): 299 | def __init__( 300 | self, 301 | dim: int, 302 | num_attention_heads: int, 303 | attention_head_dim: int, 304 | dropout=0.0, 305 | cross_attention_dim: Optional[int] = None, 306 | activation_fn: str = "geglu", 307 | num_embeds_ada_norm: Optional[int] = None, 308 | attention_bias: bool = False, 309 | only_cross_attention: bool = False, 310 | upcast_attention: bool = False, 311 | unet_use_cross_frame_attention=None, 312 | unet_use_temporal_attention=None, 313 | ): 314 | super().__init__() 315 | self.only_cross_attention = only_cross_attention 316 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 317 | self.unet_use_cross_frame_attention = unet_use_cross_frame_attention 318 | self.unet_use_temporal_attention = unet_use_temporal_attention 319 | 320 | # old self attention layer for only self-attention 321 | self.attn1 = Attention( 322 | query_dim=dim, 323 | heads=num_attention_heads, 324 | dim_head=attention_head_dim, 325 | dropout=dropout, 326 | bias=attention_bias, 327 | upcast_attention=upcast_attention, 328 | ) 329 | self.norm1 = ( 330 | AdaLayerNorm(dim, num_embeds_ada_norm) 331 | if self.use_ada_layer_norm 332 | else nn.LayerNorm(dim) 333 | ) 334 | 335 | # new self attention layer for reference features 336 | self.attn1_5 = Attention( 337 | query_dim=dim, 338 | heads=num_attention_heads, 339 | dim_head=attention_head_dim, 340 | dropout=dropout, 341 | bias=attention_bias, 342 | upcast_attention=upcast_attention, 343 | ) 344 | self.norm1_5 = ( 345 | AdaLayerNorm(dim, num_embeds_ada_norm) 346 | if self.use_ada_layer_norm 347 | else nn.LayerNorm(dim) 348 | ) 349 | 350 | # Cross-Attn 351 | if cross_attention_dim is not None: 352 | self.attn2 = Attention( 353 | query_dim=dim, 354 | cross_attention_dim=cross_attention_dim, 355 | heads=num_attention_heads, 356 | dim_head=attention_head_dim, 357 | dropout=dropout, 358 | bias=attention_bias, 359 | upcast_attention=upcast_attention, 360 | ) 361 | else: 362 | self.attn2 = None 363 | 364 | if cross_attention_dim is not None: 365 | self.norm2 = ( 366 | AdaLayerNorm(dim, num_embeds_ada_norm) 367 | if self.use_ada_layer_norm 368 | else nn.LayerNorm(dim) 369 | ) 370 | else: 371 | self.norm2 = None 372 | 373 | # Feed-forward 374 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 375 | self.norm3 = nn.LayerNorm(dim) 376 | self.use_ada_layer_norm_zero = False 377 | 378 | # Temp-Attn 379 | assert unet_use_temporal_attention is not None 380 | if unet_use_temporal_attention: 381 | self.attn_temp = Attention( 382 | query_dim=dim, 383 | heads=num_attention_heads, 384 | dim_head=attention_head_dim, 385 | 
dropout=dropout, 386 | bias=attention_bias, 387 | upcast_attention=upcast_attention, 388 | ) 389 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data) 390 | self.norm_temp = ( 391 | AdaLayerNorm(dim, num_embeds_ada_norm) 392 | if self.use_ada_layer_norm 393 | else nn.LayerNorm(dim) 394 | ) 395 | 396 | def forward( 397 | self, 398 | hidden_states, 399 | encoder_hidden_states=None, 400 | timestep=None, 401 | attention_mask=None, 402 | video_length=None, 403 | ): 404 | norm_hidden_states = ( 405 | self.norm1(hidden_states, timestep) 406 | if self.use_ada_layer_norm 407 | else self.norm1(hidden_states) 408 | ) 409 | 410 | if self.unet_use_cross_frame_attention: 411 | hidden_states = ( 412 | self.attn1( 413 | norm_hidden_states, 414 | attention_mask=attention_mask, 415 | video_length=video_length, 416 | ) 417 | + hidden_states 418 | ) 419 | else: 420 | hidden_states = ( 421 | self.attn1(norm_hidden_states, attention_mask=attention_mask) 422 | + hidden_states 423 | ) 424 | 425 | norm_hidden_states = ( 426 | self.norm1_5(hidden_states, timestep) 427 | if self.use_ada_layer_norm 428 | else self.norm1_5(hidden_states) 429 | ) 430 | 431 | if self.unet_use_cross_frame_attention: 432 | hidden_states = ( 433 | self.attn1_5( 434 | norm_hidden_states, 435 | attention_mask=attention_mask, 436 | video_length=video_length, 437 | ) 438 | + hidden_states 439 | ) 440 | else: 441 | hidden_states = ( 442 | self.attn1_5(norm_hidden_states, attention_mask=attention_mask) 443 | + hidden_states 444 | ) 445 | 446 | if self.attn2 is not None: 447 | # Cross-Attention 448 | norm_hidden_states = ( 449 | self.norm2(hidden_states, timestep) 450 | if self.use_ada_layer_norm 451 | else self.norm2(hidden_states) 452 | ) 453 | hidden_states = ( 454 | self.attn2( 455 | norm_hidden_states, 456 | encoder_hidden_states=encoder_hidden_states, 457 | attention_mask=attention_mask, 458 | ) 459 | + hidden_states 460 | ) 461 | 462 | # Feed-forward 463 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 464 | 465 | # Temporal-Attention 466 | if self.unet_use_temporal_attention: 467 | d = hidden_states.shape[1] 468 | hidden_states = rearrange( 469 | hidden_states, "(b f) d c -> (b d) f c", f=video_length 470 | ) 471 | norm_hidden_states = ( 472 | self.norm_temp(hidden_states, timestep) 473 | if self.use_ada_layer_norm 474 | else self.norm_temp(hidden_states) 475 | ) 476 | hidden_states = self.attn_temp(norm_hidden_states) + hidden_states 477 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 478 | 479 | return hidden_states 480 | 481 | class TemporalBasicTransformerBlockOld(nn.Module): 482 | def __init__( 483 | self, 484 | dim: int, 485 | num_attention_heads: int, 486 | attention_head_dim: int, 487 | dropout=0.0, 488 | cross_attention_dim: Optional[int] = None, 489 | activation_fn: str = "geglu", 490 | num_embeds_ada_norm: Optional[int] = None, 491 | attention_bias: bool = False, 492 | only_cross_attention: bool = False, 493 | upcast_attention: bool = False, 494 | unet_use_cross_frame_attention=None, 495 | unet_use_temporal_attention=None, 496 | ): 497 | super().__init__() 498 | self.only_cross_attention = only_cross_attention 499 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 500 | self.unet_use_cross_frame_attention = unet_use_cross_frame_attention 501 | self.unet_use_temporal_attention = unet_use_temporal_attention 502 | 503 | # SC-Attn 504 | self.attn1 = Attention( 505 | query_dim=dim, 506 | heads=num_attention_heads, 507 | dim_head=attention_head_dim, 508 | dropout=dropout, 509 | 
bias=attention_bias, 510 | upcast_attention=upcast_attention, 511 | ) 512 | self.norm1 = ( 513 | AdaLayerNorm(dim, num_embeds_ada_norm) 514 | if self.use_ada_layer_norm 515 | else nn.LayerNorm(dim) 516 | ) 517 | 518 | # Cross-Attn 519 | if cross_attention_dim is not None: 520 | self.attn2 = Attention( 521 | query_dim=dim, 522 | cross_attention_dim=cross_attention_dim, 523 | heads=num_attention_heads, 524 | dim_head=attention_head_dim, 525 | dropout=dropout, 526 | bias=attention_bias, 527 | upcast_attention=upcast_attention, 528 | ) 529 | else: 530 | self.attn2 = None 531 | 532 | if cross_attention_dim is not None: 533 | self.norm2 = ( 534 | AdaLayerNorm(dim, num_embeds_ada_norm) 535 | if self.use_ada_layer_norm 536 | else nn.LayerNorm(dim) 537 | ) 538 | else: 539 | self.norm2 = None 540 | 541 | # Feed-forward 542 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 543 | self.norm3 = nn.LayerNorm(dim) 544 | self.use_ada_layer_norm_zero = False 545 | 546 | # Temp-Attn 547 | assert unet_use_temporal_attention is not None 548 | if unet_use_temporal_attention: 549 | self.attn_temp = Attention( 550 | query_dim=dim, 551 | heads=num_attention_heads, 552 | dim_head=attention_head_dim, 553 | dropout=dropout, 554 | bias=attention_bias, 555 | upcast_attention=upcast_attention, 556 | ) 557 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data) 558 | self.norm_temp = ( 559 | AdaLayerNorm(dim, num_embeds_ada_norm) 560 | if self.use_ada_layer_norm 561 | else nn.LayerNorm(dim) 562 | ) 563 | 564 | def forward( 565 | self, 566 | hidden_states, 567 | encoder_hidden_states=None, 568 | timestep=None, 569 | attention_mask=None, 570 | video_length=None, 571 | ): 572 | norm_hidden_states = ( 573 | self.norm1(hidden_states, timestep) 574 | if self.use_ada_layer_norm 575 | else self.norm1(hidden_states) 576 | ) 577 | 578 | if self.unet_use_cross_frame_attention: 579 | hidden_states = ( 580 | self.attn1( 581 | norm_hidden_states, 582 | attention_mask=attention_mask, 583 | video_length=video_length, 584 | ) 585 | + hidden_states 586 | ) 587 | else: 588 | hidden_states = ( 589 | self.attn1(norm_hidden_states, attention_mask=attention_mask) 590 | + hidden_states 591 | ) 592 | 593 | if self.attn2 is not None: 594 | # Cross-Attention 595 | norm_hidden_states = ( 596 | self.norm2(hidden_states, timestep) 597 | if self.use_ada_layer_norm 598 | else self.norm2(hidden_states) 599 | ) 600 | hidden_states = ( 601 | self.attn2( 602 | norm_hidden_states, 603 | encoder_hidden_states=encoder_hidden_states, 604 | attention_mask=attention_mask, 605 | ) 606 | + hidden_states 607 | ) 608 | 609 | # Feed-forward 610 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 611 | 612 | # Temporal-Attention 613 | if self.unet_use_temporal_attention: 614 | d = hidden_states.shape[1] 615 | hidden_states = rearrange( 616 | hidden_states, "(b f) d c -> (b d) f c", f=video_length 617 | ) 618 | norm_hidden_states = ( 619 | self.norm_temp(hidden_states, timestep) 620 | if self.use_ada_layer_norm 621 | else self.norm_temp(hidden_states) 622 | ) 623 | hidden_states = self.attn_temp(norm_hidden_states) + hidden_states 624 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 625 | 626 | return hidden_states -------------------------------------------------------------------------------- /src/modules/audio_projection.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from 
diffusers.models.modeling_utils import ModelMixin 6 | from einops import rearrange 7 | from einops.layers.torch import Rearrange 8 | 9 | 10 | def reshape_tensor(x, heads): 11 | bs, length, width = x.shape 12 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head) 13 | x = x.view(bs, length, heads, -1) 14 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 15 | x = x.transpose(1, 2) 16 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 17 | x = x.reshape(bs, heads, length, -1) 18 | return x 19 | 20 | 21 | def masked_mean(t, *, dim, mask=None): 22 | if mask is None: 23 | return t.mean(dim=dim) 24 | 25 | denom = mask.sum(dim=dim, keepdim=True) 26 | mask = rearrange(mask, "b n -> b n 1") 27 | masked_t = t.masked_fill(~mask, 0.0) 28 | 29 | return masked_t.sum(dim=dim) / denom.clamp(min=1e-5) 30 | 31 | 32 | class PerceiverAttention(nn.Module): 33 | def __init__(self, *, dim, dim_head=64, heads=8): 34 | super().__init__() 35 | self.scale = dim_head ** -0.5 36 | self.dim_head = dim_head 37 | self.heads = heads 38 | inner_dim = dim_head * heads 39 | 40 | self.norm1 = nn.LayerNorm(dim) 41 | self.norm2 = nn.LayerNorm(dim) 42 | 43 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 44 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 45 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 46 | 47 | def forward(self, x, latents): 48 | """ 49 | Args: 50 | x (torch.Tensor): image features 51 | shape (b, n1, D) 52 | latent (torch.Tensor): latent features 53 | shape (b, n2, D) 54 | """ 55 | x = self.norm1(x) 56 | latents = self.norm2(latents) 57 | 58 | b, l, _ = latents.shape 59 | 60 | q = self.to_q(latents) 61 | kv_input = torch.cat((x, latents), dim=-2) 62 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 63 | 64 | q = reshape_tensor(q, self.heads) 65 | k = reshape_tensor(k, self.heads) 66 | v = reshape_tensor(v, self.heads) 67 | 68 | # attention 69 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 70 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 71 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 72 | out = weight @ v 73 | 74 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 75 | 76 | return self.to_out(out) 77 | 78 | 79 | def FeedForward(dim, mult=4): 80 | inner_dim = int(dim * mult) 81 | return nn.Sequential( 82 | nn.LayerNorm(dim), 83 | nn.Linear(dim, inner_dim, bias=False), 84 | nn.GELU(), 85 | nn.Linear(inner_dim, dim, bias=False), 86 | ) 87 | 88 | 89 | class AudioProjection(ModelMixin): 90 | def __init__( 91 | self, 92 | dim=1024, 93 | depth=8, 94 | dim_head=64, 95 | heads=16, 96 | num_queries=8, 97 | embedding_dim=768, 98 | output_dim=1024, 99 | ff_mult=4, 100 | max_seq_len: int = 257, 101 | num_latents_mean_pooled: int = 0, 102 | ): 103 | super().__init__() 104 | 105 | self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) 106 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5) 107 | 108 | self.proj_in = nn.Linear(embedding_dim, dim) 109 | 110 | self.proj_out = nn.Linear(dim, output_dim) 111 | self.norm_out = nn.LayerNorm(output_dim) 112 | 113 | self.to_latents_from_mean_pooled_seq = ( 114 | nn.Sequential( 115 | nn.LayerNorm(dim), 116 | nn.Linear(dim, dim * num_latents_mean_pooled), 117 | Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled), 118 | ) 119 | if num_latents_mean_pooled > 0 120 | else None 121 | ) 122 | 123 | self.layers = nn.ModuleList([]) 124 | for _ in range(depth): 125 | self.layers.append(nn.ModuleList([ 
126 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 127 | FeedForward(dim=dim, mult=ff_mult), 128 | ])) 129 | 130 | def forward(self, x): 131 | if self.pos_emb is not None: 132 | n, device = x.shape[1], x.device 133 | pos_emb = self.pos_emb(torch.arange(n, device=device)) 134 | x = x + pos_emb 135 | 136 | latents = self.latents.repeat(x.size(0), 1, 1) 137 | 138 | x = self.proj_in(x) 139 | 140 | if self.to_latents_from_mean_pooled_seq: 141 | meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)) 142 | meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) 143 | latents = torch.cat((meanpooled_latents, latents), dim=-2) 144 | 145 | for attn, ff in self.layers: 146 | latents = attn(x, latents) + latents 147 | latents = ff(latents) + latents 148 | 149 | latents = self.proj_out(latents) 150 | return self.norm_out(latents) 151 | -------------------------------------------------------------------------------- /src/modules/motion_module.py: -------------------------------------------------------------------------------- 1 | # Adapt from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py 2 | import math 3 | from dataclasses import dataclass 4 | from typing import Callable, Optional 5 | 6 | import torch 7 | from diffusers.models.attention import FeedForward 8 | from diffusers.models.attention_processor import Attention, AttnProcessor 9 | from diffusers.utils import BaseOutput 10 | from diffusers.utils.import_utils import is_xformers_available 11 | from einops import rearrange, repeat 12 | from torch import nn 13 | 14 | 15 | def zero_module(module): 16 | # Zero out the parameters of a module and return it. 17 | for p in module.parameters(): 18 | p.detach().zero_() 19 | return module 20 | 21 | 22 | @dataclass 23 | class TemporalTransformer3DModelOutput(BaseOutput): 24 | sample: torch.FloatTensor 25 | 26 | 27 | if is_xformers_available(): 28 | import xformers 29 | import xformers.ops 30 | else: 31 | xformers = None 32 | 33 | 34 | def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict): 35 | if motion_module_type == "Vanilla": 36 | return VanillaTemporalModule( 37 | in_channels=in_channels, 38 | **motion_module_kwargs, 39 | ) 40 | else: 41 | raise ValueError 42 | 43 | 44 | class VanillaTemporalModule(nn.Module): 45 | def __init__( 46 | self, 47 | in_channels, 48 | num_attention_heads=8, 49 | num_transformer_block=2, 50 | attention_block_types=("Temporal_Self", "Temporal_Self"), 51 | cross_frame_attention_mode=None, 52 | temporal_position_encoding=False, 53 | temporal_position_encoding_max_len=24, 54 | temporal_attention_dim_div=1, 55 | zero_initialize=True, 56 | ): 57 | super().__init__() 58 | 59 | self.temporal_transformer = TemporalTransformer3DModel( 60 | in_channels=in_channels, 61 | num_attention_heads=num_attention_heads, 62 | attention_head_dim=in_channels 63 | // num_attention_heads 64 | // temporal_attention_dim_div, 65 | num_layers=num_transformer_block, 66 | attention_block_types=attention_block_types, 67 | cross_frame_attention_mode=cross_frame_attention_mode, 68 | temporal_position_encoding=temporal_position_encoding, 69 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 70 | ) 71 | 72 | if zero_initialize: 73 | self.temporal_transformer.proj_out = zero_module( 74 | self.temporal_transformer.proj_out 75 | ) 76 | 77 | def forward( 78 | self, 79 | input_tensor, 80 | temb, 81 | encoder_hidden_states, 82 | attention_mask=None, 
83 | anchor_frame_idx=None, 84 | ): 85 | hidden_states = input_tensor 86 | hidden_states = self.temporal_transformer( 87 | hidden_states, encoder_hidden_states, attention_mask 88 | ) 89 | 90 | output = hidden_states 91 | return output 92 | 93 | 94 | class TemporalTransformer3DModel(nn.Module): 95 | def __init__( 96 | self, 97 | in_channels, 98 | num_attention_heads, 99 | attention_head_dim, 100 | num_layers, 101 | attention_block_types=( 102 | "Temporal_Self", 103 | "Temporal_Self", 104 | ), 105 | dropout=0.0, 106 | norm_num_groups=32, 107 | cross_attention_dim=768, 108 | activation_fn="geglu", 109 | attention_bias=False, 110 | upcast_attention=False, 111 | cross_frame_attention_mode=None, 112 | temporal_position_encoding=False, 113 | temporal_position_encoding_max_len=24, 114 | ): 115 | super().__init__() 116 | 117 | inner_dim = num_attention_heads * attention_head_dim 118 | 119 | self.norm = torch.nn.GroupNorm( 120 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True 121 | ) 122 | self.proj_in = nn.Linear(in_channels, inner_dim) 123 | 124 | self.transformer_blocks = nn.ModuleList( 125 | [ 126 | TemporalTransformerBlock( 127 | dim=inner_dim, 128 | num_attention_heads=num_attention_heads, 129 | attention_head_dim=attention_head_dim, 130 | attention_block_types=attention_block_types, 131 | dropout=dropout, 132 | norm_num_groups=norm_num_groups, 133 | cross_attention_dim=cross_attention_dim, 134 | activation_fn=activation_fn, 135 | attention_bias=attention_bias, 136 | upcast_attention=upcast_attention, 137 | cross_frame_attention_mode=cross_frame_attention_mode, 138 | temporal_position_encoding=temporal_position_encoding, 139 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 140 | ) 141 | for d in range(num_layers) 142 | ] 143 | ) 144 | self.proj_out = nn.Linear(inner_dim, in_channels) 145 | 146 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): 147 | assert ( 148 | hidden_states.dim() == 5 149 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
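        # Shape walk-through (illustrative, with hypothetical sizes): the temporal
        # transformer expects video latents laid out as (batch, channels, frames,
        # height, width). The rearrange below folds the frame axis into the batch
        # axis so the per-frame GroupNorm and linear projection can be applied;
        # e.g. a hypothetical (2, 320, 16, 32, 32) input becomes (32, 320, 32, 32)
        # here and is restored to 5-D by the final rearrange of this forward pass.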
150 | video_length = hidden_states.shape[2] 151 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 152 | 153 | batch, channel, height, weight = hidden_states.shape 154 | residual = hidden_states 155 | 156 | hidden_states = self.norm(hidden_states) 157 | inner_dim = hidden_states.shape[1] 158 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 159 | batch, height * weight, inner_dim 160 | ) 161 | hidden_states = self.proj_in(hidden_states) 162 | 163 | # Transformer Blocks 164 | for block in self.transformer_blocks: 165 | hidden_states = block( 166 | hidden_states, 167 | encoder_hidden_states=encoder_hidden_states, 168 | video_length=video_length, 169 | ) 170 | 171 | # output 172 | hidden_states = self.proj_out(hidden_states) 173 | hidden_states = ( 174 | hidden_states.reshape(batch, height, weight, inner_dim) 175 | .permute(0, 3, 1, 2) 176 | .contiguous() 177 | ) 178 | 179 | output = hidden_states + residual 180 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 181 | 182 | return output 183 | 184 | 185 | class TemporalTransformerBlock(nn.Module): 186 | def __init__( 187 | self, 188 | dim, 189 | num_attention_heads, 190 | attention_head_dim, 191 | attention_block_types=( 192 | "Temporal_Self", 193 | "Temporal_Self", 194 | ), 195 | dropout=0.0, 196 | norm_num_groups=32, 197 | cross_attention_dim=768, 198 | activation_fn="geglu", 199 | attention_bias=False, 200 | upcast_attention=False, 201 | cross_frame_attention_mode=None, 202 | temporal_position_encoding=False, 203 | temporal_position_encoding_max_len=24, 204 | ): 205 | super().__init__() 206 | 207 | attention_blocks = [] 208 | norms = [] 209 | 210 | for block_name in attention_block_types: 211 | attention_blocks.append( 212 | VersatileAttention( 213 | attention_mode=block_name.split("_")[0], 214 | cross_attention_dim=cross_attention_dim 215 | if block_name.endswith("_Cross") 216 | else None, 217 | query_dim=dim, 218 | heads=num_attention_heads, 219 | dim_head=attention_head_dim, 220 | dropout=dropout, 221 | bias=attention_bias, 222 | upcast_attention=upcast_attention, 223 | cross_frame_attention_mode=cross_frame_attention_mode, 224 | temporal_position_encoding=temporal_position_encoding, 225 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 226 | ) 227 | ) 228 | norms.append(nn.LayerNorm(dim)) 229 | 230 | self.attention_blocks = nn.ModuleList(attention_blocks) 231 | self.norms = nn.ModuleList(norms) 232 | 233 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 234 | self.ff_norm = nn.LayerNorm(dim) 235 | 236 | def forward( 237 | self, 238 | hidden_states, 239 | encoder_hidden_states=None, 240 | attention_mask=None, 241 | video_length=None, 242 | ): 243 | for attention_block, norm in zip(self.attention_blocks, self.norms): 244 | norm_hidden_states = norm(hidden_states) 245 | hidden_states = ( 246 | attention_block( 247 | norm_hidden_states, 248 | encoder_hidden_states=encoder_hidden_states 249 | if attention_block.is_cross_attention 250 | else None, 251 | video_length=video_length, 252 | ) 253 | + hidden_states 254 | ) 255 | 256 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states 257 | 258 | output = hidden_states 259 | return output 260 | 261 | 262 | class PositionalEncoding(nn.Module): 263 | def __init__(self, d_model, dropout=0.0, max_len=24): 264 | super().__init__() 265 | self.dropout = nn.Dropout(p=dropout) 266 | position = torch.arange(max_len).unsqueeze(1) 267 | div_term = torch.exp( 268 | torch.arange(0, d_model, 2) 
* (-math.log(10000.0) / d_model) 269 | ) 270 | pe = torch.zeros(1, max_len, d_model) 271 | pe[0, :, 0::2] = torch.sin(position * div_term) 272 | pe[0, :, 1::2] = torch.cos(position * div_term) 273 | self.register_buffer("pe", pe) 274 | 275 | def forward(self, x): 276 | x = x + self.pe[:, : x.size(1)] 277 | return self.dropout(x) 278 | 279 | 280 | class VersatileAttention(Attention): 281 | def __init__( 282 | self, 283 | attention_mode=None, 284 | cross_frame_attention_mode=None, 285 | temporal_position_encoding=False, 286 | temporal_position_encoding_max_len=24, 287 | *args, 288 | **kwargs, 289 | ): 290 | super().__init__(*args, **kwargs) 291 | assert attention_mode == "Temporal" 292 | 293 | self.attention_mode = attention_mode 294 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None 295 | 296 | self.pos_encoder = ( 297 | PositionalEncoding( 298 | kwargs["query_dim"], 299 | dropout=0.0, 300 | max_len=temporal_position_encoding_max_len, 301 | ) 302 | if (temporal_position_encoding and attention_mode == "Temporal") 303 | else None 304 | ) 305 | 306 | def extra_repr(self): 307 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" 308 | 309 | def set_use_memory_efficient_attention_xformers( 310 | self, 311 | use_memory_efficient_attention_xformers: bool, 312 | attention_op: Optional[Callable] = None, 313 | ): 314 | if use_memory_efficient_attention_xformers: 315 | if not is_xformers_available(): 316 | raise ModuleNotFoundError( 317 | ( 318 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 319 | " xformers" 320 | ), 321 | name="xformers", 322 | ) 323 | elif not torch.cuda.is_available(): 324 | raise ValueError( 325 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" 326 | " only available for GPU " 327 | ) 328 | else: 329 | try: 330 | # Make sure we can run the memory efficient attention 331 | _ = xformers.ops.memory_efficient_attention( 332 | torch.randn((1, 2, 40), device="cuda"), 333 | torch.randn((1, 2, 40), device="cuda"), 334 | torch.randn((1, 2, 40), device="cuda"), 335 | ) 336 | except Exception as e: 337 | raise e 338 | 339 | # XFormersAttnProcessor corrupts video generation and work with Pytorch 1.13. 340 | # Pytorch 2.0.1 AttnProcessor works the same as XFormersAttnProcessor in Pytorch 1.13. 341 | # You don't need XFormersAttnProcessor here. 
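            # Note: as written, both branches of this method install the default
            # diffusers AttnProcessor; the xformers path above only verifies that
            # xformers is importable and that a CUDA device can run its
            # memory-efficient attention before doing the same.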
342 | # processor = XFormersAttnProcessor( 343 | # attention_op=attention_op, 344 | # ) 345 | processor = AttnProcessor() 346 | else: 347 | processor = AttnProcessor() 348 | 349 | self.set_processor(processor) 350 | 351 | def forward( 352 | self, 353 | hidden_states, 354 | encoder_hidden_states=None, 355 | attention_mask=None, 356 | video_length=None, 357 | **cross_attention_kwargs, 358 | ): 359 | if self.attention_mode == "Temporal": 360 | d = hidden_states.shape[1] # d means HxW 361 | hidden_states = rearrange( 362 | hidden_states, "(b f) d c -> (b d) f c", f=video_length 363 | ) 364 | 365 | if self.pos_encoder is not None: 366 | hidden_states = self.pos_encoder(hidden_states) 367 | 368 | encoder_hidden_states = ( 369 | repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) 370 | if encoder_hidden_states is not None 371 | else encoder_hidden_states 372 | ) 373 | 374 | else: 375 | raise NotImplementedError 376 | 377 | hidden_states = self.processor( 378 | self, 379 | hidden_states, 380 | encoder_hidden_states=encoder_hidden_states, 381 | attention_mask=attention_mask, 382 | **cross_attention_kwargs, 383 | ) 384 | 385 | if self.attention_mode == "Temporal": 386 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 387 | 388 | return hidden_states 389 | -------------------------------------------------------------------------------- /src/modules/mutual_self_attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py 2 | from typing import Any, Dict, Optional 3 | 4 | import torch 5 | from einops import rearrange 6 | 7 | from .attention import BasicTransformerBlock 8 | from .attention import TemporalBasicTransformerBlock 9 | 10 | 11 | def torch_dfs(model: torch.nn.Module): 12 | result = [model] 13 | for child in model.children(): 14 | result += torch_dfs(child) 15 | return result 16 | 17 | 18 | class ReferenceAttentionControl: 19 | def __init__( 20 | self, 21 | unet, 22 | mode="write", 23 | do_classifier_free_guidance=False, 24 | attention_auto_machine_weight=float("inf"), 25 | gn_auto_machine_weight=1.0, 26 | style_fidelity=1.0, 27 | reference_attn=True, 28 | reference_adain=False, 29 | fusion_blocks="midup", 30 | batch_size=1, 31 | reference_attention_weight=1., 32 | audio_attention_weight=1., 33 | ) -> None: 34 | # 10. 
Modify self attention and group norm 35 | self.unet = unet 36 | assert mode in ["read", "write"] 37 | assert fusion_blocks in ["midup", "full"] 38 | self.reference_attn = reference_attn 39 | self.reference_adain = reference_adain 40 | self.fusion_blocks = fusion_blocks 41 | self.reference_attention_weight = reference_attention_weight 42 | self.audio_attention_weight = audio_attention_weight 43 | self.register_reference_hooks( 44 | mode, 45 | do_classifier_free_guidance, 46 | attention_auto_machine_weight, 47 | gn_auto_machine_weight, 48 | style_fidelity, 49 | reference_attn, 50 | reference_adain, 51 | fusion_blocks, 52 | batch_size=batch_size, 53 | ) 54 | 55 | def register_reference_hooks( 56 | self, 57 | mode, 58 | do_classifier_free_guidance, 59 | attention_auto_machine_weight, 60 | gn_auto_machine_weight, 61 | style_fidelity, 62 | reference_attn, 63 | reference_adain, 64 | dtype=torch.float16, 65 | batch_size=1, 66 | num_images_per_prompt=1, 67 | device=torch.device("cpu"), 68 | fusion_blocks="midup", 69 | ): 70 | MODE = mode 71 | do_classifier_free_guidance = do_classifier_free_guidance 72 | attention_auto_machine_weight = attention_auto_machine_weight 73 | gn_auto_machine_weight = gn_auto_machine_weight 74 | style_fidelity = style_fidelity 75 | reference_attn = reference_attn 76 | reference_adain = reference_adain 77 | fusion_blocks = fusion_blocks 78 | num_images_per_prompt = num_images_per_prompt 79 | reference_attention_weight = self.reference_attention_weight 80 | audio_attention_weight = self.audio_attention_weight 81 | dtype = dtype 82 | if do_classifier_free_guidance: 83 | uc_mask = ( 84 | torch.Tensor( 85 | [1] * batch_size * num_images_per_prompt * 16 86 | + [0] * batch_size * num_images_per_prompt * 16 87 | ) 88 | .to(device) 89 | .bool() 90 | ) 91 | else: 92 | uc_mask = ( 93 | torch.Tensor([0] * batch_size * num_images_per_prompt * 2) 94 | .to(device) 95 | .bool() 96 | ) 97 | 98 | def hacked_basic_transformer_inner_forward( 99 | self, 100 | hidden_states: torch.FloatTensor, 101 | attention_mask: Optional[torch.FloatTensor] = None, 102 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 103 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 104 | timestep: Optional[torch.LongTensor] = None, 105 | cross_attention_kwargs: Dict[str, Any] = None, 106 | class_labels: Optional[torch.LongTensor] = None, 107 | video_length=None, 108 | ): 109 | if self.use_ada_layer_norm: # False 110 | norm_hidden_states = self.norm1(hidden_states, timestep) 111 | elif self.use_ada_layer_norm_zero: 112 | ( 113 | norm_hidden_states, 114 | gate_msa, 115 | shift_mlp, 116 | scale_mlp, 117 | gate_mlp, 118 | ) = self.norm1( 119 | hidden_states, 120 | timestep, 121 | class_labels, 122 | hidden_dtype=hidden_states.dtype, 123 | ) 124 | else: 125 | norm_hidden_states = self.norm1(hidden_states) 126 | 127 | # 1. 
Self-Attention 128 | # self.only_cross_attention = False 129 | cross_attention_kwargs = ( 130 | cross_attention_kwargs if cross_attention_kwargs is not None else {} 131 | ) 132 | if self.only_cross_attention: 133 | attn_output = self.attn1( 134 | norm_hidden_states, 135 | encoder_hidden_states=encoder_hidden_states 136 | if self.only_cross_attention 137 | else None, 138 | attention_mask=attention_mask, 139 | **cross_attention_kwargs, 140 | ) 141 | else: 142 | if MODE == "write": 143 | attn_output = self.attn1( 144 | norm_hidden_states, 145 | encoder_hidden_states=encoder_hidden_states 146 | if self.only_cross_attention 147 | else None, 148 | attention_mask=attention_mask, 149 | **cross_attention_kwargs, 150 | ) 151 | 152 | if self.use_ada_layer_norm_zero: 153 | attn_output = gate_msa.unsqueeze(1) * attn_output 154 | hidden_states = attn_output + hidden_states 155 | 156 | if self.attn2 is not None: 157 | norm_hidden_states = ( 158 | self.norm2(hidden_states, timestep) 159 | if self.use_ada_layer_norm 160 | else self.norm2(hidden_states) 161 | ) 162 | self.bank.append(norm_hidden_states.clone()) 163 | 164 | # 2. Cross-Attention 165 | attn_output = self.attn2( 166 | norm_hidden_states, 167 | encoder_hidden_states=encoder_hidden_states, 168 | attention_mask=encoder_attention_mask, 169 | **cross_attention_kwargs, 170 | ) 171 | hidden_states = attn_output + hidden_states 172 | 173 | if MODE == "read": 174 | hidden_states = ( 175 | self.attn1( 176 | norm_hidden_states, 177 | encoder_hidden_states=norm_hidden_states, 178 | attention_mask=attention_mask, 179 | ) 180 | + hidden_states 181 | ) 182 | 183 | if self.use_ada_layer_norm: # False 184 | norm_hidden_states = self.norm1_5(hidden_states, timestep) 185 | elif self.use_ada_layer_norm_zero: 186 | ( 187 | norm_hidden_states, 188 | gate_msa, 189 | shift_mlp, 190 | scale_mlp, 191 | gate_mlp, 192 | ) = self.norm1_5( 193 | hidden_states, 194 | timestep, 195 | class_labels, 196 | hidden_dtype=hidden_states.dtype, 197 | ) 198 | else: 199 | norm_hidden_states = self.norm1_5(hidden_states) 200 | 201 | bank_fea = [] 202 | for d in self.bank: 203 | if len(d.shape) == 3: 204 | d = d.unsqueeze(1).repeat(1, video_length, 1, 1) 205 | bank_fea.append(rearrange(d, "b t l c -> (b t) l c")) 206 | 207 | attn_hidden_states = self.attn1_5( 208 | norm_hidden_states, 209 | encoder_hidden_states=bank_fea[0], 210 | attention_mask=attention_mask, 211 | ) 212 | 213 | if reference_attention_weight != 1.: 214 | attn_hidden_states *= reference_attention_weight 215 | 216 | hidden_states = (attn_hidden_states + hidden_states) 217 | 218 | # self.bank.clear() 219 | if self.attn2 is not None: 220 | # Cross-Attention 221 | norm_hidden_states = ( 222 | self.norm2(hidden_states, timestep) 223 | if self.use_ada_layer_norm 224 | else self.norm2(hidden_states) 225 | ) 226 | 227 | attn_hidden_states = self.attn2( 228 | norm_hidden_states, 229 | encoder_hidden_states=encoder_hidden_states, 230 | attention_mask=attention_mask, 231 | ) 232 | 233 | if audio_attention_weight != 1.: 234 | attn_hidden_states *= audio_attention_weight 235 | 236 | hidden_states = (attn_hidden_states + hidden_states) 237 | 238 | # Feed-forward 239 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 240 | 241 | # Temporal-Attention 242 | if self.unet_use_temporal_attention: 243 | d = hidden_states.shape[1] 244 | hidden_states = rearrange( 245 | hidden_states, "(b f) d c -> (b d) f c", f=video_length 246 | ) 247 | norm_hidden_states = ( 248 | self.norm_temp(hidden_states, timestep) 249 | if 
self.use_ada_layer_norm 250 | else self.norm_temp(hidden_states) 251 | ) 252 | hidden_states = ( 253 | self.attn_temp(norm_hidden_states) + hidden_states 254 | ) 255 | hidden_states = rearrange( 256 | hidden_states, "(b d) f c -> (b f) d c", d=d 257 | ) 258 | 259 | return hidden_states 260 | 261 | # 3. Feed-forward 262 | norm_hidden_states = self.norm3(hidden_states) 263 | 264 | if self.use_ada_layer_norm_zero: 265 | norm_hidden_states = ( 266 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 267 | ) 268 | 269 | ff_output = self.ff(norm_hidden_states) 270 | 271 | if self.use_ada_layer_norm_zero: 272 | ff_output = gate_mlp.unsqueeze(1) * ff_output 273 | 274 | hidden_states = ff_output + hidden_states 275 | 276 | return hidden_states 277 | 278 | if self.reference_attn: 279 | if self.fusion_blocks == "midup": 280 | attn_modules = [ 281 | module 282 | for module in ( 283 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks) 284 | ) 285 | if isinstance(module, BasicTransformerBlock) 286 | or isinstance(module, TemporalBasicTransformerBlock) 287 | ] 288 | elif self.fusion_blocks == "full": 289 | attn_modules = [ 290 | module 291 | for module in torch_dfs(self.unet) 292 | if isinstance(module, BasicTransformerBlock) 293 | or isinstance(module, TemporalBasicTransformerBlock) 294 | ] 295 | attn_modules = sorted( 296 | attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 297 | ) 298 | 299 | for i, module in enumerate(attn_modules): 300 | module._original_inner_forward = module.forward 301 | if isinstance(module, BasicTransformerBlock): 302 | module.forward = hacked_basic_transformer_inner_forward.__get__( 303 | module, BasicTransformerBlock 304 | ) 305 | if isinstance(module, TemporalBasicTransformerBlock): 306 | module.forward = hacked_basic_transformer_inner_forward.__get__( 307 | module, TemporalBasicTransformerBlock 308 | ) 309 | 310 | module.bank = [] 311 | module.attn_weight = float(i) / float(len(attn_modules)) 312 | 313 | def update( 314 | self, 315 | writer, 316 | do_classifier_free_guidance=True, 317 | dtype=torch.float16, 318 | ): 319 | if self.reference_attn: 320 | if self.fusion_blocks == "midup": 321 | reader_attn_modules = [ 322 | module 323 | for module in (torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)) 324 | if isinstance(module, TemporalBasicTransformerBlock) 325 | ] 326 | writer_attn_modules = [ 327 | module 328 | for module in (torch_dfs(writer.unet.mid_block) + torch_dfs(writer.unet.up_blocks)) 329 | if isinstance(module, BasicTransformerBlock) 330 | ] 331 | elif self.fusion_blocks == "full": 332 | reader_attn_modules = [ 333 | module 334 | for module in torch_dfs(self.unet) 335 | if isinstance(module, TemporalBasicTransformerBlock) 336 | ] 337 | writer_attn_modules = [ 338 | module 339 | for module in torch_dfs(writer.unet) 340 | if isinstance(module, BasicTransformerBlock) 341 | ] 342 | reader_attn_modules = sorted( 343 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 344 | ) 345 | writer_attn_modules = sorted( 346 | writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 347 | ) 348 | for r, w in zip(reader_attn_modules, writer_attn_modules): 349 | if do_classifier_free_guidance: 350 | r.bank = [torch.cat([torch.zeros_like(v), v]).to(dtype) for v in w.bank] 351 | else: 352 | r.bank = [v.clone().to(dtype) for v in w.bank] 353 | 354 | def clear(self): 355 | if self.reference_attn: 356 | if self.fusion_blocks == "midup": 357 | reader_attn_modules = [ 358 | module 359 | for module in ( 360 | 
torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks) 361 | ) 362 | if isinstance(module, BasicTransformerBlock) 363 | or isinstance(module, TemporalBasicTransformerBlock) 364 | ] 365 | elif self.fusion_blocks == "full": 366 | reader_attn_modules = [ 367 | module 368 | for module in torch_dfs(self.unet) 369 | if isinstance(module, BasicTransformerBlock) 370 | or isinstance(module, TemporalBasicTransformerBlock) 371 | ] 372 | reader_attn_modules = sorted( 373 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 374 | ) 375 | for r in reader_attn_modules: 376 | r.bank.clear() 377 | -------------------------------------------------------------------------------- /src/modules/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | 8 | 9 | class InflatedConv3d(nn.Conv2d): 10 | def forward(self, x): 11 | video_length = x.shape[2] 12 | 13 | x = rearrange(x, "b c f h w -> (b f) c h w") 14 | x = super().forward(x) 15 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 16 | 17 | return x 18 | 19 | 20 | class InflatedGroupNorm(nn.GroupNorm): 21 | def forward(self, x): 22 | video_length = x.shape[2] 23 | 24 | x = rearrange(x, "b c f h w -> (b f) c h w") 25 | x = super().forward(x) 26 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 27 | 28 | return x 29 | 30 | 31 | class Upsample3D(nn.Module): 32 | def __init__( 33 | self, 34 | channels, 35 | use_conv=False, 36 | use_conv_transpose=False, 37 | out_channels=None, 38 | name="conv", 39 | ): 40 | super().__init__() 41 | self.channels = channels 42 | self.out_channels = out_channels or channels 43 | self.use_conv = use_conv 44 | self.use_conv_transpose = use_conv_transpose 45 | self.name = name 46 | 47 | conv = None 48 | if use_conv_transpose: 49 | raise NotImplementedError 50 | elif use_conv: 51 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 52 | 53 | def forward(self, hidden_states, output_size=None): 54 | assert hidden_states.shape[1] == self.channels 55 | 56 | if self.use_conv_transpose: 57 | raise NotImplementedError 58 | 59 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 60 | dtype = hidden_states.dtype 61 | if dtype == torch.bfloat16: 62 | hidden_states = hidden_states.to(torch.float32) 63 | 64 | # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 65 | if hidden_states.shape[0] >= 64: 66 | hidden_states = hidden_states.contiguous() 67 | 68 | # if `output_size` is passed we force the interpolation output 69 | # size and do not make use of `scale_factor=2` 70 | if output_size is None: 71 | hidden_states = F.interpolate( 72 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest" 73 | ) 74 | else: 75 | hidden_states = F.interpolate( 76 | hidden_states, size=output_size, mode="nearest" 77 | ) 78 | 79 | # If the input is bfloat16, we cast back to bfloat16 80 | if dtype == torch.bfloat16: 81 | hidden_states = hidden_states.to(dtype) 82 | 83 | # if self.use_conv: 84 | # if self.name == "conv": 85 | # hidden_states = self.conv(hidden_states) 86 | # else: 87 | # hidden_states = self.Conv2d_0(hidden_states) 88 | hidden_states = self.conv(hidden_states) 89 | 90 | return hidden_states 91 | 92 | 93 | class Downsample3D(nn.Module): 94 | def __init__( 95 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv" 96 | ): 97 | super().__init__() 98 | self.channels = channels 99 | self.out_channels = out_channels or channels 100 | self.use_conv = use_conv 101 | self.padding = padding 102 | stride = 2 103 | self.name = name 104 | 105 | if use_conv: 106 | self.conv = InflatedConv3d( 107 | self.channels, self.out_channels, 3, stride=stride, padding=padding 108 | ) 109 | else: 110 | raise NotImplementedError 111 | 112 | def forward(self, hidden_states): 113 | assert hidden_states.shape[1] == self.channels 114 | if self.use_conv and self.padding == 0: 115 | raise NotImplementedError 116 | 117 | assert hidden_states.shape[1] == self.channels 118 | hidden_states = self.conv(hidden_states) 119 | 120 | return hidden_states 121 | 122 | 123 | class ResnetBlock3D(nn.Module): 124 | def __init__( 125 | self, 126 | *, 127 | in_channels, 128 | out_channels=None, 129 | conv_shortcut=False, 130 | dropout=0.0, 131 | temb_channels=512, 132 | groups=32, 133 | groups_out=None, 134 | pre_norm=True, 135 | eps=1e-6, 136 | non_linearity="swish", 137 | time_embedding_norm="default", 138 | output_scale_factor=1.0, 139 | use_in_shortcut=None, 140 | use_inflated_groupnorm=None, 141 | ): 142 | super().__init__() 143 | self.pre_norm = pre_norm 144 | self.pre_norm = True 145 | self.in_channels = in_channels 146 | out_channels = in_channels if out_channels is None else out_channels 147 | self.out_channels = out_channels 148 | self.use_conv_shortcut = conv_shortcut 149 | self.time_embedding_norm = time_embedding_norm 150 | self.output_scale_factor = output_scale_factor 151 | 152 | if groups_out is None: 153 | groups_out = groups 154 | 155 | assert use_inflated_groupnorm != None 156 | if use_inflated_groupnorm: 157 | self.norm1 = InflatedGroupNorm( 158 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True 159 | ) 160 | else: 161 | self.norm1 = torch.nn.GroupNorm( 162 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True 163 | ) 164 | 165 | self.conv1 = InflatedConv3d( 166 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 167 | ) 168 | 169 | if temb_channels is not None: 170 | if self.time_embedding_norm == "default": 171 | time_emb_proj_out_channels = out_channels 172 | elif self.time_embedding_norm == "scale_shift": 173 | time_emb_proj_out_channels = out_channels * 2 174 | else: 175 | raise ValueError( 176 | f"unknown time_embedding_norm : {self.time_embedding_norm} " 177 | ) 178 | 179 | self.time_emb_proj = torch.nn.Linear( 180 | temb_channels, 
time_emb_proj_out_channels 181 | ) 182 | else: 183 | self.time_emb_proj = None 184 | 185 | if use_inflated_groupnorm: 186 | self.norm2 = InflatedGroupNorm( 187 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True 188 | ) 189 | else: 190 | self.norm2 = torch.nn.GroupNorm( 191 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True 192 | ) 193 | self.dropout = torch.nn.Dropout(dropout) 194 | self.conv2 = InflatedConv3d( 195 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 196 | ) 197 | 198 | if non_linearity == "swish": 199 | self.nonlinearity = lambda x: F.silu(x) 200 | elif non_linearity == "mish": 201 | self.nonlinearity = Mish() 202 | elif non_linearity == "silu": 203 | self.nonlinearity = nn.SiLU() 204 | 205 | self.use_in_shortcut = ( 206 | self.in_channels != self.out_channels 207 | if use_in_shortcut is None 208 | else use_in_shortcut 209 | ) 210 | 211 | self.conv_shortcut = None 212 | if self.use_in_shortcut: 213 | self.conv_shortcut = InflatedConv3d( 214 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 215 | ) 216 | 217 | def forward(self, input_tensor, temb): 218 | hidden_states = input_tensor 219 | 220 | hidden_states = self.norm1(hidden_states) 221 | hidden_states = self.nonlinearity(hidden_states) 222 | 223 | hidden_states = self.conv1(hidden_states) 224 | 225 | if temb is not None: 226 | temb = self.time_emb_proj(self.nonlinearity(temb)) 227 | if len(temb.shape) == 2: 228 | temb = temb[:, :, None, None, None] 229 | elif len(temb.shape) == 3: 230 | temb = temb[:, :, :, None, None].permute(0, 2, 1, 3, 4) 231 | 232 | if temb is not None and self.time_embedding_norm == "default": 233 | hidden_states = hidden_states + temb 234 | 235 | hidden_states = self.norm2(hidden_states) 236 | 237 | if temb is not None and self.time_embedding_norm == "scale_shift": 238 | scale, shift = torch.chunk(temb, 2, dim=1) 239 | hidden_states = hidden_states * (1 + scale) + shift 240 | 241 | hidden_states = self.nonlinearity(hidden_states) 242 | 243 | hidden_states = self.dropout(hidden_states) 244 | hidden_states = self.conv2(hidden_states) 245 | 246 | if self.conv_shortcut is not None: 247 | input_tensor = self.conv_shortcut(input_tensor) 248 | 249 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 250 | 251 | return output_tensor 252 | 253 | 254 | class Mish(torch.nn.Module): 255 | def forward(self, hidden_states): 256 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 257 | -------------------------------------------------------------------------------- /src/modules/transformer_2d.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, Optional 4 | 5 | import torch 6 | from diffusers.configuration_utils import ConfigMixin, register_to_config 7 | 8 | try: 9 | from diffusers.models.embeddings import CaptionProjection 10 | except: 11 | from diffusers.models.embeddings import PixArtAlphaTextProjection as CaptionProjection 12 | 13 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear 14 | from diffusers.models.modeling_utils import ModelMixin 15 | from diffusers.models.normalization import AdaLayerNormSingle 16 | from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version 17 | from torch import nn 18 | 19 | from .attention import 
BasicTransformerBlock 20 | 21 | 22 | @dataclass 23 | class Transformer2DModelOutput(BaseOutput): 24 | """ 25 | The output of [`Transformer2DModel`]. 26 | 27 | Args: 28 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): 29 | The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability 30 | distributions for the unnoised latent pixels. 31 | """ 32 | 33 | sample: torch.FloatTensor 34 | ref_feature: torch.FloatTensor 35 | 36 | 37 | class Transformer2DModel(ModelMixin, ConfigMixin): 38 | """ 39 | A 2D Transformer model for image-like data. 40 | 41 | Parameters: 42 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 43 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 44 | in_channels (`int`, *optional*): 45 | The number of channels in the input and output (specify if the input is **continuous**). 46 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 47 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 48 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 49 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). 50 | This is fixed during training since it is used to learn a number of position embeddings. 51 | num_vector_embeds (`int`, *optional*): 52 | The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). 53 | Includes the class for the masked latent pixel. 54 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. 55 | num_embeds_ada_norm ( `int`, *optional*): 56 | The number of diffusion steps used during training. Pass if at least one of the norm_layers is 57 | `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are 58 | added to the hidden states. 59 | 60 | During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. 61 | attention_bias (`bool`, *optional*): 62 | Configure if the `TransformerBlocks` attention should contain a bias parameter. 
63 | """ 64 | 65 | _supports_gradient_checkpointing = True 66 | 67 | @register_to_config 68 | def __init__( 69 | self, 70 | num_attention_heads: int = 16, 71 | attention_head_dim: int = 88, 72 | in_channels: Optional[int] = None, 73 | out_channels: Optional[int] = None, 74 | num_layers: int = 1, 75 | dropout: float = 0.0, 76 | norm_num_groups: int = 32, 77 | cross_attention_dim: Optional[int] = None, 78 | attention_bias: bool = False, 79 | sample_size: Optional[int] = None, 80 | num_vector_embeds: Optional[int] = None, 81 | patch_size: Optional[int] = None, 82 | activation_fn: str = "geglu", 83 | num_embeds_ada_norm: Optional[int] = None, 84 | use_linear_projection: bool = False, 85 | only_cross_attention: bool = False, 86 | double_self_attention: bool = False, 87 | upcast_attention: bool = False, 88 | norm_type: str = "layer_norm", 89 | norm_elementwise_affine: bool = True, 90 | norm_eps: float = 1e-5, 91 | attention_type: str = "default", 92 | caption_channels: int = None, 93 | ): 94 | super().__init__() 95 | self.use_linear_projection = use_linear_projection 96 | self.num_attention_heads = num_attention_heads 97 | self.attention_head_dim = attention_head_dim 98 | inner_dim = num_attention_heads * attention_head_dim 99 | 100 | conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv 101 | linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear 102 | 103 | # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` 104 | # Define whether input is continuous or discrete depending on configuration 105 | self.is_input_continuous = (in_channels is not None) and (patch_size is None) 106 | self.is_input_vectorized = num_vector_embeds is not None 107 | self.is_input_patches = in_channels is not None and patch_size is not None 108 | 109 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None: 110 | deprecation_message = ( 111 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" 112 | " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." 113 | " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" 114 | " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" 115 | " would be very nice if you could open a Pull request for the `transformer/config.json` file" 116 | ) 117 | deprecate( 118 | "norm_type!=num_embeds_ada_norm", 119 | "1.0.0", 120 | deprecation_message, 121 | standard_warn=False, 122 | ) 123 | norm_type = "ada_norm" 124 | 125 | if self.is_input_continuous and self.is_input_vectorized: 126 | raise ValueError( 127 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" 128 | " sure that either `in_channels` or `num_vector_embeds` is None." 129 | ) 130 | elif self.is_input_vectorized and self.is_input_patches: 131 | raise ValueError( 132 | f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" 133 | " sure that either `num_vector_embeds` or `num_patches` is None." 
134 | ) 135 | elif ( 136 | not self.is_input_continuous 137 | and not self.is_input_vectorized 138 | and not self.is_input_patches 139 | ): 140 | raise ValueError( 141 | f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" 142 | f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." 143 | ) 144 | 145 | # 2. Define input layers 146 | self.in_channels = in_channels 147 | 148 | self.norm = torch.nn.GroupNorm( 149 | num_groups=norm_num_groups, 150 | num_channels=in_channels, 151 | eps=1e-6, 152 | affine=True, 153 | ) 154 | if use_linear_projection: 155 | self.proj_in = linear_cls(in_channels, inner_dim) 156 | else: 157 | self.proj_in = conv_cls( 158 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 159 | ) 160 | 161 | # 3. Define transformers blocks 162 | self.transformer_blocks = nn.ModuleList( 163 | [ 164 | BasicTransformerBlock( 165 | inner_dim, 166 | num_attention_heads, 167 | attention_head_dim, 168 | dropout=dropout, 169 | cross_attention_dim=cross_attention_dim, 170 | activation_fn=activation_fn, 171 | num_embeds_ada_norm=num_embeds_ada_norm, 172 | attention_bias=attention_bias, 173 | only_cross_attention=only_cross_attention, 174 | double_self_attention=double_self_attention, 175 | upcast_attention=upcast_attention, 176 | norm_type=norm_type, 177 | norm_elementwise_affine=norm_elementwise_affine, 178 | norm_eps=norm_eps, 179 | attention_type=attention_type, 180 | ) 181 | for d in range(num_layers) 182 | ] 183 | ) 184 | 185 | # 4. Define output layers 186 | self.out_channels = in_channels if out_channels is None else out_channels 187 | # TODO: should use out_channels for continuous projections 188 | if use_linear_projection: 189 | self.proj_out = linear_cls(inner_dim, in_channels) 190 | else: 191 | self.proj_out = conv_cls( 192 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0 193 | ) 194 | 195 | # 5. PixArt-Alpha blocks. 196 | self.adaln_single = None 197 | self.use_additional_conditions = False 198 | if norm_type == "ada_norm_single": 199 | self.use_additional_conditions = self.config.sample_size == 128 200 | # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use 201 | # additional conditions until we find better name 202 | self.adaln_single = AdaLayerNormSingle( 203 | inner_dim, use_additional_conditions=self.use_additional_conditions 204 | ) 205 | 206 | self.caption_projection = None 207 | if caption_channels is not None: 208 | self.caption_projection = CaptionProjection( 209 | in_features=caption_channels, hidden_size=inner_dim 210 | ) 211 | 212 | self.gradient_checkpointing = False 213 | 214 | def _set_gradient_checkpointing(self, module, value=False): 215 | if hasattr(module, "gradient_checkpointing"): 216 | module.gradient_checkpointing = value 217 | 218 | def forward( 219 | self, 220 | hidden_states: torch.Tensor, 221 | encoder_hidden_states: Optional[torch.Tensor] = None, 222 | timestep: Optional[torch.LongTensor] = None, 223 | added_cond_kwargs: Dict[str, torch.Tensor] = None, 224 | class_labels: Optional[torch.LongTensor] = None, 225 | cross_attention_kwargs: Dict[str, Any] = None, 226 | attention_mask: Optional[torch.Tensor] = None, 227 | encoder_attention_mask: Optional[torch.Tensor] = None, 228 | return_dict: bool = True, 229 | ): 230 | """ 231 | The [`Transformer2DModel`] forward method. 
232 | 233 | Args: 234 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): 235 | Input `hidden_states`. 236 | encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): 237 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 238 | self-attention. 239 | timestep ( `torch.LongTensor`, *optional*): 240 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. 241 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): 242 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in 243 | `AdaLayerZeroNorm`. 244 | cross_attention_kwargs ( `Dict[str, Any]`, *optional*): 245 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 246 | `self.processor` in 247 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 248 | attention_mask ( `torch.Tensor`, *optional*): 249 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask 250 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large 251 | negative values to the attention scores corresponding to "discard" tokens. 252 | encoder_attention_mask ( `torch.Tensor`, *optional*): 253 | Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: 254 | 255 | * Mask `(batch, sequence_length)` True = keep, False = discard. 256 | * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. 257 | 258 | If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format 259 | above. This bias will be added to the cross-attention scores. 260 | return_dict (`bool`, *optional*, defaults to `True`): 261 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain 262 | tuple. 263 | 264 | Returns: 265 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 266 | `tuple` where the first element is the sample tensor. 267 | """ 268 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. 269 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. 270 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 271 | # expects mask of shape: 272 | # [batch, key_tokens] 273 | # adds singleton query_tokens dimension: 274 | # [batch, 1, key_tokens] 275 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 276 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 277 | # [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn) 278 | if attention_mask is not None and attention_mask.ndim == 2: 279 | # assume that mask is expressed as: 280 | # (1 = keep, 0 = discard) 281 | # convert mask into a bias that can be added to attention scores: 282 | # (keep = +0, discard = -10000.0) 283 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 284 | attention_mask = attention_mask.unsqueeze(1) 285 | 286 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 287 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: 288 | encoder_attention_mask = ( 289 | 1 - encoder_attention_mask.to(hidden_states.dtype) 290 | ) * -10000.0 291 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 292 | 293 | # Retrieve lora scale. 294 | lora_scale = ( 295 | cross_attention_kwargs.get("scale", 1.0) 296 | if cross_attention_kwargs is not None 297 | else 1.0 298 | ) 299 | 300 | # 1. Input 301 | batch, _, height, width = hidden_states.shape 302 | residual = hidden_states 303 | 304 | hidden_states = self.norm(hidden_states) 305 | if not self.use_linear_projection: 306 | hidden_states = ( 307 | self.proj_in(hidden_states, scale=lora_scale) 308 | if not USE_PEFT_BACKEND 309 | else self.proj_in(hidden_states) 310 | ) 311 | inner_dim = hidden_states.shape[1] 312 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 313 | batch, height * width, inner_dim 314 | ) 315 | else: 316 | inner_dim = hidden_states.shape[1] 317 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 318 | batch, height * width, inner_dim 319 | ) 320 | hidden_states = ( 321 | self.proj_in(hidden_states, scale=lora_scale) 322 | if not USE_PEFT_BACKEND 323 | else self.proj_in(hidden_states) 324 | ) 325 | 326 | # 2. Blocks 327 | if self.caption_projection is not None: 328 | batch_size = hidden_states.shape[0] 329 | encoder_hidden_states = self.caption_projection(encoder_hidden_states) 330 | encoder_hidden_states = encoder_hidden_states.view( 331 | batch_size, -1, hidden_states.shape[-1] 332 | ) 333 | 334 | ref_feature = hidden_states.reshape(batch, height, width, inner_dim) 335 | for block in self.transformer_blocks: 336 | if self.training and self.gradient_checkpointing: 337 | 338 | def create_custom_forward(module, return_dict=None): 339 | def custom_forward(*inputs): 340 | if return_dict is not None: 341 | return module(*inputs, return_dict=return_dict) 342 | else: 343 | return module(*inputs) 344 | 345 | return custom_forward 346 | 347 | ckpt_kwargs: Dict[str, Any] = ( 348 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 349 | ) 350 | hidden_states = torch.utils.checkpoint.checkpoint( 351 | create_custom_forward(block), 352 | hidden_states, 353 | attention_mask, 354 | encoder_hidden_states, 355 | encoder_attention_mask, 356 | timestep, 357 | cross_attention_kwargs, 358 | class_labels, 359 | **ckpt_kwargs, 360 | ) 361 | else: 362 | hidden_states = block( 363 | hidden_states, 364 | attention_mask=attention_mask, 365 | encoder_hidden_states=encoder_hidden_states, 366 | encoder_attention_mask=encoder_attention_mask, 367 | timestep=timestep, 368 | cross_attention_kwargs=cross_attention_kwargs, 369 | class_labels=class_labels, 370 | ) 371 | 372 | # 3. 
Output 373 | if self.is_input_continuous: 374 | if not self.use_linear_projection: 375 | hidden_states = ( 376 | hidden_states.reshape(batch, height, width, inner_dim) 377 | .permute(0, 3, 1, 2) 378 | .contiguous() 379 | ) 380 | hidden_states = ( 381 | self.proj_out(hidden_states, scale=lora_scale) 382 | if not USE_PEFT_BACKEND 383 | else self.proj_out(hidden_states) 384 | ) 385 | else: 386 | hidden_states = ( 387 | self.proj_out(hidden_states, scale=lora_scale) 388 | if not USE_PEFT_BACKEND 389 | else self.proj_out(hidden_states) 390 | ) 391 | hidden_states = ( 392 | hidden_states.reshape(batch, height, width, inner_dim) 393 | .permute(0, 3, 1, 2) 394 | .contiguous() 395 | ) 396 | 397 | output = hidden_states + residual 398 | if not return_dict: 399 | return (output, ref_feature) 400 | 401 | return Transformer2DModelOutput(sample=output, ref_feature=ref_feature) 402 | -------------------------------------------------------------------------------- /src/modules/transformer_3d.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | import torch 5 | from diffusers.configuration_utils import ConfigMixin, register_to_config 6 | from diffusers.models import ModelMixin 7 | from diffusers.utils import BaseOutput 8 | from diffusers.utils.import_utils import is_xformers_available 9 | from einops import rearrange, repeat 10 | from torch import nn 11 | 12 | from .attention import TemporalBasicTransformerBlock 13 | 14 | 15 | @dataclass 16 | class Transformer3DModelOutput(BaseOutput): 17 | sample: torch.FloatTensor 18 | 19 | 20 | if is_xformers_available(): 21 | import xformers 22 | import xformers.ops 23 | else: 24 | xformers = None 25 | 26 | 27 | class Transformer3DModel(ModelMixin, ConfigMixin): 28 | _supports_gradient_checkpointing = True 29 | 30 | @register_to_config 31 | def __init__( 32 | self, 33 | num_attention_heads: int = 16, 34 | attention_head_dim: int = 88, 35 | in_channels: Optional[int] = None, 36 | num_layers: int = 1, 37 | dropout: float = 0.0, 38 | norm_num_groups: int = 32, 39 | cross_attention_dim: Optional[int] = None, 40 | attention_bias: bool = False, 41 | activation_fn: str = "geglu", 42 | num_embeds_ada_norm: Optional[int] = None, 43 | use_linear_projection: bool = False, 44 | only_cross_attention: bool = False, 45 | upcast_attention: bool = False, 46 | unet_use_cross_frame_attention=None, 47 | unet_use_temporal_attention=None, 48 | ): 49 | super().__init__() 50 | self.use_linear_projection = use_linear_projection 51 | self.num_attention_heads = num_attention_heads 52 | self.attention_head_dim = attention_head_dim 53 | inner_dim = num_attention_heads * attention_head_dim 54 | 55 | # Define input layers 56 | self.in_channels = in_channels 57 | 58 | self.norm = torch.nn.GroupNorm( 59 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True 60 | ) 61 | if use_linear_projection: 62 | self.proj_in = nn.Linear(in_channels, inner_dim) 63 | else: 64 | self.proj_in = nn.Conv2d( 65 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 66 | ) 67 | 68 | # Define transformers blocks 69 | self.transformer_blocks = nn.ModuleList( 70 | [ 71 | TemporalBasicTransformerBlock( 72 | inner_dim, 73 | num_attention_heads, 74 | attention_head_dim, 75 | dropout=dropout, 76 | cross_attention_dim=cross_attention_dim, 77 | activation_fn=activation_fn, 78 | num_embeds_ada_norm=num_embeds_ada_norm, 79 | attention_bias=attention_bias, 80 | 
only_cross_attention=only_cross_attention, 81 | upcast_attention=upcast_attention, 82 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 83 | unet_use_temporal_attention=unet_use_temporal_attention, 84 | ) 85 | for d in range(num_layers) 86 | ] 87 | ) 88 | 89 | # 4. Define output layers 90 | if use_linear_projection: 91 | self.proj_out = nn.Linear(in_channels, inner_dim) 92 | else: 93 | self.proj_out = nn.Conv2d( 94 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0 95 | ) 96 | 97 | self.gradient_checkpointing = False 98 | 99 | def _set_gradient_checkpointing(self, module, value=False): 100 | if hasattr(module, "gradient_checkpointing"): 101 | module.gradient_checkpointing = value 102 | 103 | def forward( 104 | self, 105 | hidden_states, 106 | encoder_hidden_states=None, 107 | timestep=None, 108 | return_dict: bool = True, 109 | ): 110 | # Input 111 | assert ( 112 | hidden_states.dim() == 5 113 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 114 | video_length = hidden_states.shape[2] 115 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 116 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]: 117 | encoder_hidden_states = repeat( 118 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length 119 | ) 120 | 121 | batch, channel, height, weight = hidden_states.shape 122 | residual = hidden_states 123 | 124 | hidden_states = self.norm(hidden_states) 125 | if not self.use_linear_projection: 126 | hidden_states = self.proj_in(hidden_states) 127 | inner_dim = hidden_states.shape[1] 128 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 129 | batch, height * weight, inner_dim 130 | ) 131 | else: 132 | inner_dim = hidden_states.shape[1] 133 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 134 | batch, height * weight, inner_dim 135 | ) 136 | hidden_states = self.proj_in(hidden_states) 137 | 138 | # Blocks 139 | for i, block in enumerate(self.transformer_blocks): 140 | hidden_states = block( 141 | hidden_states, 142 | encoder_hidden_states=encoder_hidden_states, 143 | timestep=timestep, 144 | video_length=video_length, 145 | ) 146 | 147 | # Output 148 | if not self.use_linear_projection: 149 | hidden_states = ( 150 | hidden_states.reshape(batch, height, weight, inner_dim) 151 | .permute(0, 3, 1, 2) 152 | .contiguous() 153 | ) 154 | hidden_states = self.proj_out(hidden_states) 155 | else: 156 | hidden_states = self.proj_out(hidden_states) 157 | hidden_states = ( 158 | hidden_states.reshape(batch, height, weight, inner_dim) 159 | .permute(0, 3, 1, 2) 160 | .contiguous() 161 | ) 162 | 163 | output = hidden_states + residual 164 | 165 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 166 | if not return_dict: 167 | return (output,) 168 | 169 | return Transformer3DModelOutput(sample=output) 170 | -------------------------------------------------------------------------------- /src/modules/v_kps_guider.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from diffusers.models.modeling_utils import ModelMixin 6 | from .motion_module import zero_module 7 | from .resnet import InflatedConv3d 8 | 9 | 10 | class VKpsGuider(ModelMixin): 11 | def __init__( 12 | self, 13 | conditioning_embedding_channels: int, 14 | conditioning_channels: int = 3, 15 | block_out_channels: Tuple[int] = (16, 32, 64, 128), 16 | ): 17 | super().__init__() 18 | 
self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) 19 | 20 | self.blocks = nn.ModuleList([]) 21 | 22 | for i in range(len(block_out_channels) - 1): 23 | channel_in = block_out_channels[i] 24 | channel_out = block_out_channels[i + 1] 25 | self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)) 26 | self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) 27 | 28 | self.conv_out = zero_module(InflatedConv3d( 29 | block_out_channels[-1], 30 | conditioning_embedding_channels, 31 | kernel_size=3, 32 | padding=1, 33 | )) 34 | 35 | def forward(self, conditioning): 36 | embedding = self.conv_in(conditioning) 37 | embedding = F.silu(embedding) 38 | 39 | for block in self.blocks: 40 | embedding = block(embedding) 41 | embedding = F.silu(embedding) 42 | 43 | embedding = self.conv_out(embedding) 44 | 45 | return embedding 46 | -------------------------------------------------------------------------------- /src/pipelines/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiankuan93/ComfyUI-V-Express/ee8f245406e6fa07353dbf36f8ada06f81d48a9e/src/pipelines/.DS_Store -------------------------------------------------------------------------------- /src/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .v_express_pipeline import VExpressPipeline 2 | -------------------------------------------------------------------------------- /src/pipelines/context.py: -------------------------------------------------------------------------------- 1 | # TODO: Adapted from cli 2 | from typing import Callable, List, Optional 3 | 4 | import numpy as np 5 | 6 | 7 | def ordered_halving(val): 8 | bin_str = f"{val:064b}" 9 | bin_flip = bin_str[::-1] 10 | as_int = int(bin_flip, 2) 11 | 12 | return as_int / (1 << 64) 13 | 14 | 15 | def uniform( 16 | step: int = ..., 17 | num_steps: Optional[int] = None, 18 | num_frames: int = ..., 19 | context_size: Optional[int] = None, 20 | context_stride: int = 3, 21 | context_overlap: int = 4, 22 | closed_loop: bool = True, 23 | ): 24 | if num_frames <= context_size: 25 | yield list(range(num_frames)) 26 | return 27 | 28 | context_stride = min( 29 | context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1 30 | ) 31 | 32 | for context_step in 1 << np.arange(context_stride): 33 | pad = int(round(num_frames * ordered_halving(step))) 34 | for j in range( 35 | int(ordered_halving(step) * context_step) + pad, 36 | num_frames + pad + (0 if closed_loop else -context_overlap), 37 | (context_size * context_step - context_overlap), 38 | ): 39 | next_itr = [] 40 | for e in range(j, j + context_size * context_step, context_step): 41 | if e >= num_frames: 42 | e = num_frames - 2 - e % num_frames 43 | next_itr.append(e) 44 | 45 | yield next_itr 46 | 47 | 48 | def get_context_scheduler(name: str) -> Callable: 49 | if name == "uniform": 50 | return uniform 51 | else: 52 | raise ValueError(f"Unknown context_overlap policy {name}") 53 | 54 | 55 | def get_total_steps( 56 | scheduler, 57 | timesteps: List[int], 58 | num_steps: Optional[int] = None, 59 | num_frames: int = ..., 60 | context_size: Optional[int] = None, 61 | context_stride: int = 3, 62 | context_overlap: int = 4, 63 | closed_loop: bool = True, 64 | ): 65 | return sum( 66 | len( 67 | list( 68 | scheduler( 69 | i, 70 | num_steps, 71 | num_frames, 72 | context_size, 73 | context_stride, 
74 | context_overlap, 75 | ) 76 | ) 77 | ) 78 | for i in range(len(timesteps)) 79 | ) 80 | -------------------------------------------------------------------------------- /src/pipelines/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import pathlib 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as func 9 | import tqdm 10 | from imageio_ffmpeg import get_ffmpeg_exe 11 | 12 | tensor_interpolation = None 13 | 14 | 15 | def get_tensor_interpolation_method(): 16 | return tensor_interpolation 17 | 18 | 19 | def set_tensor_interpolation_method(is_slerp): 20 | global tensor_interpolation 21 | tensor_interpolation = slerp if is_slerp else linear 22 | 23 | 24 | def linear(v1, v2, t): 25 | return (1.0 - t) * v1 + t * v2 26 | 27 | 28 | def slerp(v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995) -> torch.Tensor: 29 | u0 = v0 / v0.norm() 30 | u1 = v1 / v1.norm() 31 | dot = (u0 * u1).sum() 32 | if dot.abs() > DOT_THRESHOLD: 33 | # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.') 34 | return (1.0 - t) * v0 + t * v1 35 | omega = dot.acos() 36 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin() 37 | 38 | 39 | def draw_kps_image(height, width, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255)]): 40 | stick_width = 4 41 | limb_seq = np.array([[0, 2], [1, 2]]) 42 | kps = np.array(kps) 43 | 44 | canvas = np.zeros((height, width, 3), dtype=np.uint8) 45 | 46 | for i in range(len(limb_seq)): 47 | index = limb_seq[i] 48 | color = color_list[index[0]] 49 | 50 | x = kps[index][:, 0] 51 | y = kps[index][:, 1] 52 | length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 53 | angle = int(math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))) 54 | polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stick_width), angle, 0, 360, 1) 55 | cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color]) 56 | 57 | for idx_kp, kp in enumerate(kps): 58 | color = color_list[idx_kp] 59 | x, y = kp 60 | cv2.circle(canvas, (int(x), int(y)), 4, color, -1) 61 | 62 | return canvas 63 | 64 | 65 | def median_filter_3d(video_tensor, kernel_size, device): 66 | _, video_length, height, width = video_tensor.shape 67 | 68 | pad_size = kernel_size // 2 69 | video_tensor = func.pad(video_tensor, (pad_size, pad_size, pad_size, pad_size, pad_size, pad_size), mode='reflect') 70 | 71 | filtered_video_tensor = [] 72 | for i in tqdm.tqdm(range(video_length), desc='Median Filtering'): 73 | video_segment = video_tensor[:, i:i + kernel_size, ...].to(device) 74 | video_segment = video_segment.unfold(dimension=2, size=kernel_size, step=1) 75 | video_segment = video_segment.unfold(dimension=3, size=kernel_size, step=1) 76 | video_segment = video_segment.permute(0, 2, 3, 1, 4, 5).reshape(3, height, width, -1) 77 | filtered_video_frame = torch.median(video_segment, dim=-1)[0] 78 | filtered_video_tensor.append(filtered_video_frame.cpu()) 79 | filtered_video_tensor = torch.stack(filtered_video_tensor, dim=1) 80 | return filtered_video_tensor 81 | 82 | 83 | def save_video(video_tensor, audio_path, output_path, device, fps=30.0): 84 | pathlib.Path(output_path).parent.mkdir(exist_ok=True, parents=True) 85 | 86 | video_tensor = video_tensor[0, ...] 
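    # Drop the batch dimension; the remaining tensor is (channels, frames, height, width),
    # which is the layout expected by median_filter_3d and the frame-writing loop below.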
87 | _, num_frames, height, width = video_tensor.shape 88 | 89 | video_tensor = median_filter_3d(video_tensor, kernel_size=3, device=device) 90 | video_tensor = video_tensor.permute(1, 2, 3, 0) 91 | video_frames = (video_tensor * 255).numpy().astype(np.uint8) 92 | 93 | output_name = pathlib.Path(output_path).stem 94 | temp_output_path = output_path.replace(output_name, output_name + '-temp') 95 | video_writer = cv2.VideoWriter(temp_output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)) 96 | 97 | for i in tqdm.tqdm(range(num_frames), 'Writing frames into file'): 98 | frame_image = video_frames[i, ...] 99 | frame_image = cv2.cvtColor(frame_image, cv2.COLOR_RGB2BGR) 100 | video_writer.write(frame_image) 101 | video_writer.release() 102 | 103 | cmd = (f'{get_ffmpeg_exe()} -i "{temp_output_path}" -i "{audio_path}" ' 104 | f'-map 0:v -map 1:a -c:v h264 -shortest -y "{output_path}" -loglevel quiet') 105 | os.system(cmd) 106 | os.remove(temp_output_path) 107 | 108 | 109 | def compute_dist(x1, y1, x2, y2): 110 | return math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2) 111 | 112 | 113 | def compute_ratio(kps): 114 | l_eye_x, l_eye_y = kps[0][0], kps[0][1] 115 | r_eye_x, r_eye_y = kps[1][0], kps[1][1] 116 | nose_x, nose_y = kps[2][0], kps[2][1] 117 | d_left = compute_dist(l_eye_x, l_eye_y, nose_x, nose_y) 118 | d_right = compute_dist(r_eye_x, r_eye_y, nose_x, nose_y) 119 | ratio = d_left / (d_right + 1e-6) 120 | return ratio 121 | 122 | 123 | def point_to_line_dist(point, line_points): 124 | point = np.array(point) 125 | line_points = np.array(line_points) 126 | line_vec = line_points[1] - line_points[0] 127 | point_vec = point - line_points[0] 128 | line_norm = line_vec / np.sqrt(np.sum(line_vec ** 2)) 129 | point_vec_scaled = point_vec * 1.0 / np.sqrt(np.sum(line_vec ** 2)) 130 | t = np.dot(line_norm, point_vec_scaled) 131 | if t < 0.0: 132 | t = 0.0 133 | elif t > 1.0: 134 | t = 1.0 135 | nearest = line_points[0] + t * line_vec 136 | dist = np.sqrt(np.sum((point - nearest) ** 2)) 137 | return dist 138 | 139 | 140 | def get_face_size(kps): 141 | # 0: left eye, 1: right eye, 2: nose 142 | A = kps[0, :] 143 | B = kps[1, :] 144 | C = kps[2, :] 145 | 146 | AB_dist = math.sqrt((A[0] - B[0]) ** 2 + (A[1] - B[1]) ** 2) 147 | C_AB_dist = point_to_line_dist(C, [A, B]) 148 | return AB_dist, C_AB_dist 149 | 150 | 151 | def get_rescale_params(kps_ref, kps_target): 152 | kps_ref = np.array(kps_ref) 153 | kps_target = np.array(kps_target) 154 | 155 | ref_AB_dist, ref_C_AB_dist = get_face_size(kps_ref) 156 | target_AB_dist, target_C_AB_dist = get_face_size(kps_target) 157 | 158 | scale_width = ref_AB_dist / target_AB_dist 159 | scale_height = ref_C_AB_dist / target_C_AB_dist 160 | 161 | return scale_width, scale_height 162 | 163 | 164 | def retarget_kps(ref_kps, tgt_kps_list, only_offset=True): 165 | ref_kps = np.array(ref_kps) 166 | tgt_kps_list = np.array(tgt_kps_list) 167 | 168 | ref_ratio = compute_ratio(ref_kps) 169 | 170 | ratio_delta = 10000 171 | selected_tgt_kps_idx = None 172 | for idx, tgt_kps in enumerate(tgt_kps_list): 173 | tgt_ratio = compute_ratio(tgt_kps) 174 | if math.fabs(tgt_ratio - ref_ratio) < ratio_delta: 175 | selected_tgt_kps_idx = idx 176 | ratio_delta = tgt_ratio 177 | 178 | scale_width, scale_height = get_rescale_params( 179 | kps_ref=ref_kps, 180 | kps_target=tgt_kps_list[selected_tgt_kps_idx], 181 | ) 182 | 183 | rescaled_tgt_kps_list = np.array(tgt_kps_list) 184 | rescaled_tgt_kps_list[:, :, 0] *= scale_width 185 | rescaled_tgt_kps_list[:, :, 1] *= scale_height 186 | 187 | if 
only_offset: 188 | nose_offset = rescaled_tgt_kps_list[:, 2, :] - rescaled_tgt_kps_list[0, 2, :] 189 | nose_offset = nose_offset[:, np.newaxis, :] 190 | ref_kps_repeat = np.tile(ref_kps, (tgt_kps_list.shape[0], 1, 1)) 191 | 192 | ref_kps_repeat[:, :, :] -= (nose_offset / 2.0) 193 | rescaled_tgt_kps_list = ref_kps_repeat 194 | else: 195 | nose_offset_x = rescaled_tgt_kps_list[0, 2, 0] - ref_kps[2][0] 196 | nose_offset_y = rescaled_tgt_kps_list[0, 2, 1] - ref_kps[2][1] 197 | 198 | rescaled_tgt_kps_list[:, :, 0] -= nose_offset_x 199 | rescaled_tgt_kps_list[:, :, 1] -= nose_offset_y 200 | 201 | return rescaled_tgt_kps_list 202 | -------------------------------------------------------------------------------- /src/pipelines/v_express_pipeline.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/pipelines/pipeline_animation.py 2 | import inspect 3 | import math 4 | from typing import Callable, List, Optional, Union 5 | 6 | import torch 7 | from diffusers import DiffusionPipeline 8 | from diffusers.image_processor import VaeImageProcessor 9 | from diffusers.schedulers import ( 10 | DDIMScheduler, 11 | DPMSolverMultistepScheduler, 12 | EulerAncestralDiscreteScheduler, 13 | EulerDiscreteScheduler, 14 | LMSDiscreteScheduler, 15 | PNDMScheduler, 16 | ) 17 | from diffusers.utils import is_accelerate_available 18 | from diffusers.utils.torch_utils import randn_tensor 19 | from einops import rearrange 20 | from tqdm import tqdm 21 | from transformers import CLIPImageProcessor 22 | 23 | from modules import ReferenceAttentionControl 24 | from .context import get_context_scheduler 25 | 26 | 27 | def retrieve_timesteps( 28 | scheduler, 29 | num_inference_steps: Optional[int] = None, 30 | device: Optional[Union[str, torch.device]] = None, 31 | timesteps: Optional[List[int]] = None, 32 | **kwargs, 33 | ): 34 | """ 35 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 36 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 37 | 38 | Args: 39 | scheduler (`SchedulerMixin`): 40 | The scheduler to get timesteps from. 41 | num_inference_steps (`int`): 42 | The number of diffusion steps used when generating samples with a pre-trained model. If used, 43 | `timesteps` must be `None`. 44 | device (`str` or `torch.device`, *optional*): 45 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 46 | timesteps (`List[int]`, *optional*): 47 | Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default 48 | timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` 49 | must be `None`. 50 | 51 | Returns: 52 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 53 | second element is the number of inference steps. 54 | """ 55 | if timesteps is not None: 56 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 57 | if not accepts_timesteps: 58 | raise ValueError( 59 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 60 | f" timestep schedules. Please check whether you are using the correct scheduler." 
61 | ) 62 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 63 | timesteps = scheduler.timesteps 64 | num_inference_steps = len(timesteps) 65 | else: 66 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 67 | timesteps = scheduler.timesteps 68 | return timesteps, num_inference_steps 69 | 70 | 71 | class VExpressPipeline(DiffusionPipeline): 72 | _optional_components = [] 73 | 74 | def __init__( 75 | self, 76 | vae, 77 | reference_net, 78 | denoising_unet, 79 | v_kps_guider, 80 | audio_processor, 81 | audio_encoder, 82 | audio_projection, 83 | scheduler: Union[ 84 | DDIMScheduler, 85 | PNDMScheduler, 86 | LMSDiscreteScheduler, 87 | EulerDiscreteScheduler, 88 | EulerAncestralDiscreteScheduler, 89 | DPMSolverMultistepScheduler, 90 | ], 91 | image_proj_model=None, 92 | tokenizer=None, 93 | text_encoder=None, 94 | ): 95 | super().__init__() 96 | 97 | self.register_modules( 98 | vae=vae, 99 | reference_net=reference_net, 100 | denoising_unet=denoising_unet, 101 | v_kps_guider=v_kps_guider, 102 | audio_processor=audio_processor, 103 | audio_encoder=audio_encoder, 104 | audio_projection=audio_projection, 105 | scheduler=scheduler, 106 | image_proj_model=image_proj_model, 107 | tokenizer=tokenizer, 108 | text_encoder=text_encoder, 109 | ) 110 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 111 | self.clip_image_processor = CLIPImageProcessor() 112 | self.reference_image_processor = VaeImageProcessor( 113 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True 114 | ) 115 | self.condition_image_processor = VaeImageProcessor( 116 | vae_scale_factor=self.vae_scale_factor, 117 | do_convert_rgb=True, 118 | do_normalize=False, 119 | ) 120 | 121 | def enable_vae_slicing(self): 122 | self.vae.enable_slicing() 123 | 124 | def disable_vae_slicing(self): 125 | self.vae.disable_slicing() 126 | 127 | def enable_sequential_cpu_offload(self, gpu_id=0): 128 | if is_accelerate_available(): 129 | from accelerate import cpu_offload 130 | else: 131 | raise ImportError("Please install accelerate via `pip install accelerate`") 132 | 133 | device = torch.device(f"cuda:{gpu_id}") 134 | 135 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 136 | if cpu_offloaded_model is not None: 137 | cpu_offload(cpu_offloaded_model, device) 138 | 139 | @property 140 | def _execution_device(self): 141 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 142 | return self.device 143 | for module in self.unet.modules(): 144 | if ( 145 | hasattr(module, "_hf_hook") 146 | and hasattr(module._hf_hook, "execution_device") 147 | and module._hf_hook.execution_device is not None 148 | ): 149 | return torch.device(module._hf_hook.execution_device) 150 | return self.device 151 | 152 | @torch.no_grad() 153 | def decode_latents(self, latents): 154 | video_length = latents.shape[2] 155 | latents = 1 / 0.18215 * latents 156 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 157 | video = [] 158 | for frame_idx in tqdm(range(latents.shape[0]), desc='Decoding latents into frames'): 159 | image = self.vae.decode(latents[frame_idx: frame_idx + 1].to(self.vae.device)).sample 160 | image = (image / 2 + 0.5).clamp(0, 1) 161 | image = image.cpu().float() 162 | video.append(image) 163 | video = torch.cat(video) 164 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 165 | 166 | return video 167 | 168 | def prepare_extra_step_kwargs(self, generator, eta): 169 | # prepare extra kwargs for the scheduler step, since 
not all schedulers have the same signature 170 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 171 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 172 | # and should be between [0, 1] 173 | 174 | accepts_eta = "eta" in set( 175 | inspect.signature(self.scheduler.step).parameters.keys() 176 | ) 177 | extra_step_kwargs = {} 178 | if accepts_eta: 179 | extra_step_kwargs["eta"] = eta 180 | 181 | # check if the scheduler accepts generator 182 | accepts_generator = "generator" in set( 183 | inspect.signature(self.scheduler.step).parameters.keys() 184 | ) 185 | if accepts_generator: 186 | extra_step_kwargs["generator"] = generator 187 | return extra_step_kwargs 188 | 189 | def prepare_latents( 190 | self, 191 | batch_size, 192 | num_channels_latents, 193 | width, 194 | height, 195 | video_length, 196 | dtype, 197 | device, 198 | generator, 199 | latents=None 200 | ): 201 | shape = ( 202 | batch_size, 203 | num_channels_latents, 204 | video_length, 205 | height // self.vae_scale_factor, 206 | width // self.vae_scale_factor, 207 | ) 208 | if isinstance(generator, list) and len(generator) != batch_size: 209 | raise ValueError( 210 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 211 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 212 | ) 213 | 214 | if latents is None: 215 | latents = randn_tensor( 216 | shape, generator=generator, device=device, dtype=dtype 217 | ) 218 | 219 | else: 220 | latents = latents.to(device) 221 | 222 | # scale the initial noise by the standard deviation required by the scheduler 223 | latents = latents * self.scheduler.init_noise_sigma 224 | return latents 225 | 226 | def _encode_prompt( 227 | self, 228 | prompt, 229 | device, 230 | num_videos_per_prompt, 231 | do_classifier_free_guidance, 232 | negative_prompt, 233 | ): 234 | batch_size = len(prompt) if isinstance(prompt, list) else 1 235 | 236 | text_inputs = self.tokenizer( 237 | prompt, 238 | padding="max_length", 239 | max_length=self.tokenizer.model_max_length, 240 | truncation=True, 241 | return_tensors="pt", 242 | ) 243 | text_input_ids = text_inputs.input_ids 244 | untruncated_ids = self.tokenizer( 245 | prompt, padding="longest", return_tensors="pt" 246 | ).input_ids 247 | 248 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 249 | text_input_ids, untruncated_ids 250 | ): 251 | removed_text = self.tokenizer.batch_decode( 252 | untruncated_ids[:, self.tokenizer.model_max_length - 1: -1] 253 | ) 254 | 255 | if ( 256 | hasattr(self.text_encoder.config, "use_attention_mask") 257 | and self.text_encoder.config.use_attention_mask 258 | ): 259 | attention_mask = text_inputs.attention_mask.to(device) 260 | else: 261 | attention_mask = None 262 | 263 | text_embeddings = self.text_encoder( 264 | text_input_ids.to(device), 265 | attention_mask=attention_mask, 266 | ) 267 | text_embeddings = text_embeddings[0] 268 | 269 | # duplicate text embeddings for each generation per prompt, using mps friendly method 270 | bs_embed, seq_len, _ = text_embeddings.shape 271 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) 272 | text_embeddings = text_embeddings.view( 273 | bs_embed * num_videos_per_prompt, seq_len, -1 274 | ) 275 | 276 | # get unconditional embeddings for classifier free guidance 277 | if do_classifier_free_guidance: 278 | uncond_tokens: List[str] 279 | if negative_prompt is None: 280 | uncond_tokens = [""] 
* batch_size 281 | elif type(prompt) is not type(negative_prompt): 282 | raise TypeError( 283 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 284 | f" {type(prompt)}." 285 | ) 286 | elif isinstance(negative_prompt, str): 287 | uncond_tokens = [negative_prompt] 288 | elif batch_size != len(negative_prompt): 289 | raise ValueError( 290 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 291 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 292 | " the batch size of `prompt`." 293 | ) 294 | else: 295 | uncond_tokens = negative_prompt 296 | 297 | max_length = text_input_ids.shape[-1] 298 | uncond_input = self.tokenizer( 299 | uncond_tokens, 300 | padding="max_length", 301 | max_length=max_length, 302 | truncation=True, 303 | return_tensors="pt", 304 | ) 305 | 306 | if ( 307 | hasattr(self.text_encoder.config, "use_attention_mask") 308 | and self.text_encoder.config.use_attention_mask 309 | ): 310 | attention_mask = uncond_input.attention_mask.to(device) 311 | else: 312 | attention_mask = None 313 | 314 | uncond_embeddings = self.text_encoder( 315 | uncond_input.input_ids.to(device), 316 | attention_mask=attention_mask, 317 | ) 318 | uncond_embeddings = uncond_embeddings[0] 319 | 320 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 321 | seq_len = uncond_embeddings.shape[1] 322 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) 323 | uncond_embeddings = uncond_embeddings.view( 324 | batch_size * num_videos_per_prompt, seq_len, -1 325 | ) 326 | 327 | # For classifier free guidance, we need to do two forward passes. 328 | # Here we concatenate the unconditional and text embeddings into a single batch 329 | # to avoid doing two forward passes 330 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 331 | 332 | return text_embeddings 333 | 334 | def get_timesteps(self, num_inference_steps, strength, device): 335 | # get the original timestep using init_timestep 336 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) 337 | 338 | t_start = max(num_inference_steps - init_timestep, 0) 339 | timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:] 340 | 341 | return timesteps, num_inference_steps - t_start 342 | 343 | def prepare_reference_latent(self, reference_image, height, width): 344 | reference_image_tensor = self.reference_image_processor.preprocess(reference_image, height=height, width=width) 345 | reference_image_tensor = reference_image_tensor.to(dtype=self.dtype, device=self.device) 346 | reference_image_latents = self.vae.encode(reference_image_tensor).latent_dist.mean 347 | reference_image_latents = reference_image_latents * 0.18215 348 | return reference_image_latents 349 | 350 | def prepare_kps_feature(self, kps_images, height, width, do_classifier_free_guidance): 351 | kps_image_tensors = [] 352 | for idx, kps_image in enumerate(kps_images): 353 | kps_image_tensor = self.condition_image_processor.preprocess(kps_image, height=height, width=width) 354 | kps_image_tensor = kps_image_tensor.unsqueeze(2) # [bs, c, 1, h, w] 355 | kps_image_tensors.append(kps_image_tensor) 356 | kps_images_tensor = torch.cat(kps_image_tensors, dim=2) # [bs, c, t, h, w] 357 | 358 | bs = 16 359 | num_forward = math.ceil(kps_images_tensor.shape[2] / bs) 360 | kps_feature = [] 361 | for i in range(num_forward): 362 | tensor = kps_images_tensor[:, :, i 
* bs:(i + 1) * bs, ...].to(device=self.device, dtype=self.dtype) 363 | feature = self.v_kps_guider(tensor).cpu() 364 | kps_feature.append(feature) 365 | torch.cuda.empty_cache() 366 | kps_feature = torch.cat(kps_feature, dim=2) 367 | 368 | if do_classifier_free_guidance: 369 | uc_kps_feature = torch.zeros_like(kps_feature) 370 | kps_feature = torch.cat([uc_kps_feature, kps_feature], dim=0) 371 | 372 | return kps_feature 373 | 374 | def prepare_audio_embeddings(self, audio_waveform, video_length, num_pad_audio_frames, do_classifier_free_guidance): 375 | audio_waveform = self.audio_processor(audio_waveform, return_tensors="pt", sampling_rate=16000)['input_values'] 376 | audio_waveform = audio_waveform.to(self.device, self.dtype) 377 | audio_embeddings = self.audio_encoder(audio_waveform).last_hidden_state # [1, num_embeds, d] 378 | 379 | audio_embeddings = torch.nn.functional.interpolate( 380 | audio_embeddings.permute(0, 2, 1), 381 | size=2 * video_length, 382 | mode='linear', 383 | )[0, :, :].permute(1, 0) # [2*vid_len, dim] 384 | 385 | audio_embeddings = torch.cat([ 386 | torch.zeros_like(audio_embeddings)[:2 * num_pad_audio_frames, :], 387 | audio_embeddings, 388 | torch.zeros_like(audio_embeddings)[:2 * num_pad_audio_frames, :], 389 | ], dim=0) # [2*num_pad+2*vid_len+2*num_pad, dim] 390 | 391 | frame_audio_embeddings = [] 392 | for frame_idx in range(video_length): 393 | start_sample = frame_idx 394 | end_sample = frame_idx + 2 * num_pad_audio_frames 395 | 396 | frame_audio_embedding = audio_embeddings[2 * start_sample:2 * (end_sample + 1), :] # [2*num_pad+1, dim] 397 | frame_audio_embeddings.append(frame_audio_embedding) 398 | audio_embeddings = torch.stack(frame_audio_embeddings, dim=0) # [vid_len, 2*num_pad+1, dim] 399 | 400 | audio_embeddings = self.audio_projection(audio_embeddings).unsqueeze(0) 401 | if do_classifier_free_guidance: 402 | uc_audio_embeddings = torch.zeros_like(audio_embeddings) 403 | audio_embeddings = torch.cat([uc_audio_embeddings, audio_embeddings], dim=0) 404 | return audio_embeddings 405 | 406 | @torch.no_grad() 407 | def __call__( 408 | self, 409 | reference_image, 410 | kps_images, 411 | audio_waveform, 412 | width, 413 | height, 414 | video_length, 415 | num_inference_steps, 416 | guidance_scale, 417 | strength=1., 418 | num_images_per_prompt=1, 419 | eta: float = 0.0, 420 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 421 | output_type: Optional[str] = "tensor", 422 | return_dict: bool = True, 423 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 424 | callback_steps: Optional[int] = 1, 425 | context_schedule="uniform", 426 | context_frames=24, 427 | context_overlap=4, 428 | reference_attention_weight=1., 429 | audio_attention_weight=1., 430 | num_pad_audio_frames=2, 431 | do_multi_devices_inference=False, 432 | save_gpu_memory=False, 433 | **kwargs, 434 | ): 435 | # Default height and width to unet 436 | height = height or self.unet.config.sample_size * self.vae_scale_factor 437 | width = width or self.unet.config.sample_size * self.vae_scale_factor 438 | 439 | device = self._execution_device 440 | 441 | do_classifier_free_guidance = guidance_scale > 1.0 442 | batch_size = 1 443 | 444 | # Prepare timesteps 445 | timesteps = None 446 | timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) 447 | timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) 448 | 449 | reference_control_writer = ReferenceAttentionControl( 
450 | self.reference_net, 451 | do_classifier_free_guidance=do_classifier_free_guidance, 452 | mode="write", 453 | batch_size=batch_size, 454 | fusion_blocks="full", 455 | ) 456 | reference_control_reader = ReferenceAttentionControl( 457 | self.denoising_unet, 458 | do_classifier_free_guidance=do_classifier_free_guidance, 459 | mode="read", 460 | batch_size=batch_size, 461 | fusion_blocks="full", 462 | reference_attention_weight=reference_attention_weight, 463 | audio_attention_weight=audio_attention_weight, 464 | ) 465 | 466 | num_channels_latents = self.denoising_unet.in_channels 467 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 468 | 469 | reference_image_latents = self.prepare_reference_latent(reference_image, height, width) 470 | kps_feature = self.prepare_kps_feature(kps_images, height, width, do_classifier_free_guidance) 471 | # if save_gpu_memory: 472 | # del self.v_kps_guider 473 | torch.cuda.empty_cache() 474 | audio_embeddings = self.prepare_audio_embeddings( 475 | audio_waveform, 476 | video_length, 477 | num_pad_audio_frames, 478 | do_classifier_free_guidance, 479 | ) 480 | # if save_gpu_memory: 481 | # del self.audio_processor, self.audio_encoder, self.audio_projection 482 | torch.cuda.empty_cache() 483 | 484 | context_scheduler = get_context_scheduler(context_schedule) 485 | context_queue = list( 486 | context_scheduler( 487 | step=0, 488 | num_frames=video_length, 489 | context_size=context_frames, 490 | context_stride=1, 491 | context_overlap=context_overlap, 492 | closed_loop=False, 493 | ) 494 | ) 495 | 496 | num_frame_context = torch.zeros(video_length, device=device, dtype=torch.long) 497 | for context in context_queue: 498 | num_frame_context[context] += 1 499 | 500 | encoder_hidden_states = torch.zeros((1, 1, 768), dtype=self.dtype, device=self.device) 501 | self.reference_net( 502 | reference_image_latents, 503 | timestep=0, 504 | encoder_hidden_states=encoder_hidden_states, 505 | return_dict=False, 506 | ) 507 | reference_control_reader.update(reference_control_writer, do_classifier_free_guidance) 508 | # if save_gpu_memory: 509 | # del self.reference_net 510 | torch.cuda.empty_cache() 511 | 512 | latents = self.prepare_latents( 513 | batch_size * num_images_per_prompt, 514 | num_channels_latents, 515 | width, 516 | height, 517 | video_length, 518 | self.dtype, 519 | torch.device('cpu'), 520 | generator, 521 | ) 522 | 523 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 524 | with self.progress_bar(total=num_inference_steps) as progress_bar: 525 | for i, t in enumerate(timesteps): 526 | context_counter = torch.zeros(video_length, device=device, dtype=torch.long) 527 | noise_preds = [None] * video_length 528 | for context_idx, context in enumerate(context_queue): 529 | latent_kps_feature = kps_feature[:, :, context].to(device, self.dtype) 530 | 531 | latent_audio_embeddings = audio_embeddings[:, context, ...] 
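                # Audio embeddings for the frames of this context window. The leading dimension is 2
                # when classifier-free guidance is enabled (unconditional + conditional), so the tensor
                # is flattened over its first two dimensions below before being passed to the UNet as
                # encoder_hidden_states.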
532 | _, _, num_tokens, dim = latent_audio_embeddings.shape 533 | latent_audio_embeddings = latent_audio_embeddings.reshape(-1, num_tokens, dim) 534 | 535 | input_latents = latents[:, :, context, ...].to(device) 536 | input_latents = input_latents.repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1) 537 | input_latents = self.scheduler.scale_model_input(input_latents, t) 538 | noise_pred = self.denoising_unet( 539 | input_latents, 540 | t, 541 | encoder_hidden_states=latent_audio_embeddings.reshape(-1, num_tokens, dim), 542 | kps_features=latent_kps_feature, 543 | return_dict=False, 544 | )[0] 545 | if do_classifier_free_guidance: 546 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 547 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 548 | 549 | context_counter[context] += 1 550 | noise_pred /= num_frame_context[context][None, None, :, None, None] 551 | step_frame_ids = [] 552 | step_noise_preds = [] 553 | for latent_idx, frame_idx in enumerate(context): 554 | if noise_preds[frame_idx] is None: 555 | noise_preds[frame_idx] = noise_pred[:, :, latent_idx, ...] 556 | else: 557 | noise_preds[frame_idx] += noise_pred[:, :, latent_idx, ...] 558 | if context_counter[frame_idx] == num_frame_context[frame_idx]: 559 | step_frame_ids.append(frame_idx) 560 | step_noise_preds.append(noise_preds[frame_idx]) 561 | noise_preds[frame_idx] = None 562 | step_noise_preds = torch.stack(step_noise_preds, dim=2) 563 | output_latents = self.scheduler.step( 564 | step_noise_preds, 565 | t, 566 | latents[:, :, step_frame_ids, ...].to(device), 567 | **extra_step_kwargs, 568 | ).prev_sample 569 | latents[:, :, step_frame_ids, ...] = output_latents.cpu() 570 | 571 | progress_bar.set_description( 572 | f'Denoising Step Index: {i + 1} / {len(timesteps)}, ' 573 | f'Context Index: {context_idx + 1} / {len(context_queue)}' 574 | ) 575 | 576 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 577 | progress_bar.update() 578 | if callback is not None and i % callback_steps == 0: 579 | step_idx = i // getattr(self.scheduler, "order", 1) 580 | callback(step_idx, t, latents) 581 | 582 | reference_control_reader.clear() 583 | reference_control_writer.clear() 584 | 585 | video_tensor = self.decode_latents(latents) 586 | return video_tensor 587 | -------------------------------------------------------------------------------- /src/scripts/extract_kps_sequence_and_audio.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import os 4 | import cv2 5 | import torch 6 | from insightface.app import FaceAnalysis 7 | from imageio_ffmpeg import get_ffmpeg_exe 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--video_path', type=str, default='') 11 | parser.add_argument('--kps_sequence_save_path', type=str, default='') 12 | parser.add_argument('--audio_save_path', type=str, default='') 13 | parser.add_argument('--device', type=str, default='cuda') 14 | parser.add_argument('--gpu_id', type=int, default=0) 15 | parser.add_argument('--insightface_model_path', type=str, default='./model_ckpts/insightface_models/') 16 | parser.add_argument('--height', type=int, default=512) 17 | parser.add_argument('--width', type=int, default=512) 18 | args = parser.parse_args() 19 | 20 | app = FaceAnalysis( 21 | providers=['CUDAExecutionProvider' if args.device == 'cuda' else 'CPUExecutionProvider'], 22 | provider_options=[{'device_id': args.gpu_id}] if args.device == 'cuda' 
else [], 23 | root=args.insightface_model_path, 24 | ) 25 | app.prepare(ctx_id=0, det_size=(args.height, args.width)) 26 | 27 | os.system(f'{get_ffmpeg_exe()} -i "{args.video_path}" -y -vn "{args.audio_save_path}"') 28 | 29 | kps_sequence = [] 30 | video_capture = cv2.VideoCapture(args.video_path) 31 | frame_idx = 0 32 | while video_capture.isOpened(): 33 | ret, frame = video_capture.read() 34 | if not ret: 35 | break 36 | frame = cv2.resize(frame, (args.width, args.height)) 37 | faces = app.get(frame) 38 | assert len(faces) == 1, f'There are {len(faces)} faces in the {frame_idx}-th frame. Only one face is supported.' 39 | 40 | kps = faces[0].kps[:3] 41 | kps_sequence.append(kps) 42 | frame_idx += 1 43 | torch.save(kps_sequence, args.kps_sequence_save_path) 44 | -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | from typing import Iterable 4 | import shutil 5 | import subprocess 6 | import re 7 | 8 | 9 | def ffmpeg_suitability(path): 10 | try: 11 | version = subprocess.run([path, "-version"], check=True, 12 | capture_output=True).stdout.decode("utf-8") 13 | except: 14 | return 0 15 | score = 0 16 | #rough layout of the importance of various features 17 | simple_criterion = [("libvpx", 20),("264",10), ("265",3), 18 | ("svtav1",5),("libopus", 1)] 19 | for criterion in simple_criterion: 20 | if version.find(criterion[0]) >= 0: 21 | score += criterion[1] 22 | #obtain rough compile year from copyright information 23 | copyright_index = version.find('2000-2') 24 | if copyright_index >= 0: 25 | copyright_year = version[copyright_index+6:copyright_index+9] 26 | if copyright_year.isnumeric(): 27 | score += int(copyright_year) 28 | return score 29 | 30 | 31 | def get_ffmpeg(): 32 | ffmpeg_paths = [] 33 | try: 34 | from imageio_ffmpeg import get_ffmpeg_exe 35 | imageio_ffmpeg_path = get_ffmpeg_exe() 36 | ffmpeg_paths.append(imageio_ffmpeg_path) 37 | except: 38 | print("Failed to import imageio_ffmpeg") 39 | if "VHS_USE_IMAGEIO_FFMPEG" in os.environ: 40 | ffmpeg_path = imageio_ffmpeg_path 41 | else: 42 | system_ffmpeg = shutil.which("ffmpeg") 43 | if system_ffmpeg is not None: 44 | ffmpeg_paths.append(system_ffmpeg) 45 | if os.path.isfile("ffmpeg"): 46 | ffmpeg_paths.append(os.path.abspath("ffmpeg")) 47 | if os.path.isfile("ffmpeg.exe"): 48 | ffmpeg_paths.append(os.path.abspath("ffmpeg.exe")) 49 | if len(ffmpeg_paths) == 0: 50 | print("No valid ffmpeg found.") 51 | ffmpeg_path = None 52 | elif len(ffmpeg_paths) == 1: 53 | #Evaluation of suitability isn't required, can take sole option 54 | #to reduce startup time 55 | ffmpeg_path = ffmpeg_paths[0] 56 | else: 57 | ffmpeg_path = max(ffmpeg_paths, key=ffmpeg_suitability) 58 | return ffmpeg_path 59 | -------------------------------------------------------------------------------- /web/js/VEpreviewVideo.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../../scripts/app.js"; 2 | import { api } from "../../../scripts/api.js"; 3 | 4 | // 定义一个函数来预览视频 5 | function previewVideo(node, file, type) { 6 | // 清除 node 元素中的所有子元素 7 | try { 8 | var el = document.getElementById("VEpreviewVideo"); 9 | el.remove(); 10 | } catch (error) { 11 | console.log(error); 12 | } 13 | var element = document.createElement("div"); 14 | element.id = "VEpreviewVideo"; 15 | const previewNode = node; 16 | 17 | // 创建一个新的 video 元素 18 | let videoEl = 
document.createElement("video"); 19 | 20 | // 设置 video 元素的属性 21 | videoEl.controls = true; 22 | videoEl.style.width = "100%"; 23 | 24 | let params = { 25 | filename: file, 26 | type: type, 27 | }; 28 | // 更新 video 元素的 src 属性 29 | videoEl.src = api.apiURL("/view?" + new URLSearchParams(params)); 30 | 31 | // 重新加载并播放视频 32 | videoEl.load(); 33 | videoEl.play(); 34 | 35 | // 清除 div 元素中的所有子元素 36 | while (element.firstChild) { 37 | element.removeChild(element.firstChild); 38 | } 39 | 40 | // 将 video 元素添加到 div 元素中 41 | element.appendChild(videoEl); 42 | 43 | node.previewWidget = node.addDOMWidget("videopreview", "preview", element, { 44 | serialize: false, 45 | hideOnZoom: false, 46 | getValue() { 47 | return element.value; 48 | }, 49 | setValue(v) { 50 | element.value = v; 51 | }, 52 | }); 53 | 54 | var previewWidget = node.previewWidget; 55 | 56 | previewWidget.computeSize = function (width) { 57 | if (this.aspectRatio && !this.parentEl.hidden) { 58 | let height = (previewNode.size[0] - 20) / this.aspectRatio + 10; 59 | if (!(height > 0)) { 60 | height = 0; 61 | } 62 | this.computedHeight = height + 10; 63 | return [width, height]; 64 | } 65 | return [width, -4]; //no loaded src, widget should not display 66 | }; 67 | } 68 | 69 | app.registerExtension({ 70 | name: "ComfyUI-V-Express.VideoPreviewer", 71 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 72 | if (nodeData?.name == "VEPreview_Video") { 73 | nodeType.prototype.onExecuted = function (data) { 74 | previewVideo(this, data.video[0], data.video[1]); 75 | }; 76 | } 77 | }, 78 | }); 79 | -------------------------------------------------------------------------------- /web/js/VEuploadAudio.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../../scripts/app.js"; 2 | import { api } from "../../../scripts/api.js"; 3 | import { ComfyWidgets } from "../../../scripts/widgets.js"; 4 | 5 | function previewAudio(node, file) { 6 | while (node.widgets.length > 2) { 7 | node.widgets.pop(); 8 | } 9 | try { 10 | var el = document.getElementById("VEuploadAudio"); 11 | el.remove(); 12 | } catch (error) { 13 | console.log(error); 14 | } 15 | var element = document.createElement("div"); 16 | element.id = "VEuploadAudio"; 17 | const previewNode = node; 18 | var previewWidget = node.addDOMWidget("audiopreview", "preview", element, { 19 | serialize: false, 20 | hideOnZoom: false, 21 | getValue() { 22 | return element.value; 23 | }, 24 | setValue(v) { 25 | element.value = v; 26 | }, 27 | }); 28 | previewWidget.computeSize = function (width) { 29 | if (this.aspectRatio && !this.parentEl.hidden) { 30 | let height = (previewNode.size[0] - 20) / this.aspectRatio + 10; 31 | if (!(height > 0)) { 32 | height = 0; 33 | } 34 | this.computedHeight = height + 10; 35 | return [width, height]; 36 | } 37 | return [width, -4]; //no loaded src, widget should not display 38 | }; 39 | // element.style['pointer-events'] = "none" 40 | previewWidget.value = { hidden: false, paused: false, params: {} }; 41 | previewWidget.parentEl = document.createElement("div"); 42 | previewWidget.parentEl.className = "audio_preview"; 43 | previewWidget.parentEl.style["width"] = "100%"; 44 | element.appendChild(previewWidget.parentEl); 45 | previewWidget.audioEl = document.createElement("audio"); 46 | previewWidget.audioEl.controls = true; 47 | previewWidget.audioEl.loop = false; 48 | previewWidget.audioEl.muted = false; 49 | previewWidget.audioEl.style["width"] = "100%"; 50 | 
previewWidget.audioEl.addEventListener("loadedmetadata", () => { 51 | previewWidget.aspectRatio = 52 | previewWidget.audioEl.audioWidth / previewWidget.audioEl.audioHeight; 53 | }); 54 | previewWidget.audioEl.addEventListener("error", () => { 55 | //TODO: consider a way to properly notify the user why a preview isn't shown. 56 | previewWidget.parentEl.hidden = true; 57 | }); 58 | 59 | let params = { 60 | filename: file, 61 | type: "input", 62 | }; 63 | 64 | previewWidget.parentEl.hidden = previewWidget.value.hidden; 65 | previewWidget.audioEl.autoplay = 66 | !previewWidget.value.paused && !previewWidget.value.hidden; 67 | let target_width = 256; 68 | if (element.style?.width) { 69 | //overscale to allow scrolling. Endpoint won't return higher than native 70 | target_width = element.style.width.slice(0, -2) * 2; 71 | } 72 | if ( 73 | !params.force_size || 74 | params.force_size.includes("?") || 75 | params.force_size == "Disabled" 76 | ) { 77 | params.force_size = target_width + "x?"; 78 | } else { 79 | let size = params.force_size.split("x"); 80 | let ar = parseInt(size[0]) / parseInt(size[1]); 81 | params.force_size = target_width + "x" + target_width / ar; 82 | } 83 | 84 | previewWidget.audioEl.src = api.apiURL( 85 | "/view?" + new URLSearchParams(params) 86 | ); 87 | 88 | previewWidget.audioEl.hidden = false; 89 | previewWidget.parentEl.appendChild(previewWidget.audioEl); 90 | } 91 | 92 | function audioUpload(node, inputName, inputData, app) { 93 | const audioWidget = node.widgets.find((w) => w.name === "audio_path"); 94 | let uploadWidget; 95 | /* 96 | A method that returns the required style for the html 97 | */ 98 | var default_value = audioWidget.value; 99 | Object.defineProperty(audioWidget, "audio_path", { 100 | set: function (value) { 101 | this._real_value = value; 102 | }, 103 | 104 | get: function () { 105 | let value = ""; 106 | if (this._real_value) { 107 | value = this._real_value; 108 | } else { 109 | return default_value; 110 | } 111 | 112 | if (value.filename) { 113 | let real_value = value; 114 | value = ""; 115 | if (real_value.subfolder) { 116 | value = real_value.subfolder + "/"; 117 | } 118 | 119 | value += real_value.filename; 120 | 121 | if (real_value.type && real_value.type !== "input") 122 | value += ` [${real_value.type}]`; 123 | } 124 | return value; 125 | }, 126 | }); 127 | async function uploadFile(file, updateNode, pasted = false) { 128 | try { 129 | // Wrap file in formdata so it includes filename 130 | const body = new FormData(); 131 | body.append("image", file); 132 | if (pasted) body.append("subfolder", "pasted"); 133 | const resp = await api.fetchApi("/upload/image", { 134 | method: "POST", 135 | body, 136 | }); 137 | 138 | if (resp.status === 200) { 139 | const data = await resp.json(); 140 | // Add the file to the dropdown list and update the widget value 141 | let path = data.name; 142 | if (data.subfolder) path = data.subfolder + "/" + path; 143 | 144 | if (!audioWidget.options.values.includes(path)) { 145 | audioWidget.options.values.push(path); 146 | } 147 | 148 | if (updateNode) { 149 | audioWidget.value = path; 150 | previewAudio(node, path); 151 | } 152 | } else { 153 | alert(resp.status + " - " + resp.statusText); 154 | } 155 | } catch (error) { 156 | alert(error); 157 | } 158 | } 159 | 160 | const fileInput = document.createElement("input"); 161 | Object.assign(fileInput, { 162 | type: "file", 163 | accept: "audio/mp3", 164 | style: "display: none", 165 | onchange: async () => { 166 | if (fileInput.files.length) { 167 | await 
uploadFile(fileInput.files[0], true); 168 | } 169 | }, 170 | }); 171 | document.body.append(fileInput); 172 | 173 | // Create the button widget for selecting the files 174 | uploadWidget = node.addWidget( 175 | "button", 176 | "choose audio file to upload", 177 | "Audio", 178 | () => { 179 | fileInput.click(); 180 | } 181 | ); 182 | 183 | uploadWidget.serialize = false; 184 | 185 | previewAudio(node, audioWidget.value); 186 | const cb = node.callback; 187 | audioWidget.callback = function () { 188 | previewAudio(node, audioWidget.value); 189 | if (cb) { 190 | return cb.apply(this, arguments); 191 | } 192 | }; 193 | 194 | return { widget: uploadWidget }; 195 | } 196 | ComfyWidgets.VEAUDIOUPLOAD = audioUpload; 197 | 198 | app.registerExtension({ 199 | name: "ComfyUI-V-Express.UploadAudio", 200 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 201 | if (nodeData?.name == "Load_Audio_Path") { 202 | nodeData.input.required.upload = ["VEAUDIOUPLOAD"]; 203 | } 204 | }, 205 | }); 206 | -------------------------------------------------------------------------------- /web/js/VEuploadVideo.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../../scripts/app.js"; 2 | import { api } from "../../../scripts/api.js"; 3 | import { ComfyWidgets } from "../../../scripts/widgets.js"; 4 | 5 | function fitHeight(node) { 6 | node.setSize([ 7 | node.size[0], 8 | node.computeSize([node.size[0], node.size[1]])[1], 9 | ]); 10 | node?.graph?.setDirtyCanvas(true); 11 | } 12 | 13 | function previewVideo(node, file) { 14 | while (node.widgets.length > 2) { 15 | node.widgets.pop(); 16 | } 17 | try { 18 | var el = document.getElementById("VEuploadVideo"); 19 | el.remove(); 20 | } catch (error) { 21 | console.log(error); 22 | } 23 | var element = document.createElement("div"); 24 | element.id = "VEuploadVideo"; 25 | const previewNode = node; 26 | var previewWidget = node.addDOMWidget("videopreview", "preview", element, { 27 | serialize: false, 28 | hideOnZoom: false, 29 | getValue() { 30 | return element.value; 31 | }, 32 | setValue(v) { 33 | element.value = v; 34 | }, 35 | }); 36 | previewWidget.computeSize = function (width) { 37 | if (this.aspectRatio && !this.parentEl.hidden) { 38 | let height = (previewNode.size[0] - 20) / this.aspectRatio + 10; 39 | if (!(height > 0)) { 40 | height = 0; 41 | } 42 | this.computedHeight = height + 10; 43 | return [width, height]; 44 | } 45 | return [width, -4]; //no loaded src, widget should not display 46 | }; 47 | // element.style['pointer-events'] = "none" 48 | previewWidget.value = { hidden: false, paused: false, params: {} }; 49 | previewWidget.parentEl = document.createElement("div"); 50 | previewWidget.parentEl.className = "video_preview"; 51 | previewWidget.parentEl.style["width"] = "100%"; 52 | element.appendChild(previewWidget.parentEl); 53 | previewWidget.videoEl = document.createElement("video"); 54 | previewWidget.videoEl.controls = true; 55 | previewWidget.videoEl.loop = false; 56 | previewWidget.videoEl.muted = false; 57 | previewWidget.videoEl.style["width"] = "100%"; 58 | previewWidget.videoEl.addEventListener("loadedmetadata", () => { 59 | previewWidget.aspectRatio = 60 | previewWidget.videoEl.videoWidth / previewWidget.videoEl.videoHeight; 61 | fitHeight(previewNode); // 'this' is undefined in module scope; resize via the captured node 62 | }); 63 | previewWidget.videoEl.addEventListener("error", () => { 64 | //TODO: consider a way to properly notify the user why a preview isn't shown.
65 | previewWidget.parentEl.hidden = true; 66 | fitHeight(previewNode); 67 | }); 68 | 69 | let params = { 70 | filename: file, 71 | type: "input", 72 | }; 73 | 74 | previewWidget.parentEl.hidden = previewWidget.value.hidden; 75 | previewWidget.videoEl.autoplay = 76 | !previewWidget.value.paused && !previewWidget.value.hidden; 77 | let target_width = 256; 78 | if (element.style?.width) { 79 | //overscale to allow scrolling. Endpoint won't return higher than native 80 | target_width = element.style.width.slice(0, -2) * 2; 81 | } 82 | if ( 83 | !params.force_size || 84 | params.force_size.includes("?") || 85 | params.force_size == "Disabled" 86 | ) { 87 | params.force_size = target_width + "x?"; 88 | } else { 89 | let size = params.force_size.split("x"); 90 | let ar = parseInt(size[0]) / parseInt(size[1]); 91 | params.force_size = target_width + "x" + target_width / ar; 92 | } 93 | 94 | previewWidget.videoEl.src = api.apiURL( 95 | "/view?" + new URLSearchParams(params) 96 | ); 97 | 98 | previewWidget.videoEl.hidden = false; 99 | previewWidget.parentEl.appendChild(previewWidget.videoEl); 100 | } 101 | 102 | function videoUpload(node, inputName, inputData, app) { 103 | const videoWidget = node.widgets.find((w) => w.name === "video_path"); 104 | let uploadWidget; 105 | /* 106 | Expose the selected video as a "subfolder/filename [type]" path string 107 | */ 108 | var default_value = videoWidget.value; 109 | Object.defineProperty(videoWidget, "video_path", { 110 | set: function (value) { 111 | this._real_value = value; 112 | }, 113 | 114 | get: function () { 115 | let value = ""; 116 | if (this._real_value) { 117 | value = this._real_value; 118 | } else { 119 | return default_value; 120 | } 121 | 122 | if (value.filename) { 123 | let real_value = value; 124 | value = ""; 125 | if (real_value.subfolder) { 126 | value = real_value.subfolder + "/"; 127 | } 128 | 129 | value += real_value.filename; 130 | 131 | if (real_value.type && real_value.type !== "input") 132 | value += ` [${real_value.type}]`; 133 | } 134 | return value; 135 | }, 136 | }); 137 | async function uploadFile(file, updateNode, pasted = false) { 138 | try { 139 | // Wrap file in formdata so it includes filename 140 | const body = new FormData(); 141 | body.append("image", file); 142 | if (pasted) body.append("subfolder", "pasted"); 143 | const resp = await api.fetchApi("/upload/image", { 144 | method: "POST", 145 | body, 146 | }); 147 | 148 | if (resp.status === 200) { 149 | const data = await resp.json(); 150 | // Add the file to the dropdown list and update the widget value 151 | let path = data.name; 152 | if (data.subfolder) path = data.subfolder + "/" + path; 153 | 154 | if (!videoWidget.options.values.includes(path)) { 155 | videoWidget.options.values.push(path); 156 | } 157 | 158 | if (updateNode) { 159 | videoWidget.value = path; 160 | previewVideo(node, path); 161 | } 162 | } else { 163 | alert(resp.status + " - " + resp.statusText); 164 | } 165 | } catch (error) { 166 | alert(error); 167 | } 168 | } 169 | 170 | const fileInput = document.createElement("input"); 171 | Object.assign(fileInput, { 172 | type: "file", 173 | accept: "video/webm,video/mp4,video/mkv,video/avi", 174 | style: "display: none", 175 | onchange: async () => { 176 | if (fileInput.files.length) { 177 | await uploadFile(fileInput.files[0], true); 178 | } 179 | }, 180 | }); 181 | document.body.append(fileInput); 182 | 183 | // Create the button widget for selecting the files 184 | uploadWidget = node.addWidget( 185 | "button", 186 | "choose video file to upload", 187 | "Video", 188
| () => { 189 | fileInput.click(); 190 | } 191 | ); 192 | 193 | uploadWidget.serialize = false; 194 | 195 | previewVideo(node, videoWidget.value); 196 | const cb = node.callback; 197 | videoWidget.callback = function () { 198 | previewVideo(node, videoWidget.value); 199 | if (cb) { 200 | return cb.apply(this, arguments); 201 | } 202 | }; 203 | 204 | return { widget: uploadWidget }; 205 | } 206 | 207 | ComfyWidgets.VEVIDEOUPLOAD = videoUpload; 208 | 209 | app.registerExtension({ 210 | name: "ComfyUI-V-Express.UploadVideo", 211 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 212 | if ( 213 | nodeData?.name == "Load_Audio_Path_From_Video" || 214 | nodeData?.name == "Load_Kps_Path_From_Video" 215 | ) { 216 | nodeData.input.required.upload = ["VEVIDEOUPLOAD"]; 217 | } 218 | }, 219 | }); 220 | -------------------------------------------------------------------------------- /workflow/V-Express.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 20, 3 | "last_link_id": 33, 4 | "nodes": [ 5 | { 6 | "id": 4, 7 | "type": "V_Express_Loader", 8 | "pos": [ 9 | 443, 10 | 14 11 | ], 12 | "size": { 13 | "0": 320.79998779296875, 14 | "1": 26 15 | }, 16 | "flags": {}, 17 | "order": 4, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "vexpress_model_path", 22 | "type": "STRING_INPUT", 23 | "link": 28 24 | } 25 | ], 26 | "outputs": [ 27 | { 28 | "name": "v_express_pipeline", 29 | "type": "V_EXPRESS_PIPELINE", 30 | "links": [ 31 | 11 32 | ], 33 | "shape": 3, 34 | "slot_index": 0 35 | } 36 | ], 37 | "properties": { 38 | "Node name for S&R": "V_Express_Loader" 39 | } 40 | }, 41 | { 42 | "id": 6, 43 | "type": "Load_Audio_Path", 44 | "pos": [ 45 | -30, 46 | 138 47 | ], 48 | "size": { 49 | "0": 315, 50 | "1": 82 51 | }, 52 | "flags": {}, 53 | "order": 0, 54 | "mode": 0, 55 | "outputs": [ 56 | { 57 | "name": "AUDIO_PATH", 58 | "type": "AUDIO_PATH", 59 | "links": [ 60 | 13 61 | ], 62 | "shape": 3, 63 | "slot_index": 0 64 | } 65 | ], 66 | "properties": { 67 | "Node name for S&R": "Load_Audio_Path" 68 | }, 69 | "widgets_values": [ 70 | "aud.mp3", 71 | "Audio", 72 | { 73 | "hidden": false, 74 | "paused": false, 75 | "params": {} 76 | } 77 | ] 78 | }, 79 | { 80 | "id": 8, 81 | "type": "Load_Image_Path", 82 | "pos": [ 83 | -32, 84 | 368 85 | ], 86 | "size": { 87 | "0": 315, 88 | "1": 294 89 | }, 90 | "flags": {}, 91 | "order": 1, 92 | "mode": 0, 93 | "outputs": [ 94 | { 95 | "name": "IMAGE_PATH", 96 | "type": "IMAGE_PATH", 97 | "links": [ 98 | 15 99 | ], 100 | "shape": 3, 101 | "slot_index": 0 102 | } 103 | ], 104 | "properties": { 105 | "Node name for S&R": "Load_Image_Path" 106 | }, 107 | "widgets_values": [ 108 | "ref.jpg", 109 | "image" 110 | ] 111 | }, 112 | { 113 | "id": 13, 114 | "type": "Load_Kps_Path_From_Video", 115 | "pos": [ 116 | 496, 117 | 361 118 | ], 119 | "size": { 120 | "0": 367.79998779296875, 121 | "1": 455.79998779296875 122 | }, 123 | "flags": {}, 124 | "order": 5, 125 | "mode": 0, 126 | "inputs": [ 127 | { 128 | "name": "vexpress_model_path", 129 | "type": "STRING_INPUT", 130 | "link": 29 131 | }, 132 | { 133 | "name": "image_size", 134 | "type": "INT_INPUT", 135 | "link": 31 136 | } 137 | ], 138 | "outputs": [ 139 | { 140 | "name": "VKPS_PATH", 141 | "type": "VKPS_PATH", 142 | "links": [ 143 | 20 144 | ], 145 | "shape": 3, 146 | "slot_index": 0 147 | } 148 | ], 149 | "properties": { 150 | "Node name for S&R": "Load_Kps_Path_From_Video" 151 | }, 152 | "widgets_values": [ 153 | "00000gt.mp4", 154 | "Video", 155 | { 156 | 
"hidden": false, 157 | "paused": false, 158 | "params": {} 159 | } 160 | ] 161 | }, 162 | { 163 | "id": 18, 164 | "type": "VEStringConstant", 165 | "pos": [ 166 | -18, 167 | -50 168 | ], 169 | "size": { 170 | "0": 315, 171 | "1": 58 172 | }, 173 | "flags": {}, 174 | "order": 2, 175 | "mode": 0, 176 | "outputs": [ 177 | { 178 | "name": "STRING_INPUT", 179 | "type": "STRING_INPUT", 180 | "links": [ 181 | 28, 182 | 29, 183 | 30 184 | ], 185 | "shape": 3, 186 | "slot_index": 0 187 | } 188 | ], 189 | "properties": { 190 | "Node name for S&R": "VEStringConstant" 191 | }, 192 | "widgets_values": [ 193 | "./model_ckpts" 194 | ] 195 | }, 196 | { 197 | "id": 19, 198 | "type": "VEINTConstant", 199 | "pos": [ 200 | -43, 201 | 773 202 | ], 203 | "size": { 204 | "0": 315, 205 | "1": 58 206 | }, 207 | "flags": {}, 208 | "order": 3, 209 | "mode": 0, 210 | "outputs": [ 211 | { 212 | "name": "image_size", 213 | "type": "INT_INPUT", 214 | "links": [ 215 | 31, 216 | 32 217 | ], 218 | "shape": 3, 219 | "slot_index": 0 220 | } 221 | ], 222 | "properties": { 223 | "Node name for S&R": "VEINTConstant" 224 | }, 225 | "widgets_values": [ 226 | 512 227 | ] 228 | }, 229 | { 230 | "id": 10, 231 | "type": "V_Express_Sampler", 232 | "pos": [ 233 | 1046, 234 | 13 235 | ], 236 | "size": { 237 | "0": 393, 238 | "1": 422 239 | }, 240 | "flags": {}, 241 | "order": 6, 242 | "mode": 0, 243 | "inputs": [ 244 | { 245 | "name": "v_express_pipeline", 246 | "type": "V_EXPRESS_PIPELINE", 247 | "link": 11 248 | }, 249 | { 250 | "name": "vexpress_model_path", 251 | "type": "STRING_INPUT", 252 | "link": 30 253 | }, 254 | { 255 | "name": "audio_path", 256 | "type": "AUDIO_PATH", 257 | "link": 13 258 | }, 259 | { 260 | "name": "kps_path", 261 | "type": "VKPS_PATH", 262 | "link": 20 263 | }, 264 | { 265 | "name": "ref_image_path", 266 | "type": "IMAGE_PATH", 267 | "link": 15 268 | }, 269 | { 270 | "name": "image_size", 271 | "type": "INT_INPUT", 272 | "link": 32 273 | } 274 | ], 275 | "outputs": [ 276 | { 277 | "name": "output_path", 278 | "type": "STRING_INPUT", 279 | "links": [ 280 | 33 281 | ], 282 | "shape": 3, 283 | "slot_index": 0 284 | } 285 | ], 286 | "properties": { 287 | "Node name for S&R": "V_Express_Sampler" 288 | }, 289 | "widgets_values": [ 290 | "E:\\ComfyUI_windows_portable\\ComfyUI\\output\\test3.mp4", 291 | "fix_face", 292 | 30, 293 | 42, 294 | "fixed", 295 | 20, 296 | 3.5, 297 | 12, 298 | 1, 299 | 4, 300 | 0.95, 301 | 3 302 | ] 303 | }, 304 | { 305 | "id": 20, 306 | "type": "VEPreview_Video", 307 | "pos": [ 308 | 1527, 309 | 33 310 | ], 311 | "size": { 312 | "0": 210, 313 | "1": 26 314 | }, 315 | "flags": {}, 316 | "order": 7, 317 | "mode": 0, 318 | "inputs": [ 319 | { 320 | "name": "video", 321 | "type": "STRING_INPUT", 322 | "link": 33 323 | } 324 | ], 325 | "properties": { 326 | "Node name for S&R": "VEPreview_Video" 327 | }, 328 | "widgets_values": [ 329 | null 330 | ] 331 | } 332 | ], 333 | "links": [ 334 | [ 335 | 11, 336 | 4, 337 | 0, 338 | 10, 339 | 0, 340 | "V_EXPRESS_PIPELINE" 341 | ], 342 | [ 343 | 13, 344 | 6, 345 | 0, 346 | 10, 347 | 2, 348 | "AUDIO_PATH" 349 | ], 350 | [ 351 | 15, 352 | 8, 353 | 0, 354 | 10, 355 | 4, 356 | "IMAGE_PATH" 357 | ], 358 | [ 359 | 20, 360 | 13, 361 | 0, 362 | 10, 363 | 3, 364 | "VKPS_PATH" 365 | ], 366 | [ 367 | 28, 368 | 18, 369 | 0, 370 | 4, 371 | 0, 372 | "STRING_INPUT" 373 | ], 374 | [ 375 | 29, 376 | 18, 377 | 0, 378 | 13, 379 | 0, 380 | "STRING_INPUT" 381 | ], 382 | [ 383 | 30, 384 | 18, 385 | 0, 386 | 10, 387 | 1, 388 | "STRING_INPUT" 389 | ], 390 | [ 391 | 31, 392 | 19, 
393 | 0, 394 | 13, 395 | 1, 396 | "INT_INPUT" 397 | ], 398 | [ 399 | 32, 400 | 19, 401 | 0, 402 | 10, 403 | 5, 404 | "INT_INPUT" 405 | ], 406 | [ 407 | 33, 408 | 10, 409 | 0, 410 | 20, 411 | 0, 412 | "STRING_INPUT" 413 | ] 414 | ], 415 | "groups": [], 416 | "config": {}, 417 | "extra": {}, 418 | "version": 0.4 419 | } --------------------------------------------------------------------------------