├── hallo ├── __init__.py ├── utils │ └── __init__.py ├── animate │ ├── __init__.py │ ├── face_animate_static.py │ └── face_animate.py ├── datasets │ ├── __init__.py │ ├── mask_image.py │ ├── audio_processor.py │ ├── image_processor.py │ └── talk_video.py └── models │ ├── __init__.py │ ├── image_proj.py │ ├── face_locator.py │ ├── audio_proj.py │ ├── wav2vec.py │ ├── transformer_3d.py │ ├── resnet.py │ ├── transformer_2d.py │ └── mutual_self_attention.py ├── configs ├── inference │ ├── .gitkeep │ └── default.yaml └── unet │ └── unet.yaml ├── start_docker.sh ├── start.bat ├── start.sh ├── start_collab.sh ├── docker-compose.yml ├── .pre-commit-config.yaml ├── install.sh ├── install.bat ├── requirements.txt ├── DOCKERFILE ├── setup.py ├── .gitignore ├── app.py ├── container.py ├── README.md └── scripts └── inference.py /hallo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hallo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/inference/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hallo/animate/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hallo/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hallo/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /start_docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python container.py -------------------------------------------------------------------------------- /start.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | call venv/scripts/activate 4 | python app.py -------------------------------------------------------------------------------- /start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source venv/bin/activate 4 | python app.py -------------------------------------------------------------------------------- /start_collab.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source venv/bin/activate 4 | python app.py --share -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | services: 3 | app: 4 | build: 5 | context: ./ 6 | dockerfile: DOCKERFILE 7 | ports: 8 | - "8020:7860" 9 | deploy: 10 | resources: 11 | reservations: 12 | devices: 13 | - driver: nvidia 14 | count: 1 15 | capabilities: [gpu] -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: isort 5 | name: isort 6 
| language: system 7 | types: [python] 8 | pass_filenames: false 9 | entry: isort 10 | args: ["."] 11 | - id: pylint 12 | name: pylint 13 | language: system 14 | types: [python] 15 | pass_filenames: false 16 | entry: pylint 17 | args: ["**/*.py"] 18 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Clone models" 4 | git lfs install 5 | git clone https://huggingface.co/fudan-generative-ai/hallo pretrained_models 6 | wget -O pretrained_models/hallo/net.pth https://huggingface.co/fudan-generative-ai/hallo/resolve/main/hallo/net.pth?download=true 7 | 8 | echo "Install dependencies" 9 | python3 -m venv venv 10 | source venv/bin/activate 11 | pip install -r requirements.txt 12 | pip install -e . 13 | 14 | echo "Install GPU libraries" 15 | pip install torch==2.2.2+cu121 torchaudio torchvision --index-url https://download.pytorch.org/whl/cu121 16 | pip install onnxruntime-gpu 17 | 18 | echo "Installation complete" 19 | -------------------------------------------------------------------------------- /install.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | echo clone models 4 | git lfs install 5 | git clone https://huggingface.co/fudan-generative-ai/hallo pretrained_models 6 | curl -L -o pretrained_models/hallo/net.pth https://huggingface.co/fudan-generative-ai/hallo/resolve/main/hallo/net.pth?download=true 7 | 8 | echo Install Depends 9 | python -m venv venv 10 | call venv/scripts/activate 11 | pip install -r requirements.txt 12 | pip install -e . 13 | 14 | pip install bitsandbytes-windows --force-reinstall 15 | 16 | echo Install GPU libs 17 | pip install torch==2.2.2+cu121 torchaudio torchvision --index-url https://download.pytorch.org/whl/cu121 18 | 19 | echo install complete 20 | pause -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.28.0 2 | audio-separator==0.17.2 3 | av==12.1.0 4 | bitsandbytes==0.43.1 5 | decord==0.6.0 6 | diffusers==0.27.2 7 | einops==0.8.0 8 | insightface==0.7.3 9 | librosa==0.10.2.post1 10 | mediapipe[vision]==0.10.14 11 | mlflow==2.13.1 12 | moviepy==1.0.3 13 | numpy==1.26.4 14 | omegaconf==2.3.0 15 | onnx2torch==1.5.14 16 | onnx==1.16.1 17 | onnxruntime==1.18.0 18 | opencv-contrib-python==4.9.0.80 19 | opencv-python-headless==4.9.0.80 20 | opencv-python==4.9.0.80 21 | pillow==10.3.0 22 | setuptools==70.0.0 23 | torch==2.2.2 24 | torchvision==0.17.2 25 | tqdm==4.66.4 26 | transformers==4.39.2 27 | xformers==0.0.25.post1 28 | isort==5.13.2 29 | pylint==3.2.2 30 | pre-commit==3.7.1 31 | gradio==4.36.1 -------------------------------------------------------------------------------- /DOCKERFILE: -------------------------------------------------------------------------------- 1 | # Use an official Python runtime as a parent image 2 | FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04 3 | 4 | ENV DEBIAN_FRONTEND=noninteractive 5 | 6 | # Set the working directory in the container to /app 7 | WORKDIR /app 8 | 9 | # Install dependencies & python 10 | RUN apt-get update && apt-get install -y --no-install-recommends git wget bzip2 ffmpeg gcc g++ software-properties-common 11 | 12 | RUN add-apt-repository ppa:deadsnakes/ppa && apt update && apt install -y python3.11 13 | 14 | RUN apt install -y python3-pip 15 | 16 | # Clone the 
GitHub repository and checkout to the correct branch/tag 17 | RUN git clone https://github.com/moda20/hallo-webui.git . 18 | 19 | # Run install.sh to perform any setup or installations required before setting up the virtual environment 20 | 21 | RUN chmod +x install.sh && ./install.sh 22 | 23 | RUN pip install --upgrade pip 24 | 25 | # Expose port 7860 for the gradio app 26 | EXPOSE 7860 27 | 28 | RUN ln -s /usr/bin/python3 /usr/bin/python 29 | 30 | 31 | RUN chmod +x start_docker.sh 32 | 33 | # Run start.sh when the container starts 34 | CMD ["./start_docker.sh"] -------------------------------------------------------------------------------- /configs/unet/unet.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | use_inflated_groupnorm: true 3 | unet_use_cross_frame_attention: false 4 | unet_use_temporal_attention: false 5 | use_motion_module: true 6 | use_audio_module: true 7 | motion_module_resolutions: 8 | - 1 9 | - 2 10 | - 4 11 | - 8 12 | motion_module_mid_block: true 13 | motion_module_decoder_only: false 14 | motion_module_type: Vanilla 15 | motion_module_kwargs: 16 | num_attention_heads: 8 17 | num_transformer_block: 1 18 | attention_block_types: 19 | - Temporal_Self 20 | - Temporal_Self 21 | temporal_position_encoding: true 22 | temporal_position_encoding_max_len: 32 23 | temporal_attention_dim_div: 1 24 | audio_attention_dim: 768 25 | stack_enable_blocks_name: 26 | - "up" 27 | - "down" 28 | - "mid" 29 | stack_enable_blocks_depth: [0,1,2,3] 30 | 31 | enable_zero_snr: true 32 | 33 | noise_scheduler_kwargs: 34 | beta_start: 0.00085 35 | beta_end: 0.012 36 | beta_schedule: "linear" 37 | clip_sample: false 38 | steps_offset: 1 39 | ### Zero-SNR params 40 | prediction_type: "v_prediction" 41 | rescale_betas_zero_snr: True 42 | timestep_spacing: "trailing" 43 | 44 | sampler: DDIM 45 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | setup.py 3 | ---- 4 | This is the main setup file for the hallo face animation project. It defines the package 5 | metadata, required dependencies, and provides the entry point for installing the package. 
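An illustrative sanity check: after install.sh or install.bat has created the virtual
environment, installed requirements.txt and run pip install -e ., the subpackages listed in
the packages list below should be importable.

    from hallo.models.image_proj import ImageProjModel   # editable install puts hallo/ on the path
    from hallo.models.audio_proj import AudioProjModel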
6 | 7 | """ 8 | 9 | # -*- coding: utf-8 -*- 10 | from setuptools import setup 11 | 12 | packages = \ 13 | ['hallo', 'hallo.datasets', 'hallo.models', 'hallo.animate', 'hallo.utils'] 14 | 15 | package_data = \ 16 | {'': ['*']} 17 | 18 | install_requires = \ 19 | ['accelerate==0.28.0', 20 | 'audio-separator>=0.17.2,<0.18.0', 21 | 'av==12.1.0', 22 | 'bitsandbytes==0.43.1', 23 | 'decord==0.6.0', 24 | 'diffusers==0.27.2', 25 | 'einops>=0.8.0,<0.9.0', 26 | 'insightface>=0.7.3,<0.8.0', 27 | 'mediapipe[vision]>=0.10.14,<0.11.0', 28 | 'mlflow==2.13.1', 29 | 'moviepy>=1.0.3,<2.0.0', 30 | 'omegaconf>=2.3.0,<3.0.0', 31 | 'opencv-python>=4.9.0.80,<5.0.0.0', 32 | 'pillow>=10.3.0,<11.0.0', 33 | 'torch==2.2.2', 34 | 'torchvision==0.17.2', 35 | 'transformers==4.39.2', 36 | 'xformers==0.0.25.post1'] 37 | 38 | setup_kwargs = { 39 | 'name': 'anna', 40 | 'version': '0.1.0', 41 | 'description': '', 42 | 'long_description': '# Anna face animation', 43 | 'author': 'Your Name', 44 | 'author_email': 'you@example.com', 45 | 'maintainer': 'None', 46 | 'maintainer_email': 'None', 47 | 'url': 'None', 48 | 'packages': packages, 49 | 'package_data': package_data, 50 | 'install_requires': install_requires, 51 | 'python_requires': '>=3.10,<4.0', 52 | } 53 | 54 | 55 | setup(**setup_kwargs) 56 | -------------------------------------------------------------------------------- /configs/inference/default.yaml: -------------------------------------------------------------------------------- 1 | source_image: ./default.png 2 | driving_audio: default.wav 3 | 4 | weight_dtype: fp16 5 | 6 | data: 7 | n_motion_frames: 2 8 | n_sample_frames: 16 9 | source_image: 10 | width: 512 11 | height: 512 12 | driving_audio: 13 | sample_rate: 16000 14 | export_video: 15 | fps: 25 16 | 17 | inference_steps: 40 18 | cfg_scale: 3.5 19 | 20 | audio_ckpt_dir: ./pretrained_models/hallo 21 | 22 | base_model_path: ./pretrained_models/stable-diffusion-v1-5 23 | 24 | motion_module_path: ./pretrained_models/motion_module/mm_sd_v15_v2.ckpt 25 | 26 | face_analysis: 27 | model_path: ./pretrained_models/face_analysis 28 | 29 | wav2vec: 30 | model_path: ./pretrained_models/wav2vec/wav2vec2-base-960h 31 | features: all 32 | 33 | audio_separator: 34 | model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx 35 | 36 | vae: 37 | model_path: ./pretrained_models/sd-vae-ft-mse 38 | 39 | save_path: ./.cache 40 | 41 | face_expand_ratio: 1.1 42 | pose_weight: 1.1 43 | face_weight: 1.1 44 | lip_weight: 1.1 45 | 46 | unet_additional_kwargs: 47 | use_inflated_groupnorm: true 48 | unet_use_cross_frame_attention: false 49 | unet_use_temporal_attention: false 50 | use_motion_module: true 51 | use_audio_module: true 52 | motion_module_resolutions: 53 | - 1 54 | - 2 55 | - 4 56 | - 8 57 | motion_module_mid_block: true 58 | motion_module_decoder_only: false 59 | motion_module_type: Vanilla 60 | motion_module_kwargs: 61 | num_attention_heads: 8 62 | num_transformer_block: 1 63 | attention_block_types: 64 | - Temporal_Self 65 | - Temporal_Self 66 | temporal_position_encoding: true 67 | temporal_position_encoding_max_len: 32 68 | temporal_attention_dim_div: 1 69 | audio_attention_dim: 768 70 | stack_enable_blocks_name: 71 | - "up" 72 | - "down" 73 | - "mid" 74 | stack_enable_blocks_depth: [0,1,2,3] 75 | 76 | 77 | enable_zero_snr: true 78 | 79 | noise_scheduler_kwargs: 80 | beta_start: 0.00085 81 | beta_end: 0.012 82 | beta_schedule: "linear" 83 | clip_sample: false 84 | steps_offset: 1 85 | ### Zero-SNR params 86 | prediction_type: "v_prediction" 87 | rescale_betas_zero_snr: 
True 88 | timestep_spacing: "trailing" 89 | 90 | sampler: DDIM 91 | -------------------------------------------------------------------------------- /hallo/models/image_proj.py: -------------------------------------------------------------------------------- 1 | """ 2 | image_proj_model.py 3 | 4 | This module defines the ImageProjModel class, which is responsible for 5 | projecting image embeddings into a different dimensional space. The model 6 | leverages a linear transformation followed by a layer normalization to 7 | reshape and normalize the input image embeddings for further processing in 8 | cross-attention mechanisms or other downstream tasks. 9 | 10 | Classes: 11 | ImageProjModel 12 | 13 | Dependencies: 14 | torch 15 | diffusers.ModelMixin 16 | 17 | """ 18 | 19 | import torch 20 | from diffusers import ModelMixin 21 | 22 | 23 | class ImageProjModel(ModelMixin): 24 | """ 25 | ImageProjModel is a class that projects image embeddings into a different 26 | dimensional space. It inherits from ModelMixin, providing additional functionalities 27 | specific to image projection. 28 | 29 | Attributes: 30 | cross_attention_dim (int): The dimension of the cross attention. 31 | clip_embeddings_dim (int): The dimension of the CLIP embeddings. 32 | clip_extra_context_tokens (int): The number of extra context tokens in CLIP. 33 | 34 | Methods: 35 | forward(image_embeds): Forward pass of the ImageProjModel, which takes in image 36 | embeddings and returns the projected tokens. 37 | 38 | """ 39 | 40 | def __init__( 41 | self, 42 | cross_attention_dim=1024, 43 | clip_embeddings_dim=1024, 44 | clip_extra_context_tokens=4, 45 | ): 46 | super().__init__() 47 | 48 | self.generator = None 49 | self.cross_attention_dim = cross_attention_dim 50 | self.clip_extra_context_tokens = clip_extra_context_tokens 51 | self.proj = torch.nn.Linear( 52 | clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim 53 | ) 54 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 55 | 56 | def forward(self, image_embeds): 57 | """ 58 | Forward pass of the ImageProjModel, which takes in image embeddings and returns the 59 | projected tokens after reshaping and normalization. 60 | 61 | Args: 62 | image_embeds (torch.Tensor): The input image embeddings, with shape 63 | batch_size x num_image_tokens x clip_embeddings_dim. 64 | 65 | Returns: 66 | clip_extra_context_tokens (torch.Tensor): The projected tokens after reshaping 67 | and normalization, with shape batch_size x (clip_extra_context_tokens * 68 | cross_attention_dim). 
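        Example (a shape sketch with the default sizes; the batch of two pooled CLIP
        embeddings is illustrative):

            model = ImageProjModel()                # Linear 1024 -> 4 tokens x 1024
            image_embeds = torch.randn(2, 1024)     # pooled CLIP image embeddings
            tokens = model(image_embeds)            # torch.Size([2, 4, 1024])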
69 | 70 | """ 71 | embeds = image_embeds 72 | clip_extra_context_tokens = self.proj(embeds).reshape( 73 | -1, self.clip_extra_context_tokens, self.cross_attention_dim 74 | ) 75 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 76 | return clip_extra_context_tokens 77 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # running cache 2 | mlruns/ 3 | 4 | # Test directories 5 | test_data/ 6 | pretrained_models/ 7 | 8 | # Poetry project 9 | poetry.lock 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | cover/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | .pybuilder/ 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | # For a library or package, you might want to ignore these files since the code is 97 | # intended to run in multiple environments; otherwise, check them in: 98 | # .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # poetry 108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 109 | # This is especially recommended for binary packages to ensure reproducibility, and is more 110 | # commonly ignored for libraries. 111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 112 | #poetry.lock 113 | 114 | # pdm 115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 116 | #pdm.lock 117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 118 | # in version control. 119 | # https://pdm.fming.dev/#use-with-ide 120 | .pdm.toml 121 | 122 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 123 | __pypackages__/ 124 | 125 | # Celery stuff 126 | celerybeat-schedule 127 | celerybeat.pid 128 | 129 | # SageMath parsed files 130 | *.sage.py 131 | 132 | # Environments 133 | .env 134 | .venv 135 | env/ 136 | venv/ 137 | ENV/ 138 | env.bak/ 139 | venv.bak/ 140 | 141 | # Spyder project settings 142 | .spyderproject 143 | .spyproject 144 | 145 | # Rope project settings 146 | .ropeproject 147 | 148 | # mkdocs documentation 149 | /site 150 | 151 | # mypy 152 | .mypy_cache/ 153 | .dmypy.json 154 | dmypy.json 155 | 156 | # Pyre type checker 157 | .pyre/ 158 | 159 | # pytype static type analyzer 160 | .pytype/ 161 | 162 | # Cython debug symbols 163 | cython_debug/ 164 | 165 | # IDE 166 | .idea/ 167 | .vscode/ 168 | data 169 | pretrained_models 170 | test_data 171 | /output 172 | -------------------------------------------------------------------------------- /hallo/models/face_locator.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module implements the FaceLocator class, which is a neural network model designed to 3 | locate and extract facial features from input images or tensors. It uses a series of 4 | convolutional layers to progressively downsample and refine the facial feature map. 5 | 6 | The FaceLocator class is part of a larger system that may involve facial recognition or 7 | similar tasks where precise location and extraction of facial features are required. 8 | 9 | Attributes: 10 | conditioning_embedding_channels (int): The number of channels in the output embedding. 11 | conditioning_channels (int): The number of input channels for the conditioning tensor. 12 | block_out_channels (Tuple[int]): A tuple of integers representing the output channels 13 | for each block in the model. 14 | 15 | The model uses the following components: 16 | - InflatedConv3d: A convolutional layer that inflates the input to increase the depth. 17 | - zero_module: A utility function that may set certain parameters to zero for regularization 18 | or other purposes. 19 | 20 | The forward method of the FaceLocator class takes a conditioning tensor as input and 21 | produces an embedding tensor as output, which can be used for further processing or analysis. 22 | """ 23 | 24 | from typing import Tuple 25 | 26 | import torch.nn.functional as F 27 | from diffusers.models.modeling_utils import ModelMixin 28 | from torch import nn 29 | 30 | from .motion_module import zero_module 31 | from .resnet import InflatedConv3d 32 | 33 | 34 | class FaceLocator(ModelMixin): 35 | """ 36 | The FaceLocator class is a neural network model designed to process and extract facial 37 | features from an input tensor. It consists of a series of convolutional layers that 38 | progressively downsample the input while increasing the depth of the feature map. 39 | 40 | The model is built using InflatedConv3d layers, which are designed to inflate the 41 | feature channels, allowing for more complex feature extraction. The final output is a 42 | conditioning embedding that can be used for various tasks such as facial recognition or 43 | feature-based image manipulation. 44 | 45 | Parameters: 46 | conditioning_embedding_channels (int): The number of channels in the output embedding. 47 | conditioning_channels (int, optional): The number of input channels for the conditioning tensor. Default is 3. 
48 | block_out_channels (Tuple[int], optional): A tuple of integers representing the output channels 49 | for each block in the model. The default is (16, 32, 64, 128), which defines the 50 | progression of the network's depth. 51 | 52 | Attributes: 53 | conv_in (InflatedConv3d): The initial convolutional layer that starts the feature extraction process. 54 | blocks (ModuleList[InflatedConv3d]): A list of convolutional layers that form the core of the model. 55 | conv_out (InflatedConv3d): The final convolutional layer that produces the output embedding. 56 | 57 | The forward method applies the convolutional layers to the input conditioning tensor and 58 | returns the resulting embedding tensor. 59 | """ 60 | def __init__( 61 | self, 62 | conditioning_embedding_channels: int, 63 | conditioning_channels: int = 3, 64 | block_out_channels: Tuple[int] = (16, 32, 64, 128), 65 | ): 66 | super().__init__() 67 | self.conv_in = InflatedConv3d( 68 | conditioning_channels, block_out_channels[0], kernel_size=3, padding=1 69 | ) 70 | 71 | self.blocks = nn.ModuleList([]) 72 | 73 | for i in range(len(block_out_channels) - 1): 74 | channel_in = block_out_channels[i] 75 | channel_out = block_out_channels[i + 1] 76 | self.blocks.append( 77 | InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1) 78 | ) 79 | self.blocks.append( 80 | InflatedConv3d( 81 | channel_in, channel_out, kernel_size=3, padding=1, stride=2 82 | ) 83 | ) 84 | 85 | self.conv_out = zero_module( 86 | InflatedConv3d( 87 | block_out_channels[-1], 88 | conditioning_embedding_channels, 89 | kernel_size=3, 90 | padding=1, 91 | ) 92 | ) 93 | 94 | def forward(self, conditioning): 95 | """ 96 | Forward pass of the FaceLocator model. 97 | 98 | Args: 99 | conditioning (Tensor): The input conditioning tensor. 100 | 101 | Returns: 102 | Tensor: The output embedding tensor. 
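        Example (a shape sketch; it assumes InflatedConv3d consumes (batch, channels, frames,
        height, width) tensors and uses conditioning_embedding_channels=320, the width of the
        first Stable Diffusion 1.5 UNet block):

            locator = FaceLocator(conditioning_embedding_channels=320)
            mask = torch.randn(1, 3, 16, 512, 512)   # 16 frames of a 3-channel face mask
            emb = locator(mask)                      # (1, 320, 16, 64, 64): three stride-2 convs downsample 8x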
103 | """ 104 | embedding = self.conv_in(conditioning) 105 | embedding = F.silu(embedding) 106 | 107 | for block in self.blocks: 108 | embedding = block(embedding) 109 | embedding = F.silu(embedding) 110 | 111 | embedding = self.conv_out(embedding) 112 | 113 | return embedding 114 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import gradio as gr 3 | import subprocess 4 | from datetime import datetime 5 | import os 6 | import platform 7 | 8 | def generate_video(ref_img, ref_audio,settings_face_expand_ratio=1.2, setting_steps=40, setting_cfg=3.5, settings_seed=42, settings_fps=25, settings_motion_pose_scale=1.1, settings_motion_face_scale=1.1, settings_motion_lip_scale=1.1, settings_n_motion_frames=2, settings_n_sample_frames=16): 9 | # Ensure file paths are correct 10 | if not os.path.isfile(ref_img) or not os.path.isfile(ref_audio): 11 | return "Error: File not found", None 12 | 13 | # Path to the output video file 14 | timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") 15 | # Check if output exists and if not create it 16 | if not os.path.exists("output"): 17 | os.makedirs("output") 18 | 19 | output_video = f"output/{timestamp}.mp4" 20 | 21 | # Determine the command based on the operating system 22 | if platform.system() == "Windows": 23 | command = [ 24 | "venv\\Scripts\\python.exe", 25 | "scripts\\inference.py", 26 | "--source_image", ref_img, 27 | "--driving_audio", ref_audio, 28 | "--output", output_video, 29 | "--face_expand_ratio", str(settings_face_expand_ratio), 30 | "--setting_steps", str(setting_steps), 31 | "--setting_cfg", str(setting_cfg), 32 | "--settings_seed", str(settings_seed), 33 | "--settings_fps", str(settings_fps), 34 | "--settings_motion_pose_scale", str(settings_motion_pose_scale), 35 | "--settings_motion_face_scale", str(settings_motion_face_scale), 36 | "--settings_motion_lip_scale", str(settings_motion_lip_scale), 37 | "--settings_n_motion_frames", str(settings_n_motion_frames), 38 | "--settings_n_sample_frames", str(settings_n_sample_frames) 39 | ] 40 | else: 41 | command = [ 42 | "python3", 43 | "scripts/inference.py", 44 | "--source_image", ref_img, 45 | "--driving_audio", ref_audio, 46 | "--output", output_video, 47 | "--setting_steps", str(setting_steps), 48 | "--setting_cfg", str(setting_cfg), 49 | "--settings_seed", str(settings_seed), 50 | "--settings_fps", str(settings_fps), 51 | "--settings_motion_pose_scale", str(settings_motion_pose_scale), 52 | "--settings_motion_face_scale", str(settings_motion_face_scale), 53 | "--settings_motion_lip_scale", str(settings_motion_lip_scale), 54 | "--settings_n_motion_frames", str(settings_n_motion_frames), 55 | "--settings_n_sample_frames", str(settings_n_sample_frames) 56 | ] 57 | 58 | try: 59 | # Execute the command 60 | result = subprocess.run(command, check=True) 61 | 62 | if result.returncode == 0: 63 | return "Video generated successfully", output_video 64 | else: 65 | return "Error generating video", None 66 | 67 | except subprocess.CalledProcessError as e: 68 | return f"Error: {str(e)}", None 69 | 70 | 71 | with gr.Blocks() as demo: 72 | with gr.Row(): 73 | with gr.Column(): 74 | ref_img = gr.Image(label="Reference Image", type="filepath") 75 | ref_audio = gr.Audio(label="Audio", type="filepath") 76 | with gr.Accordion("Settings", open=True): 77 | settings_face_expand_ratio = gr.Slider(label="Face Expand Ratio", value=1.2, minimum=0, maximum=10, step=0.01) 78 | 
setting_steps = gr.Slider(label="Steps", value=40, minimum=1, maximum=200, step=1) 79 | setting_cfg = gr.Slider(label="CFG Scale", value=3.5, minimum=0, maximum=10, step=0.01) 80 | settings_seed = gr.Textbox(label="Seed", value=42) 81 | settings_fps = gr.Slider(label="FPS", value=25, minimum=1, maximum=200, step=1) 82 | with gr.Accordion("Motion Scale", open=True): 83 | settings_motion_pose_scale = gr.Slider(label="Motion Pose Scale", value=1.0, minimum=0, maximum=5, step=0.01) 84 | settings_motion_face_scale = gr.Slider(label="Motion Face Scale", value=1.0, minimum=0, maximum=5, step=0.01) 85 | settings_motion_lip_scale = gr.Slider(label="Motion Lip Scale", value=1.0, minimum=0, maximum=5, step=0.01) 86 | with gr.Accordion("Extra Settings", open=True): 87 | settings_n_motion_frames = gr.Slider(label="N Motion Frames", value=2, minimum=1, maximum=100, step=1) 88 | settings_n_sample_frames = gr.Slider(label="N Sample Frames", value=16, minimum=1, maximum=100, step=1) 89 | with gr.Column(): 90 | result_status = gr.Label(value="Status") 91 | result_video = gr.Video(label="Result Video", interactive=False) 92 | result_btn = gr.Button(value="Generate Video") 93 | 94 | result_btn.click(fn=generate_video, inputs=[ref_img, ref_audio,settings_face_expand_ratio, setting_steps, setting_cfg, settings_seed, settings_fps, settings_motion_pose_scale, settings_motion_face_scale, settings_motion_lip_scale, settings_n_motion_frames, settings_n_sample_frames], outputs=[result_status, result_video]) 95 | 96 | if __name__ == "__main__": 97 | share_url = False if "--share" not in sys.argv else True 98 | 99 | demo.queue() 100 | demo.launch(inbrowser=True, share=share_url) 101 | 102 | -------------------------------------------------------------------------------- /container.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import gradio as gr 3 | import subprocess 4 | from datetime import datetime 5 | import os 6 | import platform 7 | 8 | def generate_video(ref_img, ref_audio,settings_face_expand_ratio=1.2, setting_steps=40, setting_cfg=3.5, settings_seed=42, settings_fps=25, settings_motion_pose_scale=1.1, settings_motion_face_scale=1.1, settings_motion_lip_scale=1.1, settings_n_motion_frames=2, settings_n_sample_frames=16): 9 | # Ensure file paths are correct 10 | if not os.path.isfile(ref_img) or not os.path.isfile(ref_audio): 11 | return "Error: File not found", None 12 | 13 | # Path to the output video file 14 | timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") 15 | # Check if output exists and if not create it 16 | if not os.path.exists("output"): 17 | os.makedirs("output") 18 | 19 | output_video = f"output/{timestamp}.mp4" 20 | 21 | # Determine the command based on the operating system 22 | if platform.system() == "Windows": 23 | command = [ 24 | "venv\\Scripts\\python.exe", 25 | "scripts\\inference.py", 26 | "--source_image", ref_img, 27 | "--driving_audio", ref_audio, 28 | "--output", output_video, 29 | "--face_expand_ratio", str(settings_face_expand_ratio), 30 | "--setting_steps", str(setting_steps), 31 | "--setting_cfg", str(setting_cfg), 32 | "--settings_seed", str(settings_seed), 33 | "--settings_fps", str(settings_fps), 34 | "--settings_motion_pose_scale", str(settings_motion_pose_scale), 35 | "--settings_motion_face_scale", str(settings_motion_face_scale), 36 | "--settings_motion_lip_scale", str(settings_motion_lip_scale), 37 | "--settings_n_motion_frames", str(settings_n_motion_frames), 38 | "--settings_n_sample_frames", 
str(settings_n_sample_frames) 39 | ] 40 | else: 41 | command = [ 42 | "python3", 43 | "scripts/inference.py", 44 | "--source_image", ref_img, 45 | "--driving_audio", ref_audio, 46 | "--output", output_video, 47 | "--setting_steps", str(setting_steps), 48 | "--setting_cfg", str(setting_cfg), 49 | "--settings_seed", str(settings_seed), 50 | "--settings_fps", str(settings_fps), 51 | "--settings_motion_pose_scale", str(settings_motion_pose_scale), 52 | "--settings_motion_face_scale", str(settings_motion_face_scale), 53 | "--settings_motion_lip_scale", str(settings_motion_lip_scale), 54 | "--settings_n_motion_frames", str(settings_n_motion_frames), 55 | "--settings_n_sample_frames", str(settings_n_sample_frames) 56 | ] 57 | 58 | try: 59 | # Execute the command 60 | result = subprocess.run(command, check=True) 61 | 62 | if result.returncode == 0: 63 | return "Video generated successfully", output_video 64 | else: 65 | return "Error generating video", None 66 | 67 | except subprocess.CalledProcessError as e: 68 | return f"Error: {str(e)}", None 69 | 70 | 71 | with gr.Blocks() as demo: 72 | with gr.Row(): 73 | with gr.Column(): 74 | ref_img = gr.Image(label="Reference Image", type="filepath") 75 | ref_audio = gr.Audio(label="Audio", type="filepath") 76 | with gr.Accordion("Settings", open=True): 77 | settings_face_expand_ratio = gr.Slider(label="Face Expand Ratio", value=1.2, minimum=0, maximum=10, step=0.01) 78 | setting_steps = gr.Slider(label="Steps", value=40, minimum=1, maximum=200, step=1) 79 | setting_cfg = gr.Slider(label="CFG Scale", value=3.5, minimum=0, maximum=10, step=0.01) 80 | settings_seed = gr.Textbox(label="Seed", value=42) 81 | settings_fps = gr.Slider(label="FPS", value=25, minimum=1, maximum=200, step=1) 82 | with gr.Accordion("Motion Scale", open=True): 83 | settings_motion_pose_scale = gr.Slider(label="Motion Pose Scale", value=1.0, minimum=0, maximum=5, step=0.01) 84 | settings_motion_face_scale = gr.Slider(label="Motion Face Scale", value=1.0, minimum=0, maximum=5, step=0.01) 85 | settings_motion_lip_scale = gr.Slider(label="Motion Lip Scale", value=1.0, minimum=0, maximum=5, step=0.01) 86 | with gr.Accordion("Extra Settings", open=True): 87 | settings_n_motion_frames = gr.Slider(label="N Motion Frames", value=2, minimum=1, maximum=100, step=1) 88 | settings_n_sample_frames = gr.Slider(label="N Sample Frames", value=16, minimum=1, maximum=100, step=1) 89 | with gr.Column(): 90 | result_status = gr.Label(value="Status") 91 | result_video = gr.Video(label="Result Video", interactive=False) 92 | result_btn = gr.Button(value="Generate Video") 93 | 94 | result_btn.click(fn=generate_video, inputs=[ref_img, ref_audio,settings_face_expand_ratio, setting_steps, setting_cfg, settings_seed, settings_fps, settings_motion_pose_scale, settings_motion_face_scale, settings_motion_lip_scale, settings_n_motion_frames, settings_n_sample_frames], outputs=[result_status, result_video]) 95 | 96 | if __name__ == "__main__": 97 | share_url = False if "--share" not in sys.argv else True 98 | 99 | demo.queue() 100 | demo.launch(inbrowser=True, share=share_url, server_port=7860, server_name="0.0.0.0") -------------------------------------------------------------------------------- /hallo/models/audio_proj.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides the implementation of an Audio Projection Model, which is designed for 3 | audio processing tasks. 
The model takes audio embeddings as input and outputs context tokens 4 | that can be used for various downstream applications, such as audio analysis or synthesis. 5 | 6 | The AudioProjModel class is based on the ModelMixin class from the diffusers library, which 7 | provides a foundation for building custom models. This implementation includes multiple linear 8 | layers with ReLU activation functions and a LayerNorm for normalization. 9 | 10 | Key Features: 11 | - Audio embedding input with flexible sequence length and block structure. 12 | - Multiple linear layers for feature transformation. 13 | - ReLU activation for non-linear transformation. 14 | - LayerNorm for stabilizing and speeding up training. 15 | - Rearrangement of input embeddings to match the model's expected input shape. 16 | - Customizable number of blocks, channels, and context tokens for adaptability. 17 | 18 | The module is structured to be easily integrated into larger systems or used as a standalone 19 | component for audio feature extraction and processing. 20 | 21 | Classes: 22 | - AudioProjModel: A class representing the audio projection model with configurable parameters. 23 | 24 | Functions: 25 | - (none) 26 | 27 | Dependencies: 28 | - torch: For tensor operations and neural network components. 29 | - diffusers: For the ModelMixin base class. 30 | - einops: For tensor rearrangement operations. 31 | 32 | """ 33 | 34 | import torch 35 | from diffusers import ModelMixin 36 | from einops import rearrange 37 | from torch import nn 38 | 39 | 40 | class AudioProjModel(ModelMixin): 41 | """Audio Projection Model 42 | 43 | This class defines an audio projection model that takes audio embeddings as input 44 | and produces context tokens as output. The model is based on the ModelMixin class 45 | and consists of multiple linear layers and activation functions. It can be used 46 | for various audio processing tasks. 47 | 48 | Attributes: 49 | seq_len (int): The length of the audio sequence. 50 | blocks (int): The number of blocks in the audio projection model. 51 | channels (int): The number of channels in the audio projection model. 52 | intermediate_dim (int): The intermediate dimension of the model. 53 | context_tokens (int): The number of context tokens in the output. 54 | output_dim (int): The output dimension of the context tokens. 55 | 56 | Methods: 57 | __init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, context_tokens=32, output_dim=768): 58 | Initializes the AudioProjModel with the given parameters. 59 | forward(self, audio_embeds): 60 | Defines the forward pass for the AudioProjModel. 61 | Parameters: 62 | audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels). 63 | Returns: 64 | context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim). 65 | 66 | """ 67 | 68 | def __init__( 69 | self, 70 | seq_len=5, 71 | blocks=12, # add a new parameter blocks 72 | channels=768, # add a new parameter channels 73 | intermediate_dim=512, 74 | output_dim=768, 75 | context_tokens=32, 76 | ): 77 | super().__init__() 78 | 79 | self.seq_len = seq_len 80 | self.blocks = blocks 81 | self.channels = channels 82 | self.input_dim = ( 83 | seq_len * blocks * channels 84 | ) # update input_dim to be the product of blocks and channels. 
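        # With the defaults this is 5 * 12 * 768 = 46080 features per audio window
        # (seq_len wav2vec frames x hidden-state blocks x channels); the three linear
        # projections below squeeze this down to context_tokens * output_dim = 32 * 768.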
85 | self.intermediate_dim = intermediate_dim 86 | self.context_tokens = context_tokens 87 | self.output_dim = output_dim 88 | 89 | # define multiple linear layers 90 | self.proj1 = nn.Linear(self.input_dim, intermediate_dim) 91 | self.proj2 = nn.Linear(intermediate_dim, intermediate_dim) 92 | self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim) 93 | 94 | self.norm = nn.LayerNorm(output_dim) 95 | 96 | def forward(self, audio_embeds): 97 | """ 98 | Defines the forward pass for the AudioProjModel. 99 | 100 | Parameters: 101 | audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels). 102 | 103 | Returns: 104 | context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim). 105 | """ 106 | # merge 107 | video_length = audio_embeds.shape[1] 108 | audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") 109 | batch_size, window_size, blocks, channels = audio_embeds.shape 110 | audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) 111 | 112 | audio_embeds = torch.relu(self.proj1(audio_embeds)) 113 | audio_embeds = torch.relu(self.proj2(audio_embeds)) 114 | 115 | context_tokens = self.proj3(audio_embeds).reshape( 116 | batch_size, self.context_tokens, self.output_dim 117 | ) 118 | 119 | context_tokens = self.norm(context_tokens) 120 | context_tokens = rearrange( 121 | context_tokens, "(bz f) m c -> bz f m c", f=video_length 122 | ) 123 | 124 | return context_tokens 125 | -------------------------------------------------------------------------------- /hallo/datasets/mask_image.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R0801 2 | """ 3 | This module contains the code for a dataset class called FaceMaskDataset, which is used to process and 4 | load image data related to face masks. The dataset class inherits from the PyTorch Dataset class and 5 | provides methods for data augmentation, getting items from the dataset, and determining the length of the 6 | dataset. The module also includes imports for necessary libraries such as json, random, pathlib, torch, 7 | PIL, and transformers. 8 | """ 9 | 10 | import json 11 | import random 12 | from pathlib import Path 13 | 14 | import torch 15 | from PIL import Image 16 | from torch.utils.data import Dataset 17 | from torchvision import transforms 18 | from transformers import CLIPImageProcessor 19 | 20 | 21 | class FaceMaskDataset(Dataset): 22 | """ 23 | FaceMaskDataset is a custom dataset for face mask images. 24 | 25 | Args: 26 | img_size (int): The size of the input images. 27 | drop_ratio (float, optional): The ratio of dropped pixels during data augmentation. Defaults to 0.1. 28 | data_meta_paths (list, optional): The paths to the metadata files containing image paths and labels. Defaults to ["./data/HDTF_meta.json"]. 29 | sample_margin (int, optional): The margin for sampling regions in the image. Defaults to 30. 30 | 31 | Attributes: 32 | img_size (int): The size of the input images. 33 | drop_ratio (float): The ratio of dropped pixels during data augmentation. 34 | data_meta_paths (list): The paths to the metadata files containing image paths and labels. 35 | sample_margin (int): The margin for sampling regions in the image. 36 | processor (CLIPImageProcessor): The image processor for preprocessing images. 37 | transform (transforms.Compose): The image augmentation transform. 
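    Example (illustrative; it assumes a metadata JSON whose entries provide the image_path,
    mask_path and face_emb keys read in __getitem__):

        dataset = FaceMaskDataset(img_size=(512, 512),
                                  data_meta_paths=["./data/HDTF_meta.json"])
        sample = dataset[0]   # dict with "img", "tgt_mask", "ref_img", "face_emb", "video_dir"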
38 | """ 39 | 40 | def __init__( 41 | self, 42 | img_size, 43 | drop_ratio=0.1, 44 | data_meta_paths=None, 45 | sample_margin=30, 46 | ): 47 | super().__init__() 48 | 49 | self.img_size = img_size 50 | self.sample_margin = sample_margin 51 | 52 | vid_meta = [] 53 | for data_meta_path in data_meta_paths: 54 | with open(data_meta_path, "r", encoding="utf-8") as f: 55 | vid_meta.extend(json.load(f)) 56 | self.vid_meta = vid_meta 57 | self.length = len(self.vid_meta) 58 | 59 | self.clip_image_processor = CLIPImageProcessor() 60 | 61 | self.transform = transforms.Compose( 62 | [ 63 | transforms.Resize(self.img_size), 64 | transforms.ToTensor(), 65 | transforms.Normalize([0.5], [0.5]), 66 | ] 67 | ) 68 | 69 | self.cond_transform = transforms.Compose( 70 | [ 71 | transforms.Resize(self.img_size), 72 | transforms.ToTensor(), 73 | ] 74 | ) 75 | 76 | self.drop_ratio = drop_ratio 77 | 78 | def augmentation(self, image, transform, state=None): 79 | """ 80 | Apply data augmentation to the input image. 81 | 82 | Args: 83 | image (PIL.Image): The input image. 84 | transform (torchvision.transforms.Compose): The data augmentation transforms. 85 | state (dict, optional): The random state for reproducibility. Defaults to None. 86 | 87 | Returns: 88 | PIL.Image: The augmented image. 89 | """ 90 | if state is not None: 91 | torch.set_rng_state(state) 92 | return transform(image) 93 | 94 | def __getitem__(self, index): 95 | video_meta = self.vid_meta[index] 96 | video_path = video_meta["image_path"] 97 | mask_path = video_meta["mask_path"] 98 | face_emb_path = video_meta["face_emb"] 99 | 100 | video_frames = sorted(Path(video_path).iterdir()) 101 | video_length = len(video_frames) 102 | 103 | margin = min(self.sample_margin, video_length) 104 | 105 | ref_img_idx = random.randint(0, video_length - 1) 106 | if ref_img_idx + margin < video_length: 107 | tgt_img_idx = random.randint( 108 | ref_img_idx + margin, video_length - 1) 109 | elif ref_img_idx - margin > 0: 110 | tgt_img_idx = random.randint(0, ref_img_idx - margin) 111 | else: 112 | tgt_img_idx = random.randint(0, video_length - 1) 113 | 114 | ref_img_pil = Image.open(video_frames[ref_img_idx]) 115 | tgt_img_pil = Image.open(video_frames[tgt_img_idx]) 116 | 117 | tgt_mask_pil = Image.open(mask_path) 118 | 119 | assert ref_img_pil is not None, "Fail to load reference image." 120 | assert tgt_img_pil is not None, "Fail to load target image." 121 | assert tgt_mask_pil is not None, "Fail to load target mask." 
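        # The RNG state captured below is replayed inside augmentation() for the target
        # frame, its mask and the reference frame, so any randomness in the transforms is
        # applied identically to all three images.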
122 | 123 | state = torch.get_rng_state() 124 | tgt_img = self.augmentation(tgt_img_pil, self.transform, state) 125 | tgt_mask_img = self.augmentation( 126 | tgt_mask_pil, self.cond_transform, state) 127 | tgt_mask_img = tgt_mask_img.repeat(3, 1, 1) 128 | ref_img_vae = self.augmentation( 129 | ref_img_pil, self.transform, state) 130 | face_emb = torch.load(face_emb_path) 131 | 132 | 133 | sample = { 134 | "video_dir": video_path, 135 | "img": tgt_img, 136 | "tgt_mask": tgt_mask_img, 137 | "ref_img": ref_img_vae, 138 | "face_emb": face_emb, 139 | } 140 | 141 | return sample 142 | 143 | def __len__(self): 144 | return len(self.vid_meta) 145 | 146 | 147 | if __name__ == "__main__": 148 | data = FaceMaskDataset(img_size=(512, 512)) 149 | train_dataloader = torch.utils.data.DataLoader( 150 | data, batch_size=4, shuffle=True, num_workers=1 151 | ) 152 | for step, batch in enumerate(train_dataloader): 153 | print(batch["tgt_mask"].shape) 154 | break 155 | -------------------------------------------------------------------------------- /hallo/datasets/audio_processor.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=C0301 2 | ''' 3 | This module contains the AudioProcessor class and related functions for processing audio data. 4 | It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction, 5 | and audio separation. The class is initialized with configuration parameters and can process 6 | audio files using the provided models. 7 | ''' 8 | import math 9 | import os 10 | 11 | import librosa 12 | import numpy as np 13 | import torch 14 | from audio_separator.separator import Separator 15 | from einops import rearrange 16 | from transformers import Wav2Vec2FeatureExtractor 17 | 18 | from hallo.models.wav2vec import Wav2VecModel 19 | from hallo.utils.util import resample_audio 20 | 21 | 22 | class AudioProcessor: 23 | """ 24 | AudioProcessor is a class that handles the processing of audio files. 25 | It takes care of preprocessing the audio files, extracting features 26 | using wav2vec models, and separating audio signals if needed. 
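    A minimal usage sketch (paths mirror configs/inference/default.yaml; splitting the
    separator path into a directory plus the Kim_Vocal_2.onnx model name, mapping
    features: all to only_last_features=False, and the placeholder WAV path are assumptions):

        with AudioProcessor(
            sample_rate=16000,
            fps=25,
            wav2vec_model_path="./pretrained_models/wav2vec/wav2vec2-base-960h",
            only_last_features=False,
            audio_separator_model_path="./pretrained_models/audio_separator",
            audio_separator_model_name="Kim_Vocal_2.onnx",
            cache_dir="./.cache/audio_preprocess",
        ) as processor:
            audio_emb = processor.preprocess("driving_audio.wav")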
27 | 28 | :param sample_rate: Sampling rate of the audio file 29 | :param fps: Frames per second for the extracted features 30 | :param wav2vec_model_path: Path to the wav2vec model 31 | :param only_last_features: Whether to only use the last features 32 | :param audio_separator_model_path: Path to the audio separator model 33 | :param audio_separator_model_name: Name of the audio separator model 34 | :param cache_dir: Directory to cache the intermediate results 35 | :param device: Device to run the processing on 36 | """ 37 | def __init__( 38 | self, 39 | sample_rate, 40 | fps, 41 | wav2vec_model_path, 42 | only_last_features, 43 | audio_separator_model_path:str=None, 44 | audio_separator_model_name:str=None, 45 | cache_dir:str='', 46 | device="cuda:0", 47 | ) -> None: 48 | self.sample_rate = sample_rate 49 | self.fps = fps 50 | self.device = device 51 | 52 | self.audio_encoder = Wav2VecModel.from_pretrained(wav2vec_model_path, local_files_only=True).to(device=device) 53 | self.audio_encoder.feature_extractor._freeze_parameters() 54 | self.only_last_features = only_last_features 55 | 56 | if audio_separator_model_name is not None: 57 | try: 58 | os.makedirs(cache_dir, exist_ok=True) 59 | except OSError as _: 60 | print("Fail to create the output cache dir.") 61 | self.audio_separator = Separator( 62 | output_dir=cache_dir, 63 | output_single_stem="vocals", 64 | model_file_dir=audio_separator_model_path, 65 | ) 66 | self.audio_separator.load_model(audio_separator_model_name) 67 | assert self.audio_separator.model_instance is not None, "Fail to load audio separate model." 68 | else: 69 | self.audio_separator=None 70 | print("Use audio directly without vocals seperator.") 71 | 72 | 73 | self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True) 74 | 75 | 76 | def preprocess(self, wav_file: str): 77 | """ 78 | Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate. 79 | The separated vocal track is then converted into wav2vec2 for further processing or analysis. 80 | 81 | Args: 82 | wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format. 83 | 84 | Raises: 85 | RuntimeError: Raises an exception if the WAV file cannot be processed. This could be due to issues 86 | such as file not found, unsupported file format, or errors during the audio processing steps. 87 | 88 | Returns: 89 | torch.tensor: Returns an audio embedding as a torch.tensor 90 | """ 91 | if self.audio_separator is not None: 92 | # 1. separate vocals 93 | # TODO: process in memory 94 | outputs = self.audio_separator.separate(wav_file) 95 | if len(outputs) <= 0: 96 | raise RuntimeError("Audio separate failed.") 97 | 98 | vocal_audio_file = outputs[0] 99 | vocal_audio_name, _ = os.path.splitext(vocal_audio_file) 100 | vocal_audio_file = os.path.join(self.audio_separator.output_dir, vocal_audio_file) 101 | vocal_audio_file = resample_audio(vocal_audio_file, os.path.join(self.audio_separator.output_dir, f"{vocal_audio_name}-16k.wav"), self.sample_rate) 102 | else: 103 | vocal_audio_file=wav_file 104 | 105 | # 2. 
extract wav2vec features 106 | speech_array, sampling_rate = librosa.load(vocal_audio_file, sr=self.sample_rate) 107 | audio_feature = np.squeeze(self.wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values) 108 | seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps) 109 | 110 | audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device) 111 | audio_feature = audio_feature.unsqueeze(0) 112 | 113 | with torch.no_grad(): 114 | embeddings = self.audio_encoder(audio_feature, seq_len=seq_len, output_hidden_states=True) 115 | assert len(embeddings) > 0, "Fail to extract audio embedding" 116 | if self.only_last_features: 117 | audio_emb = embeddings.last_hidden_state.squeeze() 118 | else: 119 | audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0) 120 | audio_emb = rearrange(audio_emb, "b s d -> s b d") 121 | 122 | audio_emb = audio_emb.cpu().detach() 123 | 124 | return audio_emb 125 | 126 | def get_embedding(self, wav_file: str): 127 | """preprocess wav audio file convert to embeddings 128 | 129 | Args: 130 | wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format. 131 | 132 | Returns: 133 | torch.tensor: Returns an audio embedding as a torch.tensor 134 | """ 135 | speech_array, sampling_rate = librosa.load( 136 | wav_file, sr=self.sample_rate) 137 | assert sampling_rate == 16000, "The audio sample rate must be 16000" 138 | audio_feature = np.squeeze(self.wav2vec_feature_extractor( 139 | speech_array, sampling_rate=sampling_rate).input_values) 140 | seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps) 141 | 142 | audio_feature = torch.from_numpy( 143 | audio_feature).float().to(device=self.device) 144 | audio_feature = audio_feature.unsqueeze(0) 145 | 146 | with torch.no_grad(): 147 | embeddings = self.audio_encoder( 148 | audio_feature, seq_len=seq_len, output_hidden_states=True) 149 | assert len(embeddings) > 0, "Fail to extract audio embedding" 150 | 151 | if self.only_last_features: 152 | audio_emb = embeddings.last_hidden_state.squeeze() 153 | else: 154 | audio_emb = torch.stack( 155 | embeddings.hidden_states[1:], dim=1).squeeze(0) 156 | audio_emb = rearrange(audio_emb, "b s d -> s b d") 157 | 158 | audio_emb = audio_emb.cpu().detach() 159 | 160 | return audio_emb 161 | 162 | def close(self): 163 | """ 164 | TODO: to be implemented 165 | """ 166 | return self 167 | 168 | def __enter__(self): 169 | return self 170 | 171 | def __exit__(self, _exc_type, _exc_val, _exc_tb): 172 | self.close() 173 | -------------------------------------------------------------------------------- /hallo/datasets/image_processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module is responsible for processing images, particularly for face-related tasks. 3 | It uses various libraries such as OpenCV, NumPy, and InsightFace to perform tasks like 4 | face detection, augmentation, and mask rendering. The ImageProcessor class encapsulates 5 | the functionality for these operations. 
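A minimal usage sketch (the face_analysis path mirrors configs/inference/default.yaml; the
reference-image path, cache directory and 1.2 face-expand ratio are illustrative):

    with ImageProcessor(img_size=(512, 512),
                        face_analysis_model_path="./pretrained_models/face_analysis") as ip:
        (ref_img, face_mask, face_emb,
         full_masks, face_masks, lip_masks) = ip.preprocess("ref.png", "./.cache", 1.2)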
6 | """ 7 | import os 8 | from typing import List 9 | 10 | import cv2 11 | import numpy as np 12 | import torch 13 | from insightface.app import FaceAnalysis 14 | from PIL import Image 15 | from torchvision import transforms 16 | 17 | from ..utils.util import get_mask 18 | 19 | MEAN = 0.5 20 | STD = 0.5 21 | 22 | class ImageProcessor: 23 | """ 24 | ImageProcessor is a class responsible for processing images, particularly for face-related tasks. 25 | It takes in an image and performs various operations such as augmentation, face detection, 26 | face embedding extraction, and rendering a face mask. The processed images are then used for 27 | further analysis or recognition purposes. 28 | 29 | Attributes: 30 | img_size (int): The size of the image to be processed. 31 | face_analysis_model_path (str): The path to the face analysis model. 32 | 33 | Methods: 34 | preprocess(source_image_path, cache_dir): 35 | Preprocesses the input image by performing augmentation, face detection, 36 | face embedding extraction, and rendering a face mask. 37 | 38 | close(): 39 | Closes the ImageProcessor and releases any resources being used. 40 | 41 | _augmentation(images, transform, state=None): 42 | Applies image augmentation to the input images using the given transform and state. 43 | 44 | __enter__(): 45 | Enters a runtime context and returns the ImageProcessor object. 46 | 47 | __exit__(_exc_type, _exc_val, _exc_tb): 48 | Exits a runtime context and handles any exceptions that occurred during the processing. 49 | """ 50 | def __init__(self, img_size, face_analysis_model_path) -> None: 51 | self.img_size = img_size 52 | 53 | self.pixel_transform = transforms.Compose( 54 | [ 55 | transforms.Resize(self.img_size), 56 | transforms.ToTensor(), 57 | transforms.Normalize([MEAN], [STD]), 58 | ] 59 | ) 60 | 61 | self.cond_transform = transforms.Compose( 62 | [ 63 | transforms.Resize(self.img_size), 64 | transforms.ToTensor(), 65 | ] 66 | ) 67 | 68 | self.attn_transform_64 = transforms.Compose( 69 | [ 70 | transforms.Resize( 71 | (self.img_size[0] // 8, self.img_size[0] // 8)), 72 | transforms.ToTensor(), 73 | ] 74 | ) 75 | self.attn_transform_32 = transforms.Compose( 76 | [ 77 | transforms.Resize( 78 | (self.img_size[0] // 16, self.img_size[0] // 16)), 79 | transforms.ToTensor(), 80 | ] 81 | ) 82 | self.attn_transform_16 = transforms.Compose( 83 | [ 84 | transforms.Resize( 85 | (self.img_size[0] // 32, self.img_size[0] // 32)), 86 | transforms.ToTensor(), 87 | ] 88 | ) 89 | self.attn_transform_8 = transforms.Compose( 90 | [ 91 | transforms.Resize( 92 | (self.img_size[0] // 64, self.img_size[0] // 64)), 93 | transforms.ToTensor(), 94 | ] 95 | ) 96 | 97 | self.face_analysis = FaceAnalysis( 98 | name="", 99 | root=face_analysis_model_path, 100 | providers=["CUDAExecutionProvider", "CPUExecutionProvider"], 101 | ) 102 | self.face_analysis.prepare(ctx_id=0, det_size=(640, 640)) 103 | 104 | def preprocess(self, source_image_path: str, cache_dir: str, face_region_ratio: float): 105 | """ 106 | Apply preprocessing to the source image to prepare for face analysis. 107 | 108 | Parameters: 109 | source_image_path (str): The path to the source image. 110 | cache_dir (str): The directory to cache intermediate results. 111 | 112 | Returns: 113 | None 114 | """ 115 | source_image = Image.open(source_image_path) 116 | ref_image_pil = source_image.convert("RGB") 117 | # 1. 
image augmentation 118 | pixel_values_ref_img = self._augmentation(ref_image_pil, self.pixel_transform) 119 | 120 | 121 | # 2.1 detect face 122 | faces = self.face_analysis.get(cv2.cvtColor(np.array(ref_image_pil.copy()), cv2.COLOR_RGB2BGR)) 123 | # use max size face 124 | face = sorted(faces, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[-1] 125 | 126 | # 2.2 face embedding 127 | face_emb = face["embedding"] 128 | 129 | # 2.3 render face mask 130 | get_mask(source_image_path, cache_dir, face_region_ratio) 131 | file_name = os.path.basename(source_image_path).split(".")[0] 132 | face_mask_pil = Image.open( 133 | os.path.join(cache_dir, f"{file_name}_face_mask.png")).convert("RGB") 134 | 135 | face_mask = self._augmentation(face_mask_pil, self.cond_transform) 136 | 137 | # 2.4 detect and expand lip, face mask 138 | sep_background_mask = Image.open( 139 | os.path.join(cache_dir, f"{file_name}_sep_background.png")) 140 | sep_face_mask = Image.open( 141 | os.path.join(cache_dir, f"{file_name}_sep_face.png")) 142 | sep_lip_mask = Image.open( 143 | os.path.join(cache_dir, f"{file_name}_sep_lip.png")) 144 | 145 | pixel_values_face_mask = [ 146 | self._augmentation(sep_face_mask, self.attn_transform_64), 147 | self._augmentation(sep_face_mask, self.attn_transform_32), 148 | self._augmentation(sep_face_mask, self.attn_transform_16), 149 | self._augmentation(sep_face_mask, self.attn_transform_8), 150 | ] 151 | pixel_values_lip_mask = [ 152 | self._augmentation(sep_lip_mask, self.attn_transform_64), 153 | self._augmentation(sep_lip_mask, self.attn_transform_32), 154 | self._augmentation(sep_lip_mask, self.attn_transform_16), 155 | self._augmentation(sep_lip_mask, self.attn_transform_8), 156 | ] 157 | pixel_values_full_mask = [ 158 | self._augmentation(sep_background_mask, self.attn_transform_64), 159 | self._augmentation(sep_background_mask, self.attn_transform_32), 160 | self._augmentation(sep_background_mask, self.attn_transform_16), 161 | self._augmentation(sep_background_mask, self.attn_transform_8), 162 | ] 163 | 164 | pixel_values_full_mask = [mask.view(1, -1) 165 | for mask in pixel_values_full_mask] 166 | pixel_values_face_mask = [mask.view(1, -1) 167 | for mask in pixel_values_face_mask] 168 | pixel_values_lip_mask = [mask.view(1, -1) 169 | for mask in pixel_values_lip_mask] 170 | 171 | return pixel_values_ref_img, face_mask, face_emb, pixel_values_full_mask, pixel_values_face_mask, pixel_values_lip_mask 172 | 173 | def close(self): 174 | """ 175 | Closes the ImageProcessor and releases any resources held by the FaceAnalysis instance. 176 | 177 | Args: 178 | self: The ImageProcessor instance. 179 | 180 | Returns: 181 | None. 
182 | """ 183 | for _, model in self.face_analysis.models.items(): 184 | if hasattr(model, "Dispose"): 185 | model.Dispose() 186 | 187 | def _augmentation(self, images, transform, state=None): 188 | if state is not None: 189 | torch.set_rng_state(state) 190 | if isinstance(images, List): 191 | transformed_images = [transform(img) for img in images] 192 | ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w) 193 | else: 194 | ret_tensor = transform(images) # (c, h, w) 195 | return ret_tensor 196 | 197 | def __enter__(self): 198 | return self 199 | 200 | def __exit__(self, _exc_type, _exc_val, _exc_tb): 201 | self.close() 202 | -------------------------------------------------------------------------------- /hallo/models/wav2vec.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R0901 2 | # src/models/wav2vec.py 3 | 4 | """ 5 | This module defines the Wav2Vec model, which is a pre-trained model for speech recognition and understanding. 6 | It inherits from the Wav2Vec2Model class in the transformers library and provides additional functionalities 7 | such as feature extraction and encoding. 8 | 9 | Classes: 10 | Wav2VecModel: Inherits from Wav2Vec2Model and adds additional methods for feature extraction and encoding. 11 | 12 | Functions: 13 | linear_interpolation: Interpolates the features based on the sequence length. 14 | """ 15 | 16 | import torch.nn.functional as F 17 | from transformers import Wav2Vec2Model 18 | from transformers.modeling_outputs import BaseModelOutput 19 | 20 | 21 | class Wav2VecModel(Wav2Vec2Model): 22 | """ 23 | Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library. 24 | It inherits all the functionality of the Wav2Vec2Model and adds additional methods for feature extraction and encoding. 25 | ... 26 | 27 | Attributes: 28 | base_model (Wav2Vec2Model): The base Wav2Vec2Model object. 29 | 30 | Methods: 31 | forward(input_values, seq_len, attention_mask=None, mask_time_indices=None 32 | , output_attentions=None, output_hidden_states=None, return_dict=None): 33 | Forward pass of the Wav2VecModel. 34 | It takes input_values, seq_len, and other optional parameters as input and returns the output of the base model. 35 | 36 | feature_extract(input_values, seq_len): 37 | Extracts features from the input_values using the base model. 38 | 39 | encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None): 40 | Encodes the extracted features using the base model and returns the encoded features. 41 | """ 42 | def forward( 43 | self, 44 | input_values, 45 | seq_len, 46 | attention_mask=None, 47 | mask_time_indices=None, 48 | output_attentions=None, 49 | output_hidden_states=None, 50 | return_dict=None, 51 | ): 52 | """ 53 | Forward pass of the Wav2Vec model. 54 | 55 | Args: 56 | self: The instance of the model. 57 | input_values: The input values (waveform) to the model. 58 | seq_len: The sequence length of the input values. 59 | attention_mask: Attention mask to be used for the model. 60 | mask_time_indices: Mask indices to be used for the model. 61 | output_attentions: If set to True, returns attentions. 62 | output_hidden_states: If set to True, returns hidden states. 63 | return_dict: If set to True, returns a BaseModelOutput instead of a tuple. 64 | 65 | Returns: 66 | The output of the Wav2Vec model. 
67 | """ 68 | self.config.output_attentions = True 69 | 70 | output_hidden_states = ( 71 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 72 | ) 73 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 74 | 75 | extract_features = self.feature_extractor(input_values) 76 | extract_features = extract_features.transpose(1, 2) 77 | extract_features = linear_interpolation(extract_features, seq_len=seq_len) 78 | 79 | if attention_mask is not None: 80 | # compute reduced attention_mask corresponding to feature vectors 81 | attention_mask = self._get_feature_vector_attention_mask( 82 | extract_features.shape[1], attention_mask, add_adapter=False 83 | ) 84 | 85 | hidden_states, extract_features = self.feature_projection(extract_features) 86 | hidden_states = self._mask_hidden_states( 87 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask 88 | ) 89 | 90 | encoder_outputs = self.encoder( 91 | hidden_states, 92 | attention_mask=attention_mask, 93 | output_attentions=output_attentions, 94 | output_hidden_states=output_hidden_states, 95 | return_dict=return_dict, 96 | ) 97 | 98 | hidden_states = encoder_outputs[0] 99 | 100 | if self.adapter is not None: 101 | hidden_states = self.adapter(hidden_states) 102 | 103 | if not return_dict: 104 | return (hidden_states, ) + encoder_outputs[1:] 105 | return BaseModelOutput( 106 | last_hidden_state=hidden_states, 107 | hidden_states=encoder_outputs.hidden_states, 108 | attentions=encoder_outputs.attentions, 109 | ) 110 | 111 | 112 | def feature_extract( 113 | self, 114 | input_values, 115 | seq_len, 116 | ): 117 | """ 118 | Extracts features from the input values and returns the extracted features. 119 | 120 | Parameters: 121 | input_values (torch.Tensor): The input values to be processed. 122 | seq_len (torch.Tensor): The sequence lengths of the input values. 123 | 124 | Returns: 125 | extracted_features (torch.Tensor): The extracted features from the input values. 126 | """ 127 | extract_features = self.feature_extractor(input_values) 128 | extract_features = extract_features.transpose(1, 2) 129 | extract_features = linear_interpolation(extract_features, seq_len=seq_len) 130 | 131 | return extract_features 132 | 133 | def encode( 134 | self, 135 | extract_features, 136 | attention_mask=None, 137 | mask_time_indices=None, 138 | output_attentions=None, 139 | output_hidden_states=None, 140 | return_dict=None, 141 | ): 142 | """ 143 | Encodes the input features into the output space. 144 | 145 | Args: 146 | extract_features (torch.Tensor): The extracted features from the audio signal. 147 | attention_mask (torch.Tensor, optional): Attention mask to be used for padding. 148 | mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension. 149 | output_attentions (bool, optional): If set to True, returns the attention weights. 150 | output_hidden_states (bool, optional): If set to True, returns all hidden states. 151 | return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of the tuple. 152 | 153 | Returns: 154 | The encoded output features. 
155 | """ 156 | self.config.output_attentions = True 157 | 158 | output_hidden_states = ( 159 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 160 | ) 161 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 162 | 163 | if attention_mask is not None: 164 | # compute reduced attention_mask corresponding to feature vectors 165 | attention_mask = self._get_feature_vector_attention_mask( 166 | extract_features.shape[1], attention_mask, add_adapter=False 167 | ) 168 | 169 | hidden_states, extract_features = self.feature_projection(extract_features) 170 | hidden_states = self._mask_hidden_states( 171 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask 172 | ) 173 | 174 | encoder_outputs = self.encoder( 175 | hidden_states, 176 | attention_mask=attention_mask, 177 | output_attentions=output_attentions, 178 | output_hidden_states=output_hidden_states, 179 | return_dict=return_dict, 180 | ) 181 | 182 | hidden_states = encoder_outputs[0] 183 | 184 | if self.adapter is not None: 185 | hidden_states = self.adapter(hidden_states) 186 | 187 | if not return_dict: 188 | return (hidden_states, ) + encoder_outputs[1:] 189 | return BaseModelOutput( 190 | last_hidden_state=hidden_states, 191 | hidden_states=encoder_outputs.hidden_states, 192 | attentions=encoder_outputs.attentions, 193 | ) 194 | 195 | 196 | def linear_interpolation(features, seq_len): 197 | """ 198 | Transpose the features to interpolate linearly. 199 | 200 | Args: 201 | features (torch.Tensor): The extracted features to be interpolated. 202 | seq_len (torch.Tensor): The sequence lengths of the features. 203 | 204 | Returns: 205 | torch.Tensor: The interpolated features. 206 | """ 207 | features = features.transpose(1, 2) 208 | output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear') 209 | return output_features.transpose(1, 2) 210 | -------------------------------------------------------------------------------- /hallo/models/transformer_3d.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R0801 2 | """ 3 | This module implements the Transformer3DModel, a PyTorch model designed for processing 4 | 3D data such as videos. It extends ModelMixin and ConfigMixin to provide a transformer 5 | model with support for gradient checkpointing and various types of attention mechanisms. 6 | The model can be configured with different parameters such as the number of attention heads, 7 | attention head dimension, and the number of layers. It also supports the use of audio modules 8 | for enhanced feature extraction from video data. 9 | """ 10 | 11 | from dataclasses import dataclass 12 | from typing import Optional 13 | 14 | import torch 15 | from diffusers.configuration_utils import ConfigMixin, register_to_config 16 | from diffusers.models import ModelMixin 17 | from diffusers.utils import BaseOutput 18 | from einops import rearrange, repeat 19 | from torch import nn 20 | 21 | from .attention import (AudioTemporalBasicTransformerBlock, 22 | TemporalBasicTransformerBlock) 23 | 24 | 25 | @dataclass 26 | class Transformer3DModelOutput(BaseOutput): 27 | """ 28 | The output of the [`Transformer3DModel`]. 29 | 30 | Attributes: 31 | sample (`torch.FloatTensor`): 32 | The output tensor from the transformer model, which is the result of processing the input 33 | hidden states through the transformer blocks and any subsequent layers. 
34 | """ 35 | sample: torch.FloatTensor 36 | 37 | 38 | class Transformer3DModel(ModelMixin, ConfigMixin): 39 | """ 40 | Transformer3DModel is a PyTorch model that extends `ModelMixin` and `ConfigMixin` to create a 3D transformer model. 41 | It implements the forward pass for processing input hidden states, encoder hidden states, and various types of attention masks. 42 | The model supports gradient checkpointing, which can be enabled by calling the `enable_gradient_checkpointing()` method. 43 | """ 44 | _supports_gradient_checkpointing = True 45 | 46 | @register_to_config 47 | def __init__( 48 | self, 49 | num_attention_heads: int = 16, 50 | attention_head_dim: int = 88, 51 | in_channels: Optional[int] = None, 52 | num_layers: int = 1, 53 | dropout: float = 0.0, 54 | norm_num_groups: int = 32, 55 | cross_attention_dim: Optional[int] = None, 56 | attention_bias: bool = False, 57 | activation_fn: str = "geglu", 58 | num_embeds_ada_norm: Optional[int] = None, 59 | use_linear_projection: bool = False, 60 | only_cross_attention: bool = False, 61 | upcast_attention: bool = False, 62 | unet_use_cross_frame_attention=None, 63 | unet_use_temporal_attention=None, 64 | use_audio_module=False, 65 | depth=0, 66 | unet_block_name=None, 67 | stack_enable_blocks_name = None, 68 | stack_enable_blocks_depth = None, 69 | ): 70 | super().__init__() 71 | self.use_linear_projection = use_linear_projection 72 | self.num_attention_heads = num_attention_heads 73 | self.attention_head_dim = attention_head_dim 74 | inner_dim = num_attention_heads * attention_head_dim 75 | self.use_audio_module = use_audio_module 76 | # Define input layers 77 | self.in_channels = in_channels 78 | 79 | self.norm = torch.nn.GroupNorm( 80 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True 81 | ) 82 | if use_linear_projection: 83 | self.proj_in = nn.Linear(in_channels, inner_dim) 84 | else: 85 | self.proj_in = nn.Conv2d( 86 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 87 | ) 88 | 89 | if use_audio_module: 90 | self.transformer_blocks = nn.ModuleList( 91 | [ 92 | AudioTemporalBasicTransformerBlock( 93 | inner_dim, 94 | num_attention_heads, 95 | attention_head_dim, 96 | dropout=dropout, 97 | cross_attention_dim=cross_attention_dim, 98 | activation_fn=activation_fn, 99 | num_embeds_ada_norm=num_embeds_ada_norm, 100 | attention_bias=attention_bias, 101 | only_cross_attention=only_cross_attention, 102 | upcast_attention=upcast_attention, 103 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 104 | unet_use_temporal_attention=unet_use_temporal_attention, 105 | depth=depth, 106 | unet_block_name=unet_block_name, 107 | stack_enable_blocks_name=stack_enable_blocks_name, 108 | stack_enable_blocks_depth=stack_enable_blocks_depth, 109 | ) 110 | for d in range(num_layers) 111 | ] 112 | ) 113 | else: 114 | # Define transformers blocks 115 | self.transformer_blocks = nn.ModuleList( 116 | [ 117 | TemporalBasicTransformerBlock( 118 | inner_dim, 119 | num_attention_heads, 120 | attention_head_dim, 121 | dropout=dropout, 122 | cross_attention_dim=cross_attention_dim, 123 | activation_fn=activation_fn, 124 | num_embeds_ada_norm=num_embeds_ada_norm, 125 | attention_bias=attention_bias, 126 | only_cross_attention=only_cross_attention, 127 | upcast_attention=upcast_attention, 128 | ) 129 | for d in range(num_layers) 130 | ] 131 | ) 132 | 133 | # 4. 
Define output layers 134 | if use_linear_projection: 135 | self.proj_out = nn.Linear(in_channels, inner_dim) 136 | else: 137 | self.proj_out = nn.Conv2d( 138 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0 139 | ) 140 | 141 | self.gradient_checkpointing = False 142 | 143 | def _set_gradient_checkpointing(self, module, value=False): 144 | if hasattr(module, "gradient_checkpointing"): 145 | module.gradient_checkpointing = value 146 | 147 | def forward( 148 | self, 149 | hidden_states, 150 | encoder_hidden_states=None, 151 | attention_mask=None, 152 | full_mask=None, 153 | face_mask=None, 154 | lip_mask=None, 155 | motion_scale=None, 156 | timestep=None, 157 | return_dict: bool = True, 158 | ): 159 | """ 160 | Forward pass for the Transformer3DModel. 161 | 162 | Args: 163 | hidden_states (torch.Tensor): The input hidden states. 164 | encoder_hidden_states (torch.Tensor, optional): The input encoder hidden states. 165 | attention_mask (torch.Tensor, optional): The attention mask. 166 | full_mask (torch.Tensor, optional): The full mask. 167 | face_mask (torch.Tensor, optional): The face mask. 168 | lip_mask (torch.Tensor, optional): The lip mask. 169 | timestep (int, optional): The current timestep. 170 | return_dict (bool, optional): Whether to return a dictionary or a tuple. 171 | 172 | Returns: 173 | output (Union[Tuple, BaseOutput]): The output of the Transformer3DModel. 174 | """ 175 | # Input 176 | assert ( 177 | hidden_states.dim() == 5 178 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 179 | video_length = hidden_states.shape[2] 180 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 181 | 182 | # TODO 183 | if self.use_audio_module: 184 | encoder_hidden_states = rearrange( 185 | encoder_hidden_states, 186 | "bs f margin dim -> (bs f) margin dim", 187 | ) 188 | else: 189 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]: 190 | encoder_hidden_states = repeat( 191 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length 192 | ) 193 | 194 | batch, _, height, weight = hidden_states.shape 195 | residual = hidden_states 196 | 197 | hidden_states = self.norm(hidden_states) 198 | if not self.use_linear_projection: 199 | hidden_states = self.proj_in(hidden_states) 200 | inner_dim = hidden_states.shape[1] 201 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 202 | batch, height * weight, inner_dim 203 | ) 204 | else: 205 | inner_dim = hidden_states.shape[1] 206 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 207 | batch, height * weight, inner_dim 208 | ) 209 | hidden_states = self.proj_in(hidden_states) 210 | 211 | # Blocks 212 | motion_frames = [] 213 | for _, block in enumerate(self.transformer_blocks): 214 | if isinstance(block, TemporalBasicTransformerBlock): 215 | hidden_states, motion_frame_fea = block( 216 | hidden_states, 217 | encoder_hidden_states=encoder_hidden_states, 218 | timestep=timestep, 219 | video_length=video_length, 220 | ) 221 | motion_frames.append(motion_frame_fea) 222 | else: 223 | hidden_states = block( 224 | hidden_states, # shape [2, 4096, 320] 225 | encoder_hidden_states=encoder_hidden_states, # shape [2, 20, 640] 226 | attention_mask=attention_mask, 227 | full_mask=full_mask, 228 | face_mask=face_mask, 229 | lip_mask=lip_mask, 230 | timestep=timestep, 231 | video_length=video_length, 232 | motion_scale=motion_scale, 233 | ) 234 | 235 | # Output 236 | if not self.use_linear_projection: 237 | hidden_states = ( 238 | hidden_states.reshape(batch, height, weight, 
inner_dim) 239 | .permute(0, 3, 1, 2) 240 | .contiguous() 241 | ) 242 | hidden_states = self.proj_out(hidden_states) 243 | else: 244 | hidden_states = self.proj_out(hidden_states) 245 | hidden_states = ( 246 | hidden_states.reshape(batch, height, weight, inner_dim) 247 | .permute(0, 3, 1, 2) 248 | .contiguous() 249 | ) 250 | 251 | output = hidden_states + residual 252 | 253 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 254 | if not return_dict: 255 | return (output, motion_frames) 256 | 257 | return Transformer3DModelOutput(sample=output) 258 | -------------------------------------------------------------------------------- /hallo/datasets/talk_video.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R0801 2 | """ 3 | talking_video_dataset.py 4 | 5 | This module defines the TalkingVideoDataset class, a custom PyTorch dataset 6 | for handling talking video data. The dataset uses video files, masks, and 7 | embeddings to prepare data for tasks such as video generation and 8 | speech-driven video animation. 9 | 10 | Classes: 11 | TalkingVideoDataset 12 | 13 | Dependencies: 14 | json 15 | random 16 | torch 17 | decord.VideoReader, decord.cpu 18 | PIL.Image 19 | torch.utils.data.Dataset 20 | torchvision.transforms 21 | 22 | Example: 23 | from talking_video_dataset import TalkingVideoDataset 24 | from torch.utils.data import DataLoader 25 | 26 | # Example configuration for the Wav2Vec model 27 | class Wav2VecConfig: 28 | def __init__(self, audio_type, model_scale, features): 29 | self.audio_type = audio_type 30 | self.model_scale = model_scale 31 | self.features = features 32 | 33 | wav2vec_cfg = Wav2VecConfig(audio_type="wav2vec2", model_scale="base", features="feature") 34 | 35 | # Initialize dataset 36 | dataset = TalkingVideoDataset( 37 | img_size=(512, 512), 38 | sample_rate=16000, 39 | audio_margin=2, 40 | n_motion_frames=0, 41 | n_sample_frames=16, 42 | data_meta_paths=["path/to/meta1.json", "path/to/meta2.json"], 43 | wav2vec_cfg=wav2vec_cfg, 44 | ) 45 | 46 | # Initialize dataloader 47 | dataloader = DataLoader(dataset, batch_size=4, shuffle=True) 48 | 49 | # Fetch one batch of data 50 | batch = next(iter(dataloader)) 51 | print(batch["pixel_values_vid"].shape) # Example output: (4, 16, 3, 512, 512) 52 | 53 | The TalkingVideoDataset class provides methods for loading video frames, masks, 54 | audio embeddings, and other relevant data, applying transformations, and preparing 55 | the data for training and evaluation in a deep learning pipeline. 56 | 57 | Attributes: 58 | img_size (tuple): The dimensions to resize the video frames to. 59 | sample_rate (int): The audio sample rate. 60 | audio_margin (int): The margin for audio sampling. 61 | n_motion_frames (int): The number of motion frames. 62 | n_sample_frames (int): The number of sample frames. 63 | data_meta_paths (list): List of paths to the JSON metadata files. 64 | wav2vec_cfg (object): Configuration for the Wav2Vec model. 65 | 66 | Methods: 67 | augmentation(images, transform, state=None): Apply transformation to input images. 68 | __getitem__(index): Get a sample from the dataset at the specified index. 69 | __len__(): Return the length of the dataset. 
70 | """ 71 | 72 | import json 73 | import random 74 | from typing import List 75 | 76 | import torch 77 | from decord import VideoReader, cpu 78 | from PIL import Image 79 | from torch.utils.data import Dataset 80 | from torchvision import transforms 81 | 82 | 83 | class TalkingVideoDataset(Dataset): 84 | """ 85 | A dataset class for processing talking video data. 86 | 87 | Args: 88 | img_size (tuple, optional): The size of the output images. Defaults to (512, 512). 89 | sample_rate (int, optional): The sample rate of the audio data. Defaults to 16000. 90 | audio_margin (int, optional): The margin for the audio data. Defaults to 2. 91 | n_motion_frames (int, optional): The number of motion frames. Defaults to 0. 92 | n_sample_frames (int, optional): The number of sample frames. Defaults to 16. 93 | data_meta_paths (list, optional): The paths to the data metadata. Defaults to None. 94 | wav2vec_cfg (dict, optional): The configuration for the wav2vec model. Defaults to None. 95 | 96 | Attributes: 97 | img_size (tuple): The size of the output images. 98 | sample_rate (int): The sample rate of the audio data. 99 | audio_margin (int): The margin for the audio data. 100 | n_motion_frames (int): The number of motion frames. 101 | n_sample_frames (int): The number of sample frames. 102 | data_meta_paths (list): The paths to the data metadata. 103 | wav2vec_cfg (dict): The configuration for the wav2vec model. 104 | """ 105 | 106 | def __init__( 107 | self, 108 | img_size=(512, 512), 109 | sample_rate=16000, 110 | audio_margin=2, 111 | n_motion_frames=0, 112 | n_sample_frames=16, 113 | data_meta_paths=None, 114 | wav2vec_cfg=None, 115 | ): 116 | super().__init__() 117 | self.sample_rate = sample_rate 118 | self.img_size = img_size 119 | self.audio_margin = audio_margin 120 | self.n_motion_frames = n_motion_frames 121 | self.n_sample_frames = n_sample_frames 122 | self.audio_type = wav2vec_cfg.audio_type 123 | self.audio_model = wav2vec_cfg.model_scale 124 | self.audio_features = wav2vec_cfg.features 125 | 126 | vid_meta = [] 127 | for data_meta_path in data_meta_paths: 128 | with open(data_meta_path, "r", encoding="utf-8") as f: 129 | vid_meta.extend(json.load(f)) 130 | self.vid_meta = vid_meta 131 | self.length = len(self.vid_meta) 132 | self.pixel_transform = transforms.Compose( 133 | [ 134 | transforms.Resize(self.img_size), 135 | transforms.ToTensor(), 136 | transforms.Normalize([0.5], [0.5]), 137 | ] 138 | ) 139 | 140 | self.cond_transform = transforms.Compose( 141 | [ 142 | transforms.Resize(self.img_size), 143 | transforms.ToTensor(), 144 | ] 145 | ) 146 | self.attn_transform_64 = transforms.Compose( 147 | [ 148 | transforms.Resize((64,64)), 149 | transforms.ToTensor(), 150 | ] 151 | ) 152 | self.attn_transform_32 = transforms.Compose( 153 | [ 154 | transforms.Resize((32, 32)), 155 | transforms.ToTensor(), 156 | ] 157 | ) 158 | self.attn_transform_16 = transforms.Compose( 159 | [ 160 | transforms.Resize((16, 16)), 161 | transforms.ToTensor(), 162 | ] 163 | ) 164 | self.attn_transform_8 = transforms.Compose( 165 | [ 166 | transforms.Resize((8, 8)), 167 | transforms.ToTensor(), 168 | ] 169 | ) 170 | 171 | def augmentation(self, images, transform, state=None): 172 | """ 173 | Apply the given transformation to the input images. 174 | 175 | Args: 176 | images (List[PIL.Image] or PIL.Image): The input images to be transformed. 177 | transform (torchvision.transforms.Compose): The transformation to be applied to the images. 
178 | state (torch.ByteTensor, optional): The state of the random number generator. 179 | If provided, it will set the RNG state to this value before applying the transformation. Defaults to None. 180 | 181 | Returns: 182 | torch.Tensor: The transformed images as a tensor. 183 | If the input was a list of images, the tensor will have shape (f, c, h, w), 184 | where f is the number of images, c is the number of channels, h is the height, and w is the width. 185 | If the input was a single image, the tensor will have shape (c, h, w), 186 | where c is the number of channels, h is the height, and w is the width. 187 | """ 188 | if state is not None: 189 | torch.set_rng_state(state) 190 | if isinstance(images, List): 191 | transformed_images = [transform(img) for img in images] 192 | ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w) 193 | else: 194 | ret_tensor = transform(images) # (c, h, w) 195 | return ret_tensor 196 | 197 | def __getitem__(self, index): 198 | video_meta = self.vid_meta[index] 199 | video_path = video_meta["video_path"] 200 | mask_path = video_meta["mask_path"] 201 | lip_mask_union_path = video_meta.get("sep_mask_lip", None) 202 | face_mask_union_path = video_meta.get("sep_mask_face", None) 203 | full_mask_union_path = video_meta.get("sep_mask_border", None) 204 | face_emb_path = video_meta["face_emb_path"] 205 | audio_emb_path = video_meta[ 206 | f"{self.audio_type}_emb_{self.audio_model}_{self.audio_features}" 207 | ] 208 | tgt_mask_pil = Image.open(mask_path) 209 | video_frames = VideoReader(video_path, ctx=cpu(0)) 210 | assert tgt_mask_pil is not None, "Fail to load target mask." 211 | assert (video_frames is not None and len(video_frames) > 0), "Fail to load video frames." 212 | video_length = len(video_frames) 213 | 214 | assert ( 215 | video_length 216 | > self.n_sample_frames + self.n_motion_frames + 2 * self.audio_margin 217 | ) 218 | start_idx = random.randint( 219 | self.n_motion_frames, 220 | video_length - self.n_sample_frames - self.audio_margin - 1, 221 | ) 222 | 223 | videos = video_frames[start_idx : start_idx + self.n_sample_frames] 224 | 225 | frame_list = [ 226 | Image.fromarray(video).convert("RGB") for video in videos.asnumpy() 227 | ] 228 | 229 | face_masks_list = [Image.open(face_mask_union_path)] * self.n_sample_frames 230 | lip_masks_list = [Image.open(lip_mask_union_path)] * self.n_sample_frames 231 | full_masks_list = [Image.open(full_mask_union_path)] * self.n_sample_frames 232 | assert face_masks_list[0] is not None, "Fail to load face mask." 233 | assert lip_masks_list[0] is not None, "Fail to load lip mask." 234 | assert full_masks_list[0] is not None, "Fail to load full mask." 
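        # The block below gathers, for every sampled video frame, a small window of
        # audio embeddings centred on that frame: `indices` spans
        # [-audio_margin, ..., +audio_margin] and is added to each frame index, so
        # `audio_tensor` has shape (n_sample_frames, 2 * audio_margin + 1, ...).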
235 | 236 | 237 | face_emb = torch.load(face_emb_path) 238 | audio_emb = torch.load(audio_emb_path) 239 | indices = ( 240 | torch.arange(2 * self.audio_margin + 1) - self.audio_margin 241 | ) # Generates [-2, -1, 0, 1, 2] 242 | center_indices = torch.arange( 243 | start_idx, 244 | start_idx + self.n_sample_frames, 245 | ).unsqueeze(1) + indices.unsqueeze(0) 246 | audio_tensor = audio_emb[center_indices] 247 | 248 | ref_img_idx = random.randint( 249 | self.n_motion_frames, 250 | video_length - self.n_sample_frames - self.audio_margin - 1, 251 | ) 252 | ref_img = video_frames[ref_img_idx].asnumpy() 253 | ref_img = Image.fromarray(ref_img) 254 | 255 | if self.n_motion_frames > 0: 256 | motions = video_frames[start_idx - self.n_motion_frames : start_idx] 257 | motion_list = [ 258 | Image.fromarray(motion).convert("RGB") for motion in motions.asnumpy() 259 | ] 260 | 261 | # transform 262 | state = torch.get_rng_state() 263 | pixel_values_vid = self.augmentation(frame_list, self.pixel_transform, state) 264 | 265 | pixel_values_mask = self.augmentation(tgt_mask_pil, self.cond_transform, state) 266 | pixel_values_mask = pixel_values_mask.repeat(3, 1, 1) 267 | 268 | pixel_values_face_mask = [ 269 | self.augmentation(face_masks_list, self.attn_transform_64, state), 270 | self.augmentation(face_masks_list, self.attn_transform_32, state), 271 | self.augmentation(face_masks_list, self.attn_transform_16, state), 272 | self.augmentation(face_masks_list, self.attn_transform_8, state), 273 | ] 274 | pixel_values_lip_mask = [ 275 | self.augmentation(lip_masks_list, self.attn_transform_64, state), 276 | self.augmentation(lip_masks_list, self.attn_transform_32, state), 277 | self.augmentation(lip_masks_list, self.attn_transform_16, state), 278 | self.augmentation(lip_masks_list, self.attn_transform_8, state), 279 | ] 280 | pixel_values_full_mask = [ 281 | self.augmentation(full_masks_list, self.attn_transform_64, state), 282 | self.augmentation(full_masks_list, self.attn_transform_32, state), 283 | self.augmentation(full_masks_list, self.attn_transform_16, state), 284 | self.augmentation(full_masks_list, self.attn_transform_8, state), 285 | ] 286 | 287 | pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state) 288 | pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0) 289 | if self.n_motion_frames > 0: 290 | pixel_values_motion = self.augmentation( 291 | motion_list, self.pixel_transform, state 292 | ) 293 | pixel_values_ref_img = torch.cat( 294 | [pixel_values_ref_img, pixel_values_motion], dim=0 295 | ) 296 | 297 | sample = { 298 | "video_dir": video_path, 299 | "pixel_values_vid": pixel_values_vid, 300 | "pixel_values_mask": pixel_values_mask, 301 | "pixel_values_face_mask": pixel_values_face_mask, 302 | "pixel_values_lip_mask": pixel_values_lip_mask, 303 | "pixel_values_full_mask": pixel_values_full_mask, 304 | "audio_tensor": audio_tensor, 305 | "pixel_values_ref_img": pixel_values_ref_img, 306 | "face_emb": face_emb, 307 | } 308 | 309 | return sample 310 | 311 | def __len__(self): 312 | return len(self.vid_meta) 313 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # About this fork 2 | 3 | This fork was created to provide a convenient web interface for using Hallo. The original code has been slightly modified to allow for more control over the generation process. 
4 | 5 | ## About colab 6 | ⚠️ To run the web interface, you need at least 12 GB of video memory (VRAM) and more than 12 GB of RAM. ⚠️ 7 | 8 | Unfortunately, I was unable to create a free-tier Colab notebook, as there is not enough RAM available. However, you can try this [Colab notebook](https://colab.research.google.com/drive/1JGkftvdEksrhJbeAUnnRAZNAfZyGjP44?usp=sharing) if you have a `pro` subscription. 9 | 10 | ## Portable version 11 | 12 | If you are on Windows and don't want to bother with installing libraries, you can download the [portable version](https://huggingface.co/daswer123/portable_webuis/resolve/main/hallo-portable-2.zip?download=true), unpack it, and launch `run.bat`. 13 | 14 | 15 | 16 | ## Screenshot 17 | 18 | ![image](https://github.com/daswer123/hallo-webui/assets/22278673/ebd9c9cd-9d37-4772-8d7c-edaf68dbe15b) 19 | 20 | ## Installation 21 | 22 | ### Docker 23 | 24 | ```bash 25 | docker compose up -d 26 | ``` 27 | This starts the Gradio web UI on port 7860 inside the container, which is mapped to port 8020 on the host. 28 | The app will be available at http://localhost:8020 29 | 30 | Note: if the image does not build or run out of the box, make sure the CUDA base image in the `DOCKERFILE` matches your GPU driver version. 31 | 32 | ### Windows 33 | 34 | 1. Clone this repository: 35 | ``` 36 | git clone https://github.com/yourusername/hallo.git 37 | ``` 38 | 39 | 2. Run `install.bat` to set up the environment and download the pretrained models. 40 | 41 | 3. Make sure ffmpeg is installed on your system and that the system can find it (i.e. it is on your `PATH`). 42 | 43 | 4. Launch the web interface by running `start.bat`. 44 | 45 | ### Linux 46 | 47 | 1. Clone this repository: 48 | ``` 49 | git clone https://github.com/yourusername/hallo.git 50 | ``` 51 | 52 | 2. Run `install.sh` to set up the environment and download the pretrained models. 53 | 54 | 3. Ensure ffmpeg is installed on your system. You can install it with: 55 | ``` 56 | sudo apt-get install ffmpeg 57 | ``` 58 | 59 | 4. Launch the web interface by running `start.sh`. 60 | 61 | ### Manual Installation 62 | 63 | If you prefer to install manually, here are the detailed steps: 64 | 65 | 1. Clone the repository and pretrained models: 66 | ``` 67 | git lfs install 68 | git clone https://github.com/yourusername/hallo.git 69 | git clone https://huggingface.co/fudan-generative-ai/hallo pretrained_models 70 | curl -L -o pretrained_models/hallo/net.pth https://huggingface.co/fudan-generative-ai/hallo/resolve/main/hallo/net.pth?download=true 71 | ``` 72 | 73 | 2. Create a virtual environment and activate it: 74 | ``` 75 | python -m venv venv 76 | venv\Scripts\activate # For Windows 77 | source venv/bin/activate # For Linux 78 | ``` 79 | 80 | 3. Install the required packages: 81 | ``` 82 | pip install -r requirements.txt 83 | pip install -e . 84 | pip install bitsandbytes-windows --force-reinstall # For Windows only 85 | ``` 86 | 87 | 4. Install GPU libraries: 88 | ``` 89 | pip install torch==2.2.2+cu121 torchaudio torchvision --index-url https://download.pytorch.org/whl/cu121 90 | pip install onnxruntime-gpu 91 | ``` 92 | 93 | 5. Launch the web interface: 94 | ``` 95 | python app.py 96 | ``` 97 | 98 | 6. To share the interface publicly, use the `--share` flag: 99 | ``` 100 | python app.py --share 101 | ``` 102 |
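After installation, it is worth checking that both PyTorch and ONNX Runtime can actually see the GPU, since face analysis will otherwise fall back to the much slower CPU provider. This is an optional sanity check (run it with the virtual environment activated):

```bash
python -c "import torch; print('torch CUDA available:', torch.cuda.is_available())"
python -c "import onnxruntime; print('ORT providers:', onnxruntime.get_available_providers())"
```

If `CUDAExecutionProvider` does not appear in the second list, reinstall `onnxruntime-gpu` for your CUDA version.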

Hallo: Hierarchical Audio-Driven Visual Synthesis for Portrait Image Animation

104 | 105 |
106 | Mingwang Xu1*  107 | Hui Li1*  108 | Qingkun Su1*  109 | Hanlin Shang1  110 | Liwei Zhang1  111 | Ce Liu3  112 |
113 |
114 | Jingdong Wang2  115 | Yao Yao4  116 | Siyu Zhu1  117 |
118 | 119 |
120 | 1Fudan University  2Baidu Inc  3ETH Zurich  4Nanjing University 121 |
122 | 123 |
124 |
125 | 126 | 127 | 128 | 129 | 130 |
131 | 132 |
133 | 134 | # Showcase 135 | 136 | 137 | https://github.com/fudan-generative-vision/hallo/assets/17402682/294e78ef-c60d-4c32-8e3c-7f8d6934c6bd 138 | 139 | 140 | # Framework 141 | 142 | ![abstract](assets/framework_1.jpg) 143 | ![framework](assets/framework_2.jpg) 144 | 145 | # News 146 | 147 | - **`2024/06/15`**: 🎉🎉🎉 Release the first version on [GitHub](https://github.com/fudan-generative-vision/hallo). 148 | - **`2024/06/15`**: ✨✨✨ Release some images and audios for inference testing on [Huggingface](https://huggingface.co/datasets/fudan-generative-ai/hallo_inference_samples). 149 | 150 | # Installation 151 | 152 | - System requirement: Ubuntu 20.04/Ubuntu 22.04, Cuda 12.1 153 | - Tested GPUs: A100 154 | 155 | Create conda environment: 156 | 157 | ```bash 158 | conda create -n hallo python=3.10 159 | conda activate hallo 160 | ``` 161 | 162 | Install packages with `pip` 163 | 164 | ```bash 165 | pip install -r requirements.txt 166 | pip install . 167 | ``` 168 | 169 | Besides, ffmpeg is also need: 170 | ```bash 171 | apt-get install ffmpeg 172 | ``` 173 | 174 | # Inference 175 | 176 | The inference entrypoint script is `scripts/inference.py`. Before testing your cases, there are two preparations need to be completed: 177 | 178 | 1. [Download all required pretrained models](#download-pretrained-models). 179 | 2. [Prepare source image and driving audio pairs](#prepare-inference-data). 180 | 3. [Run inference](#run-inference). 181 | 182 | ## Download pretrained models 183 | 184 | You can easily get all pretrained models required by inference from our [HuggingFace repo](https://huggingface.co/fudan-generative-ai/hallo). 185 | 186 | Clone the the pretrained models into `${PROJECT_ROOT}/pretrained_models` directory by cmd below: 187 | 188 | ```shell 189 | git lfs install 190 | git clone https://huggingface.co/fudan-generative-ai/hallo pretrained_models 191 | ``` 192 | 193 | Or you can download them separately from their source repo: 194 | 195 | - [hallo](https://huggingface.co/fudan-generative-ai/hallo/tree/main/hallo): Our checkpoints consist of denoising UNet, face locator, image & audio proj. 196 | - [audio_separator](https://huggingface.co/huangjackson/Kim_Vocal_2): Kim\_Vocal\_2 MDX-Net vocal removal model by [KimberleyJensen](https://github.com/KimberleyJensen). (_Thanks to runwayml_) 197 | - [insightface](https://github.com/deepinsight/insightface/tree/master/python-package#model-zoo): 2D and 3D Face Analysis placed into `pretrained_models/face_analysis/models/`. (_Thanks to deepinsight_) 198 | - [face landmarker](https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task): Face detection & mesh model from [mediapipe](https://ai.google.dev/edge/mediapipe/solutions/vision/face_landmarker#models) placed into `pretrained_models/face_analysis/models`. 199 | - [motion module](https://github.com/guoyww/AnimateDiff/blob/main/README.md#202309-animatediff-v2): motion module from [AnimateDiff](https://github.com/guoyww/AnimateDiff). (_Thanks to guoyww_). 200 | - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse): Weights are intended to be used with the diffusers library. (_Thanks to stablilityai_) 201 | - [StableDiffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5): Initialized and fine-tuned from Stable-Diffusion-v1-2. (_Thanks to runwayml_) 202 | - [wav2vec](https://huggingface.co/facebook/wav2vec2-base-960h): wav audio to vector model from [Facebook](https://huggingface.co/facebook/wav2vec2-base-960h). 
203 | 204 | Finally, these pretrained models should be organized as follows: 205 | 206 | ```text 207 | ./pretrained_models/ 208 | |-- audio_separator/ 209 | | `-- Kim_Vocal_2.onnx 210 | |-- face_analysis/ 211 | | `-- models/ 212 | | |-- face_landmarker_v2_with_blendshapes.task # face landmarker model from mediapipe 213 | | |-- 1k3d68.onnx 214 | | |-- 2d106det.onnx 215 | | |-- genderage.onnx 216 | | |-- glintr100.onnx 217 | | `-- scrfd_10g_bnkps.onnx 218 | |-- motion_module/ 219 | | `-- mm_sd_v15_v2.ckpt 220 | |-- sd-vae-ft-mse/ 221 | | |-- config.json 222 | | `-- diffusion_pytorch_model.safetensors 223 | |-- stable-diffusion-v1-5/ 224 | | |-- feature_extractor/ 225 | | | `-- preprocessor_config.json 226 | | |-- model_index.json 227 | | |-- unet/ 228 | | | |-- config.json 229 | | | `-- diffusion_pytorch_model.safetensors 230 | | `-- v1-inference.yaml 231 | `-- wav2vec/ 232 | |-- wav2vec2-base-960h/ 233 | | |-- config.json 234 | | |-- feature_extractor_config.json 235 | | |-- model.safetensors 236 | | |-- preprocessor_config.json 237 | | |-- special_tokens_map.json 238 | | |-- tokenizer_config.json 239 | | `-- vocab.json 240 | ``` 241 | 242 | ## Prepare Inference Data 243 | 244 | Hallo has a few simple requirements for input data: 245 | 246 | For the source image: 247 | 248 | 1. It should be cropped into squares. 249 | 2. The face should be the main focus, making up 50%-70% of the image. 250 | 3. The face should be facing forward, with a rotation angle of less than 30° (no side profiles). 251 | 252 | For the driving audio: 253 | 254 | 1. It must be in WAV format. 255 | 2. It must be in English since our training datasets are only in this language. 256 | 3. Ensure the vocals are clear; background music is acceptable. 257 | 258 | We have provided some samples for your reference. 259 | 260 | ## Run inference 261 | 262 | Simply to run the `scripts/inference.py` and pass `source_image` and `driving_audio` as input: 263 | 264 | ```bash 265 | python scripts/inference.py --source_image examples/source_images/1.jpg --driving_audio examples/driving_audios/1.wav 266 | ``` 267 | 268 | Animation results will be saved as `${PROJECT_ROOT}/.cache/output.mp4` by default. You can pass `--output` to specify the output file name. You can find more examples for inference at [examples folder](https://github.com/fudan-generative-vision/hallo/tree/main/examples). 
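For example, to write the result to a custom path and adjust the motion weights, a command along these lines can be used (the file names and weight values here are illustrative only; see the full option list below):

```bash
python scripts/inference.py \
  --source_image examples/source_images/1.jpg \
  --driving_audio examples/driving_audios/1.wav \
  --output .cache/result_1.mp4 \
  --pose_weight 1.0 --face_weight 1.0 --lip_weight 1.2
```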
269 | 270 | For more options: 271 | 272 | ```shell 273 | usage: inference.py [-h] [-c CONFIG] [--source_image SOURCE_IMAGE] [--driving_audio DRIVING_AUDIO] [--output OUTPUT] [--pose_weight POSE_WEIGHT] 274 | [--face_weight FACE_WEIGHT] [--lip_weight LIP_WEIGHT] [--face_expand_ratio FACE_EXPAND_RATIO] 275 | 276 | options: 277 | -h, --help show this help message and exit 278 | -c CONFIG, --config CONFIG 279 | --source_image SOURCE_IMAGE 280 | source image 281 | --driving_audio DRIVING_AUDIO 282 | driving audio 283 | --output OUTPUT output video file name 284 | --pose_weight POSE_WEIGHT 285 | weight of pose 286 | --face_weight FACE_WEIGHT 287 | weight of face 288 | --lip_weight LIP_WEIGHT 289 | weight of lip 290 | --face_expand_ratio FACE_EXPAND_RATIO 291 | face region 292 | ``` 293 | 294 | # Roadmap 295 | 296 | | Status | Milestone | ETA | 297 | | :----: | :---------------------------------------------------------------------------------------------------- | :--------: | 298 | | ✅ | **[Inference source code released on GitHub](https://github.com/fudan-generative-vision/hallo)** | 2024-06-15 | 299 | | ✅ | **[Pretrained models on Huggingface](https://huggingface.co/fudan-generative-ai/hallo)** | 2024-06-15 | 300 | | 🚀🚀🚀 | **[Training: data preparation and training scripts]()** | 2024-06-25 | 301 | | 🚀🚀🚀 | **[Optimize inference performance for Mandarin]()** | TBD | 302 | 303 | # Citation 304 | 305 | If you find our work useful for your research, please consider citing the paper: 306 | 307 | ``` 308 | @misc{xu2024hallo, 309 | title={Hallo: Hierarchical Audio-Driven Visual Synthesis for Portrait Image Animation}, 310 | author={Mingwang Xu and Hui Li and Qingkun Su and Hanlin Shang and Liwei Zhang and Ce Liu and Jingdong Wang and Yao Yao and Siyu Zhu}, 311 | year={2024}, 312 | eprint={2406.08801}, 313 | archivePrefix={arXiv}, 314 | primaryClass={cs.CV} 315 | } 316 | ``` 317 | 318 | # Opportunities available 319 | 320 | Multiple research positions are open at the **Generative Vision Lab, Fudan University**! These include: 321 | 322 | - Research assistant 323 | - Postdoctoral researcher 324 | - PhD candidate 325 | - Master's students 326 | 327 | Interested individuals are encouraged to contact us at [siyuzhu@fudan.edu.cn](mailto:siyuzhu@fudan.edu.cn) for further information. 328 | 329 | # Social Risks and Mitigations 330 | 331 | The development of portrait image animation technologies driven by audio inputs poses social risks, such as the ethical implications of creating realistic portraits that could be misused for deepfakes. To mitigate these risks, it is crucial to establish ethical guidelines and responsible use practices. Privacy and consent concerns also arise from using individuals' images and voices. Addressing these involves transparent data usage policies, informed consent, and safeguarding privacy rights. By addressing these risks and implementing mitigations, the research aims to ensure the responsible and ethical development of this technology. 332 | -------------------------------------------------------------------------------- /hallo/models/resnet.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=E1120 2 | # pylint: disable=E1102 3 | # pylint: disable=W0237 4 | 5 | # src/models/resnet.py 6 | 7 | """ 8 | This module defines various components used in the ResNet model, such as InflatedConv3d, InflatedGroupNorm, 9 | Upsample3D, Downsample3D, ResnetBlock3D, and the Mish activation function.
These components are used to construct 10 | a deep neural network model for image classification or other computer vision tasks. 11 | 12 | Classes: 13 | - InflatedConv3d: An inflated 3D convolutional layer, inheriting from nn.Conv2d. 14 | - InflatedGroupNorm: An inflated group normalization layer, inheriting from nn.GroupNorm. 15 | - Upsample3D: A 3D upsampling module, used to increase the resolution of the input tensor. 16 | - Downsample3D: A 3D downsampling module, used to decrease the resolution of the input tensor. 17 | - ResnetBlock3D: A 3D residual block, commonly used in ResNet architectures. 18 | - Mish: A Mish activation function, which is a smooth, non-monotonic activation function. 19 | 20 | To use this module, simply import the classes and functions you need and follow the instructions provided in 21 | the respective class and function docstrings. 22 | """ 23 | 24 | import torch 25 | import torch.nn.functional as F 26 | from einops import rearrange 27 | from torch import nn 28 | 29 | 30 | class InflatedConv3d(nn.Conv2d): 31 | """ 32 | InflatedConv3d is a class that inherits from torch.nn.Conv2d and overrides the forward method. 33 | 34 | This class is used to perform 3D convolution on input tensor x. It is a specialized type of convolutional layer 35 | commonly used in deep learning models for computer vision tasks. The main difference between a regular Conv2d and 36 | InflatedConv3d is that InflatedConv3d is designed to handle 3D input tensors, which are typically the result of 37 | inflating 2D convolutional layers to 3D for use in 3D deep learning tasks. 38 | 39 | Attributes: 40 | Same as torch.nn.Conv2d. 41 | 42 | Methods: 43 | forward(self, x): 44 | Performs 3D convolution on the input tensor x using the InflatedConv3d layer. 45 | 46 | Example: 47 | conv_layer = InflatedConv3d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1) 48 | output = conv_layer(input_tensor) 49 | """ 50 | def forward(self, x): 51 | """ 52 | Forward pass of the InflatedConv3d layer. 53 | 54 | Args: 55 | x (torch.Tensor): Input tensor to the layer. 56 | 57 | Returns: 58 | torch.Tensor: Output tensor after applying the InflatedConv3d layer. 59 | """ 60 | video_length = x.shape[2] 61 | 62 | x = rearrange(x, "b c f h w -> (b f) c h w") 63 | x = super().forward(x) 64 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 65 | 66 | return x 67 | 68 | 69 | class InflatedGroupNorm(nn.GroupNorm): 70 | """ 71 | InflatedGroupNorm is a custom class that inherits from torch.nn.GroupNorm. 72 | It is used to apply group normalization to 3D tensors. 73 | 74 | Args: 75 | num_groups (int): The number of groups to divide the channels into. 76 | num_channels (int): The number of channels in the input tensor. 77 | eps (float, optional): A small constant to add to the variance to avoid division by zero. Defaults to 1e-5. 78 | affine (bool, optional): If True, the module has learnable affine parameters. Defaults to True. 79 | 80 | Attributes: 81 | weight (torch.Tensor): The learnable weight tensor for scale. 82 | bias (torch.Tensor): The learnable bias tensor for shift. 83 | 84 | Forward method: 85 | x (torch.Tensor): Input tensor to be normalized. 86 | return (torch.Tensor): Normalized tensor. 87 | """ 88 | def forward(self, x): 89 | """ 90 | Performs a forward pass through the CustomClassName. 91 | 92 | :param x: Input tensor of shape (batch_size, channels, video_length, height, width). 93 | :return: Output tensor of shape (batch_size, channels, video_length, height, width). 
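        Note that the normalization itself is the ordinary 2-D ``GroupNorm`` from the parent class: the frame axis is temporarily folded into the batch axis, normalized, and then unfolded back into the 5-D video layout.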
94 | """ 95 | video_length = x.shape[2] 96 | 97 | x = rearrange(x, "b c f h w -> (b f) c h w") 98 | x = super().forward(x) 99 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 100 | 101 | return x 102 | 103 | 104 | class Upsample3D(nn.Module): 105 | """ 106 | Upsample3D is a PyTorch module that upsamples a 3D tensor. 107 | 108 | Args: 109 | channels (int): The number of channels in the input tensor. 110 | use_conv (bool): Whether to use a convolutional layer for upsampling. 111 | use_conv_transpose (bool): Whether to use a transposed convolutional layer for upsampling. 112 | out_channels (int): The number of channels in the output tensor. 113 | name (str): The name of the convolutional layer. 114 | """ 115 | def __init__( 116 | self, 117 | channels, 118 | use_conv=False, 119 | use_conv_transpose=False, 120 | out_channels=None, 121 | name="conv", 122 | ): 123 | super().__init__() 124 | self.channels = channels 125 | self.out_channels = out_channels or channels 126 | self.use_conv = use_conv 127 | self.use_conv_transpose = use_conv_transpose 128 | self.name = name 129 | 130 | if use_conv_transpose: 131 | raise NotImplementedError 132 | if use_conv: 133 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 134 | 135 | def forward(self, hidden_states, output_size=None): 136 | """ 137 | Forward pass of the Upsample3D class. 138 | 139 | Args: 140 | hidden_states (torch.Tensor): Input tensor to be upsampled. 141 | output_size (tuple, optional): Desired output size of the upsampled tensor. 142 | 143 | Returns: 144 | torch.Tensor: Upsampled tensor. 145 | 146 | Raises: 147 | AssertionError: If the number of channels in the input tensor does not match the expected channels. 148 | """ 149 | assert hidden_states.shape[1] == self.channels 150 | 151 | if self.use_conv_transpose: 152 | raise NotImplementedError 153 | 154 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 155 | dtype = hidden_states.dtype 156 | if dtype == torch.bfloat16: 157 | hidden_states = hidden_states.to(torch.float32) 158 | 159 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 160 | if hidden_states.shape[0] >= 64: 161 | hidden_states = hidden_states.contiguous() 162 | 163 | # if `output_size` is passed we force the interpolation output 164 | # size and do not make use of `scale_factor=2` 165 | if output_size is None: 166 | hidden_states = F.interpolate( 167 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest" 168 | ) 169 | else: 170 | hidden_states = F.interpolate( 171 | hidden_states, size=output_size, mode="nearest" 172 | ) 173 | 174 | # If the input is bfloat16, we cast back to bfloat16 175 | if dtype == torch.bfloat16: 176 | hidden_states = hidden_states.to(dtype) 177 | 178 | # if self.use_conv: 179 | # if self.name == "conv": 180 | # hidden_states = self.conv(hidden_states) 181 | # else: 182 | # hidden_states = self.Conv2d_0(hidden_states) 183 | hidden_states = self.conv(hidden_states) 184 | 185 | return hidden_states 186 | 187 | 188 | class Downsample3D(nn.Module): 189 | """ 190 | The Downsample3D class is a PyTorch module for downsampling a 3D tensor, which is used to 191 | reduce the spatial resolution of feature maps, commonly in the encoder part of a neural network. 192 | 193 | Attributes: 194 | channels (int): Number of input channels. 195 | use_conv (bool): Flag to use a convolutional layer for downsampling. 196 | out_channels (int, optional): Number of output channels. 
Defaults to input channels if None. 197 | padding (int): Padding added to the input. 198 | name (str): Name of the convolutional layer used for downsampling. 199 | 200 | Methods: 201 | forward(self, hidden_states): 202 | Downsamples the input tensor hidden_states and returns the downsampled tensor. 203 | """ 204 | def __init__( 205 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv" 206 | ): 207 | """ 208 | Downsamples the given input in the 3D space. 209 | 210 | Args: 211 | channels: The number of input channels. 212 | use_conv: Whether to use a convolutional layer for downsampling. 213 | out_channels: The number of output channels. If None, the input channels are used. 214 | padding: The amount of padding to be added to the input. 215 | name: The name of the convolutional layer. 216 | """ 217 | super().__init__() 218 | self.channels = channels 219 | self.out_channels = out_channels or channels 220 | self.use_conv = use_conv 221 | self.padding = padding 222 | stride = 2 223 | self.name = name 224 | 225 | if use_conv: 226 | self.conv = InflatedConv3d( 227 | self.channels, self.out_channels, 3, stride=stride, padding=padding 228 | ) 229 | else: 230 | raise NotImplementedError 231 | 232 | def forward(self, hidden_states): 233 | """ 234 | Forward pass for the Downsample3D class. 235 | 236 | Args: 237 | hidden_states (torch.Tensor): Input tensor to be downsampled. 238 | 239 | Returns: 240 | torch.Tensor: Downsampled tensor. 241 | 242 | Raises: 243 | AssertionError: If the number of channels in the input tensor does not match the expected channels. 244 | """ 245 | assert hidden_states.shape[1] == self.channels 246 | if self.use_conv and self.padding == 0: 247 | raise NotImplementedError 248 | 249 | assert hidden_states.shape[1] == self.channels 250 | hidden_states = self.conv(hidden_states) 251 | 252 | return hidden_states 253 | 254 | 255 | class ResnetBlock3D(nn.Module): 256 | """ 257 | The ResnetBlock3D class defines a 3D residual block, a common building block in ResNet 258 | architectures for both image and video modeling tasks. 259 | 260 | Attributes: 261 | in_channels (int): Number of input channels. 262 | out_channels (int, optional): Number of output channels, defaults to in_channels if None. 263 | conv_shortcut (bool): Flag to use a convolutional shortcut. 264 | dropout (float): Dropout rate. 265 | temb_channels (int): Number of channels in the time embedding tensor. 266 | groups (int): Number of groups for the group normalization layers. 267 | eps (float): Epsilon value for group normalization. 268 | non_linearity (str): Type of nonlinearity to apply after convolutions. 269 | time_embedding_norm (str): Type of normalization for the time embedding. 270 | output_scale_factor (float): Scaling factor for the output tensor. 271 | use_in_shortcut (bool): Flag to include the input tensor in the shortcut connection. 272 | use_inflated_groupnorm (bool): Flag to use inflated group normalization layers. 273 | 274 | Methods: 275 | forward(self, input_tensor, temb): 276 | Passes the input tensor and time embedding through the residual block and 277 | returns the output tensor. 
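    Example (an illustrative sketch only; the channel and embedding sizes below are
    made up for the example rather than taken from the repository configs):

        block = ResnetBlock3D(
            in_channels=320,
            temb_channels=1280,
            non_linearity="silu",
            use_inflated_groupnorm=True,
        )
        video = torch.randn(1, 320, 16, 32, 32)  # (batch, channels, frames, h, w)
        temb = torch.randn(1, 1280)               # time embedding
        out = block(video, temb)                  # same shape as `video`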
278 | """ 279 | def __init__( 280 | self, 281 | *, 282 | in_channels, 283 | out_channels=None, 284 | conv_shortcut=False, 285 | dropout=0.0, 286 | temb_channels=512, 287 | groups=32, 288 | groups_out=None, 289 | pre_norm=True, 290 | eps=1e-6, 291 | non_linearity="swish", 292 | time_embedding_norm="default", 293 | output_scale_factor=1.0, 294 | use_in_shortcut=None, 295 | use_inflated_groupnorm=None, 296 | ): 297 | super().__init__() 298 | self.pre_norm = pre_norm 299 | self.pre_norm = True 300 | self.in_channels = in_channels 301 | out_channels = in_channels if out_channels is None else out_channels 302 | self.out_channels = out_channels 303 | self.use_conv_shortcut = conv_shortcut 304 | self.time_embedding_norm = time_embedding_norm 305 | self.output_scale_factor = output_scale_factor 306 | 307 | if groups_out is None: 308 | groups_out = groups 309 | 310 | assert use_inflated_groupnorm is not None 311 | if use_inflated_groupnorm: 312 | self.norm1 = InflatedGroupNorm( 313 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True 314 | ) 315 | else: 316 | self.norm1 = torch.nn.GroupNorm( 317 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True 318 | ) 319 | 320 | self.conv1 = InflatedConv3d( 321 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 322 | ) 323 | 324 | if temb_channels is not None: 325 | if self.time_embedding_norm == "default": 326 | time_emb_proj_out_channels = out_channels 327 | elif self.time_embedding_norm == "scale_shift": 328 | time_emb_proj_out_channels = out_channels * 2 329 | else: 330 | raise ValueError( 331 | f"unknown time_embedding_norm : {self.time_embedding_norm} " 332 | ) 333 | 334 | self.time_emb_proj = torch.nn.Linear( 335 | temb_channels, time_emb_proj_out_channels 336 | ) 337 | else: 338 | self.time_emb_proj = None 339 | 340 | if use_inflated_groupnorm: 341 | self.norm2 = InflatedGroupNorm( 342 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True 343 | ) 344 | else: 345 | self.norm2 = torch.nn.GroupNorm( 346 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True 347 | ) 348 | self.dropout = torch.nn.Dropout(dropout) 349 | self.conv2 = InflatedConv3d( 350 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 351 | ) 352 | 353 | if non_linearity == "swish": 354 | self.nonlinearity = F.silu() 355 | elif non_linearity == "mish": 356 | self.nonlinearity = Mish() 357 | elif non_linearity == "silu": 358 | self.nonlinearity = nn.SiLU() 359 | 360 | self.use_in_shortcut = ( 361 | self.in_channels != self.out_channels 362 | if use_in_shortcut is None 363 | else use_in_shortcut 364 | ) 365 | 366 | self.conv_shortcut = None 367 | if self.use_in_shortcut: 368 | self.conv_shortcut = InflatedConv3d( 369 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 370 | ) 371 | 372 | def forward(self, input_tensor, temb): 373 | """ 374 | Forward pass for the ResnetBlock3D class. 375 | 376 | Args: 377 | input_tensor (torch.Tensor): Input tensor to the ResnetBlock3D layer. 378 | temb (torch.Tensor): Token embedding tensor. 379 | 380 | Returns: 381 | torch.Tensor: Output tensor after passing through the ResnetBlock3D layer. 
382 | """ 383 | hidden_states = input_tensor 384 | 385 | hidden_states = self.norm1(hidden_states) 386 | hidden_states = self.nonlinearity(hidden_states) 387 | 388 | hidden_states = self.conv1(hidden_states) 389 | 390 | if temb is not None: 391 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 392 | 393 | if temb is not None and self.time_embedding_norm == "default": 394 | hidden_states = hidden_states + temb 395 | 396 | hidden_states = self.norm2(hidden_states) 397 | 398 | if temb is not None and self.time_embedding_norm == "scale_shift": 399 | scale, shift = torch.chunk(temb, 2, dim=1) 400 | hidden_states = hidden_states * (1 + scale) + shift 401 | 402 | hidden_states = self.nonlinearity(hidden_states) 403 | 404 | hidden_states = self.dropout(hidden_states) 405 | hidden_states = self.conv2(hidden_states) 406 | 407 | if self.conv_shortcut is not None: 408 | input_tensor = self.conv_shortcut(input_tensor) 409 | 410 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 411 | 412 | return output_tensor 413 | 414 | 415 | class Mish(torch.nn.Module): 416 | """ 417 | The Mish class implements the Mish activation function, a smooth, non-monotonic function 418 | that can be used in neural networks as an alternative to traditional activation functions like ReLU. 419 | 420 | Methods: 421 | forward(self, hidden_states): 422 | Applies the Mish activation function to the input tensor hidden_states and 423 | returns the resulting tensor. 424 | """ 425 | def forward(self, hidden_states): 426 | """ 427 | Mish activation function. 428 | 429 | Args: 430 | hidden_states (torch.Tensor): The input tensor to apply the Mish activation function to. 431 | 432 | Returns: 433 | hidden_states (torch.Tensor): The output tensor after applying the Mish activation function. 434 | """ 435 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 436 | -------------------------------------------------------------------------------- /scripts/inference.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=E1101 2 | # scripts/inference.py 3 | 4 | """ 5 | This script contains the main inference pipeline for processing audio and image inputs to generate a video output. 6 | 7 | The script imports necessary packages and classes, defines a neural network model, 8 | and contains functions for processing audio embeddings and performing inference. 9 | 10 | The main inference process is outlined in the following steps: 11 | 1. Initialize the configuration. 12 | 2. Set up runtime variables. 13 | 3. Prepare the input data for inference (source image, face mask, and face embeddings). 14 | 4. Process the audio embeddings. 15 | 5. Build and freeze the model and scheduler. 16 | 6. Run the inference loop and save the result. 17 | 18 | Usage: 19 | This script can be run from the command line with the following arguments: 20 | - audio_path: Path to the audio file. 21 | - image_path: Path to the source image. 22 | - face_mask_path: Path to the face mask image. 23 | - face_emb_path: Path to the face embeddings file. 24 | - output_path: Path to save the output video. 
25 | 26 | Example: 27 | python scripts/inference.py --audio_path audio.wav --image_path image.jpg 28 | --face_mask_path face_mask.png --face_emb_path face_emb.pt --output_path output.mp4 29 | """ 30 | 31 | import argparse 32 | import gc 33 | import os 34 | 35 | import torch 36 | from diffusers import AutoencoderKL, DDIMScheduler 37 | from omegaconf import OmegaConf 38 | from torch import nn 39 | 40 | from hallo.animate.face_animate import FaceAnimatePipeline 41 | from hallo.datasets.audio_processor import AudioProcessor 42 | from hallo.datasets.image_processor import ImageProcessor 43 | from hallo.models.audio_proj import AudioProjModel 44 | from hallo.models.face_locator import FaceLocator 45 | from hallo.models.image_proj import ImageProjModel 46 | from hallo.models.unet_2d_condition import UNet2DConditionModel 47 | from hallo.models.unet_3d import UNet3DConditionModel 48 | from hallo.utils.util import tensor_to_video 49 | 50 | # from diffusers.utils.import_utils import is_xformers_available 51 | 52 | 53 | class Net(nn.Module): 54 | """ 55 | The Net class combines all the necessary modules for the inference process. 56 | 57 | Args: 58 | reference_unet (UNet2DConditionModel): The UNet2DConditionModel used as a reference for inference. 59 | denoising_unet (UNet3DConditionModel): The UNet3DConditionModel used for denoising the input audio. 60 | face_locator (FaceLocator): The FaceLocator model used to locate the face in the input image. 61 | imageproj (nn.Module): The ImageProjector model used to project the source image onto the face. 62 | audioproj (nn.Module): The AudioProjector model used to project the audio embeddings onto the face. 63 | """ 64 | def __init__( 65 | self, 66 | reference_unet: UNet2DConditionModel, 67 | denoising_unet: UNet3DConditionModel, 68 | face_locator: FaceLocator, 69 | imageproj, 70 | audioproj, 71 | ): 72 | super().__init__() 73 | self.reference_unet = reference_unet 74 | self.denoising_unet = denoising_unet 75 | self.face_locator = face_locator 76 | self.imageproj = imageproj 77 | self.audioproj = audioproj 78 | 79 | def forward(self,): 80 | """ 81 | empty function to override abstract function of nn Module 82 | """ 83 | 84 | def get_modules(self): 85 | """ 86 | Simple method to avoid too-few-public-methods pylint error 87 | """ 88 | return { 89 | "reference_unet": self.reference_unet, 90 | "denoising_unet": self.denoising_unet, 91 | "face_locator": self.face_locator, 92 | "imageproj": self.imageproj, 93 | "audioproj": self.audioproj, 94 | } 95 | 96 | 97 | def process_audio_emb(audio_emb): 98 | """ 99 | Process the audio embedding to concatenate with other tensors. 100 | 101 | Parameters: 102 | audio_emb (torch.Tensor): The audio embedding tensor to process. 103 | 104 | Returns: 105 | concatenated_tensors (List[torch.Tensor]): The concatenated tensor list. 106 | """ 107 | concatenated_tensors = [] 108 | 109 | for i in range(audio_emb.shape[0]): 110 | vectors_to_concat = [ 111 | audio_emb[max(min(i + j, audio_emb.shape[0]-1), 0)]for j in range(-2, 3)] 112 | concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0)) 113 | 114 | audio_emb = torch.stack(concatenated_tensors, dim=0) 115 | 116 | return audio_emb 117 | 118 | 119 | 120 | def inference_process(args: argparse.Namespace, setting_steps=40, setting_cfg=3.5, settings_seed=42, settings_fps=25, settings_motion_pose_scale=1.1, settings_motion_face_scale=1.1, settings_motion_lip_scale=1.1, settings_n_motion_frames=2, settings_n_sample_frames=16): 121 | """ 122 | Perform inference processing. 
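    In addition to `args`, the keyword parameters override config entries when they are not None:
    setting_steps -> inference_steps, setting_cfg -> cfg_scale, settings_seed -> seed,
    settings_fps -> data.export_video.fps, settings_motion_pose_scale -> pose_weight,
    settings_motion_face_scale -> face_weight, settings_motion_lip_scale -> lip_weight,
    settings_n_motion_frames -> data.n_motion_frames, and
    settings_n_sample_frames -> data.n_sample_frames. For example (values are illustrative):

        inference_process(command_line_args, setting_steps=30, setting_cfg=2.5)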
123 | 124 | Args: 125 | args (argparse.Namespace): Command-line arguments. 126 | 127 | This function initializes the configuration for the inference process. It sets up the necessary 128 | modules and variables to prepare for the upcoming inference steps. 129 | """ 130 | # 1. init config 131 | config = OmegaConf.load(args.config) 132 | config = OmegaConf.merge(config, vars(args)) 133 | 134 | 135 | if setting_steps is not None: 136 | config.inference_steps = setting_steps 137 | if setting_cfg is not None: 138 | config.cfg_scale = setting_cfg 139 | if settings_seed is not None: 140 | config.seed = int(settings_seed) 141 | if settings_fps is not None: 142 | config.data.export_video.fps = settings_fps 143 | if settings_motion_pose_scale is not None: 144 | config.pose_weight = settings_motion_pose_scale 145 | if settings_motion_face_scale is not None: 146 | config.face_weight = settings_motion_face_scale 147 | if settings_motion_lip_scale is not None: 148 | config.lip_weight = settings_motion_lip_scale 149 | if settings_n_motion_frames is not None: 150 | config.data.n_motion_frames = settings_n_motion_frames 151 | if settings_n_sample_frames is not None: 152 | config.data.n_sample_frames = settings_n_sample_frames 153 | 154 | 155 | source_image_path = config.source_image 156 | driving_audio_path = config.driving_audio 157 | save_path = config.save_path 158 | if not os.path.exists(save_path): 159 | os.makedirs(save_path) 160 | motion_scale = [config.pose_weight, config.face_weight, config.lip_weight] 161 | if args.checkpoint is not None: 162 | config.audio_ckpt_dir = args.checkpoint 163 | # 2. runtime variables 164 | device = torch.device( 165 | "cuda") if torch.cuda.is_available() else torch.device("cpu") 166 | if config.weight_dtype == "fp16": 167 | weight_dtype = torch.float16 168 | elif config.weight_dtype == "bf16": 169 | weight_dtype = torch.bfloat16 170 | elif config.weight_dtype == "fp32": 171 | weight_dtype = torch.float32 172 | else: 173 | weight_dtype = torch.float32 174 | 175 | # 3. 
prepare inference data 176 | # 3.1 prepare source image, face mask, face embeddings 177 | img_size = (config.data.source_image.width, 178 | config.data.source_image.height) 179 | clip_length = config.data.n_sample_frames 180 | face_analysis_model_path = config.face_analysis.model_path 181 | with ImageProcessor(img_size, face_analysis_model_path) as image_processor: 182 | source_image_pixels, \ 183 | source_image_face_region, \ 184 | source_image_face_emb, \ 185 | source_image_full_mask, \ 186 | source_image_face_mask, \ 187 | source_image_lip_mask = image_processor.preprocess( 188 | source_image_path, save_path, config.face_expand_ratio) 189 | 190 | # 3.2 prepare audio embeddings 191 | sample_rate = config.data.driving_audio.sample_rate 192 | assert sample_rate == 16000, "audio sample rate must be 16000" 193 | fps = config.data.export_video.fps 194 | wav2vec_model_path = config.wav2vec.model_path 195 | wav2vec_only_last_features = config.wav2vec.features == "last" 196 | audio_separator_model_file = config.audio_separator.model_path 197 | with AudioProcessor( 198 | sample_rate, 199 | fps, 200 | wav2vec_model_path, 201 | wav2vec_only_last_features, 202 | os.path.dirname(audio_separator_model_file), 203 | os.path.basename(audio_separator_model_file), 204 | os.path.join(save_path, "audio_preprocess") 205 | ) as audio_processor: 206 | audio_emb = audio_processor.preprocess(driving_audio_path) 207 | 208 | # Clear memory 209 | # del image_processor 210 | del audio_processor 211 | gc.collect() 212 | torch.cuda.empty_cache() 213 | 214 | 215 | # 4. build modules 216 | sched_kwargs = OmegaConf.to_container(config.noise_scheduler_kwargs) 217 | if config.enable_zero_snr: 218 | sched_kwargs.update( 219 | rescale_betas_zero_snr=True, 220 | timestep_spacing="trailing", 221 | prediction_type="v_prediction", 222 | ) 223 | val_noise_scheduler = DDIMScheduler(**sched_kwargs) 224 | sched_kwargs.update({"beta_schedule": "scaled_linear"}) 225 | 226 | vae = AutoencoderKL.from_pretrained(config.vae.model_path) 227 | reference_unet = UNet2DConditionModel.from_pretrained( 228 | config.base_model_path, subfolder="unet") 229 | denoising_unet = UNet3DConditionModel.from_pretrained_2d( 230 | config.base_model_path, 231 | config.motion_module_path, 232 | subfolder="unet", 233 | unet_additional_kwargs=OmegaConf.to_container( 234 | config.unet_additional_kwargs), 235 | use_landmark=False, 236 | ) 237 | face_locator = FaceLocator(conditioning_embedding_channels=320) 238 | image_proj = ImageProjModel( 239 | cross_attention_dim=denoising_unet.config.cross_attention_dim, 240 | clip_embeddings_dim=512, 241 | clip_extra_context_tokens=4, 242 | ) 243 | 244 | audio_proj = AudioProjModel( 245 | seq_len=5, 246 | blocks=12, # use 12 layers' hidden states of wav2vec 247 | channels=768, # audio embedding channel 248 | intermediate_dim=512, 249 | output_dim=768, 250 | context_tokens=32, 251 | ).to(device=device, dtype=weight_dtype) 252 | 253 | audio_ckpt_dir = config.audio_ckpt_dir 254 | 255 | 256 | # Freeze 257 | vae.requires_grad_(False) 258 | image_proj.requires_grad_(False) 259 | reference_unet.requires_grad_(False) 260 | denoising_unet.requires_grad_(False) 261 | face_locator.requires_grad_(False) 262 | audio_proj.requires_grad_(False) 263 | 264 | # Not working soryy :( 265 | # if is_xformers_available(): 266 | # reference_unet.enable_xformers_memory_efficient_attention() 267 | # denoising_unet.enable_xformers_memory_efficient_attention() 268 | 269 | reference_unet.enable_gradient_checkpointing() 270 | 
denoising_unet.enable_gradient_checkpointing() 271 | 272 | net = Net( 273 | reference_unet, 274 | denoising_unet, 275 | face_locator, 276 | image_proj, 277 | audio_proj, 278 | ) 279 | 280 | m,u = net.load_state_dict( 281 | torch.load( 282 | os.path.join(audio_ckpt_dir, "net.pth"), 283 | map_location="cpu", 284 | ), 285 | ) 286 | assert len(m) == 0 and len(u) == 0, "Fail to load correct checkpoint." 287 | print("loaded weight from ", os.path.join(audio_ckpt_dir, "net.pth")) 288 | 289 | # 5. inference 290 | pipeline = FaceAnimatePipeline( 291 | vae=vae, 292 | reference_unet=net.reference_unet, 293 | denoising_unet=net.denoising_unet, 294 | face_locator=net.face_locator, 295 | scheduler=val_noise_scheduler, 296 | image_proj=net.imageproj, 297 | ) 298 | pipeline.to(device=device, dtype=weight_dtype) 299 | 300 | audio_emb = process_audio_emb(audio_emb) 301 | 302 | source_image_pixels = source_image_pixels.unsqueeze(0) 303 | source_image_face_region = source_image_face_region.unsqueeze(0) 304 | source_image_face_emb = source_image_face_emb.reshape(1, -1) 305 | source_image_face_emb = torch.tensor(source_image_face_emb) 306 | 307 | source_image_full_mask = [ 308 | (mask.repeat(clip_length, 1)) 309 | for mask in source_image_full_mask 310 | ] 311 | source_image_face_mask = [ 312 | (mask.repeat(clip_length, 1)) 313 | for mask in source_image_face_mask 314 | ] 315 | source_image_lip_mask = [ 316 | (mask.repeat(clip_length, 1)) 317 | for mask in source_image_lip_mask 318 | ] 319 | 320 | 321 | times = audio_emb.shape[0] // clip_length 322 | 323 | tensor_result = [] 324 | 325 | generator = torch.manual_seed(42) 326 | 327 | for t in range(times): 328 | 329 | if len(tensor_result) == 0: 330 | # The first iteration 331 | motion_zeros = source_image_pixels.repeat( 332 | config.data.n_motion_frames, 1, 1, 1) 333 | motion_zeros = motion_zeros.to( 334 | dtype=source_image_pixels.dtype, device=source_image_pixels.device) 335 | pixel_values_ref_img = torch.cat( 336 | [source_image_pixels, motion_zeros], dim=0) # concat the ref image and the first motion frames 337 | else: 338 | motion_frames = tensor_result[-1][0] 339 | motion_frames = motion_frames.permute(1, 0, 2, 3) 340 | motion_frames = motion_frames[0-config.data.n_motion_frames:] 341 | motion_frames = motion_frames * 2.0 - 1.0 342 | motion_frames = motion_frames.to( 343 | dtype=source_image_pixels.dtype, device=source_image_pixels.device) 344 | pixel_values_ref_img = torch.cat( 345 | [source_image_pixels, motion_frames], dim=0) # concat the ref image and the motion frames 346 | 347 | pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0) 348 | 349 | audio_tensor = audio_emb[ 350 | t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0]) 351 | ] 352 | audio_tensor = audio_tensor.unsqueeze(0) 353 | audio_tensor = audio_tensor.to( 354 | device=net.audioproj.device, dtype=net.audioproj.dtype) 355 | audio_tensor = net.audioproj(audio_tensor) 356 | 357 | # Get all params 358 | print( 359 | f""" 360 | inference {t+1} / {times} 361 | """ 362 | ) 363 | 364 | pipeline_output = pipeline( 365 | ref_image=pixel_values_ref_img, 366 | audio_tensor=audio_tensor, 367 | face_emb=source_image_face_emb, 368 | face_mask=source_image_face_region, 369 | pixel_values_full_mask=source_image_full_mask, 370 | pixel_values_face_mask=source_image_face_mask, 371 | pixel_values_lip_mask=source_image_lip_mask, 372 | width=img_size[0], 373 | height=img_size[1], 374 | video_length=clip_length, 375 | num_inference_steps=config.inference_steps, 376 | 
guidance_scale=config.cfg_scale, 377 | generator=generator, 378 | motion_scale=motion_scale, 379 | ) 380 | 381 | tensor_result.append(pipeline_output.videos) 382 | 383 | tensor_result = torch.cat(tensor_result, dim=2) 384 | tensor_result = tensor_result.squeeze(0) 385 | 386 | output_file = config.output 387 | # save the result after all iteration 388 | tensor_to_video(tensor_result, output_file, driving_audio_path) 389 | 390 | 391 | if __name__ == "__main__": 392 | parser = argparse.ArgumentParser() 393 | 394 | parser.add_argument( 395 | "-c", "--config", default="configs/inference/default.yaml") 396 | parser.add_argument("--source_image", type=str, required=False, 397 | help="source image", default="test_data/source_images/6.jpg") 398 | parser.add_argument("--driving_audio", type=str, required=False, 399 | help="driving audio", default="test_data/driving_audios/singing/sing_4.wav") 400 | parser.add_argument( 401 | "--output", type=str, help="output video file name", default=".cache/output.mp4") 402 | parser.add_argument( 403 | "--pose_weight", type=float, help="weight of pose", default=1.0) 404 | parser.add_argument( 405 | "--face_weight", type=float, help="weight of face", default=1.0) 406 | parser.add_argument( 407 | "--lip_weight", type=float, help="weight of lip", default=1.0) 408 | parser.add_argument( 409 | "--face_expand_ratio", type=float, help="face region", default=1.2) 410 | parser.add_argument( 411 | "--checkpoint", type=str, help="which checkpoint", default=None) 412 | parser.add_argument("--setting_steps", type=int, default=40) 413 | parser.add_argument("--setting_cfg", type=float, default=3.5) 414 | parser.add_argument("--settings_seed", type=int, default=42) 415 | parser.add_argument("--settings_fps", type=int, default=25) 416 | parser.add_argument("--settings_motion_pose_scale", type=float, default=1.1) 417 | parser.add_argument("--settings_motion_face_scale", type=float, default=1.1) 418 | parser.add_argument("--settings_motion_lip_scale", type=float, default=1.1) 419 | parser.add_argument("--settings_n_motion_frames", type=int, default=2) 420 | parser.add_argument("--settings_n_sample_frames", type=int, default=16) 421 | 422 | command_line_args = parser.parse_args() 423 | 424 | inference_process( 425 | command_line_args, 426 | command_line_args.setting_steps, 427 | command_line_args.setting_cfg, 428 | command_line_args.settings_seed, 429 | command_line_args.settings_fps, 430 | command_line_args.settings_motion_pose_scale, 431 | command_line_args.settings_motion_face_scale, 432 | command_line_args.settings_motion_lip_scale, 433 | command_line_args.settings_n_motion_frames, 434 | command_line_args.settings_n_sample_frames 435 | ) 436 | -------------------------------------------------------------------------------- /hallo/animate/face_animate_static.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R0801 2 | """ 3 | This module is responsible for handling the animation of faces using a combination of deep learning models and image processing techniques. 4 | It provides a pipeline to generate realistic face animations by incorporating user-provided conditions such as facial expressions and environments. 5 | The module utilizes various schedulers and utilities to optimize the animation process and ensure efficient performance. 6 | 7 | Functions and Classes: 8 | - StaticPipelineOutput: A class that represents the output of the animation pipeline, c 9 | ontaining properties and methods related to the generated images. 
10 | - prepare_latents: A function that prepares the initial noise for the animation process, 11 | scaling it according to the scheduler's requirements. 12 | - prepare_condition: A function that processes the user-provided conditions 13 | (e.g., facial expressions) and prepares them for use in the animation pipeline. 14 | - decode_latents: A function that decodes the latent representations of the face animations into 15 | their corresponding image formats. 16 | - prepare_extra_step_kwargs: A function that prepares additional parameters for each step of 17 | the animation process, such as the generator and eta values. 18 | 19 | Dependencies: 20 | - numpy: A library for numerical computing. 21 | - torch: A machine learning library based on PyTorch. 22 | - diffusers: A library for image-to-image diffusion models. 23 | - transformers: A library for pre-trained transformer models. 24 | 25 | Usage: 26 | - To create an instance of the animation pipeline, provide the necessary components such as 27 | the VAE, reference UNET, denoising UNET, face locator, and image processor. 28 | - Use the pipeline's methods to prepare the latents, conditions, and extra step arguments as 29 | required for the animation process. 30 | - Generate the face animations by decoding the latents and processing the conditions. 31 | 32 | Note: 33 | - The module is designed to work with the diffusers library, which is based on 34 | the paper "Diffusion Models for Image-to-Image Translation" (https://arxiv.org/abs/2102.02765). 35 | - The face animations generated by this module should be used for entertainment purposes 36 | only and should respect the rights and privacy of the individuals involved. 37 | """ 38 | import inspect 39 | from dataclasses import dataclass 40 | from typing import Callable, List, Optional, Union 41 | 42 | import numpy as np 43 | import torch 44 | from diffusers import DiffusionPipeline 45 | from diffusers.image_processor import VaeImageProcessor 46 | from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler, 47 | EulerAncestralDiscreteScheduler, 48 | EulerDiscreteScheduler, LMSDiscreteScheduler, 49 | PNDMScheduler) 50 | from diffusers.utils import BaseOutput, is_accelerate_available 51 | from diffusers.utils.torch_utils import randn_tensor 52 | from einops import rearrange 53 | from tqdm import tqdm 54 | from transformers import CLIPImageProcessor 55 | 56 | from hallo.models.mutual_self_attention import ReferenceAttentionControl 57 | 58 | if is_accelerate_available(): 59 | from accelerate import cpu_offload 60 | else: 61 | raise ImportError("Please install accelerate via `pip install accelerate`") 62 | 63 | 64 | @dataclass 65 | class StaticPipelineOutput(BaseOutput): 66 | """ 67 | StaticPipelineOutput is a class that represents the output of the static pipeline. 68 | It contains the images generated by the pipeline as a union of torch.Tensor and np.ndarray. 69 | 70 | Attributes: 71 | images (Union[torch.Tensor, np.ndarray]): The generated images. 72 | """ 73 | images: Union[torch.Tensor, np.ndarray] 74 | 75 | 76 | class StaticPipeline(DiffusionPipeline): 77 | """ 78 | StaticPipelineOutput is a class that represents the output of the static pipeline. 79 | It contains the images generated by the pipeline as a union of torch.Tensor and np.ndarray. 80 | 81 | Attributes: 82 | images (Union[torch.Tensor, np.ndarray]): The generated images. 
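    Note: the summary above appears to have been copied from StaticPipelineOutput. The class
    defined here is the pipeline itself: it registers the VAE, reference UNet, denoising UNet,
    face locator, image projector, and scheduler, and its __call__ returns a StaticPipelineOutput
    containing the generated images.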
83 | """ 84 | _optional_components = [] 85 | 86 | def __init__( 87 | self, 88 | vae, 89 | reference_unet, 90 | denoising_unet, 91 | face_locator, 92 | imageproj, 93 | scheduler: Union[ 94 | DDIMScheduler, 95 | PNDMScheduler, 96 | LMSDiscreteScheduler, 97 | EulerDiscreteScheduler, 98 | EulerAncestralDiscreteScheduler, 99 | DPMSolverMultistepScheduler, 100 | ], 101 | ): 102 | super().__init__() 103 | 104 | self.register_modules( 105 | vae=vae, 106 | reference_unet=reference_unet, 107 | denoising_unet=denoising_unet, 108 | face_locator=face_locator, 109 | scheduler=scheduler, 110 | imageproj=imageproj, 111 | ) 112 | self.vae_scale_factor = 2 ** ( 113 | len(self.vae.config.block_out_channels) - 1) 114 | self.clip_image_processor = CLIPImageProcessor() 115 | self.ref_image_processor = VaeImageProcessor( 116 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True 117 | ) 118 | self.cond_image_processor = VaeImageProcessor( 119 | vae_scale_factor=self.vae_scale_factor, 120 | do_convert_rgb=True, 121 | do_normalize=False, 122 | ) 123 | 124 | def enable_vae_slicing(self): 125 | """ 126 | Enable VAE slicing. 127 | 128 | This method enables slicing for the VAE model, which can help improve the performance of decoding latents when working with large images. 129 | """ 130 | self.vae.enable_slicing() 131 | 132 | def disable_vae_slicing(self): 133 | """ 134 | Disable vae slicing. 135 | 136 | This function disables the vae slicing for the StaticPipeline object. 137 | It calls the `disable_slicing()` method of the vae model. 138 | This is useful when you want to use the entire vae model for decoding latents 139 | instead of slicing it for better performance. 140 | """ 141 | self.vae.disable_slicing() 142 | 143 | def enable_sequential_cpu_offload(self, gpu_id=0): 144 | """ 145 | Offloads selected models to the GPU for increased performance. 146 | 147 | Args: 148 | gpu_id (int, optional): The ID of the GPU to offload models to. Defaults to 0. 149 | """ 150 | device = torch.device(f"cuda:{gpu_id}") 151 | 152 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 153 | if cpu_offloaded_model is not None: 154 | cpu_offload(cpu_offloaded_model, device) 155 | 156 | @property 157 | def _execution_device(self): 158 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 159 | return self.device 160 | for module in self.unet.modules(): 161 | if ( 162 | hasattr(module, "_hf_hook") 163 | and hasattr(module._hf_hook, "execution_device") 164 | and module._hf_hook.execution_device is not None 165 | ): 166 | return torch.device(module._hf_hook.execution_device) 167 | return self.device 168 | 169 | def decode_latents(self, latents): 170 | """ 171 | Decode the given latents to video frames. 172 | 173 | Parameters: 174 | latents (torch.Tensor): The latents to be decoded. Shape: (batch_size, num_channels_latents, video_length, height, width). 175 | 176 | Returns: 177 | video (torch.Tensor): The decoded video frames. Shape: (batch_size, num_channels_latents, video_length, height, width). 
178 | """ 179 | video_length = latents.shape[2] 180 | latents = 1 / 0.18215 * latents 181 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 182 | # video = self.vae.decode(latents).sample 183 | video = [] 184 | for frame_idx in tqdm(range(latents.shape[0])): 185 | video.append(self.vae.decode( 186 | latents[frame_idx: frame_idx + 1]).sample) 187 | video = torch.cat(video) 188 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 189 | video = (video / 2 + 0.5).clamp(0, 1) 190 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 191 | video = video.cpu().float().numpy() 192 | return video 193 | 194 | def prepare_extra_step_kwargs(self, generator, eta): 195 | """ 196 | Prepare extra keyword arguments for the scheduler step. 197 | 198 | Since not all schedulers have the same signature, this function helps to create a consistent interface for the scheduler. 199 | 200 | Args: 201 | generator (Optional[torch.Generator]): A random number generator for reproducibility. 202 | eta (float): The eta parameter used with the DDIMScheduler. It should be between 0 and 1. 203 | 204 | Returns: 205 | dict: A dictionary containing the extra keyword arguments for the scheduler step. 206 | """ 207 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 208 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 209 | # and should be between [0, 1] 210 | 211 | accepts_eta = "eta" in set( 212 | inspect.signature(self.scheduler.step).parameters.keys() 213 | ) 214 | extra_step_kwargs = {} 215 | if accepts_eta: 216 | extra_step_kwargs["eta"] = eta 217 | 218 | # check if the scheduler accepts generator 219 | accepts_generator = "generator" in set( 220 | inspect.signature(self.scheduler.step).parameters.keys() 221 | ) 222 | if accepts_generator: 223 | extra_step_kwargs["generator"] = generator 224 | return extra_step_kwargs 225 | 226 | def prepare_latents( 227 | self, 228 | batch_size, 229 | num_channels_latents, 230 | width, 231 | height, 232 | dtype, 233 | device, 234 | generator, 235 | latents=None, 236 | ): 237 | """ 238 | Prepares the initial latents for the diffusion pipeline. 239 | 240 | Args: 241 | batch_size (int): The number of images to generate in one forward pass. 242 | num_channels_latents (int): The number of channels in the latents tensor. 243 | width (int): The width of the latents tensor. 244 | height (int): The height of the latents tensor. 245 | dtype (torch.dtype): The data type of the latents tensor. 246 | device (torch.device): The device to place the latents tensor on. 247 | generator (Optional[torch.Generator], optional): A random number generator 248 | for reproducibility. Defaults to None. 249 | latents (Optional[torch.Tensor], optional): Pre-computed latents to use as 250 | initial conditions for the diffusion process. Defaults to None. 251 | 252 | Returns: 253 | torch.Tensor: The prepared latents tensor. 254 | """ 255 | shape = ( 256 | batch_size, 257 | num_channels_latents, 258 | height // self.vae_scale_factor, 259 | width // self.vae_scale_factor, 260 | ) 261 | if isinstance(generator, list) and len(generator) != batch_size: 262 | raise ValueError( 263 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 264 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
265 | ) 266 | 267 | if latents is None: 268 | latents = randn_tensor( 269 | shape, generator=generator, device=device, dtype=dtype 270 | ) 271 | else: 272 | latents = latents.to(device) 273 | 274 | # scale the initial noise by the standard deviation required by the scheduler 275 | latents = latents * self.scheduler.init_noise_sigma 276 | return latents 277 | 278 | def prepare_condition( 279 | self, 280 | cond_image, 281 | width, 282 | height, 283 | device, 284 | dtype, 285 | do_classififer_free_guidance=False, 286 | ): 287 | """ 288 | Prepares the condition for the face animation pipeline. 289 | 290 | Args: 291 | cond_image (torch.Tensor): The conditional image tensor. 292 | width (int): The width of the output image. 293 | height (int): The height of the output image. 294 | device (torch.device): The device to run the pipeline on. 295 | dtype (torch.dtype): The data type of the tensor. 296 | do_classififer_free_guidance (bool, optional): Whether to use classifier-free guidance or not. Defaults to False. 297 | 298 | Returns: 299 | Tuple[torch.Tensor, torch.Tensor]: A tuple of processed condition and mask tensors. 300 | """ 301 | image = self.cond_image_processor.preprocess( 302 | cond_image, height=height, width=width 303 | ).to(dtype=torch.float32) 304 | 305 | image = image.to(device=device, dtype=dtype) 306 | 307 | if do_classififer_free_guidance: 308 | image = torch.cat([image] * 2) 309 | 310 | return image 311 | 312 | @torch.no_grad() 313 | def __call__( 314 | self, 315 | ref_image, 316 | face_mask, 317 | width, 318 | height, 319 | num_inference_steps, 320 | guidance_scale, 321 | face_embedding, 322 | num_images_per_prompt=1, 323 | eta: float = 0.0, 324 | generator: Optional[Union[torch.Generator, 325 | List[torch.Generator]]] = None, 326 | output_type: Optional[str] = "tensor", 327 | return_dict: bool = True, 328 | callback: Optional[Callable[[ 329 | int, int, torch.FloatTensor], None]] = None, 330 | callback_steps: Optional[int] = 1, 331 | **kwargs, 332 | ): 333 | # Default height and width to unet 334 | height = height or self.unet.config.sample_size * self.vae_scale_factor 335 | width = width or self.unet.config.sample_size * self.vae_scale_factor 336 | 337 | device = self._execution_device 338 | 339 | do_classifier_free_guidance = guidance_scale > 1.0 340 | 341 | # Prepare timesteps 342 | self.scheduler.set_timesteps(num_inference_steps, device=device) 343 | timesteps = self.scheduler.timesteps 344 | 345 | batch_size = 1 346 | 347 | image_prompt_embeds = self.imageproj(face_embedding) 348 | uncond_image_prompt_embeds = self.imageproj( 349 | torch.zeros_like(face_embedding)) 350 | 351 | if do_classifier_free_guidance: 352 | image_prompt_embeds = torch.cat( 353 | [uncond_image_prompt_embeds, image_prompt_embeds], dim=0 354 | ) 355 | 356 | reference_control_writer = ReferenceAttentionControl( 357 | self.reference_unet, 358 | do_classifier_free_guidance=do_classifier_free_guidance, 359 | mode="write", 360 | batch_size=batch_size, 361 | fusion_blocks="full", 362 | ) 363 | reference_control_reader = ReferenceAttentionControl( 364 | self.denoising_unet, 365 | do_classifier_free_guidance=do_classifier_free_guidance, 366 | mode="read", 367 | batch_size=batch_size, 368 | fusion_blocks="full", 369 | ) 370 | 371 | num_channels_latents = self.denoising_unet.in_channels 372 | latents = self.prepare_latents( 373 | batch_size * num_images_per_prompt, 374 | num_channels_latents, 375 | width, 376 | height, 377 | face_embedding.dtype, 378 | device, 379 | generator, 380 | ) 381 | latents = 
latents.unsqueeze(2) # (bs, c, 1, h', w') 382 | # latents_dtype = latents.dtype 383 | 384 | # Prepare extra step kwargs. 385 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 386 | 387 | # Prepare ref image latents 388 | ref_image_tensor = self.ref_image_processor.preprocess( 389 | ref_image, height=height, width=width 390 | ) # (bs, c, width, height) 391 | ref_image_tensor = ref_image_tensor.to( 392 | dtype=self.vae.dtype, device=self.vae.device 393 | ) 394 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean 395 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w) 396 | 397 | # Prepare face mask image 398 | face_mask_tensor = self.cond_image_processor.preprocess( 399 | face_mask, height=height, width=width 400 | ) 401 | face_mask_tensor = face_mask_tensor.unsqueeze(2) # (bs, c, 1, h, w) 402 | face_mask_tensor = face_mask_tensor.to( 403 | device=device, dtype=self.face_locator.dtype 404 | ) 405 | mask_fea = self.face_locator(face_mask_tensor) 406 | mask_fea = ( 407 | torch.cat( 408 | [mask_fea] * 2) if do_classifier_free_guidance else mask_fea 409 | ) 410 | 411 | # denoising loop 412 | num_warmup_steps = len(timesteps) - \ 413 | num_inference_steps * self.scheduler.order 414 | with self.progress_bar(total=num_inference_steps) as progress_bar: 415 | for i, t in enumerate(timesteps): 416 | # 1. Forward reference image 417 | if i == 0: 418 | self.reference_unet( 419 | ref_image_latents.repeat( 420 | (2 if do_classifier_free_guidance else 1), 1, 1, 1 421 | ), 422 | torch.zeros_like(t), 423 | encoder_hidden_states=image_prompt_embeds, 424 | return_dict=False, 425 | ) 426 | 427 | # 2. Update reference unet feature into denosing net 428 | reference_control_reader.update(reference_control_writer) 429 | 430 | # 3.1 expand the latents if we are doing classifier free guidance 431 | latent_model_input = ( 432 | torch.cat( 433 | [latents] * 2) if do_classifier_free_guidance else latents 434 | ) 435 | latent_model_input = self.scheduler.scale_model_input( 436 | latent_model_input, t 437 | ) 438 | 439 | noise_pred = self.denoising_unet( 440 | latent_model_input, 441 | t, 442 | encoder_hidden_states=image_prompt_embeds, 443 | mask_cond_fea=mask_fea, 444 | return_dict=False, 445 | )[0] 446 | 447 | # perform guidance 448 | if do_classifier_free_guidance: 449 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 450 | noise_pred = noise_pred_uncond + guidance_scale * ( 451 | noise_pred_text - noise_pred_uncond 452 | ) 453 | 454 | # compute the previous noisy sample x_t -> x_t-1 455 | latents = self.scheduler.step( 456 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False 457 | )[0] 458 | 459 | # call the callback, if provided 460 | if i == len(timesteps) - 1 or ( 461 | (i + 1) > num_warmup_steps and (i + 462 | 1) % self.scheduler.order == 0 463 | ): 464 | progress_bar.update() 465 | if callback is not None and i % callback_steps == 0: 466 | step_idx = i // getattr(self.scheduler, "order", 1) 467 | callback(step_idx, t, latents) 468 | reference_control_reader.clear() 469 | reference_control_writer.clear() 470 | 471 | # Post-processing 472 | image = self.decode_latents(latents) # (b, c, 1, h, w) 473 | 474 | # Convert to tensor 475 | if output_type == "tensor": 476 | image = torch.from_numpy(image) 477 | 478 | if not return_dict: 479 | return image 480 | 481 | return StaticPipelineOutput(images=image) 482 | -------------------------------------------------------------------------------- /hallo/animate/face_animate.py: 
-------------------------------------------------------------------------------- 1 | # pylint: disable=R0801 2 | """ 3 | This module is responsible for animating faces in videos using a combination of deep learning techniques. 4 | It provides a pipeline for generating face animations by processing video frames and extracting face features. 5 | The module utilizes various schedulers and utilities for efficient face animation and supports different types 6 | of latents for more control over the animation process. 7 | 8 | Functions and Classes: 9 | - FaceAnimatePipeline: A class that extends the DiffusionPipeline class from the diffusers library to handle face animation tasks. 10 | - __init__: Initializes the pipeline with the necessary components (VAE, UNets, face locator, etc.). 11 | - prepare_latents: Generates or loads latents for the animation process, scaling them according to the scheduler's requirements. 12 | - prepare_extra_step_kwargs: Prepares extra keyword arguments for the scheduler step, ensuring compatibility with different schedulers. 13 | - decode_latents: Decodes the latents into video frames, ready for animation. 14 | 15 | Usage: 16 | - Import the necessary packages and classes. 17 | - Create a FaceAnimatePipeline instance with the required components. 18 | - Prepare the latents for the animation process. 19 | - Use the pipeline to generate the animated video. 20 | 21 | Note: 22 | - This module is designed to work with the diffusers library, which provides the underlying framework for face animation using deep learning. 23 | - The module is intended for research and development purposes, and further optimization and customization may be required for specific use cases. 24 | """ 25 | 26 | import inspect 27 | from dataclasses import dataclass 28 | from typing import Callable, List, Optional, Union 29 | 30 | import numpy as np 31 | import torch 32 | from diffusers import (DDIMScheduler, DiffusionPipeline, 33 | DPMSolverMultistepScheduler, 34 | EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, 35 | LMSDiscreteScheduler, PNDMScheduler) 36 | from diffusers.image_processor import VaeImageProcessor 37 | from diffusers.utils import BaseOutput 38 | from diffusers.utils.torch_utils import randn_tensor 39 | from einops import rearrange, repeat 40 | from tqdm import tqdm 41 | 42 | from hallo.models.mutual_self_attention import ReferenceAttentionControl 43 | 44 | 45 | @dataclass 46 | class FaceAnimatePipelineOutput(BaseOutput): 47 | """ 48 | FaceAnimatePipelineOutput is a custom class that inherits from BaseOutput and represents the output of the FaceAnimatePipeline. 49 | 50 | Attributes: 51 | videos (Union[torch.Tensor, np.ndarray]): A tensor or numpy array containing the generated video frames. 52 | 53 | Methods: 54 | __init__(self, videos: Union[torch.Tensor, np.ndarray]): Initializes the FaceAnimatePipelineOutput object with the generated video frames. 55 | """ 56 | videos: Union[torch.Tensor, np.ndarray] 57 | 58 | class FaceAnimatePipeline(DiffusionPipeline): 59 | """ 60 | FaceAnimatePipeline is a custom DiffusionPipeline for animating faces. 61 | 62 | It inherits from the DiffusionPipeline class and is used to animate faces by 63 | utilizing a variational autoencoder (VAE), a reference UNet, a denoising UNet, 64 | a face locator, and an image processor. The pipeline is responsible for generating 65 | and animating face latents, and decoding the latents to produce the final video output. 
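    Example (a sketch mirroring the call in scripts/inference.py; every name below stands for an
    already-constructed component or tensor, and the numeric values are illustrative):

        pipeline = FaceAnimatePipeline(
            vae=vae,
            reference_unet=reference_unet,
            denoising_unet=denoising_unet,
            face_locator=face_locator,
            image_proj=image_proj,
            scheduler=scheduler,
        )
        pipeline.to(device=device, dtype=weight_dtype)
        output = pipeline(
            ref_image=pixel_values_ref_img,      # (b, f, c, h, w): reference + motion frames
            face_emb=face_emb,
            audio_tensor=audio_tensor,
            face_mask=face_region,
            pixel_values_full_mask=full_masks,
            pixel_values_face_mask=face_masks,
            pixel_values_lip_mask=lip_masks,
            width=512,
            height=512,
            video_length=16,
            num_inference_steps=40,
            guidance_scale=3.5,
            motion_scale=[1.0, 1.0, 1.0],
        )
        video = output.videos                    # (b, c, f, h, w), values in [0, 1]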
66 | 67 | Attributes: 68 | vae (VaeImageProcessor): Variational autoencoder for processing images. 69 | reference_unet (nn.Module): Reference UNet for mutual self-attention. 70 | denoising_unet (nn.Module): Denoising UNet for image denoising. 71 | face_locator (nn.Module): Face locator for detecting and cropping faces. 72 | image_proj (nn.Module): Image projector for processing images. 73 | scheduler (Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, 74 | EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, 75 | DPMSolverMultistepScheduler]): Diffusion scheduler for 76 | controlling the noise level. 77 | 78 | Methods: 79 | __init__(self, vae, reference_unet, denoising_unet, face_locator, 80 | image_proj, scheduler): Initializes the FaceAnimatePipeline 81 | with the given components and scheduler. 82 | prepare_latents(self, batch_size, num_channels_latents, width, height, 83 | video_length, dtype, device, generator=None, latents=None): 84 | Prepares the initial latents for video generation. 85 | prepare_extra_step_kwargs(self, generator, eta): Prepares extra keyword 86 | arguments for the scheduler step. 87 | decode_latents(self, latents): Decodes the latents to produce the final 88 | video output. 89 | """ 90 | def __init__( 91 | self, 92 | vae, 93 | reference_unet, 94 | denoising_unet, 95 | face_locator, 96 | image_proj, 97 | scheduler: Union[ 98 | DDIMScheduler, 99 | PNDMScheduler, 100 | LMSDiscreteScheduler, 101 | EulerDiscreteScheduler, 102 | EulerAncestralDiscreteScheduler, 103 | DPMSolverMultistepScheduler, 104 | ], 105 | ) -> None: 106 | super().__init__() 107 | 108 | self.register_modules( 109 | vae=vae, 110 | reference_unet=reference_unet, 111 | denoising_unet=denoising_unet, 112 | face_locator=face_locator, 113 | scheduler=scheduler, 114 | image_proj=image_proj, 115 | ) 116 | 117 | self.vae_scale_factor: int = 2 ** (len(self.vae.config.block_out_channels) - 1) 118 | 119 | self.ref_image_processor = VaeImageProcessor( 120 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, 121 | ) 122 | 123 | @property 124 | def _execution_device(self): 125 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 126 | return self.device 127 | for module in self.unet.modules(): 128 | if ( 129 | hasattr(module, "_hf_hook") 130 | and hasattr(module._hf_hook, "execution_device") 131 | and module._hf_hook.execution_device is not None 132 | ): 133 | return torch.device(module._hf_hook.execution_device) 134 | return self.device 135 | 136 | def prepare_latents( 137 | self, 138 | batch_size: int, # Number of videos to generate in parallel 139 | num_channels_latents: int, # Number of channels in the latents 140 | width: int, # Width of the video frame 141 | height: int, # Height of the video frame 142 | video_length: int, # Length of the video in frames 143 | dtype: torch.dtype, # Data type of the latents 144 | device: torch.device, # Device to store the latents on 145 | generator: Optional[torch.Generator] = None, # Random number generator for reproducibility 146 | latents: Optional[torch.Tensor] = None # Pre-generated latents (optional) 147 | ): 148 | """ 149 | Prepares the initial latents for video generation. 150 | 151 | Args: 152 | batch_size (int): Number of videos to generate in parallel. 153 | num_channels_latents (int): Number of channels in the latents. 154 | width (int): Width of the video frame. 155 | height (int): Height of the video frame. 156 | video_length (int): Length of the video in frames. 157 | dtype (torch.dtype): Data type of the latents. 
158 | device (torch.device): Device to store the latents on. 159 | generator (Optional[torch.Generator]): Random number generator for reproducibility. 160 | latents (Optional[torch.Tensor]): Pre-generated latents (optional). 161 | 162 | Returns: 163 | latents (torch.Tensor): Tensor of shape (batch_size, num_channels_latents, width, height) 164 | containing the initial latents for video generation. 165 | """ 166 | shape = ( 167 | batch_size, 168 | num_channels_latents, 169 | video_length, 170 | height // self.vae_scale_factor, 171 | width // self.vae_scale_factor, 172 | ) 173 | if isinstance(generator, list) and len(generator) != batch_size: 174 | raise ValueError( 175 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 176 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 177 | ) 178 | 179 | if latents is None: 180 | latents = randn_tensor( 181 | shape, generator=generator, device=device, dtype=dtype 182 | ) 183 | else: 184 | latents = latents.to(device) 185 | 186 | # scale the initial noise by the standard deviation required by the scheduler 187 | latents = latents * self.scheduler.init_noise_sigma 188 | return latents 189 | 190 | def prepare_extra_step_kwargs(self, generator, eta): 191 | """ 192 | Prepares extra keyword arguments for the scheduler step. 193 | 194 | Args: 195 | generator (Optional[torch.Generator]): Random number generator for reproducibility. 196 | eta (float): The eta (η) parameter used with the DDIMScheduler. 197 | It corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1]. 198 | 199 | Returns: 200 | dict: A dictionary containing the extra keyword arguments for the scheduler step. 201 | """ 202 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 203 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 204 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 205 | # and should be between [0, 1] 206 | 207 | accepts_eta = "eta" in set( 208 | inspect.signature(self.scheduler.step).parameters.keys() 209 | ) 210 | extra_step_kwargs = {} 211 | if accepts_eta: 212 | extra_step_kwargs["eta"] = eta 213 | 214 | # check if the scheduler accepts generator 215 | accepts_generator = "generator" in set( 216 | inspect.signature(self.scheduler.step).parameters.keys() 217 | ) 218 | if accepts_generator: 219 | extra_step_kwargs["generator"] = generator 220 | return extra_step_kwargs 221 | 222 | def decode_latents(self, latents): 223 | """ 224 | Decode the latents to produce a video. 225 | 226 | Parameters: 227 | latents (torch.Tensor): The latents to be decoded. 228 | 229 | Returns: 230 | video (torch.Tensor): The decoded video. 231 | video_length (int): The length of the video in frames. 
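        Note: as written, the method returns only the decoded video, a float32 NumPy array of
        shape (b, c, f, h, w) with values in [0, 1]; video_length is read from the latents and
        used internally for the rearrange, not returned.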
232 | """ 233 | video_length = latents.shape[2] 234 | latents = 1 / 0.18215 * latents 235 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 236 | # video = self.vae.decode(latents).sample 237 | video = [] 238 | for frame_idx in tqdm(range(latents.shape[0])): 239 | video.append(self.vae.decode( 240 | latents[frame_idx: frame_idx + 1]).sample) 241 | video = torch.cat(video) 242 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 243 | video = (video / 2 + 0.5).clamp(0, 1) 244 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 245 | video = video.cpu().float().numpy() 246 | return video 247 | 248 | 249 | @torch.no_grad() 250 | def __call__( 251 | self, 252 | ref_image, 253 | face_emb, 254 | audio_tensor, 255 | face_mask, 256 | pixel_values_full_mask, 257 | pixel_values_face_mask, 258 | pixel_values_lip_mask, 259 | width, 260 | height, 261 | video_length, 262 | num_inference_steps, 263 | guidance_scale, 264 | num_images_per_prompt=1, 265 | eta: float = 0.0, 266 | motion_scale: Optional[List[torch.Tensor]] = None, 267 | generator: Optional[Union[torch.Generator, 268 | List[torch.Generator]]] = None, 269 | output_type: Optional[str] = "tensor", 270 | return_dict: bool = True, 271 | callback: Optional[Callable[[ 272 | int, int, torch.FloatTensor], None]] = None, 273 | callback_steps: Optional[int] = 1, 274 | **kwargs, 275 | ): 276 | # Default height and width to unet 277 | height = height or self.unet.config.sample_size * self.vae_scale_factor 278 | width = width or self.unet.config.sample_size * self.vae_scale_factor 279 | 280 | device = self._execution_device 281 | 282 | do_classifier_free_guidance = guidance_scale > 1.0 283 | 284 | # Prepare timesteps 285 | self.scheduler.set_timesteps(num_inference_steps, device=device) 286 | timesteps = self.scheduler.timesteps 287 | 288 | batch_size = 1 289 | 290 | # prepare clip image embeddings 291 | clip_image_embeds = face_emb 292 | clip_image_embeds = clip_image_embeds.to(self.image_proj.device, self.image_proj.dtype) 293 | 294 | encoder_hidden_states = self.image_proj(clip_image_embeds) 295 | uncond_encoder_hidden_states = self.image_proj(torch.zeros_like(clip_image_embeds)) 296 | 297 | if do_classifier_free_guidance: 298 | encoder_hidden_states = torch.cat([uncond_encoder_hidden_states, encoder_hidden_states], dim=0) 299 | 300 | reference_control_writer = ReferenceAttentionControl( 301 | self.reference_unet, 302 | do_classifier_free_guidance=do_classifier_free_guidance, 303 | mode="write", 304 | batch_size=batch_size, 305 | fusion_blocks="full", 306 | ) 307 | reference_control_reader = ReferenceAttentionControl( 308 | self.denoising_unet, 309 | do_classifier_free_guidance=do_classifier_free_guidance, 310 | mode="read", 311 | batch_size=batch_size, 312 | fusion_blocks="full", 313 | ) 314 | 315 | num_channels_latents = self.denoising_unet.in_channels 316 | 317 | latents = self.prepare_latents( 318 | batch_size * num_images_per_prompt, 319 | num_channels_latents, 320 | width, 321 | height, 322 | video_length, 323 | clip_image_embeds.dtype, 324 | device, 325 | generator, 326 | ) 327 | 328 | # Prepare extra step kwargs. 
329 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 330 | 331 | # Prepare ref image latents 332 | ref_image_tensor = rearrange(ref_image, "b f c h w -> (b f) c h w") 333 | ref_image_tensor = self.ref_image_processor.preprocess(ref_image_tensor, height=height, width=width) # (bs, c, width, height) 334 | ref_image_tensor = ref_image_tensor.to(dtype=self.vae.dtype, device=self.vae.device) 335 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean 336 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w) 337 | 338 | 339 | face_mask = face_mask.unsqueeze(1).to(dtype=self.face_locator.dtype, device=self.face_locator.device) # (bs, f, c, H, W) 340 | face_mask = repeat(face_mask, "b f c h w -> b (repeat f) c h w", repeat=video_length) 341 | face_mask = face_mask.transpose(1, 2) # (bs, c, f, H, W) 342 | face_mask = self.face_locator(face_mask) 343 | face_mask = torch.cat([torch.zeros_like(face_mask), face_mask], dim=0) if do_classifier_free_guidance else face_mask 344 | 345 | pixel_values_full_mask = ( 346 | [torch.cat([mask] * 2) for mask in pixel_values_full_mask] 347 | if do_classifier_free_guidance 348 | else pixel_values_full_mask 349 | ) 350 | pixel_values_face_mask = ( 351 | [torch.cat([mask] * 2) for mask in pixel_values_face_mask] 352 | if do_classifier_free_guidance 353 | else pixel_values_face_mask 354 | ) 355 | pixel_values_lip_mask = ( 356 | [torch.cat([mask] * 2) for mask in pixel_values_lip_mask] 357 | if do_classifier_free_guidance 358 | else pixel_values_lip_mask 359 | ) 360 | pixel_values_face_mask_ = [] 361 | for mask in pixel_values_face_mask: 362 | pixel_values_face_mask_.append( 363 | mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype)) 364 | pixel_values_face_mask = pixel_values_face_mask_ 365 | pixel_values_lip_mask_ = [] 366 | for mask in pixel_values_lip_mask: 367 | pixel_values_lip_mask_.append( 368 | mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype)) 369 | pixel_values_lip_mask = pixel_values_lip_mask_ 370 | pixel_values_full_mask_ = [] 371 | for mask in pixel_values_full_mask: 372 | pixel_values_full_mask_.append( 373 | mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype)) 374 | pixel_values_full_mask = pixel_values_full_mask_ 375 | 376 | 377 | uncond_audio_tensor = torch.zeros_like(audio_tensor) 378 | audio_tensor = torch.cat([uncond_audio_tensor, audio_tensor], dim=0) 379 | audio_tensor = audio_tensor.to(dtype=self.denoising_unet.dtype, device=self.denoising_unet.device) 380 | 381 | # denoising loop 382 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 383 | with self.progress_bar(total=num_inference_steps) as progress_bar: 384 | for i, t in enumerate(timesteps): 385 | # Forward reference image 386 | if i == 0: 387 | self.reference_unet( 388 | ref_image_latents.repeat( 389 | (2 if do_classifier_free_guidance else 1), 1, 1, 1 390 | ), 391 | torch.zeros_like(t), 392 | encoder_hidden_states=encoder_hidden_states, 393 | return_dict=False, 394 | ) 395 | reference_control_reader.update(reference_control_writer) 396 | 397 | # expand the latents if we are doing classifier free guidance 398 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 399 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 400 | 401 | noise_pred = self.denoising_unet( 402 | latent_model_input, 403 | t, 404 | encoder_hidden_states=encoder_hidden_states, 405 | mask_cond_fea=face_mask, 406 
| full_mask=pixel_values_full_mask, 407 | face_mask=pixel_values_face_mask, 408 | lip_mask=pixel_values_lip_mask, 409 | audio_embedding=audio_tensor, 410 | motion_scale=motion_scale, 411 | return_dict=False, 412 | )[0] 413 | 414 | # perform guidance 415 | if do_classifier_free_guidance: 416 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 417 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 418 | 419 | # compute the previous noisy sample x_t -> x_t-1 420 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 421 | 422 | # call the callback, if provided 423 | if i == len(timesteps) - 1 or (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: 424 | progress_bar.update() 425 | if callback is not None and i % callback_steps == 0: 426 | step_idx = i // getattr(self.scheduler, "order", 1) 427 | callback(step_idx, t, latents) 428 | 429 | reference_control_reader.clear() 430 | reference_control_writer.clear() 431 | 432 | # Post-processing 433 | images = self.decode_latents(latents) # (b, c, f, h, w) 434 | 435 | # Convert to tensor 436 | if output_type == "tensor": 437 | images = torch.from_numpy(images) 438 | 439 | if not return_dict: 440 | return images 441 | 442 | return FaceAnimatePipelineOutput(videos=images) 443 | -------------------------------------------------------------------------------- /hallo/models/transformer_2d.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=E1101 2 | # src/models/transformer_2d.py 3 | 4 | """ 5 | This module defines the Transformer2DModel, a PyTorch model that extends ModelMixin and ConfigMixin. It includes 6 | methods for gradient checkpointing, forward propagation, and various utility functions. The model is designed for 7 | 2D image-related tasks and uses LoRa (Low-Rank All-Attention) compatible layers for efficient attention computation. 8 | 9 | The file includes the following import statements: 10 | 11 | - From dataclasses import dataclass 12 | - From typing import Any, Dict, Optional 13 | - Import torch 14 | - From diffusers.configuration_utils import ConfigMixin, register_to_config 15 | - From diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear 16 | - From diffusers.models.modeling_utils import ModelMixin 17 | - From diffusers.models.normalization import AdaLayerNormSingle 18 | - From diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, deprecate, 19 | is_torch_version) 20 | - From torch import nn 21 | - From .attention import BasicTransformerBlock 22 | 23 | The file also includes the following classes and functions: 24 | 25 | - Transformer2DModel: A model class that extends ModelMixin and ConfigMixin. It includes methods for gradient 26 | checkpointing, forward propagation, and various utility functions. 27 | - _set_gradient_checkpointing: A utility function to set gradient checkpointing for a given module. 28 | - forward: The forward propagation method for the Transformer2DModel. 29 | 30 | To use this module, you can import the Transformer2DModel class and create an instance of the model with the desired 31 | configuration. Then, you can use the forward method to pass input tensors through the model and get the output tensors. 
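(Strictly speaking, "LoRA" stands for Low-Rank Adaptation; the low-rank-compatible layers used
here are diffusers' LoRACompatibleConv and LoRACompatibleLinear, selected when the PEFT backend
is not enabled.)

A minimal sketch of the continuous-input path, assuming it behaves like the upstream diffusers
Transformer2DModel it is derived from (all sizes below are illustrative):

    import torch

    model = Transformer2DModel(
        num_attention_heads=8,
        attention_head_dim=64,
        in_channels=320,          # must be divisible by norm_num_groups (default 32)
        num_layers=1,
        cross_attention_dim=768,
    )
    hidden_states = torch.randn(1, 320, 32, 32)      # (batch, channels, height, width)
    encoder_hidden_states = torch.randn(1, 77, 768)  # (batch, sequence, cross_attention_dim)
    out = model(hidden_states, encoder_hidden_states=encoder_hidden_states)
    sample, ref_feature = out.sample, out.ref_feature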
32 | """ 33 | 34 | from dataclasses import dataclass 35 | from typing import Any, Dict, Optional 36 | 37 | import torch 38 | from diffusers.configuration_utils import ConfigMixin, register_to_config 39 | # from diffusers.models.embeddings import CaptionProjection 40 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear 41 | from diffusers.models.modeling_utils import ModelMixin 42 | from diffusers.models.normalization import AdaLayerNormSingle 43 | from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, deprecate, 44 | is_torch_version) 45 | from torch import nn 46 | 47 | from .attention import BasicTransformerBlock 48 | 49 | 50 | @dataclass 51 | class Transformer2DModelOutput(BaseOutput): 52 | """ 53 | The output of [`Transformer2DModel`]. 54 | 55 | Args: 56 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` 57 | or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): 58 | The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability 59 | distributions for the unnoised latent pixels. 60 | """ 61 | 62 | sample: torch.FloatTensor 63 | ref_feature: torch.FloatTensor 64 | 65 | 66 | class Transformer2DModel(ModelMixin, ConfigMixin): 67 | """ 68 | A 2D Transformer model for image-like data. 69 | 70 | Parameters: 71 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 72 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 73 | in_channels (`int`, *optional*): 74 | The number of channels in the input and output (specify if the input is **continuous**). 75 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 76 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 77 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 78 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). 79 | This is fixed during training since it is used to learn a number of position embeddings. 80 | num_vector_embeds (`int`, *optional*): 81 | The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). 82 | Includes the class for the masked latent pixel. 83 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. 84 | num_embeds_ada_norm ( `int`, *optional*): 85 | The number of diffusion steps used during training. Pass if at least one of the norm_layers is 86 | `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are 87 | added to the hidden states. 88 | 89 | During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. 90 | attention_bias (`bool`, *optional*): 91 | Configure if the `TransformerBlocks` attention should contain a bias parameter. 
92 | """ 93 | 94 | _supports_gradient_checkpointing = True 95 | 96 | @register_to_config 97 | def __init__( 98 | self, 99 | num_attention_heads: int = 16, 100 | attention_head_dim: int = 88, 101 | in_channels: Optional[int] = None, 102 | out_channels: Optional[int] = None, 103 | num_layers: int = 1, 104 | dropout: float = 0.0, 105 | norm_num_groups: int = 32, 106 | cross_attention_dim: Optional[int] = None, 107 | attention_bias: bool = False, 108 | num_vector_embeds: Optional[int] = None, 109 | patch_size: Optional[int] = None, 110 | activation_fn: str = "geglu", 111 | num_embeds_ada_norm: Optional[int] = None, 112 | use_linear_projection: bool = False, 113 | only_cross_attention: bool = False, 114 | double_self_attention: bool = False, 115 | upcast_attention: bool = False, 116 | norm_type: str = "layer_norm", 117 | norm_elementwise_affine: bool = True, 118 | norm_eps: float = 1e-5, 119 | attention_type: str = "default", 120 | ): 121 | super().__init__() 122 | self.use_linear_projection = use_linear_projection 123 | self.num_attention_heads = num_attention_heads 124 | self.attention_head_dim = attention_head_dim 125 | inner_dim = num_attention_heads * attention_head_dim 126 | 127 | conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv 128 | linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear 129 | 130 | # 1. Transformer2DModel can process both standard continuous images of 131 | # shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of 132 | # shape `(batch_size, num_image_vectors)` 133 | # Define whether input is continuous or discrete depending on configuration 134 | self.is_input_continuous = (in_channels is not None) and (patch_size is None) 135 | self.is_input_vectorized = num_vector_embeds is not None 136 | self.is_input_patches = in_channels is not None and patch_size is not None 137 | 138 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None: 139 | deprecation_message = ( 140 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" 141 | " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." 142 | " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" 143 | " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" 144 | " would be very nice if you could open a Pull request for the `transformer/config.json` file" 145 | ) 146 | deprecate( 147 | "norm_type!=num_embeds_ada_norm", 148 | "1.0.0", 149 | deprecation_message, 150 | standard_warn=False, 151 | ) 152 | norm_type = "ada_norm" 153 | 154 | if self.is_input_continuous and self.is_input_vectorized: 155 | raise ValueError( 156 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" 157 | " sure that either `in_channels` or `num_vector_embeds` is None." 158 | ) 159 | 160 | if self.is_input_vectorized and self.is_input_patches: 161 | raise ValueError( 162 | f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" 163 | " sure that either `num_vector_embeds` or `num_patches` is None." 164 | ) 165 | 166 | if ( 167 | not self.is_input_continuous 168 | and not self.is_input_vectorized 169 | and not self.is_input_patches 170 | ): 171 | raise ValueError( 172 | f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" 173 | f" {patch_size}. 
Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." 174 | ) 175 | 176 | # 2. Define input layers 177 | self.in_channels = in_channels 178 | 179 | self.norm = torch.nn.GroupNorm( 180 | num_groups=norm_num_groups, 181 | num_channels=in_channels, 182 | eps=1e-6, 183 | affine=True, 184 | ) 185 | if use_linear_projection: 186 | self.proj_in = linear_cls(in_channels, inner_dim) 187 | else: 188 | self.proj_in = conv_cls( 189 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 190 | ) 191 | 192 | # 3. Define transformers blocks 193 | self.transformer_blocks = nn.ModuleList( 194 | [ 195 | BasicTransformerBlock( 196 | inner_dim, 197 | num_attention_heads, 198 | attention_head_dim, 199 | dropout=dropout, 200 | cross_attention_dim=cross_attention_dim, 201 | activation_fn=activation_fn, 202 | num_embeds_ada_norm=num_embeds_ada_norm, 203 | attention_bias=attention_bias, 204 | only_cross_attention=only_cross_attention, 205 | double_self_attention=double_self_attention, 206 | upcast_attention=upcast_attention, 207 | norm_type=norm_type, 208 | norm_elementwise_affine=norm_elementwise_affine, 209 | norm_eps=norm_eps, 210 | attention_type=attention_type, 211 | ) 212 | for d in range(num_layers) 213 | ] 214 | ) 215 | 216 | # 4. Define output layers 217 | self.out_channels = in_channels if out_channels is None else out_channels 218 | # TODO: should use out_channels for continuous projections 219 | if use_linear_projection: 220 | self.proj_out = linear_cls(inner_dim, in_channels) 221 | else: 222 | self.proj_out = conv_cls( 223 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0 224 | ) 225 | 226 | # 5. PixArt-Alpha blocks. 227 | self.adaln_single = None 228 | self.use_additional_conditions = False 229 | if norm_type == "ada_norm_single": 230 | self.use_additional_conditions = self.config.sample_size == 128 231 | # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use 232 | # additional conditions until we find better name 233 | self.adaln_single = AdaLayerNormSingle( 234 | inner_dim, use_additional_conditions=self.use_additional_conditions 235 | ) 236 | 237 | self.caption_projection = None 238 | 239 | self.gradient_checkpointing = False 240 | 241 | def _set_gradient_checkpointing(self, module, value=False): 242 | if hasattr(module, "gradient_checkpointing"): 243 | module.gradient_checkpointing = value 244 | 245 | def forward( 246 | self, 247 | hidden_states: torch.Tensor, 248 | encoder_hidden_states: Optional[torch.Tensor] = None, 249 | timestep: Optional[torch.LongTensor] = None, 250 | _added_cond_kwargs: Dict[str, torch.Tensor] = None, 251 | class_labels: Optional[torch.LongTensor] = None, 252 | cross_attention_kwargs: Dict[str, Any] = None, 253 | attention_mask: Optional[torch.Tensor] = None, 254 | encoder_attention_mask: Optional[torch.Tensor] = None, 255 | return_dict: bool = True, 256 | ): 257 | """ 258 | The [`Transformer2DModel`] forward method. 259 | 260 | Args: 261 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, 262 | `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): 263 | Input `hidden_states`. 264 | encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): 265 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 266 | self-attention. 267 | timestep ( `torch.LongTensor`, *optional*): 268 | Used to indicate denoising step. 
Optional timestep to be applied as an embedding in `AdaLayerNorm`. 269 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): 270 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in 271 | `AdaLayerZeroNorm`. 272 | cross_attention_kwargs ( `Dict[str, Any]`, *optional*): 273 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 274 | `self.processor` in 275 | [diffusers.models.attention_processor] 276 | (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 277 | attention_mask ( `torch.Tensor`, *optional*): 278 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask 279 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large 280 | negative values to the attention scores corresponding to "discard" tokens. 281 | encoder_attention_mask ( `torch.Tensor`, *optional*): 282 | Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: 283 | 284 | * Mask `(batch, sequence_length)` True = keep, False = discard. 285 | * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. 286 | 287 | If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format 288 | above. This bias will be added to the cross-attention scores. 289 | return_dict (`bool`, *optional*, defaults to `True`): 290 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain 291 | tuple. 292 | 293 | Returns: 294 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 295 | `tuple` where the first element is the sample tensor. 296 | """ 297 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. 298 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. 299 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 300 | # expects mask of shape: 301 | # [batch, key_tokens] 302 | # adds singleton query_tokens dimension: 303 | # [batch, 1, key_tokens] 304 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 305 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 306 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) 307 | if attention_mask is not None and attention_mask.ndim == 2: 308 | # assume that mask is expressed as: 309 | # (1 = keep, 0 = discard) 310 | # convert mask into a bias that can be added to attention scores: 311 | # (keep = +0, discard = -10000.0) 312 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 313 | attention_mask = attention_mask.unsqueeze(1) 314 | 315 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 316 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: 317 | encoder_attention_mask = ( 318 | 1 - encoder_attention_mask.to(hidden_states.dtype) 319 | ) * -10000.0 320 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 321 | 322 | # Retrieve lora scale. 323 | lora_scale = ( 324 | cross_attention_kwargs.get("scale", 1.0) 325 | if cross_attention_kwargs is not None 326 | else 1.0 327 | ) 328 | 329 | # 1. 
Input 330 | batch, _, height, width = hidden_states.shape 331 | residual = hidden_states 332 | 333 | hidden_states = self.norm(hidden_states) 334 | if not self.use_linear_projection: 335 | hidden_states = ( 336 | self.proj_in(hidden_states, scale=lora_scale) 337 | if not USE_PEFT_BACKEND 338 | else self.proj_in(hidden_states) 339 | ) 340 | inner_dim = hidden_states.shape[1] 341 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 342 | batch, height * width, inner_dim 343 | ) 344 | else: 345 | inner_dim = hidden_states.shape[1] 346 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 347 | batch, height * width, inner_dim 348 | ) 349 | hidden_states = ( 350 | self.proj_in(hidden_states, scale=lora_scale) 351 | if not USE_PEFT_BACKEND 352 | else self.proj_in(hidden_states) 353 | ) 354 | 355 | # 2. Blocks 356 | if self.caption_projection is not None: 357 | batch_size = hidden_states.shape[0] 358 | encoder_hidden_states = self.caption_projection(encoder_hidden_states) 359 | encoder_hidden_states = encoder_hidden_states.view( 360 | batch_size, -1, hidden_states.shape[-1] 361 | ) 362 | 363 | ref_feature = hidden_states.reshape(batch, height, width, inner_dim) 364 | for block in self.transformer_blocks: 365 | if self.training and self.gradient_checkpointing: 366 | 367 | def create_custom_forward(module, return_dict=None): 368 | def custom_forward(*inputs): 369 | if return_dict is not None: 370 | return module(*inputs, return_dict=return_dict) 371 | 372 | return module(*inputs) 373 | 374 | return custom_forward 375 | 376 | ckpt_kwargs: Dict[str, Any] = ( 377 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 378 | ) 379 | hidden_states = torch.utils.checkpoint.checkpoint( 380 | create_custom_forward(block), 381 | hidden_states, 382 | attention_mask, 383 | encoder_hidden_states, 384 | encoder_attention_mask, 385 | timestep, 386 | cross_attention_kwargs, 387 | class_labels, 388 | **ckpt_kwargs, 389 | ) 390 | else: 391 | hidden_states = block( 392 | hidden_states, # shape [5, 4096, 320] 393 | attention_mask=attention_mask, 394 | encoder_hidden_states=encoder_hidden_states, # shape [1,4,768] 395 | encoder_attention_mask=encoder_attention_mask, 396 | timestep=timestep, 397 | cross_attention_kwargs=cross_attention_kwargs, 398 | class_labels=class_labels, 399 | ) 400 | 401 | # 3. 
Output 402 | output = None 403 | if self.is_input_continuous: 404 | if not self.use_linear_projection: 405 | hidden_states = ( 406 | hidden_states.reshape(batch, height, width, inner_dim) 407 | .permute(0, 3, 1, 2) 408 | .contiguous() 409 | ) 410 | hidden_states = ( 411 | self.proj_out(hidden_states, scale=lora_scale) 412 | if not USE_PEFT_BACKEND 413 | else self.proj_out(hidden_states) 414 | ) 415 | else: 416 | hidden_states = ( 417 | self.proj_out(hidden_states, scale=lora_scale) 418 | if not USE_PEFT_BACKEND 419 | else self.proj_out(hidden_states) 420 | ) 421 | hidden_states = ( 422 | hidden_states.reshape(batch, height, width, inner_dim) 423 | .permute(0, 3, 1, 2) 424 | .contiguous() 425 | ) 426 | 427 | output = hidden_states + residual 428 | if not return_dict: 429 | return (output, ref_feature) 430 | 431 | return Transformer2DModelOutput(sample=output, ref_feature=ref_feature) 432 | -------------------------------------------------------------------------------- /hallo/models/mutual_self_attention.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=E1120 2 | """ 3 | This module contains the implementation of mutual self-attention, 4 | which is a type of attention mechanism used in deep learning models. 5 | The module includes several classes and functions related to attention mechanisms, 6 | such as BasicTransformerBlock and TemporalBasicTransformerBlock. 7 | The main purpose of this module is to provide a comprehensive attention mechanism for various tasks in deep learning, 8 | such as image and video processing, natural language processing, and so on. 9 | """ 10 | 11 | from typing import Any, Dict, Optional 12 | 13 | import torch 14 | from einops import rearrange 15 | 16 | from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock 17 | 18 | 19 | def torch_dfs(model: torch.nn.Module): 20 | """ 21 | Perform a depth-first search (DFS) traversal on a PyTorch model's neural network architecture. 22 | 23 | This function recursively traverses all the children modules of a given PyTorch model and returns a list 24 | containing all the modules in the model's architecture. The DFS approach starts with the input model and 25 | explores its children modules depth-wise before backtracking and exploring other branches. 26 | 27 | Args: 28 | model (torch.nn.Module): The root module of the neural network to traverse. 29 | 30 | Returns: 31 | list: A list of all the modules in the model's architecture. 32 | """ 33 | result = [model] 34 | for child in model.children(): 35 | result += torch_dfs(child) 36 | return result 37 | 38 | 39 | class ReferenceAttentionControl: 40 | """ 41 | This class is used to control the reference attention mechanism in a neural network model. 42 | It is responsible for managing the guidance and fusion blocks, and modifying the self-attention 43 | and group normalization mechanisms. The class also provides methods for registering reference hooks 44 | and updating/clearing the internal state of the attention control object. 45 | 46 | Attributes: 47 | unet: The UNet model associated with this attention control object. 48 | mode: The operating mode of the attention control object, either 'write' or 'read'. 49 | do_classifier_free_guidance: Whether to use classifier-free guidance in the attention mechanism. 50 | attention_auto_machine_weight: The weight assigned to the attention auto-machine. 51 | gn_auto_machine_weight: The weight assigned to the group normalization auto-machine. 
52 | style_fidelity: The style fidelity parameter for the attention mechanism. 53 | reference_attn: Whether to use reference attention in the model. 54 | reference_adain: Whether to use reference AdaIN in the model. 55 | fusion_blocks: The type of fusion blocks to use in the model ('midup' or 'full'). 56 | batch_size: The batch size used for processing video frames. 57 | 58 | Methods: 59 | register_reference_hooks: Registers the reference hooks for the attention control object. 60 | hacked_basic_transformer_inner_forward: The modified inner forward method for the basic transformer block. 61 | update: Updates the internal state of the attention control object using the provided writer and dtype. 62 | clear: Clears the internal state of the attention control object. 63 | """ 64 | def __init__( 65 | self, 66 | unet, 67 | mode="write", 68 | do_classifier_free_guidance=False, 69 | attention_auto_machine_weight=float("inf"), 70 | gn_auto_machine_weight=1.0, 71 | style_fidelity=1.0, 72 | reference_attn=True, 73 | reference_adain=False, 74 | fusion_blocks="midup", 75 | batch_size=1, 76 | ) -> None: 77 | """ 78 | Initializes the ReferenceAttentionControl class. 79 | 80 | Args: 81 | unet (torch.nn.Module): The UNet model. 82 | mode (str, optional): The mode of operation. Defaults to "write". 83 | do_classifier_free_guidance (bool, optional): Whether to do classifier-free guidance. Defaults to False. 84 | attention_auto_machine_weight (float, optional): The weight for attention auto-machine. Defaults to infinity. 85 | gn_auto_machine_weight (float, optional): The weight for group-norm auto-machine. Defaults to 1.0. 86 | style_fidelity (float, optional): The style fidelity. Defaults to 1.0. 87 | reference_attn (bool, optional): Whether to use reference attention. Defaults to True. 88 | reference_adain (bool, optional): Whether to use reference AdaIN. Defaults to False. 89 | fusion_blocks (str, optional): The fusion blocks to use. Defaults to "midup". 90 | batch_size (int, optional): The batch size. Defaults to 1. 91 | 92 | Raises: 93 | AssertionError: If `mode` is not "read" or "write". 94 | AssertionError: If `fusion_blocks` is not "midup" or "full". 95 | """ 96 | # 10. Modify self attention and group norm 97 | self.unet = unet 98 | assert mode in ["read", "write"] 99 | assert fusion_blocks in ["midup", "full"] 100 | self.reference_attn = reference_attn 101 | self.reference_adain = reference_adain 102 | self.fusion_blocks = fusion_blocks 103 | self.register_reference_hooks( 104 | mode, 105 | do_classifier_free_guidance, 106 | attention_auto_machine_weight, 107 | gn_auto_machine_weight, 108 | style_fidelity, 109 | reference_attn, 110 | reference_adain, 111 | fusion_blocks, 112 | batch_size=batch_size, 113 | ) 114 | 115 | def register_reference_hooks( 116 | self, 117 | mode, 118 | do_classifier_free_guidance, 119 | _attention_auto_machine_weight, 120 | _gn_auto_machine_weight, 121 | _style_fidelity, 122 | _reference_attn, 123 | _reference_adain, 124 | _dtype=torch.float16, 125 | batch_size=1, 126 | num_images_per_prompt=1, 127 | device=torch.device("cpu"), 128 | _fusion_blocks="midup", 129 | ): 130 | """ 131 | Registers reference hooks for the model. 132 | 133 | This function is responsible for registering reference hooks in the model, 134 | which are used to modify the attention mechanism and group normalization layers.
135 | It takes various parameters as input, such as mode, 136 | do_classifier_free_guidance, _attention_auto_machine_weight, _gn_auto_machine_weight, _style_fidelity, 137 | _reference_attn, _reference_adain, _dtype, batch_size, num_images_per_prompt, device, and _fusion_blocks. 138 | 139 | Args: 140 | self: Reference to the instance of the class. 141 | mode: The mode of operation for the reference hooks. 142 | do_classifier_free_guidance: A boolean flag indicating whether to use classifier-free guidance. 143 | _attention_auto_machine_weight: The weight for the attention auto-machine. 144 | _gn_auto_machine_weight: The weight for the group normalization auto-machine. 145 | _style_fidelity: The style fidelity for the reference hooks. 146 | _reference_attn: A boolean flag indicating whether to use reference attention. 147 | _reference_adain: A boolean flag indicating whether to use reference AdaIN. 148 | _dtype: The data type for the reference hooks. 149 | batch_size: The batch size for the reference hooks. 150 | num_images_per_prompt: The number of images per prompt for the reference hooks. 151 | device: The device for the reference hooks. 152 | _fusion_blocks: The fusion blocks for the reference hooks. 153 | 154 | Returns: 155 | None 156 | """ 157 | MODE = mode 158 | if do_classifier_free_guidance: 159 | uc_mask = ( 160 | torch.Tensor( 161 | [1] * batch_size * num_images_per_prompt * 16 162 | + [0] * batch_size * num_images_per_prompt * 16 163 | ) 164 | .to(device) 165 | .bool() 166 | ) 167 | else: 168 | uc_mask = ( 169 | torch.Tensor([0] * batch_size * num_images_per_prompt * 2) 170 | .to(device) 171 | .bool() 172 | ) 173 | 174 | def hacked_basic_transformer_inner_forward( 175 | self, 176 | hidden_states: torch.FloatTensor, 177 | attention_mask: Optional[torch.FloatTensor] = None, 178 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 179 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 180 | timestep: Optional[torch.LongTensor] = None, 181 | cross_attention_kwargs: Dict[str, Any] = None, 182 | class_labels: Optional[torch.LongTensor] = None, 183 | video_length=None, 184 | ): 185 | gate_msa = None 186 | shift_mlp = None 187 | scale_mlp = None 188 | gate_mlp = None 189 | 190 | if self.use_ada_layer_norm: # False 191 | norm_hidden_states = self.norm1(hidden_states, timestep) 192 | elif self.use_ada_layer_norm_zero: 193 | ( 194 | norm_hidden_states, 195 | gate_msa, 196 | shift_mlp, 197 | scale_mlp, 198 | gate_mlp, 199 | ) = self.norm1( 200 | hidden_states, 201 | timestep, 202 | class_labels, 203 | hidden_dtype=hidden_states.dtype, 204 | ) 205 | else: 206 | norm_hidden_states = self.norm1(hidden_states) 207 | 208 | # 1. 
Self-Attention 209 | # self.only_cross_attention = False 210 | cross_attention_kwargs = ( 211 | cross_attention_kwargs if cross_attention_kwargs is not None else {} 212 | ) 213 | if self.only_cross_attention: 214 | attn_output = self.attn1( 215 | norm_hidden_states, 216 | encoder_hidden_states=( 217 | encoder_hidden_states if self.only_cross_attention else None 218 | ), 219 | attention_mask=attention_mask, 220 | **cross_attention_kwargs, 221 | ) 222 | else: 223 | if MODE == "write": 224 | self.bank.append(norm_hidden_states.clone()) 225 | attn_output = self.attn1( 226 | norm_hidden_states, 227 | encoder_hidden_states=( 228 | encoder_hidden_states if self.only_cross_attention else None 229 | ), 230 | attention_mask=attention_mask, 231 | **cross_attention_kwargs, 232 | ) 233 | if MODE == "read": 234 | 235 | bank_fea = [ 236 | rearrange( 237 | rearrange( 238 | d, 239 | "(b s) l c -> b s l c", 240 | b=norm_hidden_states.shape[0] // video_length, 241 | )[:, 0, :, :] 242 | # .unsqueeze(1) 243 | .repeat(1, video_length, 1, 1), 244 | "b t l c -> (b t) l c", 245 | ) 246 | for d in self.bank 247 | ] 248 | motion_frames_fea = [rearrange( 249 | d, 250 | "(b s) l c -> b s l c", 251 | b=norm_hidden_states.shape[0] // video_length, 252 | )[:, 1:, :, :] for d in self.bank] 253 | modify_norm_hidden_states = torch.cat( 254 | [norm_hidden_states] + bank_fea, dim=1 255 | ) 256 | hidden_states_uc = ( 257 | self.attn1( 258 | norm_hidden_states, 259 | encoder_hidden_states=modify_norm_hidden_states, 260 | attention_mask=attention_mask, 261 | ) 262 | + hidden_states 263 | ) 264 | if do_classifier_free_guidance: 265 | hidden_states_c = hidden_states_uc.clone() 266 | _uc_mask = uc_mask.clone() 267 | if hidden_states.shape[0] != _uc_mask.shape[0]: 268 | _uc_mask = ( 269 | torch.Tensor( 270 | [1] * (hidden_states.shape[0] // 2) 271 | + [0] * (hidden_states.shape[0] // 2) 272 | ) 273 | .to(device) 274 | .bool() 275 | ) 276 | hidden_states_c[_uc_mask] = ( 277 | self.attn1( 278 | norm_hidden_states[_uc_mask], 279 | encoder_hidden_states=norm_hidden_states[_uc_mask], 280 | attention_mask=attention_mask, 281 | ) 282 | + hidden_states[_uc_mask] 283 | ) 284 | hidden_states = hidden_states_c.clone() 285 | else: 286 | hidden_states = hidden_states_uc 287 | 288 | # self.bank.clear() 289 | if self.attn2 is not None: 290 | # Cross-Attention 291 | norm_hidden_states = ( 292 | self.norm2(hidden_states, timestep) 293 | if self.use_ada_layer_norm 294 | else self.norm2(hidden_states) 295 | ) 296 | hidden_states = ( 297 | self.attn2( 298 | norm_hidden_states, 299 | encoder_hidden_states=encoder_hidden_states, 300 | attention_mask=attention_mask, 301 | ) 302 | + hidden_states 303 | ) 304 | 305 | # Feed-forward 306 | hidden_states = self.ff(self.norm3( 307 | hidden_states)) + hidden_states 308 | 309 | # Temporal-Attention 310 | if self.unet_use_temporal_attention: 311 | d = hidden_states.shape[1] 312 | hidden_states = rearrange( 313 | hidden_states, "(b f) d c -> (b d) f c", f=video_length 314 | ) 315 | norm_hidden_states = ( 316 | self.norm_temp(hidden_states, timestep) 317 | if self.use_ada_layer_norm 318 | else self.norm_temp(hidden_states) 319 | ) 320 | hidden_states = ( 321 | self.attn_temp(norm_hidden_states) + hidden_states 322 | ) 323 | hidden_states = rearrange( 324 | hidden_states, "(b d) f c -> (b f) d c", d=d 325 | ) 326 | 327 | return hidden_states, motion_frames_fea 328 | 329 | if self.use_ada_layer_norm_zero: 330 | attn_output = gate_msa.unsqueeze(1) * attn_output 331 | hidden_states = attn_output + hidden_states 332 | 
333 | if self.attn2 is not None: 334 | norm_hidden_states = ( 335 | self.norm2(hidden_states, timestep) 336 | if self.use_ada_layer_norm 337 | else self.norm2(hidden_states) 338 | ) 339 | 340 | # 2. Cross-Attention 341 | tmp = norm_hidden_states.shape[0] // encoder_hidden_states.shape[0] 342 | attn_output = self.attn2( 343 | norm_hidden_states, 344 | # TODO: the repeat here needs to be reconsidered 345 | encoder_hidden_states=encoder_hidden_states.repeat( 346 | tmp, 1, 1), 347 | attention_mask=encoder_attention_mask, 348 | **cross_attention_kwargs, 349 | ) 350 | hidden_states = attn_output + hidden_states 351 | 352 | # 3. Feed-forward 353 | norm_hidden_states = self.norm3(hidden_states) 354 | 355 | if self.use_ada_layer_norm_zero: 356 | norm_hidden_states = ( 357 | norm_hidden_states * 358 | (1 + scale_mlp[:, None]) + shift_mlp[:, None] 359 | ) 360 | 361 | ff_output = self.ff(norm_hidden_states) 362 | 363 | if self.use_ada_layer_norm_zero: 364 | ff_output = gate_mlp.unsqueeze(1) * ff_output 365 | 366 | hidden_states = ff_output + hidden_states 367 | 368 | return hidden_states 369 | 370 | if self.reference_attn: 371 | if self.fusion_blocks == "midup": 372 | attn_modules = [ 373 | module 374 | for module in ( 375 | torch_dfs(self.unet.mid_block) + 376 | torch_dfs(self.unet.up_blocks) 377 | ) 378 | if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock)) 379 | ] 380 | elif self.fusion_blocks == "full": 381 | attn_modules = [ 382 | module 383 | for module in torch_dfs(self.unet) 384 | if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock)) 385 | ] 386 | attn_modules = sorted( 387 | attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 388 | ) 389 | 390 | for i, module in enumerate(attn_modules): 391 | module._original_inner_forward = module.forward 392 | if isinstance(module, BasicTransformerBlock): 393 | module.forward = hacked_basic_transformer_inner_forward.__get__( 394 | module, 395 | BasicTransformerBlock) 396 | if isinstance(module, TemporalBasicTransformerBlock): 397 | module.forward = hacked_basic_transformer_inner_forward.__get__( 398 | module, 399 | TemporalBasicTransformerBlock) 400 | 401 | module.bank = [] 402 | module.attn_weight = float(i) / float(len(attn_modules)) 403 | 404 | def update(self, writer, dtype=torch.float16): 405 | """ 406 | Update the reader's attention feature banks from the given writer. 407 | 408 | Args: 409 | writer (ReferenceAttentionControl): The write-mode control object whose cached attention banks are copied into this reader's blocks. 410 | dtype (torch.dtype, optional): The data type the copied features are cast to. Defaults to torch.float16. 411 | 412 | Returns: 413 | None.
414 | """ 415 | if self.reference_attn: 416 | if self.fusion_blocks == "midup": 417 | reader_attn_modules = [ 418 | module 419 | for module in ( 420 | torch_dfs(self.unet.mid_block) + 421 | torch_dfs(self.unet.up_blocks) 422 | ) 423 | if isinstance(module, TemporalBasicTransformerBlock) 424 | ] 425 | writer_attn_modules = [ 426 | module 427 | for module in ( 428 | torch_dfs(writer.unet.mid_block) 429 | + torch_dfs(writer.unet.up_blocks) 430 | ) 431 | if isinstance(module, BasicTransformerBlock) 432 | ] 433 | elif self.fusion_blocks == "full": 434 | reader_attn_modules = [ 435 | module 436 | for module in torch_dfs(self.unet) 437 | if isinstance(module, TemporalBasicTransformerBlock) 438 | ] 439 | writer_attn_modules = [ 440 | module 441 | for module in torch_dfs(writer.unet) 442 | if isinstance(module, BasicTransformerBlock) 443 | ] 444 | 445 | assert len(reader_attn_modules) == len(writer_attn_modules) 446 | reader_attn_modules = sorted( 447 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 448 | ) 449 | writer_attn_modules = sorted( 450 | writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 451 | ) 452 | for r, w in zip(reader_attn_modules, writer_attn_modules): 453 | r.bank = [v.clone().to(dtype) for v in w.bank] 454 | 455 | 456 | def clear(self): 457 | """ 458 | Clears the attention bank of all reader attention modules. 459 | 460 | This method is used when the `reference_attn` attribute is set to `True`. 461 | It clears the attention bank of all reader attention modules inside the UNet 462 | model based on the selected `fusion_blocks` mode. 463 | 464 | If `fusion_blocks` is set to "midup", it searches for reader attention modules 465 | in both the mid block and up blocks of the UNet model. If `fusion_blocks` is set 466 | to "full", it searches for reader attention modules in the entire UNet model. 467 | 468 | It sorts the reader attention modules by the number of neurons in their 469 | `norm1.normalized_shape[0]` attribute in descending order. This sorting ensures 470 | that the modules with more neurons are cleared first. 471 | 472 | Finally, it iterates through the sorted list of reader attention modules and 473 | calls the `clear()` method on each module's `bank` attribute to clear the 474 | attention bank. 475 | """ 476 | if self.reference_attn: 477 | if self.fusion_blocks == "midup": 478 | reader_attn_modules = [ 479 | module 480 | for module in ( 481 | torch_dfs(self.unet.mid_block) + 482 | torch_dfs(self.unet.up_blocks) 483 | ) 484 | if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock)) 485 | ] 486 | elif self.fusion_blocks == "full": 487 | reader_attn_modules = [ 488 | module 489 | for module in torch_dfs(self.unet) 490 | if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock)) 491 | ] 492 | reader_attn_modules = sorted( 493 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 494 | ) 495 | for r in reader_attn_modules: 496 | r.bank.clear() 497 | --------------------------------------------------------------------------------
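Usage sketch for hallo/models/transformer_2d.py: the model is continuous-input when `in_channels` is set and `patch_size` is left as None, and its forward pass returns both the residual-added `sample` and the `ref_feature` map captured after `proj_in` but before any transformer block runs. The snippet below is a minimal smoke test under two assumptions: the package is installed from this repo (see setup.py / install.sh), and `hallo.models.attention.BasicTransformerBlock` accepts the same constructor and forward arguments as its diffusers counterpart. The tiny dimensions are arbitrary and chosen only to keep the example cheap.

import torch

from hallo.models.transformer_2d import Transformer2DModel

# Continuous (image-like) configuration: in_channels set, patch_size None.
model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,    # inner_dim = 2 * 8 = 16
    in_channels=16,
    num_layers=1,
    norm_num_groups=4,       # must evenly divide in_channels
    cross_attention_dim=32,
)

hidden_states = torch.randn(1, 16, 8, 8)       # (batch, channels, height, width)
encoder_hidden_states = torch.randn(1, 4, 32)  # (batch, seq_len, cross_attention_dim)

out = model(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.sample.shape)       # (1, 16, 8, 8) -- input layout, input + transformer residual
print(out.ref_feature.shape)  # (1, 8, 8, 16) -- hidden states cached before the blocks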
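Usage sketch for hallo/models/mutual_self_attention.py: `ReferenceAttentionControl` is designed to be created twice, a "write" instance on a reference UNet (its hooked blocks append their normalized hidden states to `module.bank`) and a "read" instance on the denoising UNet (its temporal blocks concatenate those cached states into self-attention). The names `reference_unet`, `denoising_unet`, `ref_latents`, `noisy_latents`, `timestep`, `ref_embeds`, and `audio_embeds` below are illustrative placeholders the caller must supply, not objects defined in this file; the sketch only shows the call order implied by `register_reference_hooks`, `update`, and `clear`.

from hallo.models.mutual_self_attention import ReferenceAttentionControl

writer = ReferenceAttentionControl(
    reference_unet, mode="write", fusion_blocks="full",
    do_classifier_free_guidance=False, batch_size=1,
)
reader = ReferenceAttentionControl(
    denoising_unet, mode="read", fusion_blocks="full",
    do_classifier_free_guidance=False, batch_size=1,
)

# 1. Write pass: one forward through the reference UNet fills module.bank
#    on every hooked BasicTransformerBlock.
reference_unet(ref_latents, timestep, encoder_hidden_states=ref_embeds)

# 2. Copy the cached banks from the writer's blocks into the reader's blocks.
reader.update(writer, dtype=noisy_latents.dtype)

# 3. Read pass: the denoising UNet's hooked blocks now attend over
#    [their own tokens, cached reference tokens].
denoising_unet(noisy_latents, timestep, encoder_hidden_states=audio_embeds)

# 4. Drop the cached features before the next clip / denoising step.
reader.clear()
writer.clear()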