├── hallo
│   ├── __init__.py
│   ├── utils
│   │   └── __init__.py
│   ├── animate
│   │   ├── __init__.py
│   │   ├── face_animate_static.py
│   │   └── face_animate.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── mask_image.py
│   │   ├── audio_processor.py
│   │   ├── image_processor.py
│   │   └── talk_video.py
│   └── models
│       ├── __init__.py
│       ├── image_proj.py
│       ├── face_locator.py
│       ├── audio_proj.py
│       ├── wav2vec.py
│       ├── transformer_3d.py
│       ├── resnet.py
│       ├── transformer_2d.py
│       └── mutual_self_attention.py
├── configs
│   ├── inference
│   │   ├── .gitkeep
│   │   └── default.yaml
│   └── unet
│       └── unet.yaml
├── start_docker.sh
├── start.bat
├── start.sh
├── start_collab.sh
├── docker-compose.yml
├── .pre-commit-config.yaml
├── install.sh
├── install.bat
├── requirements.txt
├── DOCKERFILE
├── setup.py
├── .gitignore
├── app.py
├── container.py
├── README.md
└── scripts
    └── inference.py
/hallo/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/hallo/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/configs/inference/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/hallo/animate/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/hallo/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/hallo/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/start_docker.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python container.py
--------------------------------------------------------------------------------
/start.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | call venv/scripts/activate
4 | python app.py
--------------------------------------------------------------------------------
/start.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source venv/bin/activate
4 | python app.py
--------------------------------------------------------------------------------
/start_collab.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source venv/bin/activate
4 | python app.py --share
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3.8'
2 | services:
3 | app:
4 | build:
5 | context: ./
6 | dockerfile: DOCKERFILE
7 | ports:
8 | - "8020:7860"
9 | deploy:
10 | resources:
11 | reservations:
12 | devices:
13 | - driver: nvidia
14 | count: 1
15 | capabilities: [gpu]
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: local
3 | hooks:
4 | - id: isort
5 | name: isort
6 | language: system
7 | types: [python]
8 | pass_filenames: false
9 | entry: isort
10 | args: ["."]
11 | - id: pylint
12 | name: pylint
13 | language: system
14 | types: [python]
15 | pass_filenames: false
16 | entry: pylint
17 | args: ["**/*.py"]
18 |
--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo "Clone models"
4 | git lfs install
5 | git clone https://huggingface.co/fudan-generative-ai/hallo pretrained_models
6 | wget -O pretrained_models/hallo/net.pth https://huggingface.co/fudan-generative-ai/hallo/resolve/main/hallo/net.pth?download=true
7 |
8 | echo "Install dependencies"
9 | python3 -m venv venv
10 | source venv/bin/activate
11 | pip install -r requirements.txt
12 | pip install -e .
13 |
14 | echo "Install GPU libraries"
15 | pip install torch==2.2.2+cu121 torchaudio torchvision --index-url https://download.pytorch.org/whl/cu121
16 | pip install onnxruntime-gpu
17 |
18 | echo "Installation complete"
19 |
--------------------------------------------------------------------------------
/install.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | echo Clone models
4 | git lfs install
5 | git clone https://huggingface.co/fudan-generative-ai/hallo pretrained_models
6 | curl -L -o pretrained_models/hallo/net.pth https://huggingface.co/fudan-generative-ai/hallo/resolve/main/hallo/net.pth?download=true
7 |
8 | echo Install dependencies
9 | python -m venv venv
10 | call venv/scripts/activate
11 | pip install -r requirements.txt
12 | pip install -e .
13 |
14 | pip install bitsandbytes-windows --force-reinstall
15 |
16 | echo Install GPU libs
17 | pip install torch==2.2.2+cu121 torchaudio torchvision --index-url https://download.pytorch.org/whl/cu121
18 |
19 | echo Install complete
20 | pause
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.28.0
2 | audio-separator==0.17.2
3 | av==12.1.0
4 | bitsandbytes==0.43.1
5 | decord==0.6.0
6 | diffusers==0.27.2
7 | einops==0.8.0
8 | insightface==0.7.3
9 | librosa==0.10.2.post1
10 | mediapipe[vision]==0.10.14
11 | mlflow==2.13.1
12 | moviepy==1.0.3
13 | numpy==1.26.4
14 | omegaconf==2.3.0
15 | onnx2torch==1.5.14
16 | onnx==1.16.1
17 | onnxruntime==1.18.0
18 | opencv-contrib-python==4.9.0.80
19 | opencv-python-headless==4.9.0.80
20 | opencv-python==4.9.0.80
21 | pillow==10.3.0
22 | setuptools==70.0.0
23 | torch==2.2.2
24 | torchvision==0.17.2
25 | tqdm==4.66.4
26 | transformers==4.39.2
27 | xformers==0.0.25.post1
28 | isort==5.13.2
29 | pylint==3.2.2
30 | pre-commit==3.7.1
31 | gradio==4.36.1
--------------------------------------------------------------------------------
/DOCKERFILE:
--------------------------------------------------------------------------------
1 | # Use the official NVIDIA CUDA runtime image as the parent image
2 | FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04
3 |
4 | ENV DEBIAN_FRONTEND=noninteractive
5 |
6 | # Set the working directory in the container to /app
7 | WORKDIR /app
8 |
9 | # Install dependencies & python
10 | RUN apt-get update && apt-get install -y --no-install-recommends git wget bzip2 ffmpeg gcc g++ software-properties-common
11 |
12 | RUN add-apt-repository ppa:deadsnakes/ppa && apt update && apt install -y python3.11
13 |
14 | RUN apt install -y python3-pip
15 |
16 | # Clone the GitHub repository and checkout to the correct branch/tag
17 | RUN git clone https://github.com/moda20/hallo-webui.git .
18 |
19 | # Run install.sh to perform any setup or installations required before setting up the virtual environment
20 |
21 | RUN chmod +x install.sh && ./install.sh
22 |
23 | RUN pip install --upgrade pip
24 |
25 | # Expose port 7860 for the gradio app
26 | EXPOSE 7860
27 |
28 | RUN ln -s /usr/bin/python3 /usr/bin/python
29 |
30 |
31 | RUN chmod +x start_docker.sh
32 |
33 | # Run start.sh when the container starts
34 | CMD ["./start_docker.sh"]
--------------------------------------------------------------------------------
/configs/unet/unet.yaml:
--------------------------------------------------------------------------------
1 | unet_additional_kwargs:
2 | use_inflated_groupnorm: true
3 | unet_use_cross_frame_attention: false
4 | unet_use_temporal_attention: false
5 | use_motion_module: true
6 | use_audio_module: true
7 | motion_module_resolutions:
8 | - 1
9 | - 2
10 | - 4
11 | - 8
12 | motion_module_mid_block: true
13 | motion_module_decoder_only: false
14 | motion_module_type: Vanilla
15 | motion_module_kwargs:
16 | num_attention_heads: 8
17 | num_transformer_block: 1
18 | attention_block_types:
19 | - Temporal_Self
20 | - Temporal_Self
21 | temporal_position_encoding: true
22 | temporal_position_encoding_max_len: 32
23 | temporal_attention_dim_div: 1
24 | audio_attention_dim: 768
25 | stack_enable_blocks_name:
26 | - "up"
27 | - "down"
28 | - "mid"
29 | stack_enable_blocks_depth: [0,1,2,3]
30 |
31 | enable_zero_snr: true
32 |
33 | noise_scheduler_kwargs:
34 | beta_start: 0.00085
35 | beta_end: 0.012
36 | beta_schedule: "linear"
37 | clip_sample: false
38 | steps_offset: 1
39 | ### Zero-SNR params
40 | prediction_type: "v_prediction"
41 | rescale_betas_zero_snr: True
42 | timestep_spacing: "trailing"
43 |
44 | sampler: DDIM
45 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | setup.py
3 | ----
4 | This is the main setup file for the hallo face animation project. It defines the package
5 | metadata, required dependencies, and provides the entry point for installing the package.
6 |
7 | """
8 |
9 | # -*- coding: utf-8 -*-
10 | from setuptools import setup
11 |
12 | packages = \
13 | ['hallo', 'hallo.datasets', 'hallo.models', 'hallo.animate', 'hallo.utils']
14 |
15 | package_data = \
16 | {'': ['*']}
17 |
18 | install_requires = \
19 | ['accelerate==0.28.0',
20 | 'audio-separator>=0.17.2,<0.18.0',
21 | 'av==12.1.0',
22 | 'bitsandbytes==0.43.1',
23 | 'decord==0.6.0',
24 | 'diffusers==0.27.2',
25 | 'einops>=0.8.0,<0.9.0',
26 | 'insightface>=0.7.3,<0.8.0',
27 | 'mediapipe[vision]>=0.10.14,<0.11.0',
28 | 'mlflow==2.13.1',
29 | 'moviepy>=1.0.3,<2.0.0',
30 | 'omegaconf>=2.3.0,<3.0.0',
31 | 'opencv-python>=4.9.0.80,<5.0.0.0',
32 | 'pillow>=10.3.0,<11.0.0',
33 | 'torch==2.2.2',
34 | 'torchvision==0.17.2',
35 | 'transformers==4.39.2',
36 | 'xformers==0.0.25.post1']
37 |
38 | setup_kwargs = {
39 | 'name': 'anna',
40 | 'version': '0.1.0',
41 | 'description': '',
42 | 'long_description': '# Anna face animation',
43 | 'author': 'Your Name',
44 | 'author_email': 'you@example.com',
45 | 'maintainer': 'None',
46 | 'maintainer_email': 'None',
47 | 'url': 'None',
48 | 'packages': packages,
49 | 'package_data': package_data,
50 | 'install_requires': install_requires,
51 | 'python_requires': '>=3.10,<4.0',
52 | }
53 |
54 |
55 | setup(**setup_kwargs)
56 |
--------------------------------------------------------------------------------
/configs/inference/default.yaml:
--------------------------------------------------------------------------------
1 | source_image: ./default.png
2 | driving_audio: default.wav
3 |
4 | weight_dtype: fp16
5 |
6 | data:
7 | n_motion_frames: 2
8 | n_sample_frames: 16
9 | source_image:
10 | width: 512
11 | height: 512
12 | driving_audio:
13 | sample_rate: 16000
14 | export_video:
15 | fps: 25
16 |
17 | inference_steps: 40
18 | cfg_scale: 3.5
19 |
20 | audio_ckpt_dir: ./pretrained_models/hallo
21 |
22 | base_model_path: ./pretrained_models/stable-diffusion-v1-5
23 |
24 | motion_module_path: ./pretrained_models/motion_module/mm_sd_v15_v2.ckpt
25 |
26 | face_analysis:
27 | model_path: ./pretrained_models/face_analysis
28 |
29 | wav2vec:
30 | model_path: ./pretrained_models/wav2vec/wav2vec2-base-960h
31 | features: all
32 |
33 | audio_separator:
34 | model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx
35 |
36 | vae:
37 | model_path: ./pretrained_models/sd-vae-ft-mse
38 |
39 | save_path: ./.cache
40 |
41 | face_expand_ratio: 1.1
42 | pose_weight: 1.1
43 | face_weight: 1.1
44 | lip_weight: 1.1
45 |
46 | unet_additional_kwargs:
47 | use_inflated_groupnorm: true
48 | unet_use_cross_frame_attention: false
49 | unet_use_temporal_attention: false
50 | use_motion_module: true
51 | use_audio_module: true
52 | motion_module_resolutions:
53 | - 1
54 | - 2
55 | - 4
56 | - 8
57 | motion_module_mid_block: true
58 | motion_module_decoder_only: false
59 | motion_module_type: Vanilla
60 | motion_module_kwargs:
61 | num_attention_heads: 8
62 | num_transformer_block: 1
63 | attention_block_types:
64 | - Temporal_Self
65 | - Temporal_Self
66 | temporal_position_encoding: true
67 | temporal_position_encoding_max_len: 32
68 | temporal_attention_dim_div: 1
69 | audio_attention_dim: 768
70 | stack_enable_blocks_name:
71 | - "up"
72 | - "down"
73 | - "mid"
74 | stack_enable_blocks_depth: [0,1,2,3]
75 |
76 |
77 | enable_zero_snr: true
78 |
79 | noise_scheduler_kwargs:
80 | beta_start: 0.00085
81 | beta_end: 0.012
82 | beta_schedule: "linear"
83 | clip_sample: false
84 | steps_offset: 1
85 | ### Zero-SNR params
86 | prediction_type: "v_prediction"
87 | rescale_betas_zero_snr: True
88 | timestep_spacing: "trailing"
89 |
90 | sampler: DDIM
91 |
--------------------------------------------------------------------------------
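
The values above are plain YAML, and omegaconf is pinned in requirements.txt, so the inference script presumably reads this file through OmegaConf. A minimal loading sketch, assuming the repository root as the working directory (the keys mirror default.yaml; nothing here is specific to scripts/inference.py):

from omegaconf import OmegaConf

# Load the inference defaults and read a few nested keys (sketch only).
config = OmegaConf.load("configs/inference/default.yaml")
print(config.weight_dtype)                               # fp16
print(config.data.export_video.fps)                      # 25
print(config.unet_additional_kwargs.motion_module_type)  # Vanilla
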
/hallo/models/image_proj.py:
--------------------------------------------------------------------------------
1 | """
2 | image_proj.py
3 |
4 | This module defines the ImageProjModel class, which is responsible for
5 | projecting image embeddings into a different dimensional space. The model
6 | leverages a linear transformation followed by a layer normalization to
7 | reshape and normalize the input image embeddings for further processing in
8 | cross-attention mechanisms or other downstream tasks.
9 |
10 | Classes:
11 | ImageProjModel
12 |
13 | Dependencies:
14 | torch
15 | diffusers.ModelMixin
16 |
17 | """
18 |
19 | import torch
20 | from diffusers import ModelMixin
21 |
22 |
23 | class ImageProjModel(ModelMixin):
24 | """
25 | ImageProjModel is a class that projects image embeddings into a different
26 | dimensional space. It inherits from ModelMixin, providing additional functionalities
27 | specific to image projection.
28 |
29 | Attributes:
30 | cross_attention_dim (int): The dimension of the cross attention.
31 | clip_embeddings_dim (int): The dimension of the CLIP embeddings.
32 | clip_extra_context_tokens (int): The number of extra context tokens in CLIP.
33 |
34 | Methods:
35 | forward(image_embeds): Forward pass of the ImageProjModel, which takes in image
36 | embeddings and returns the projected tokens.
37 |
38 | """
39 |
40 | def __init__(
41 | self,
42 | cross_attention_dim=1024,
43 | clip_embeddings_dim=1024,
44 | clip_extra_context_tokens=4,
45 | ):
46 | super().__init__()
47 |
48 | self.generator = None
49 | self.cross_attention_dim = cross_attention_dim
50 | self.clip_extra_context_tokens = clip_extra_context_tokens
51 | self.proj = torch.nn.Linear(
52 | clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
53 | )
54 | self.norm = torch.nn.LayerNorm(cross_attention_dim)
55 |
56 | def forward(self, image_embeds):
57 | """
58 | Forward pass of the ImageProjModel, which takes in image embeddings and returns the
59 | projected tokens after reshaping and normalization.
60 |
61 | Args:
62 | image_embeds (torch.Tensor): The input image embeddings, with shape
63 | batch_size x num_image_tokens x clip_embeddings_dim.
64 |
65 | Returns:
66 | clip_extra_context_tokens (torch.Tensor): The projected tokens after reshaping
67 |                 and normalization, with shape batch_size x clip_extra_context_tokens x
68 |                 cross_attention_dim.
69 |
70 | """
71 | embeds = image_embeds
72 | clip_extra_context_tokens = self.proj(embeds).reshape(
73 | -1, self.clip_extra_context_tokens, self.cross_attention_dim
74 | )
75 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
76 | return clip_extra_context_tokens
77 |
--------------------------------------------------------------------------------
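
A minimal usage sketch of ImageProjModel, assuming the package has been installed with pip install -e . as in install.sh; the batch size is arbitrary and the tensor values are random:

import torch

from hallo.models.image_proj import ImageProjModel

# Project dummy CLIP image embeddings into 4 extra context tokens of width 1024.
model = ImageProjModel(
    cross_attention_dim=1024,
    clip_embeddings_dim=1024,
    clip_extra_context_tokens=4,
)
image_embeds = torch.randn(2, 1024)  # (batch_size, clip_embeddings_dim)
tokens = model(image_embeds)
print(tokens.shape)                  # torch.Size([2, 4, 1024])
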
/.gitignore:
--------------------------------------------------------------------------------
1 | # running cache
2 | mlruns/
3 |
4 | # Test directories
5 | test_data/
6 | pretrained_models/
7 |
8 | # Poetry project
9 | poetry.lock
10 |
11 | # Byte-compiled / optimized / DLL files
12 | __pycache__/
13 | *.py[cod]
14 | *$py.class
15 |
16 | # C extensions
17 | *.so
18 |
19 | # Distribution / packaging
20 | .Python
21 | build/
22 | develop-eggs/
23 | dist/
24 | downloads/
25 | eggs/
26 | .eggs/
27 | lib/
28 | lib64/
29 | parts/
30 | sdist/
31 | var/
32 | wheels/
33 | share/python-wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 |
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .nox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | *.py,cover
60 | .hypothesis/
61 | .pytest_cache/
62 | cover/
63 |
64 | # Translations
65 | *.mo
66 | *.pot
67 |
68 | # Django stuff:
69 | *.log
70 | local_settings.py
71 | db.sqlite3
72 | db.sqlite3-journal
73 |
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 |
78 | # Scrapy stuff:
79 | .scrapy
80 |
81 | # Sphinx documentation
82 | docs/_build/
83 |
84 | # PyBuilder
85 | .pybuilder/
86 | target/
87 |
88 | # Jupyter Notebook
89 | .ipynb_checkpoints
90 |
91 | # IPython
92 | profile_default/
93 | ipython_config.py
94 |
95 | # pyenv
96 | # For a library or package, you might want to ignore these files since the code is
97 | # intended to run in multiple environments; otherwise, check them in:
98 | # .python-version
99 |
100 | # pipenv
101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
104 | # install all needed dependencies.
105 | #Pipfile.lock
106 |
107 | # poetry
108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
109 | # This is especially recommended for binary packages to ensure reproducibility, and is more
110 | # commonly ignored for libraries.
111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
112 | #poetry.lock
113 |
114 | # pdm
115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
116 | #pdm.lock
117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
118 | # in version control.
119 | # https://pdm.fming.dev/#use-with-ide
120 | .pdm.toml
121 |
122 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
123 | __pypackages__/
124 |
125 | # Celery stuff
126 | celerybeat-schedule
127 | celerybeat.pid
128 |
129 | # SageMath parsed files
130 | *.sage.py
131 |
132 | # Environments
133 | .env
134 | .venv
135 | env/
136 | venv/
137 | ENV/
138 | env.bak/
139 | venv.bak/
140 |
141 | # Spyder project settings
142 | .spyderproject
143 | .spyproject
144 |
145 | # Rope project settings
146 | .ropeproject
147 |
148 | # mkdocs documentation
149 | /site
150 |
151 | # mypy
152 | .mypy_cache/
153 | .dmypy.json
154 | dmypy.json
155 |
156 | # Pyre type checker
157 | .pyre/
158 |
159 | # pytype static type analyzer
160 | .pytype/
161 |
162 | # Cython debug symbols
163 | cython_debug/
164 |
165 | # IDE
166 | .idea/
167 | .vscode/
168 | data
169 | pretrained_models
170 | test_data
171 | /output
172 |
--------------------------------------------------------------------------------
/hallo/models/face_locator.py:
--------------------------------------------------------------------------------
1 | """
2 | This module implements the FaceLocator class, which is a neural network model designed to
3 | locate and extract facial features from input images or tensors. It uses a series of
4 | convolutional layers to progressively downsample and refine the facial feature map.
5 |
6 | The FaceLocator class is part of a larger system that may involve facial recognition or
7 | similar tasks where precise location and extraction of facial features are required.
8 |
9 | Attributes:
10 | conditioning_embedding_channels (int): The number of channels in the output embedding.
11 | conditioning_channels (int): The number of input channels for the conditioning tensor.
12 | block_out_channels (Tuple[int]): A tuple of integers representing the output channels
13 | for each block in the model.
14 |
15 | The model uses the following components:
16 | - InflatedConv3d: A convolutional layer that inflates the input to increase the depth.
17 | - zero_module: A utility function that may set certain parameters to zero for regularization
18 | or other purposes.
19 |
20 | The forward method of the FaceLocator class takes a conditioning tensor as input and
21 | produces an embedding tensor as output, which can be used for further processing or analysis.
22 | """
23 |
24 | from typing import Tuple
25 |
26 | import torch.nn.functional as F
27 | from diffusers.models.modeling_utils import ModelMixin
28 | from torch import nn
29 |
30 | from .motion_module import zero_module
31 | from .resnet import InflatedConv3d
32 |
33 |
34 | class FaceLocator(ModelMixin):
35 | """
36 | The FaceLocator class is a neural network model designed to process and extract facial
37 | features from an input tensor. It consists of a series of convolutional layers that
38 | progressively downsample the input while increasing the depth of the feature map.
39 |
40 | The model is built using InflatedConv3d layers, which are designed to inflate the
41 | feature channels, allowing for more complex feature extraction. The final output is a
42 | conditioning embedding that can be used for various tasks such as facial recognition or
43 | feature-based image manipulation.
44 |
45 | Parameters:
46 | conditioning_embedding_channels (int): The number of channels in the output embedding.
47 | conditioning_channels (int, optional): The number of input channels for the conditioning tensor. Default is 3.
48 | block_out_channels (Tuple[int], optional): A tuple of integers representing the output channels
49 | for each block in the model. The default is (16, 32, 64, 128), which defines the
50 | progression of the network's depth.
51 |
52 | Attributes:
53 | conv_in (InflatedConv3d): The initial convolutional layer that starts the feature extraction process.
54 | blocks (ModuleList[InflatedConv3d]): A list of convolutional layers that form the core of the model.
55 | conv_out (InflatedConv3d): The final convolutional layer that produces the output embedding.
56 |
57 | The forward method applies the convolutional layers to the input conditioning tensor and
58 | returns the resulting embedding tensor.
59 | """
60 | def __init__(
61 | self,
62 | conditioning_embedding_channels: int,
63 | conditioning_channels: int = 3,
64 | block_out_channels: Tuple[int] = (16, 32, 64, 128),
65 | ):
66 | super().__init__()
67 | self.conv_in = InflatedConv3d(
68 | conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
69 | )
70 |
71 | self.blocks = nn.ModuleList([])
72 |
73 | for i in range(len(block_out_channels) - 1):
74 | channel_in = block_out_channels[i]
75 | channel_out = block_out_channels[i + 1]
76 | self.blocks.append(
77 | InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
78 | )
79 | self.blocks.append(
80 | InflatedConv3d(
81 | channel_in, channel_out, kernel_size=3, padding=1, stride=2
82 | )
83 | )
84 |
85 | self.conv_out = zero_module(
86 | InflatedConv3d(
87 | block_out_channels[-1],
88 | conditioning_embedding_channels,
89 | kernel_size=3,
90 | padding=1,
91 | )
92 | )
93 |
94 | def forward(self, conditioning):
95 | """
96 | Forward pass of the FaceLocator model.
97 |
98 | Args:
99 | conditioning (Tensor): The input conditioning tensor.
100 |
101 | Returns:
102 | Tensor: The output embedding tensor.
103 | """
104 | embedding = self.conv_in(conditioning)
105 | embedding = F.silu(embedding)
106 |
107 | for block in self.blocks:
108 | embedding = block(embedding)
109 | embedding = F.silu(embedding)
110 |
111 | embedding = self.conv_out(embedding)
112 |
113 | return embedding
114 |
--------------------------------------------------------------------------------
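
A rough usage sketch of FaceLocator. The 320-channel output and the (batch, channels, frames, height, width) input layout are assumptions based on how InflatedConv3d is typically paired with an SD-style UNet; resnet.py and motion_module.py are not included in this dump, so treat the shapes as illustrative:

import torch

from hallo.models.face_locator import FaceLocator

# Encode a dummy face-region mask; the three stride-2 blocks downsample H and W by 8.
locator = FaceLocator(conditioning_embedding_channels=320)  # assumed channel count
mask = torch.randn(1, 3, 16, 512, 512)  # assumed (b, c, f, h, w) layout
embedding = locator(mask)
print(embedding.shape)                  # expected: torch.Size([1, 320, 16, 64, 64])
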
/app.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import gradio as gr
3 | import subprocess
4 | from datetime import datetime
5 | import os
6 | import platform
7 |
8 | def generate_video(ref_img, ref_audio,settings_face_expand_ratio=1.2, setting_steps=40, setting_cfg=3.5, settings_seed=42, settings_fps=25, settings_motion_pose_scale=1.1, settings_motion_face_scale=1.1, settings_motion_lip_scale=1.1, settings_n_motion_frames=2, settings_n_sample_frames=16):
9 | # Ensure file paths are correct
10 | if not os.path.isfile(ref_img) or not os.path.isfile(ref_audio):
11 | return "Error: File not found", None
12 |
13 | # Path to the output video file
14 | timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
15 | # Check if output exists and if not create it
16 | if not os.path.exists("output"):
17 | os.makedirs("output")
18 |
19 | output_video = f"output/{timestamp}.mp4"
20 |
21 | # Determine the command based on the operating system
22 | if platform.system() == "Windows":
23 | command = [
24 | "venv\\Scripts\\python.exe",
25 | "scripts\\inference.py",
26 | "--source_image", ref_img,
27 | "--driving_audio", ref_audio,
28 | "--output", output_video,
29 | "--face_expand_ratio", str(settings_face_expand_ratio),
30 | "--setting_steps", str(setting_steps),
31 | "--setting_cfg", str(setting_cfg),
32 | "--settings_seed", str(settings_seed),
33 | "--settings_fps", str(settings_fps),
34 | "--settings_motion_pose_scale", str(settings_motion_pose_scale),
35 | "--settings_motion_face_scale", str(settings_motion_face_scale),
36 | "--settings_motion_lip_scale", str(settings_motion_lip_scale),
37 | "--settings_n_motion_frames", str(settings_n_motion_frames),
38 | "--settings_n_sample_frames", str(settings_n_sample_frames)
39 | ]
40 | else:
41 | command = [
42 | "python3",
43 | "scripts/inference.py",
44 | "--source_image", ref_img,
45 | "--driving_audio", ref_audio,
46 | "--output", output_video,
47 | "--setting_steps", str(setting_steps),
48 | "--setting_cfg", str(setting_cfg),
49 | "--settings_seed", str(settings_seed),
50 | "--settings_fps", str(settings_fps),
51 | "--settings_motion_pose_scale", str(settings_motion_pose_scale),
52 | "--settings_motion_face_scale", str(settings_motion_face_scale),
53 | "--settings_motion_lip_scale", str(settings_motion_lip_scale),
54 | "--settings_n_motion_frames", str(settings_n_motion_frames),
55 | "--settings_n_sample_frames", str(settings_n_sample_frames)
56 | ]
57 |
58 | try:
59 | # Execute the command
60 | result = subprocess.run(command, check=True)
61 |
62 | if result.returncode == 0:
63 | return "Video generated successfully", output_video
64 | else:
65 | return "Error generating video", None
66 |
67 | except subprocess.CalledProcessError as e:
68 | return f"Error: {str(e)}", None
69 |
70 |
71 | with gr.Blocks() as demo:
72 | with gr.Row():
73 | with gr.Column():
74 | ref_img = gr.Image(label="Reference Image", type="filepath")
75 | ref_audio = gr.Audio(label="Audio", type="filepath")
76 | with gr.Accordion("Settings", open=True):
77 | settings_face_expand_ratio = gr.Slider(label="Face Expand Ratio", value=1.2, minimum=0, maximum=10, step=0.01)
78 | setting_steps = gr.Slider(label="Steps", value=40, minimum=1, maximum=200, step=1)
79 | setting_cfg = gr.Slider(label="CFG Scale", value=3.5, minimum=0, maximum=10, step=0.01)
80 | settings_seed = gr.Textbox(label="Seed", value=42)
81 | settings_fps = gr.Slider(label="FPS", value=25, minimum=1, maximum=200, step=1)
82 | with gr.Accordion("Motion Scale", open=True):
83 | settings_motion_pose_scale = gr.Slider(label="Motion Pose Scale", value=1.0, minimum=0, maximum=5, step=0.01)
84 | settings_motion_face_scale = gr.Slider(label="Motion Face Scale", value=1.0, minimum=0, maximum=5, step=0.01)
85 | settings_motion_lip_scale = gr.Slider(label="Motion Lip Scale", value=1.0, minimum=0, maximum=5, step=0.01)
86 | with gr.Accordion("Extra Settings", open=True):
87 | settings_n_motion_frames = gr.Slider(label="N Motion Frames", value=2, minimum=1, maximum=100, step=1)
88 | settings_n_sample_frames = gr.Slider(label="N Sample Frames", value=16, minimum=1, maximum=100, step=1)
89 | with gr.Column():
90 | result_status = gr.Label(value="Status")
91 | result_video = gr.Video(label="Result Video", interactive=False)
92 | result_btn = gr.Button(value="Generate Video")
93 |
94 | result_btn.click(fn=generate_video, inputs=[ref_img, ref_audio,settings_face_expand_ratio, setting_steps, setting_cfg, settings_seed, settings_fps, settings_motion_pose_scale, settings_motion_face_scale, settings_motion_lip_scale, settings_n_motion_frames, settings_n_sample_frames], outputs=[result_status, result_video])
95 |
96 | if __name__ == "__main__":
97 |     share_url = "--share" in sys.argv
98 |
99 | demo.queue()
100 | demo.launch(inbrowser=True, share=share_url)
101 |
102 |
--------------------------------------------------------------------------------
/container.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import gradio as gr
3 | import subprocess
4 | from datetime import datetime
5 | import os
6 | import platform
7 |
8 | def generate_video(ref_img, ref_audio,settings_face_expand_ratio=1.2, setting_steps=40, setting_cfg=3.5, settings_seed=42, settings_fps=25, settings_motion_pose_scale=1.1, settings_motion_face_scale=1.1, settings_motion_lip_scale=1.1, settings_n_motion_frames=2, settings_n_sample_frames=16):
9 | # Ensure file paths are correct
10 | if not os.path.isfile(ref_img) or not os.path.isfile(ref_audio):
11 | return "Error: File not found", None
12 |
13 | # Path to the output video file
14 | timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
15 | # Check if output exists and if not create it
16 | if not os.path.exists("output"):
17 | os.makedirs("output")
18 |
19 | output_video = f"output/{timestamp}.mp4"
20 |
21 | # Determine the command based on the operating system
22 | if platform.system() == "Windows":
23 | command = [
24 | "venv\\Scripts\\python.exe",
25 | "scripts\\inference.py",
26 | "--source_image", ref_img,
27 | "--driving_audio", ref_audio,
28 | "--output", output_video,
29 | "--face_expand_ratio", str(settings_face_expand_ratio),
30 | "--setting_steps", str(setting_steps),
31 | "--setting_cfg", str(setting_cfg),
32 | "--settings_seed", str(settings_seed),
33 | "--settings_fps", str(settings_fps),
34 | "--settings_motion_pose_scale", str(settings_motion_pose_scale),
35 | "--settings_motion_face_scale", str(settings_motion_face_scale),
36 | "--settings_motion_lip_scale", str(settings_motion_lip_scale),
37 | "--settings_n_motion_frames", str(settings_n_motion_frames),
38 | "--settings_n_sample_frames", str(settings_n_sample_frames)
39 | ]
40 | else:
41 | command = [
42 | "python3",
43 | "scripts/inference.py",
44 | "--source_image", ref_img,
45 | "--driving_audio", ref_audio,
46 | "--output", output_video,
47 | "--setting_steps", str(setting_steps),
48 | "--setting_cfg", str(setting_cfg),
49 | "--settings_seed", str(settings_seed),
50 | "--settings_fps", str(settings_fps),
51 | "--settings_motion_pose_scale", str(settings_motion_pose_scale),
52 | "--settings_motion_face_scale", str(settings_motion_face_scale),
53 | "--settings_motion_lip_scale", str(settings_motion_lip_scale),
54 | "--settings_n_motion_frames", str(settings_n_motion_frames),
55 | "--settings_n_sample_frames", str(settings_n_sample_frames)
56 | ]
57 |
58 | try:
59 | # Execute the command
60 | result = subprocess.run(command, check=True)
61 |
62 | if result.returncode == 0:
63 | return "Video generated successfully", output_video
64 | else:
65 | return "Error generating video", None
66 |
67 | except subprocess.CalledProcessError as e:
68 | return f"Error: {str(e)}", None
69 |
70 |
71 | with gr.Blocks() as demo:
72 | with gr.Row():
73 | with gr.Column():
74 | ref_img = gr.Image(label="Reference Image", type="filepath")
75 | ref_audio = gr.Audio(label="Audio", type="filepath")
76 | with gr.Accordion("Settings", open=True):
77 | settings_face_expand_ratio = gr.Slider(label="Face Expand Ratio", value=1.2, minimum=0, maximum=10, step=0.01)
78 | setting_steps = gr.Slider(label="Steps", value=40, minimum=1, maximum=200, step=1)
79 | setting_cfg = gr.Slider(label="CFG Scale", value=3.5, minimum=0, maximum=10, step=0.01)
80 | settings_seed = gr.Textbox(label="Seed", value=42)
81 | settings_fps = gr.Slider(label="FPS", value=25, minimum=1, maximum=200, step=1)
82 | with gr.Accordion("Motion Scale", open=True):
83 | settings_motion_pose_scale = gr.Slider(label="Motion Pose Scale", value=1.0, minimum=0, maximum=5, step=0.01)
84 | settings_motion_face_scale = gr.Slider(label="Motion Face Scale", value=1.0, minimum=0, maximum=5, step=0.01)
85 | settings_motion_lip_scale = gr.Slider(label="Motion Lip Scale", value=1.0, minimum=0, maximum=5, step=0.01)
86 | with gr.Accordion("Extra Settings", open=True):
87 | settings_n_motion_frames = gr.Slider(label="N Motion Frames", value=2, minimum=1, maximum=100, step=1)
88 | settings_n_sample_frames = gr.Slider(label="N Sample Frames", value=16, minimum=1, maximum=100, step=1)
89 | with gr.Column():
90 | result_status = gr.Label(value="Status")
91 | result_video = gr.Video(label="Result Video", interactive=False)
92 | result_btn = gr.Button(value="Generate Video")
93 |
94 | result_btn.click(fn=generate_video, inputs=[ref_img, ref_audio,settings_face_expand_ratio, setting_steps, setting_cfg, settings_seed, settings_fps, settings_motion_pose_scale, settings_motion_face_scale, settings_motion_lip_scale, settings_n_motion_frames, settings_n_sample_frames], outputs=[result_status, result_video])
95 |
96 | if __name__ == "__main__":
97 |     share_url = "--share" in sys.argv
98 |
99 | demo.queue()
100 | demo.launch(inbrowser=True, share=share_url, server_port=7860, server_name="0.0.0.0")
--------------------------------------------------------------------------------
/hallo/models/audio_proj.py:
--------------------------------------------------------------------------------
1 | """
2 | This module provides the implementation of an Audio Projection Model, which is designed for
3 | audio processing tasks. The model takes audio embeddings as input and outputs context tokens
4 | that can be used for various downstream applications, such as audio analysis or synthesis.
5 |
6 | The AudioProjModel class is based on the ModelMixin class from the diffusers library, which
7 | provides a foundation for building custom models. This implementation includes multiple linear
8 | layers with ReLU activation functions and a LayerNorm for normalization.
9 |
10 | Key Features:
11 | - Audio embedding input with flexible sequence length and block structure.
12 | - Multiple linear layers for feature transformation.
13 | - ReLU activation for non-linear transformation.
14 | - LayerNorm for stabilizing and speeding up training.
15 | - Rearrangement of input embeddings to match the model's expected input shape.
16 | - Customizable number of blocks, channels, and context tokens for adaptability.
17 |
18 | The module is structured to be easily integrated into larger systems or used as a standalone
19 | component for audio feature extraction and processing.
20 |
21 | Classes:
22 | - AudioProjModel: A class representing the audio projection model with configurable parameters.
23 |
24 | Functions:
25 | - (none)
26 |
27 | Dependencies:
28 | - torch: For tensor operations and neural network components.
29 | - diffusers: For the ModelMixin base class.
30 | - einops: For tensor rearrangement operations.
31 |
32 | """
33 |
34 | import torch
35 | from diffusers import ModelMixin
36 | from einops import rearrange
37 | from torch import nn
38 |
39 |
40 | class AudioProjModel(ModelMixin):
41 | """Audio Projection Model
42 |
43 | This class defines an audio projection model that takes audio embeddings as input
44 | and produces context tokens as output. The model is based on the ModelMixin class
45 | and consists of multiple linear layers and activation functions. It can be used
46 | for various audio processing tasks.
47 |
48 | Attributes:
49 | seq_len (int): The length of the audio sequence.
50 | blocks (int): The number of blocks in the audio projection model.
51 | channels (int): The number of channels in the audio projection model.
52 | intermediate_dim (int): The intermediate dimension of the model.
53 | context_tokens (int): The number of context tokens in the output.
54 | output_dim (int): The output dimension of the context tokens.
55 |
56 | Methods:
57 | __init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, context_tokens=32, output_dim=768):
58 | Initializes the AudioProjModel with the given parameters.
59 | forward(self, audio_embeds):
60 | Defines the forward pass for the AudioProjModel.
61 | Parameters:
62 |                 audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, seq_len, blocks, channels).
63 | Returns:
64 | context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
65 |
66 | """
67 |
68 | def __init__(
69 | self,
70 | seq_len=5,
71 | blocks=12, # add a new parameter blocks
72 | channels=768, # add a new parameter channels
73 | intermediate_dim=512,
74 | output_dim=768,
75 | context_tokens=32,
76 | ):
77 | super().__init__()
78 |
79 | self.seq_len = seq_len
80 | self.blocks = blocks
81 | self.channels = channels
82 | self.input_dim = (
83 | seq_len * blocks * channels
84 | ) # update input_dim to be the product of blocks and channels.
85 | self.intermediate_dim = intermediate_dim
86 | self.context_tokens = context_tokens
87 | self.output_dim = output_dim
88 |
89 | # define multiple linear layers
90 | self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
91 | self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
92 | self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
93 |
94 | self.norm = nn.LayerNorm(output_dim)
95 |
96 | def forward(self, audio_embeds):
97 | """
98 | Defines the forward pass for the AudioProjModel.
99 |
100 | Parameters:
101 |             audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, seq_len, blocks, channels).
102 |
103 | Returns:
104 | context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
105 | """
106 | # merge
107 | video_length = audio_embeds.shape[1]
108 | audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
109 | batch_size, window_size, blocks, channels = audio_embeds.shape
110 | audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
111 |
112 | audio_embeds = torch.relu(self.proj1(audio_embeds))
113 | audio_embeds = torch.relu(self.proj2(audio_embeds))
114 |
115 | context_tokens = self.proj3(audio_embeds).reshape(
116 | batch_size, self.context_tokens, self.output_dim
117 | )
118 |
119 | context_tokens = self.norm(context_tokens)
120 | context_tokens = rearrange(
121 | context_tokens, "(bz f) m c -> bz f m c", f=video_length
122 | )
123 |
124 | return context_tokens
125 |
--------------------------------------------------------------------------------
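
A minimal usage sketch of AudioProjModel with its default dimensions; the batch size and video length below are arbitrary:

import torch

from hallo.models.audio_proj import AudioProjModel

model = AudioProjModel(
    seq_len=5, blocks=12, channels=768,
    intermediate_dim=512, output_dim=768, context_tokens=32,
)
# Dummy wav2vec features: (batch_size, video_length, seq_len, blocks, channels)
audio_embeds = torch.randn(1, 16, 5, 12, 768)
context_tokens = model(audio_embeds)
print(context_tokens.shape)  # torch.Size([1, 16, 32, 768])
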
/hallo/datasets/mask_image.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R0801
2 | """
3 | This module contains the code for a dataset class called FaceMaskDataset, which is used to process and
4 | load image data related to face masks. The dataset class inherits from the PyTorch Dataset class and
5 | provides methods for data augmentation, getting items from the dataset, and determining the length of the
6 | dataset. The module also includes imports for necessary libraries such as json, random, pathlib, torch,
7 | PIL, and transformers.
8 | """
9 |
10 | import json
11 | import random
12 | from pathlib import Path
13 |
14 | import torch
15 | from PIL import Image
16 | from torch.utils.data import Dataset
17 | from torchvision import transforms
18 | from transformers import CLIPImageProcessor
19 |
20 |
21 | class FaceMaskDataset(Dataset):
22 | """
23 | FaceMaskDataset is a custom dataset for face mask images.
24 |
25 | Args:
26 | img_size (int): The size of the input images.
27 | drop_ratio (float, optional): The ratio of dropped pixels during data augmentation. Defaults to 0.1.
28 | data_meta_paths (list, optional): The paths to the metadata files containing image paths and labels. Defaults to ["./data/HDTF_meta.json"].
29 | sample_margin (int, optional): The margin for sampling regions in the image. Defaults to 30.
30 |
31 | Attributes:
32 | img_size (int): The size of the input images.
33 | drop_ratio (float): The ratio of dropped pixels during data augmentation.
34 | data_meta_paths (list): The paths to the metadata files containing image paths and labels.
35 | sample_margin (int): The margin for sampling regions in the image.
36 | processor (CLIPImageProcessor): The image processor for preprocessing images.
37 | transform (transforms.Compose): The image augmentation transform.
38 | """
39 |
40 | def __init__(
41 | self,
42 | img_size,
43 | drop_ratio=0.1,
44 |         data_meta_paths=("./data/HDTF_meta.json",),
45 | sample_margin=30,
46 | ):
47 | super().__init__()
48 |
49 | self.img_size = img_size
50 | self.sample_margin = sample_margin
51 |
52 | vid_meta = []
53 | for data_meta_path in data_meta_paths:
54 | with open(data_meta_path, "r", encoding="utf-8") as f:
55 | vid_meta.extend(json.load(f))
56 | self.vid_meta = vid_meta
57 | self.length = len(self.vid_meta)
58 |
59 | self.clip_image_processor = CLIPImageProcessor()
60 |
61 | self.transform = transforms.Compose(
62 | [
63 | transforms.Resize(self.img_size),
64 | transforms.ToTensor(),
65 | transforms.Normalize([0.5], [0.5]),
66 | ]
67 | )
68 |
69 | self.cond_transform = transforms.Compose(
70 | [
71 | transforms.Resize(self.img_size),
72 | transforms.ToTensor(),
73 | ]
74 | )
75 |
76 | self.drop_ratio = drop_ratio
77 |
78 | def augmentation(self, image, transform, state=None):
79 | """
80 | Apply data augmentation to the input image.
81 |
82 | Args:
83 | image (PIL.Image): The input image.
84 | transform (torchvision.transforms.Compose): The data augmentation transforms.
85 | state (dict, optional): The random state for reproducibility. Defaults to None.
86 |
87 | Returns:
88 | PIL.Image: The augmented image.
89 | """
90 | if state is not None:
91 | torch.set_rng_state(state)
92 | return transform(image)
93 |
94 | def __getitem__(self, index):
95 | video_meta = self.vid_meta[index]
96 | video_path = video_meta["image_path"]
97 | mask_path = video_meta["mask_path"]
98 | face_emb_path = video_meta["face_emb"]
99 |
100 | video_frames = sorted(Path(video_path).iterdir())
101 | video_length = len(video_frames)
102 |
103 | margin = min(self.sample_margin, video_length)
104 |
105 | ref_img_idx = random.randint(0, video_length - 1)
106 | if ref_img_idx + margin < video_length:
107 | tgt_img_idx = random.randint(
108 | ref_img_idx + margin, video_length - 1)
109 | elif ref_img_idx - margin > 0:
110 | tgt_img_idx = random.randint(0, ref_img_idx - margin)
111 | else:
112 | tgt_img_idx = random.randint(0, video_length - 1)
113 |
114 | ref_img_pil = Image.open(video_frames[ref_img_idx])
115 | tgt_img_pil = Image.open(video_frames[tgt_img_idx])
116 |
117 | tgt_mask_pil = Image.open(mask_path)
118 |
119 |         assert ref_img_pil is not None, "Failed to load reference image."
120 |         assert tgt_img_pil is not None, "Failed to load target image."
121 |         assert tgt_mask_pil is not None, "Failed to load target mask."
122 |
123 | state = torch.get_rng_state()
124 | tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
125 | tgt_mask_img = self.augmentation(
126 | tgt_mask_pil, self.cond_transform, state)
127 | tgt_mask_img = tgt_mask_img.repeat(3, 1, 1)
128 | ref_img_vae = self.augmentation(
129 | ref_img_pil, self.transform, state)
130 | face_emb = torch.load(face_emb_path)
131 |
132 |
133 | sample = {
134 | "video_dir": video_path,
135 | "img": tgt_img,
136 | "tgt_mask": tgt_mask_img,
137 | "ref_img": ref_img_vae,
138 | "face_emb": face_emb,
139 | }
140 |
141 | return sample
142 |
143 | def __len__(self):
144 | return len(self.vid_meta)
145 |
146 |
147 | if __name__ == "__main__":
148 | data = FaceMaskDataset(img_size=(512, 512))
149 | train_dataloader = torch.utils.data.DataLoader(
150 | data, batch_size=4, shuffle=True, num_workers=1
151 | )
152 | for step, batch in enumerate(train_dataloader):
153 | print(batch["tgt_mask"].shape)
154 | break
155 |
--------------------------------------------------------------------------------
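
For reference, each entry in a data_meta_paths JSON file needs the three keys read in __getitem__; only the key names come from the code, the paths below are placeholders:

# Illustrative metadata entry for FaceMaskDataset (values are placeholders).
example_meta_entry = {
    "image_path": "./data/HDTF/frames/clip_0001",    # directory of per-frame images
    "mask_path": "./data/HDTF/masks/clip_0001.png",  # target face mask image
    "face_emb": "./data/HDTF/face_emb/clip_0001.pt", # torch-saved face embedding
}
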
/hallo/datasets/audio_processor.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=C0301
2 | '''
3 | This module contains the AudioProcessor class and related functions for processing audio data.
4 | It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction,
5 | and audio separation. The class is initialized with configuration parameters and can process
6 | audio files using the provided models.
7 | '''
8 | import math
9 | import os
10 |
11 | import librosa
12 | import numpy as np
13 | import torch
14 | from audio_separator.separator import Separator
15 | from einops import rearrange
16 | from transformers import Wav2Vec2FeatureExtractor
17 |
18 | from hallo.models.wav2vec import Wav2VecModel
19 | from hallo.utils.util import resample_audio
20 |
21 |
22 | class AudioProcessor:
23 | """
24 | AudioProcessor is a class that handles the processing of audio files.
25 | It takes care of preprocessing the audio files, extracting features
26 | using wav2vec models, and separating audio signals if needed.
27 |
28 | :param sample_rate: Sampling rate of the audio file
29 | :param fps: Frames per second for the extracted features
30 | :param wav2vec_model_path: Path to the wav2vec model
31 | :param only_last_features: Whether to only use the last features
32 | :param audio_separator_model_path: Path to the audio separator model
33 | :param audio_separator_model_name: Name of the audio separator model
34 | :param cache_dir: Directory to cache the intermediate results
35 | :param device: Device to run the processing on
36 | """
37 | def __init__(
38 | self,
39 | sample_rate,
40 | fps,
41 | wav2vec_model_path,
42 | only_last_features,
43 | audio_separator_model_path:str=None,
44 | audio_separator_model_name:str=None,
45 | cache_dir:str='',
46 | device="cuda:0",
47 | ) -> None:
48 | self.sample_rate = sample_rate
49 | self.fps = fps
50 | self.device = device
51 |
52 | self.audio_encoder = Wav2VecModel.from_pretrained(wav2vec_model_path, local_files_only=True).to(device=device)
53 | self.audio_encoder.feature_extractor._freeze_parameters()
54 | self.only_last_features = only_last_features
55 |
56 | if audio_separator_model_name is not None:
57 | try:
58 | os.makedirs(cache_dir, exist_ok=True)
59 | except OSError as _:
60 | print("Fail to create the output cache dir.")
61 | self.audio_separator = Separator(
62 | output_dir=cache_dir,
63 | output_single_stem="vocals",
64 | model_file_dir=audio_separator_model_path,
65 | )
66 | self.audio_separator.load_model(audio_separator_model_name)
67 |             assert self.audio_separator.model_instance is not None, "Failed to load the audio separator model."
68 | else:
69 | self.audio_separator=None
70 | print("Use audio directly without vocals seperator.")
71 |
72 |
73 | self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True)
74 |
75 |
76 | def preprocess(self, wav_file: str):
77 | """
78 | Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate.
79 |         The separated vocal track is then converted into wav2vec2 features for further processing or analysis.
80 |
81 | Args:
82 | wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format.
83 |
84 | Raises:
85 | RuntimeError: Raises an exception if the WAV file cannot be processed. This could be due to issues
86 | such as file not found, unsupported file format, or errors during the audio processing steps.
87 |
88 | Returns:
89 | torch.tensor: Returns an audio embedding as a torch.tensor
90 | """
91 | if self.audio_separator is not None:
92 | # 1. separate vocals
93 | # TODO: process in memory
94 | outputs = self.audio_separator.separate(wav_file)
95 | if len(outputs) <= 0:
96 | raise RuntimeError("Audio separate failed.")
97 |
98 | vocal_audio_file = outputs[0]
99 | vocal_audio_name, _ = os.path.splitext(vocal_audio_file)
100 | vocal_audio_file = os.path.join(self.audio_separator.output_dir, vocal_audio_file)
101 | vocal_audio_file = resample_audio(vocal_audio_file, os.path.join(self.audio_separator.output_dir, f"{vocal_audio_name}-16k.wav"), self.sample_rate)
102 | else:
103 | vocal_audio_file=wav_file
104 |
105 | # 2. extract wav2vec features
106 | speech_array, sampling_rate = librosa.load(vocal_audio_file, sr=self.sample_rate)
107 | audio_feature = np.squeeze(self.wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values)
108 | seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps)
109 |
110 | audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device)
111 | audio_feature = audio_feature.unsqueeze(0)
112 |
113 | with torch.no_grad():
114 | embeddings = self.audio_encoder(audio_feature, seq_len=seq_len, output_hidden_states=True)
115 |             assert len(embeddings) > 0, "Failed to extract audio embedding"
116 | if self.only_last_features:
117 | audio_emb = embeddings.last_hidden_state.squeeze()
118 | else:
119 | audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
120 | audio_emb = rearrange(audio_emb, "b s d -> s b d")
121 |
122 | audio_emb = audio_emb.cpu().detach()
123 |
124 | return audio_emb
125 |
126 | def get_embedding(self, wav_file: str):
127 | """preprocess wav audio file convert to embeddings
128 |
129 | Args:
130 | wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format.
131 |
132 | Returns:
133 | torch.tensor: Returns an audio embedding as a torch.tensor
134 | """
135 | speech_array, sampling_rate = librosa.load(
136 | wav_file, sr=self.sample_rate)
137 | assert sampling_rate == 16000, "The audio sample rate must be 16000"
138 | audio_feature = np.squeeze(self.wav2vec_feature_extractor(
139 | speech_array, sampling_rate=sampling_rate).input_values)
140 | seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps)
141 |
142 | audio_feature = torch.from_numpy(
143 | audio_feature).float().to(device=self.device)
144 | audio_feature = audio_feature.unsqueeze(0)
145 |
146 | with torch.no_grad():
147 | embeddings = self.audio_encoder(
148 | audio_feature, seq_len=seq_len, output_hidden_states=True)
149 |             assert len(embeddings) > 0, "Failed to extract audio embedding"
150 |
151 | if self.only_last_features:
152 | audio_emb = embeddings.last_hidden_state.squeeze()
153 | else:
154 | audio_emb = torch.stack(
155 | embeddings.hidden_states[1:], dim=1).squeeze(0)
156 | audio_emb = rearrange(audio_emb, "b s d -> s b d")
157 |
158 | audio_emb = audio_emb.cpu().detach()
159 |
160 | return audio_emb
161 |
162 | def close(self):
163 | """
164 | TODO: to be implemented
165 | """
166 | return self
167 |
168 | def __enter__(self):
169 | return self
170 |
171 | def __exit__(self, _exc_type, _exc_val, _exc_tb):
172 | self.close()
173 |
--------------------------------------------------------------------------------
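
A minimal usage sketch of AudioProcessor, assuming the pretrained models fetched by install.sh, a CUDA device, and the full hallo package (hallo.utils.util is not shown in this dump); the WAV path is a placeholder:

from hallo.datasets.audio_processor import AudioProcessor

with AudioProcessor(
    sample_rate=16000,
    fps=25,
    wav2vec_model_path="./pretrained_models/wav2vec/wav2vec2-base-960h",
    only_last_features=False,
    audio_separator_model_path="./pretrained_models/audio_separator",
    audio_separator_model_name="Kim_Vocal_2.onnx",
    cache_dir="./.cache/audio_preprocess",
) as audio_processor:
    # Placeholder input file; returns per-frame wav2vec embeddings.
    audio_emb = audio_processor.preprocess("examples/driving_audio.wav")
    print(audio_emb.shape)
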
/hallo/datasets/image_processor.py:
--------------------------------------------------------------------------------
1 | """
2 | This module is responsible for processing images, particularly for face-related tasks.
3 | It uses various libraries such as OpenCV, NumPy, and InsightFace to perform tasks like
4 | face detection, augmentation, and mask rendering. The ImageProcessor class encapsulates
5 | the functionality for these operations.
6 | """
7 | import os
8 | from typing import List
9 |
10 | import cv2
11 | import numpy as np
12 | import torch
13 | from insightface.app import FaceAnalysis
14 | from PIL import Image
15 | from torchvision import transforms
16 |
17 | from ..utils.util import get_mask
18 |
19 | MEAN = 0.5
20 | STD = 0.5
21 |
22 | class ImageProcessor:
23 | """
24 | ImageProcessor is a class responsible for processing images, particularly for face-related tasks.
25 | It takes in an image and performs various operations such as augmentation, face detection,
26 | face embedding extraction, and rendering a face mask. The processed images are then used for
27 | further analysis or recognition purposes.
28 |
29 | Attributes:
30 | img_size (int): The size of the image to be processed.
31 | face_analysis_model_path (str): The path to the face analysis model.
32 |
33 | Methods:
34 | preprocess(source_image_path, cache_dir):
35 | Preprocesses the input image by performing augmentation, face detection,
36 | face embedding extraction, and rendering a face mask.
37 |
38 | close():
39 | Closes the ImageProcessor and releases any resources being used.
40 |
41 | _augmentation(images, transform, state=None):
42 | Applies image augmentation to the input images using the given transform and state.
43 |
44 | __enter__():
45 | Enters a runtime context and returns the ImageProcessor object.
46 |
47 | __exit__(_exc_type, _exc_val, _exc_tb):
48 | Exits a runtime context and handles any exceptions that occurred during the processing.
49 | """
50 | def __init__(self, img_size, face_analysis_model_path) -> None:
51 | self.img_size = img_size
52 |
53 | self.pixel_transform = transforms.Compose(
54 | [
55 | transforms.Resize(self.img_size),
56 | transforms.ToTensor(),
57 | transforms.Normalize([MEAN], [STD]),
58 | ]
59 | )
60 |
61 | self.cond_transform = transforms.Compose(
62 | [
63 | transforms.Resize(self.img_size),
64 | transforms.ToTensor(),
65 | ]
66 | )
67 |
68 | self.attn_transform_64 = transforms.Compose(
69 | [
70 | transforms.Resize(
71 | (self.img_size[0] // 8, self.img_size[0] // 8)),
72 | transforms.ToTensor(),
73 | ]
74 | )
75 | self.attn_transform_32 = transforms.Compose(
76 | [
77 | transforms.Resize(
78 | (self.img_size[0] // 16, self.img_size[0] // 16)),
79 | transforms.ToTensor(),
80 | ]
81 | )
82 | self.attn_transform_16 = transforms.Compose(
83 | [
84 | transforms.Resize(
85 | (self.img_size[0] // 32, self.img_size[0] // 32)),
86 | transforms.ToTensor(),
87 | ]
88 | )
89 | self.attn_transform_8 = transforms.Compose(
90 | [
91 | transforms.Resize(
92 | (self.img_size[0] // 64, self.img_size[0] // 64)),
93 | transforms.ToTensor(),
94 | ]
95 | )
96 |
97 | self.face_analysis = FaceAnalysis(
98 | name="",
99 | root=face_analysis_model_path,
100 | providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
101 | )
102 | self.face_analysis.prepare(ctx_id=0, det_size=(640, 640))
103 |
104 | def preprocess(self, source_image_path: str, cache_dir: str, face_region_ratio: float):
105 | """
106 | Apply preprocessing to the source image to prepare for face analysis.
107 |
108 |         Parameters:
109 |             source_image_path (str): The path to the source image.
110 |             cache_dir (str): The directory to cache intermediate results.
111 |             face_region_ratio (float): Ratio used to expand the detected face region.
112 |         Returns:
113 |             tuple: The reference image tensor, face mask, face embedding, and the multi-resolution full, face, and lip mask lists.
114 |         """
115 | source_image = Image.open(source_image_path)
116 | ref_image_pil = source_image.convert("RGB")
117 | # 1. image augmentation
118 | pixel_values_ref_img = self._augmentation(ref_image_pil, self.pixel_transform)
119 |
120 |
121 | # 2.1 detect face
122 | faces = self.face_analysis.get(cv2.cvtColor(np.array(ref_image_pil.copy()), cv2.COLOR_RGB2BGR))
123 | # use max size face
124 | face = sorted(faces, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[-1]
125 |
126 | # 2.2 face embedding
127 | face_emb = face["embedding"]
128 |
129 | # 2.3 render face mask
130 | get_mask(source_image_path, cache_dir, face_region_ratio)
131 | file_name = os.path.basename(source_image_path).split(".")[0]
132 | face_mask_pil = Image.open(
133 | os.path.join(cache_dir, f"{file_name}_face_mask.png")).convert("RGB")
134 |
135 | face_mask = self._augmentation(face_mask_pil, self.cond_transform)
136 |
137 | # 2.4 detect and expand lip, face mask
138 | sep_background_mask = Image.open(
139 | os.path.join(cache_dir, f"{file_name}_sep_background.png"))
140 | sep_face_mask = Image.open(
141 | os.path.join(cache_dir, f"{file_name}_sep_face.png"))
142 | sep_lip_mask = Image.open(
143 | os.path.join(cache_dir, f"{file_name}_sep_lip.png"))
144 |
145 | pixel_values_face_mask = [
146 | self._augmentation(sep_face_mask, self.attn_transform_64),
147 | self._augmentation(sep_face_mask, self.attn_transform_32),
148 | self._augmentation(sep_face_mask, self.attn_transform_16),
149 | self._augmentation(sep_face_mask, self.attn_transform_8),
150 | ]
151 | pixel_values_lip_mask = [
152 | self._augmentation(sep_lip_mask, self.attn_transform_64),
153 | self._augmentation(sep_lip_mask, self.attn_transform_32),
154 | self._augmentation(sep_lip_mask, self.attn_transform_16),
155 | self._augmentation(sep_lip_mask, self.attn_transform_8),
156 | ]
157 | pixel_values_full_mask = [
158 | self._augmentation(sep_background_mask, self.attn_transform_64),
159 | self._augmentation(sep_background_mask, self.attn_transform_32),
160 | self._augmentation(sep_background_mask, self.attn_transform_16),
161 | self._augmentation(sep_background_mask, self.attn_transform_8),
162 | ]
163 |
164 | pixel_values_full_mask = [mask.view(1, -1)
165 | for mask in pixel_values_full_mask]
166 | pixel_values_face_mask = [mask.view(1, -1)
167 | for mask in pixel_values_face_mask]
168 | pixel_values_lip_mask = [mask.view(1, -1)
169 | for mask in pixel_values_lip_mask]
170 |
171 | return pixel_values_ref_img, face_mask, face_emb, pixel_values_full_mask, pixel_values_face_mask, pixel_values_lip_mask
172 |
173 | def close(self):
174 | """
175 | Closes the ImageProcessor and releases any resources held by the FaceAnalysis instance.
176 |
177 | Args:
178 | self: The ImageProcessor instance.
179 |
180 | Returns:
181 | None.
182 | """
183 | for _, model in self.face_analysis.models.items():
184 | if hasattr(model, "Dispose"):
185 | model.Dispose()
186 |
187 | def _augmentation(self, images, transform, state=None):
188 | if state is not None:
189 | torch.set_rng_state(state)
190 | if isinstance(images, List):
191 | transformed_images = [transform(img) for img in images]
192 | ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
193 | else:
194 | ret_tensor = transform(images) # (c, h, w)
195 | return ret_tensor
196 |
197 | def __enter__(self):
198 | return self
199 |
200 | def __exit__(self, _exc_type, _exc_val, _exc_tb):
201 | self.close()
202 |
--------------------------------------------------------------------------------
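For reference, a minimal usage sketch of the `ImageProcessor` defined above. This is not part of the repository: the image path, cache directory, and `face_region_ratio` value are illustrative placeholders, and the `face_analysis` weights layout follows the README.

```python
# Hypothetical usage sketch: ImageProcessor is a context manager, so the insightface
# models it loads are released when the block exits.
from hallo.datasets.image_processor import ImageProcessor

with ImageProcessor(
    img_size=(512, 512),
    face_analysis_model_path="./pretrained_models/face_analysis",  # assumed layout, see README
) as image_processor:
    (ref_img, face_mask, face_emb,
     full_masks, face_masks, lip_masks) = image_processor.preprocess(
        "examples/source_images/1.jpg",  # placeholder path
        cache_dir=".cache",
        face_region_ratio=1.2,           # placeholder value
    )
```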
/hallo/models/wav2vec.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R0901
2 | # src/models/wav2vec.py
3 |
4 | """
5 | This module defines the Wav2Vec model, which is a pre-trained model for speech recognition and understanding.
6 | It inherits from the Wav2Vec2Model class in the transformers library and provides additional functionalities
7 | such as feature extraction and encoding.
8 |
9 | Classes:
10 | Wav2VecModel: Inherits from Wav2Vec2Model and adds additional methods for feature extraction and encoding.
11 |
12 | Functions:
13 | linear_interpolation: Interpolates the features based on the sequence length.
14 | """
15 |
16 | import torch.nn.functional as F
17 | from transformers import Wav2Vec2Model
18 | from transformers.modeling_outputs import BaseModelOutput
19 |
20 |
21 | class Wav2VecModel(Wav2Vec2Model):
22 | """
23 | Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library.
24 | It inherits all the functionality of the Wav2Vec2Model and adds additional methods for feature extraction and encoding.
25 | ...
26 |
27 | Attributes:
28 | base_model (Wav2Vec2Model): The base Wav2Vec2Model object.
29 |
30 | Methods:
31 | forward(input_values, seq_len, attention_mask=None, mask_time_indices=None
32 | , output_attentions=None, output_hidden_states=None, return_dict=None):
33 | Forward pass of the Wav2VecModel.
34 | It takes input_values, seq_len, and other optional parameters as input and returns the output of the base model.
35 |
36 | feature_extract(input_values, seq_len):
37 | Extracts features from the input_values using the base model.
38 |
39 | encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None):
40 | Encodes the extracted features using the base model and returns the encoded features.
41 | """
42 | def forward(
43 | self,
44 | input_values,
45 | seq_len,
46 | attention_mask=None,
47 | mask_time_indices=None,
48 | output_attentions=None,
49 | output_hidden_states=None,
50 | return_dict=None,
51 | ):
52 | """
53 | Forward pass of the Wav2Vec model.
54 |
55 | Args:
56 | self: The instance of the model.
57 | input_values: The input values (waveform) to the model.
58 | seq_len: The sequence length of the input values.
59 | attention_mask: Attention mask to be used for the model.
60 | mask_time_indices: Mask indices to be used for the model.
61 | output_attentions: If set to True, returns attentions.
62 | output_hidden_states: If set to True, returns hidden states.
63 | return_dict: If set to True, returns a BaseModelOutput instead of a tuple.
64 |
65 | Returns:
66 | The output of the Wav2Vec model.
67 | """
68 | self.config.output_attentions = True
69 |
70 | output_hidden_states = (
71 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
72 | )
73 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
74 |
75 | extract_features = self.feature_extractor(input_values)
76 | extract_features = extract_features.transpose(1, 2)
77 | extract_features = linear_interpolation(extract_features, seq_len=seq_len)
78 |
79 | if attention_mask is not None:
80 | # compute reduced attention_mask corresponding to feature vectors
81 | attention_mask = self._get_feature_vector_attention_mask(
82 | extract_features.shape[1], attention_mask, add_adapter=False
83 | )
84 |
85 | hidden_states, extract_features = self.feature_projection(extract_features)
86 | hidden_states = self._mask_hidden_states(
87 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
88 | )
89 |
90 | encoder_outputs = self.encoder(
91 | hidden_states,
92 | attention_mask=attention_mask,
93 | output_attentions=output_attentions,
94 | output_hidden_states=output_hidden_states,
95 | return_dict=return_dict,
96 | )
97 |
98 | hidden_states = encoder_outputs[0]
99 |
100 | if self.adapter is not None:
101 | hidden_states = self.adapter(hidden_states)
102 |
103 | if not return_dict:
104 | return (hidden_states, ) + encoder_outputs[1:]
105 | return BaseModelOutput(
106 | last_hidden_state=hidden_states,
107 | hidden_states=encoder_outputs.hidden_states,
108 | attentions=encoder_outputs.attentions,
109 | )
110 |
111 |
112 | def feature_extract(
113 | self,
114 | input_values,
115 | seq_len,
116 | ):
117 | """
118 | Extracts features from the input values and returns the extracted features.
119 |
120 | Parameters:
121 | input_values (torch.Tensor): The input values to be processed.
122 | seq_len (torch.Tensor): The sequence lengths of the input values.
123 |
124 | Returns:
125 | extracted_features (torch.Tensor): The extracted features from the input values.
126 | """
127 | extract_features = self.feature_extractor(input_values)
128 | extract_features = extract_features.transpose(1, 2)
129 | extract_features = linear_interpolation(extract_features, seq_len=seq_len)
130 |
131 | return extract_features
132 |
133 | def encode(
134 | self,
135 | extract_features,
136 | attention_mask=None,
137 | mask_time_indices=None,
138 | output_attentions=None,
139 | output_hidden_states=None,
140 | return_dict=None,
141 | ):
142 | """
143 | Encodes the input features into the output space.
144 |
145 | Args:
146 | extract_features (torch.Tensor): The extracted features from the audio signal.
147 | attention_mask (torch.Tensor, optional): Attention mask to be used for padding.
148 | mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension.
149 | output_attentions (bool, optional): If set to True, returns the attention weights.
150 | output_hidden_states (bool, optional): If set to True, returns all hidden states.
151 | return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of the tuple.
152 |
153 | Returns:
154 | The encoded output features.
155 | """
156 | self.config.output_attentions = True
157 |
158 | output_hidden_states = (
159 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
160 | )
161 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
162 |
163 | if attention_mask is not None:
164 | # compute reduced attention_mask corresponding to feature vectors
165 | attention_mask = self._get_feature_vector_attention_mask(
166 | extract_features.shape[1], attention_mask, add_adapter=False
167 | )
168 |
169 | hidden_states, extract_features = self.feature_projection(extract_features)
170 | hidden_states = self._mask_hidden_states(
171 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
172 | )
173 |
174 | encoder_outputs = self.encoder(
175 | hidden_states,
176 | attention_mask=attention_mask,
177 | output_attentions=output_attentions,
178 | output_hidden_states=output_hidden_states,
179 | return_dict=return_dict,
180 | )
181 |
182 | hidden_states = encoder_outputs[0]
183 |
184 | if self.adapter is not None:
185 | hidden_states = self.adapter(hidden_states)
186 |
187 | if not return_dict:
188 | return (hidden_states, ) + encoder_outputs[1:]
189 | return BaseModelOutput(
190 | last_hidden_state=hidden_states,
191 | hidden_states=encoder_outputs.hidden_states,
192 | attentions=encoder_outputs.attentions,
193 | )
194 |
195 |
196 | def linear_interpolation(features, seq_len):
197 |     """
198 |     Linearly interpolate the features along the time dimension to match the target sequence length.
199 | 
200 |     Args:
201 |         features (torch.Tensor): The extracted features to be interpolated, shaped (batch, time, channels).
202 |         seq_len (int): The target sequence length to interpolate to.
203 | 
204 |     Returns:
205 |         torch.Tensor: The interpolated features, shaped (batch, seq_len, channels).
206 |     """
207 | features = features.transpose(1, 2)
208 | output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
209 | return output_features.transpose(1, 2)
210 |
--------------------------------------------------------------------------------
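For reference, a minimal sketch (not part of the repository) of how the `Wav2VecModel` above can turn a waveform into per-frame audio features. The checkpoint path assumes the `pretrained_models` layout described in the README; the waveform and frame count are placeholders.

```python
# Hypothetical usage sketch: extract one audio feature vector per video frame.
import torch

from hallo.models.wav2vec import Wav2VecModel

model = Wav2VecModel.from_pretrained(
    "./pretrained_models/wav2vec/wav2vec2-base-960h"  # assumed layout, see README
).eval()

waveform = torch.randn(1, 16000 * 4)    # 4 seconds of 16 kHz audio (placeholder)
n_frames = 100                          # e.g. a 25 fps clip over the same 4 seconds

with torch.no_grad():
    # Convolutional features, linearly interpolated to one vector per video frame.
    feats = model.feature_extract(waveform, seq_len=n_frames)
    # Transformer encoding of the aligned features.
    encoded = model.encode(feats)

print(encoded.last_hidden_state.shape)  # expected: torch.Size([1, 100, 768]) for the base model
```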
/hallo/models/transformer_3d.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R0801
2 | """
3 | This module implements the Transformer3DModel, a PyTorch model designed for processing
4 | 3D data such as videos. It extends ModelMixin and ConfigMixin to provide a transformer
5 | model with support for gradient checkpointing and various types of attention mechanisms.
6 | The model can be configured with different parameters such as the number of attention heads,
7 | attention head dimension, and the number of layers. It also supports the use of audio modules
8 | for enhanced feature extraction from video data.
9 | """
10 |
11 | from dataclasses import dataclass
12 | from typing import Optional
13 |
14 | import torch
15 | from diffusers.configuration_utils import ConfigMixin, register_to_config
16 | from diffusers.models import ModelMixin
17 | from diffusers.utils import BaseOutput
18 | from einops import rearrange, repeat
19 | from torch import nn
20 |
21 | from .attention import (AudioTemporalBasicTransformerBlock,
22 | TemporalBasicTransformerBlock)
23 |
24 |
25 | @dataclass
26 | class Transformer3DModelOutput(BaseOutput):
27 | """
28 | The output of the [`Transformer3DModel`].
29 |
30 | Attributes:
31 | sample (`torch.FloatTensor`):
32 | The output tensor from the transformer model, which is the result of processing the input
33 | hidden states through the transformer blocks and any subsequent layers.
34 | """
35 | sample: torch.FloatTensor
36 |
37 |
38 | class Transformer3DModel(ModelMixin, ConfigMixin):
39 | """
40 | Transformer3DModel is a PyTorch model that extends `ModelMixin` and `ConfigMixin` to create a 3D transformer model.
41 | It implements the forward pass for processing input hidden states, encoder hidden states, and various types of attention masks.
42 | The model supports gradient checkpointing, which can be enabled by calling the `enable_gradient_checkpointing()` method.
43 | """
44 | _supports_gradient_checkpointing = True
45 |
46 | @register_to_config
47 | def __init__(
48 | self,
49 | num_attention_heads: int = 16,
50 | attention_head_dim: int = 88,
51 | in_channels: Optional[int] = None,
52 | num_layers: int = 1,
53 | dropout: float = 0.0,
54 | norm_num_groups: int = 32,
55 | cross_attention_dim: Optional[int] = None,
56 | attention_bias: bool = False,
57 | activation_fn: str = "geglu",
58 | num_embeds_ada_norm: Optional[int] = None,
59 | use_linear_projection: bool = False,
60 | only_cross_attention: bool = False,
61 | upcast_attention: bool = False,
62 | unet_use_cross_frame_attention=None,
63 | unet_use_temporal_attention=None,
64 | use_audio_module=False,
65 | depth=0,
66 | unet_block_name=None,
67 | stack_enable_blocks_name = None,
68 | stack_enable_blocks_depth = None,
69 | ):
70 | super().__init__()
71 | self.use_linear_projection = use_linear_projection
72 | self.num_attention_heads = num_attention_heads
73 | self.attention_head_dim = attention_head_dim
74 | inner_dim = num_attention_heads * attention_head_dim
75 | self.use_audio_module = use_audio_module
76 | # Define input layers
77 | self.in_channels = in_channels
78 |
79 | self.norm = torch.nn.GroupNorm(
80 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
81 | )
82 | if use_linear_projection:
83 | self.proj_in = nn.Linear(in_channels, inner_dim)
84 | else:
85 | self.proj_in = nn.Conv2d(
86 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0
87 | )
88 |
89 | if use_audio_module:
90 | self.transformer_blocks = nn.ModuleList(
91 | [
92 | AudioTemporalBasicTransformerBlock(
93 | inner_dim,
94 | num_attention_heads,
95 | attention_head_dim,
96 | dropout=dropout,
97 | cross_attention_dim=cross_attention_dim,
98 | activation_fn=activation_fn,
99 | num_embeds_ada_norm=num_embeds_ada_norm,
100 | attention_bias=attention_bias,
101 | only_cross_attention=only_cross_attention,
102 | upcast_attention=upcast_attention,
103 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
104 | unet_use_temporal_attention=unet_use_temporal_attention,
105 | depth=depth,
106 | unet_block_name=unet_block_name,
107 | stack_enable_blocks_name=stack_enable_blocks_name,
108 | stack_enable_blocks_depth=stack_enable_blocks_depth,
109 | )
110 | for d in range(num_layers)
111 | ]
112 | )
113 | else:
114 | # Define transformers blocks
115 | self.transformer_blocks = nn.ModuleList(
116 | [
117 | TemporalBasicTransformerBlock(
118 | inner_dim,
119 | num_attention_heads,
120 | attention_head_dim,
121 | dropout=dropout,
122 | cross_attention_dim=cross_attention_dim,
123 | activation_fn=activation_fn,
124 | num_embeds_ada_norm=num_embeds_ada_norm,
125 | attention_bias=attention_bias,
126 | only_cross_attention=only_cross_attention,
127 | upcast_attention=upcast_attention,
128 | )
129 | for d in range(num_layers)
130 | ]
131 | )
132 |
133 | # 4. Define output layers
134 | if use_linear_projection:
135 | self.proj_out = nn.Linear(in_channels, inner_dim)
136 | else:
137 | self.proj_out = nn.Conv2d(
138 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0
139 | )
140 |
141 | self.gradient_checkpointing = False
142 |
143 | def _set_gradient_checkpointing(self, module, value=False):
144 | if hasattr(module, "gradient_checkpointing"):
145 | module.gradient_checkpointing = value
146 |
147 | def forward(
148 | self,
149 | hidden_states,
150 | encoder_hidden_states=None,
151 | attention_mask=None,
152 | full_mask=None,
153 | face_mask=None,
154 | lip_mask=None,
155 | motion_scale=None,
156 | timestep=None,
157 | return_dict: bool = True,
158 | ):
159 | """
160 | Forward pass for the Transformer3DModel.
161 |
162 | Args:
163 | hidden_states (torch.Tensor): The input hidden states.
164 | encoder_hidden_states (torch.Tensor, optional): The input encoder hidden states.
165 | attention_mask (torch.Tensor, optional): The attention mask.
166 | full_mask (torch.Tensor, optional): The full mask.
167 | face_mask (torch.Tensor, optional): The face mask.
168 | lip_mask (torch.Tensor, optional): The lip mask.
169 | timestep (int, optional): The current timestep.
170 | return_dict (bool, optional): Whether to return a dictionary or a tuple.
171 |
172 | Returns:
173 | output (Union[Tuple, BaseOutput]): The output of the Transformer3DModel.
174 | """
175 | # Input
176 | assert (
177 | hidden_states.dim() == 5
178 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
179 | video_length = hidden_states.shape[2]
180 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
181 |
182 | # TODO
183 | if self.use_audio_module:
184 | encoder_hidden_states = rearrange(
185 | encoder_hidden_states,
186 | "bs f margin dim -> (bs f) margin dim",
187 | )
188 | else:
189 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
190 | encoder_hidden_states = repeat(
191 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length
192 | )
193 |
194 | batch, _, height, weight = hidden_states.shape
195 | residual = hidden_states
196 |
197 | hidden_states = self.norm(hidden_states)
198 | if not self.use_linear_projection:
199 | hidden_states = self.proj_in(hidden_states)
200 | inner_dim = hidden_states.shape[1]
201 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
202 | batch, height * weight, inner_dim
203 | )
204 | else:
205 | inner_dim = hidden_states.shape[1]
206 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
207 | batch, height * weight, inner_dim
208 | )
209 | hidden_states = self.proj_in(hidden_states)
210 |
211 | # Blocks
212 | motion_frames = []
213 | for _, block in enumerate(self.transformer_blocks):
214 | if isinstance(block, TemporalBasicTransformerBlock):
215 | hidden_states, motion_frame_fea = block(
216 | hidden_states,
217 | encoder_hidden_states=encoder_hidden_states,
218 | timestep=timestep,
219 | video_length=video_length,
220 | )
221 | motion_frames.append(motion_frame_fea)
222 | else:
223 | hidden_states = block(
224 | hidden_states, # shape [2, 4096, 320]
225 | encoder_hidden_states=encoder_hidden_states, # shape [2, 20, 640]
226 | attention_mask=attention_mask,
227 | full_mask=full_mask,
228 | face_mask=face_mask,
229 | lip_mask=lip_mask,
230 | timestep=timestep,
231 | video_length=video_length,
232 | motion_scale=motion_scale,
233 | )
234 |
235 | # Output
236 | if not self.use_linear_projection:
237 | hidden_states = (
238 | hidden_states.reshape(batch, height, weight, inner_dim)
239 | .permute(0, 3, 1, 2)
240 | .contiguous()
241 | )
242 | hidden_states = self.proj_out(hidden_states)
243 | else:
244 | hidden_states = self.proj_out(hidden_states)
245 | hidden_states = (
246 | hidden_states.reshape(batch, height, weight, inner_dim)
247 | .permute(0, 3, 1, 2)
248 | .contiguous()
249 | )
250 |
251 | output = hidden_states + residual
252 |
253 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
254 | if not return_dict:
255 | return (output, motion_frames)
256 |
257 | return Transformer3DModelOutput(sample=output)
258 |
--------------------------------------------------------------------------------
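To make the frame-folding convention used by `Transformer3DModel` concrete, here is a small standalone sketch (not part of the repository) of the `rearrange` pattern it relies on: video tensors are flattened from `(b, c, f, h, w)` to `((b f), c, h, w)` so that 2D layers run per frame, then restored afterwards. The tensor sizes are placeholders.

```python
import torch
from einops import rearrange

b, c, f, h, w = 2, 320, 16, 32, 32        # placeholder sizes
video = torch.randn(b, c, f, h, w)

# Fold frames into the batch dimension so 2D norms, projections and attention run per frame.
frames = rearrange(video, "b c f h w -> (b f) c h w")   # (32, 320, 32, 32)

# ... per-frame processing would happen here ...

# Restore the original 5-D video layout.
restored = rearrange(frames, "(b f) c h w -> b c f h w", f=f)
assert restored.shape == video.shape
```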
/hallo/datasets/talk_video.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R0801
2 | """
3 | talk_video.py
4 |
5 | This module defines the TalkingVideoDataset class, a custom PyTorch dataset
6 | for handling talking video data. The dataset uses video files, masks, and
7 | embeddings to prepare data for tasks such as video generation and
8 | speech-driven video animation.
9 |
10 | Classes:
11 | TalkingVideoDataset
12 |
13 | Dependencies:
14 | json
15 | random
16 | torch
17 | decord.VideoReader, decord.cpu
18 | PIL.Image
19 | torch.utils.data.Dataset
20 | torchvision.transforms
21 |
22 | Example:
23 | from talking_video_dataset import TalkingVideoDataset
24 | from torch.utils.data import DataLoader
25 |
26 | # Example configuration for the Wav2Vec model
27 | class Wav2VecConfig:
28 | def __init__(self, audio_type, model_scale, features):
29 | self.audio_type = audio_type
30 | self.model_scale = model_scale
31 | self.features = features
32 |
33 | wav2vec_cfg = Wav2VecConfig(audio_type="wav2vec2", model_scale="base", features="feature")
34 |
35 | # Initialize dataset
36 | dataset = TalkingVideoDataset(
37 | img_size=(512, 512),
38 | sample_rate=16000,
39 | audio_margin=2,
40 | n_motion_frames=0,
41 | n_sample_frames=16,
42 | data_meta_paths=["path/to/meta1.json", "path/to/meta2.json"],
43 | wav2vec_cfg=wav2vec_cfg,
44 | )
45 |
46 | # Initialize dataloader
47 | dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
48 |
49 | # Fetch one batch of data
50 | batch = next(iter(dataloader))
51 | print(batch["pixel_values_vid"].shape) # Example output: (4, 16, 3, 512, 512)
52 |
53 | The TalkingVideoDataset class provides methods for loading video frames, masks,
54 | audio embeddings, and other relevant data, applying transformations, and preparing
55 | the data for training and evaluation in a deep learning pipeline.
56 |
57 | Attributes:
58 | img_size (tuple): The dimensions to resize the video frames to.
59 | sample_rate (int): The audio sample rate.
60 | audio_margin (int): The margin for audio sampling.
61 | n_motion_frames (int): The number of motion frames.
62 | n_sample_frames (int): The number of sample frames.
63 | data_meta_paths (list): List of paths to the JSON metadata files.
64 | wav2vec_cfg (object): Configuration for the Wav2Vec model.
65 |
66 | Methods:
67 | augmentation(images, transform, state=None): Apply transformation to input images.
68 | __getitem__(index): Get a sample from the dataset at the specified index.
69 | __len__(): Return the length of the dataset.
70 | """
71 |
72 | import json
73 | import random
74 | from typing import List
75 |
76 | import torch
77 | from decord import VideoReader, cpu
78 | from PIL import Image
79 | from torch.utils.data import Dataset
80 | from torchvision import transforms
81 |
82 |
83 | class TalkingVideoDataset(Dataset):
84 | """
85 | A dataset class for processing talking video data.
86 |
87 | Args:
88 | img_size (tuple, optional): The size of the output images. Defaults to (512, 512).
89 | sample_rate (int, optional): The sample rate of the audio data. Defaults to 16000.
90 | audio_margin (int, optional): The margin for the audio data. Defaults to 2.
91 | n_motion_frames (int, optional): The number of motion frames. Defaults to 0.
92 | n_sample_frames (int, optional): The number of sample frames. Defaults to 16.
93 | data_meta_paths (list, optional): The paths to the data metadata. Defaults to None.
94 | wav2vec_cfg (dict, optional): The configuration for the wav2vec model. Defaults to None.
95 |
96 | Attributes:
97 | img_size (tuple): The size of the output images.
98 | sample_rate (int): The sample rate of the audio data.
99 | audio_margin (int): The margin for the audio data.
100 | n_motion_frames (int): The number of motion frames.
101 | n_sample_frames (int): The number of sample frames.
102 | data_meta_paths (list): The paths to the data metadata.
103 | wav2vec_cfg (dict): The configuration for the wav2vec model.
104 | """
105 |
106 | def __init__(
107 | self,
108 | img_size=(512, 512),
109 | sample_rate=16000,
110 | audio_margin=2,
111 | n_motion_frames=0,
112 | n_sample_frames=16,
113 | data_meta_paths=None,
114 | wav2vec_cfg=None,
115 | ):
116 | super().__init__()
117 | self.sample_rate = sample_rate
118 | self.img_size = img_size
119 | self.audio_margin = audio_margin
120 | self.n_motion_frames = n_motion_frames
121 | self.n_sample_frames = n_sample_frames
122 | self.audio_type = wav2vec_cfg.audio_type
123 | self.audio_model = wav2vec_cfg.model_scale
124 | self.audio_features = wav2vec_cfg.features
125 |
126 | vid_meta = []
127 | for data_meta_path in data_meta_paths:
128 | with open(data_meta_path, "r", encoding="utf-8") as f:
129 | vid_meta.extend(json.load(f))
130 | self.vid_meta = vid_meta
131 | self.length = len(self.vid_meta)
132 | self.pixel_transform = transforms.Compose(
133 | [
134 | transforms.Resize(self.img_size),
135 | transforms.ToTensor(),
136 | transforms.Normalize([0.5], [0.5]),
137 | ]
138 | )
139 |
140 | self.cond_transform = transforms.Compose(
141 | [
142 | transforms.Resize(self.img_size),
143 | transforms.ToTensor(),
144 | ]
145 | )
146 | self.attn_transform_64 = transforms.Compose(
147 | [
148 | transforms.Resize((64,64)),
149 | transforms.ToTensor(),
150 | ]
151 | )
152 | self.attn_transform_32 = transforms.Compose(
153 | [
154 | transforms.Resize((32, 32)),
155 | transforms.ToTensor(),
156 | ]
157 | )
158 | self.attn_transform_16 = transforms.Compose(
159 | [
160 | transforms.Resize((16, 16)),
161 | transforms.ToTensor(),
162 | ]
163 | )
164 | self.attn_transform_8 = transforms.Compose(
165 | [
166 | transforms.Resize((8, 8)),
167 | transforms.ToTensor(),
168 | ]
169 | )
170 |
171 | def augmentation(self, images, transform, state=None):
172 | """
173 | Apply the given transformation to the input images.
174 |
175 | Args:
176 | images (List[PIL.Image] or PIL.Image): The input images to be transformed.
177 | transform (torchvision.transforms.Compose): The transformation to be applied to the images.
178 | state (torch.ByteTensor, optional): The state of the random number generator.
179 | If provided, it will set the RNG state to this value before applying the transformation. Defaults to None.
180 |
181 | Returns:
182 | torch.Tensor: The transformed images as a tensor.
183 | If the input was a list of images, the tensor will have shape (f, c, h, w),
184 | where f is the number of images, c is the number of channels, h is the height, and w is the width.
185 | If the input was a single image, the tensor will have shape (c, h, w),
186 | where c is the number of channels, h is the height, and w is the width.
187 | """
188 | if state is not None:
189 | torch.set_rng_state(state)
190 | if isinstance(images, List):
191 | transformed_images = [transform(img) for img in images]
192 | ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
193 | else:
194 | ret_tensor = transform(images) # (c, h, w)
195 | return ret_tensor
196 |
197 | def __getitem__(self, index):
198 | video_meta = self.vid_meta[index]
199 | video_path = video_meta["video_path"]
200 | mask_path = video_meta["mask_path"]
201 | lip_mask_union_path = video_meta.get("sep_mask_lip", None)
202 | face_mask_union_path = video_meta.get("sep_mask_face", None)
203 | full_mask_union_path = video_meta.get("sep_mask_border", None)
204 | face_emb_path = video_meta["face_emb_path"]
205 | audio_emb_path = video_meta[
206 | f"{self.audio_type}_emb_{self.audio_model}_{self.audio_features}"
207 | ]
208 | tgt_mask_pil = Image.open(mask_path)
209 | video_frames = VideoReader(video_path, ctx=cpu(0))
210 | assert tgt_mask_pil is not None, "Fail to load target mask."
211 | assert (video_frames is not None and len(video_frames) > 0), "Fail to load video frames."
212 | video_length = len(video_frames)
213 |
214 | assert (
215 | video_length
216 | > self.n_sample_frames + self.n_motion_frames + 2 * self.audio_margin
217 | )
218 | start_idx = random.randint(
219 | self.n_motion_frames,
220 | video_length - self.n_sample_frames - self.audio_margin - 1,
221 | )
222 |
223 | videos = video_frames[start_idx : start_idx + self.n_sample_frames]
224 |
225 | frame_list = [
226 | Image.fromarray(video).convert("RGB") for video in videos.asnumpy()
227 | ]
228 |
229 | face_masks_list = [Image.open(face_mask_union_path)] * self.n_sample_frames
230 | lip_masks_list = [Image.open(lip_mask_union_path)] * self.n_sample_frames
231 | full_masks_list = [Image.open(full_mask_union_path)] * self.n_sample_frames
232 | assert face_masks_list[0] is not None, "Fail to load face mask."
233 | assert lip_masks_list[0] is not None, "Fail to load lip mask."
234 | assert full_masks_list[0] is not None, "Fail to load full mask."
235 |
236 |
237 | face_emb = torch.load(face_emb_path)
238 | audio_emb = torch.load(audio_emb_path)
239 | indices = (
240 | torch.arange(2 * self.audio_margin + 1) - self.audio_margin
241 | ) # Generates [-2, -1, 0, 1, 2]
242 | center_indices = torch.arange(
243 | start_idx,
244 | start_idx + self.n_sample_frames,
245 | ).unsqueeze(1) + indices.unsqueeze(0)
246 | audio_tensor = audio_emb[center_indices]
247 |
248 | ref_img_idx = random.randint(
249 | self.n_motion_frames,
250 | video_length - self.n_sample_frames - self.audio_margin - 1,
251 | )
252 | ref_img = video_frames[ref_img_idx].asnumpy()
253 | ref_img = Image.fromarray(ref_img)
254 |
255 | if self.n_motion_frames > 0:
256 | motions = video_frames[start_idx - self.n_motion_frames : start_idx]
257 | motion_list = [
258 | Image.fromarray(motion).convert("RGB") for motion in motions.asnumpy()
259 | ]
260 |
261 | # transform
262 | state = torch.get_rng_state()
263 | pixel_values_vid = self.augmentation(frame_list, self.pixel_transform, state)
264 |
265 | pixel_values_mask = self.augmentation(tgt_mask_pil, self.cond_transform, state)
266 | pixel_values_mask = pixel_values_mask.repeat(3, 1, 1)
267 |
268 | pixel_values_face_mask = [
269 | self.augmentation(face_masks_list, self.attn_transform_64, state),
270 | self.augmentation(face_masks_list, self.attn_transform_32, state),
271 | self.augmentation(face_masks_list, self.attn_transform_16, state),
272 | self.augmentation(face_masks_list, self.attn_transform_8, state),
273 | ]
274 | pixel_values_lip_mask = [
275 | self.augmentation(lip_masks_list, self.attn_transform_64, state),
276 | self.augmentation(lip_masks_list, self.attn_transform_32, state),
277 | self.augmentation(lip_masks_list, self.attn_transform_16, state),
278 | self.augmentation(lip_masks_list, self.attn_transform_8, state),
279 | ]
280 | pixel_values_full_mask = [
281 | self.augmentation(full_masks_list, self.attn_transform_64, state),
282 | self.augmentation(full_masks_list, self.attn_transform_32, state),
283 | self.augmentation(full_masks_list, self.attn_transform_16, state),
284 | self.augmentation(full_masks_list, self.attn_transform_8, state),
285 | ]
286 |
287 | pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state)
288 | pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
289 | if self.n_motion_frames > 0:
290 | pixel_values_motion = self.augmentation(
291 | motion_list, self.pixel_transform, state
292 | )
293 | pixel_values_ref_img = torch.cat(
294 | [pixel_values_ref_img, pixel_values_motion], dim=0
295 | )
296 |
297 | sample = {
298 | "video_dir": video_path,
299 | "pixel_values_vid": pixel_values_vid,
300 | "pixel_values_mask": pixel_values_mask,
301 | "pixel_values_face_mask": pixel_values_face_mask,
302 | "pixel_values_lip_mask": pixel_values_lip_mask,
303 | "pixel_values_full_mask": pixel_values_full_mask,
304 | "audio_tensor": audio_tensor,
305 | "pixel_values_ref_img": pixel_values_ref_img,
306 | "face_emb": face_emb,
307 | }
308 |
309 | return sample
310 |
311 | def __len__(self):
312 | return len(self.vid_meta)
313 |
--------------------------------------------------------------------------------
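For reference, a small standalone sketch (not part of the repository) of how `__getitem__` above gathers an audio context window of ±`audio_margin` embeddings around each sampled video frame. All values below are placeholders.

```python
import torch

audio_margin, n_sample_frames, start_idx = 2, 4, 10
audio_emb = torch.randn(100, 768)            # hypothetical per-frame audio embeddings

indices = torch.arange(2 * audio_margin + 1) - audio_margin        # [-2, -1, 0, 1, 2]
center_indices = torch.arange(
    start_idx, start_idx + n_sample_frames
).unsqueeze(1) + indices.unsqueeze(0)                              # (4, 5) frame indices

audio_tensor = audio_emb[center_indices]                           # (4, 5, 768)
print(audio_tensor.shape)
```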
/README.md:
--------------------------------------------------------------------------------
1 | # About this fork
2 |
3 | This fork was created to provide a convenient web interface for using Hallo. The original code has been slightly modified to allow for more control over the generation process.
4 |
5 | ## About colab
6 | ⚠️ To run the web interface, you need at least 12 GB of video memory (VRAM) and more than 12 GB of RAM. ⚠️
7 |
8 | Unfortunately, I was unable to create a free tier Colab notebook as there is not enough RAM available.
9 |
10 | If you have Colab Pro, you can try this [notebook](https://colab.research.google.com/drive/1JGkftvdEksrhJbeAUnnRAZNAfZyGjP44?usp=sharing).
11 | 
12 | ## Portable version
13 | 
14 | If you have Windows and don't want to bother with installing libraries, you can download the [portable version](https://huggingface.co/daswer123/portable_webuis/resolve/main/hallo-portable-2.zip?download=true), unpack it, and launch `run.bat`.
15 |
16 | ## Screenshot
17 |
18 | 
19 |
20 | ## Installation
21 |
22 | ### Docker
23 |
24 | ```bash
25 | docker compose up -d
26 | ```
27 | This starts the Gradio web UI inside the container on port 7860, which is mapped to port 8020 on the host.
28 | The app will be available at http://localhost:8020
29 |
30 | Note: if the image fails to build or run, make sure the CUDA base image used in the DOCKERFILE matches your GPU driver version.
31 |
32 | ### Windows
33 |
34 | 1. Clone this repository:
35 | ```
36 | git clone https://github.com/yourusername/hallo.git
37 | ```
38 |
39 | 2. Run `install.bat` to set up the environment and download the pretrained models.
40 |
41 | 3. Make sure ffmpeg is installed on your system. It doesn't matter where it's located, as long as the system can find it.
42 |
43 | 4. Launch the web interface by running `start.bat`.
44 |
45 | ### Linux
46 |
47 | 1. Clone this repository:
48 | ```
49 | git clone https://github.com/yourusername/hallo.git
50 | ```
51 |
52 | 2. Run `install.sh` to set up the environment and download the pretrained models.
53 |
54 | 3. Ensure ffmpeg is installed on your system. You can install it with:
55 | ```
56 | sudo apt-get install ffmpeg
57 | ```
58 |
59 | 4. Launch the web interface by running `start.sh`.
60 |
61 | ### Manual Installation
62 |
63 | If you prefer to install manually, here are the detailed steps:
64 |
65 | 1. Clone the repository and pretrained models:
66 | ```
67 | git lfs install
68 | git clone https://github.com/yourusername/hallo.git
69 | git clone https://huggingface.co/fudan-generative-ai/hallo pretrained_models
70 | curl -L -o pretrained_models/hallo/net.pth https://huggingface.co/fudan-generative-ai/hallo/resolve/main/hallo/net.pth?download=true
71 | ```
72 |
73 | 2. Create a virtual environment and activate it:
74 | ```
75 | python -m venv venv
76 | venv\Scripts\activate # For Windows
77 | source venv/bin/activate # For Linux
78 | ```
79 |
80 | 3. Install the required packages:
81 | ```
82 | pip install -r requirements.txt
83 | pip install -e .
84 | pip install bitsandbytes-windows --force-reinstall # For Windows only
85 | ```
86 |
87 | 4. Install GPU libraries:
88 | ```
89 | pip install torch==2.2.2+cu121 torchaudio torchvision --index-url https://download.pytorch.org/whl/cu121
90 | pip install onnxruntime-gpu
91 | ```
92 |
93 | 5. Launch the web interface:
94 | ```
95 | python app.py
96 | ```
97 |
98 | 6. To share the interface publicly, use the `--share` flag:
99 | ```
100 | python app.py --share
101 | ```
102 |
103 |
104 | 
105 | # Hallo: Hierarchical Audio-Driven Visual Synthesis for Portrait Image Animation
106 | 
107 | Fudan University · Baidu Inc · ETH Zurich · Nanjing University
108 | 
134 | # Showcase
135 |
136 |
137 | https://github.com/fudan-generative-vision/hallo/assets/17402682/294e78ef-c60d-4c32-8e3c-7f8d6934c6bd
138 |
139 |
140 | # Framework
141 |
142 | 
143 | 
144 |
145 | # News
146 |
147 | - **`2024/06/15`**: 🎉🎉🎉 Release the first version on [GitHub](https://github.com/fudan-generative-vision/hallo).
148 | - **`2024/06/15`**: ✨✨✨ Release some images and audio samples for inference testing on [Huggingface](https://huggingface.co/datasets/fudan-generative-ai/hallo_inference_samples).
149 |
150 | # Installation
151 |
152 | - System requirement: Ubuntu 20.04/Ubuntu 22.04, CUDA 12.1
153 | - Tested GPUs: A100
154 |
155 | Create conda environment:
156 |
157 | ```bash
158 | conda create -n hallo python=3.10
159 | conda activate hallo
160 | ```
161 |
162 | Install packages with `pip`
163 |
164 | ```bash
165 | pip install -r requirements.txt
166 | pip install .
167 | ```
168 |
169 | Besides, ffmpeg is also needed:
170 | ```bash
171 | apt-get install ffmpeg
172 | ```
173 |
174 | # Inference
175 |
176 | The inference entrypoint script is `scripts/inference.py`. Before testing your own cases, the following preparations need to be completed:
177 |
178 | 1. [Download all required pretrained models](#download-pretrained-models).
179 | 2. [Prepare source image and driving audio pairs](#prepare-inference-data).
180 | 3. [Run inference](#run-inference).
181 |
182 | ## Download pretrained models
183 |
184 | You can easily get all pretrained models required by inference from our [HuggingFace repo](https://huggingface.co/fudan-generative-ai/hallo).
185 |
186 | Clone the pretrained models into the `${PROJECT_ROOT}/pretrained_models` directory with the command below:
187 |
188 | ```shell
189 | git lfs install
190 | git clone https://huggingface.co/fudan-generative-ai/hallo pretrained_models
191 | ```
192 |
193 | Or you can download them separately from their source repo:
194 |
195 | - [hallo](https://huggingface.co/fudan-generative-ai/hallo/tree/main/hallo): Our checkpoints consist of denoising UNet, face locator, image & audio proj.
196 | - [audio_separator](https://huggingface.co/huangjackson/Kim_Vocal_2): Kim\_Vocal\_2 MDX-Net vocal removal model by [KimberleyJensen](https://github.com/KimberleyJensen). (_Thanks to KimberleyJensen_)
197 | - [insightface](https://github.com/deepinsight/insightface/tree/master/python-package#model-zoo): 2D and 3D Face Analysis placed into `pretrained_models/face_analysis/models/`. (_Thanks to deepinsight_)
198 | - [face landmarker](https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task): Face detection & mesh model from [mediapipe](https://ai.google.dev/edge/mediapipe/solutions/vision/face_landmarker#models) placed into `pretrained_models/face_analysis/models`.
199 | - [motion module](https://github.com/guoyww/AnimateDiff/blob/main/README.md#202309-animatediff-v2): motion module from [AnimateDiff](https://github.com/guoyww/AnimateDiff). (_Thanks to guoyww_).
200 | - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse): Weights are intended to be used with the diffusers library. (_Thanks to stabilityai_)
201 | - [StableDiffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5): Initialized and fine-tuned from Stable-Diffusion-v1-2. (_Thanks to runwayml_)
202 | - [wav2vec](https://huggingface.co/facebook/wav2vec2-base-960h): wav audio to vector model from [Facebook](https://huggingface.co/facebook/wav2vec2-base-960h).
203 |
204 | Finally, these pretrained models should be organized as follows:
205 |
206 | ```text
207 | ./pretrained_models/
208 | |-- audio_separator/
209 | | `-- Kim_Vocal_2.onnx
210 | |-- face_analysis/
211 | | `-- models/
212 | | |-- face_landmarker_v2_with_blendshapes.task # face landmarker model from mediapipe
213 | | |-- 1k3d68.onnx
214 | | |-- 2d106det.onnx
215 | | |-- genderage.onnx
216 | | |-- glintr100.onnx
217 | | `-- scrfd_10g_bnkps.onnx
218 | |-- motion_module/
219 | | `-- mm_sd_v15_v2.ckpt
220 | |-- sd-vae-ft-mse/
221 | | |-- config.json
222 | | `-- diffusion_pytorch_model.safetensors
223 | |-- stable-diffusion-v1-5/
224 | | |-- feature_extractor/
225 | | | `-- preprocessor_config.json
226 | | |-- model_index.json
227 | | |-- unet/
228 | | | |-- config.json
229 | | | `-- diffusion_pytorch_model.safetensors
230 | | `-- v1-inference.yaml
231 | `-- wav2vec/
232 | |-- wav2vec2-base-960h/
233 | | |-- config.json
234 | | |-- feature_extractor_config.json
235 | | |-- model.safetensors
236 | | |-- preprocessor_config.json
237 | | |-- special_tokens_map.json
238 | | |-- tokenizer_config.json
239 | | `-- vocab.json
240 | ```
241 |
242 | ## Prepare Inference Data
243 |
244 | Hallo has a few simple requirements for input data:
245 |
246 | For the source image:
247 |
248 | 1. It should be cropped into a square.
249 | 2. The face should be the main focus, making up 50%-70% of the image.
250 | 3. The face should be facing forward, with a rotation angle of less than 30° (no side profiles).
251 |
252 | For the driving audio:
253 |
254 | 1. It must be in WAV format.
255 | 2. It must be in English since our training datasets are only in this language.
256 | 3. Ensure the vocals are clear; background music is acceptable.
257 |
258 | We have provided some samples for your reference; a small preprocessing sketch for your own inputs follows below.
259 |
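The sketch below is not part of the project: file names are placeholders, Pillow is assumed to be available, and ffmpeg must be on your PATH (already a requirement of this project). The 16 kHz mono target matches the sample rate used by the wav2vec audio model.

```python
import subprocess
from PIL import Image

# Crop the source image to a centered square, keeping the face as the main subject.
img = Image.open("my_portrait.png").convert("RGB")        # placeholder file name
side = min(img.size)
left, top = (img.width - side) // 2, (img.height - side) // 2
img.crop((left, top, left + side, top + side)).save("source_square.png")

# Convert the driving audio to 16 kHz mono WAV.
subprocess.run(
    ["ffmpeg", "-y", "-i", "my_audio.mp3", "-ac", "1", "-ar", "16000", "driving.wav"],
    check=True,
)
```
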
260 | ## Run inference
261 |
262 | Simply run `scripts/inference.py` and pass `source_image` and `driving_audio` as inputs:
263 |
264 | ```bash
265 | python scripts/inference.py --source_image examples/source_images/1.jpg --driving_audio examples/driving_audios/1.wav
266 | ```
267 |
268 | Animation results will be saved as `${PROJECT_ROOT}/.cache/output.mp4` by default. You can pass `--output` to specify the output file name. You can find more inference examples in the [examples folder](https://github.com/fudan-generative-vision/hallo/tree/main/examples).
269 |
270 | For more options:
271 |
272 | ```shell
273 | usage: inference.py [-h] [-c CONFIG] [--source_image SOURCE_IMAGE] [--driving_audio DRIVING_AUDIO] [--output OUTPUT] [--pose_weight POSE_WEIGHT]
274 | [--face_weight FACE_WEIGHT] [--lip_weight LIP_WEIGHT] [--face_expand_ratio FACE_EXPAND_RATIO]
275 |
276 | options:
277 | -h, --help show this help message and exit
278 | -c CONFIG, --config CONFIG
279 | --source_image SOURCE_IMAGE
280 | source image
281 | --driving_audio DRIVING_AUDIO
282 | driving audio
283 | --output OUTPUT output video file name
284 | --pose_weight POSE_WEIGHT
285 | weight of pose
286 | --face_weight FACE_WEIGHT
287 | weight of face
288 | --lip_weight LIP_WEIGHT
289 | weight of lip
290 | --face_expand_ratio FACE_EXPAND_RATIO
291 | face region
292 | ```
293 |
294 | # Roadmap
295 |
296 | | Status | Milestone | ETA |
297 | | :----: | :---------------------------------------------------------------------------------------------------- | :--------: |
298 | | ✅ | **[Inference source code meet everyone on GitHub](https://github.com/fudan-generative-vision/hallo)** | 2024-06-15 |
299 | | ✅ | **[Pretrained models on Huggingface](https://huggingface.co/fudan-generative-ai/hallo)** | 2024-06-15 |
300 | | 🚀🚀🚀 | **[Training: data preparation and training scripts]()**                                                | 2024-06-25 |
301 | | 🚀🚀🚀 | **[Optimize inference performance in Mandarin]()** | TBD |
302 |
303 | # Citation
304 |
305 | If you find our work useful for your research, please consider citing the paper:
306 |
307 | ```
308 | @misc{xu2024hallo,
309 | title={Hallo: Hierarchical Audio-Driven Visual Synthesis for Portrait Image Animation},
310 |       author={Mingwang Xu and Hui Li and Qingkun Su and Hanlin Shang and Liwei Zhang and Ce Liu and Jingdong Wang and Yao Yao and Siyu Zhu},
311 | year={2024},
312 | eprint={2406.08801},
313 | archivePrefix={arXiv},
314 | primaryClass={cs.CV}
315 | }
316 | ```
317 |
318 | # Opportunities available
319 |
320 | Multiple research positions are open at the **Generative Vision Lab, Fudan University**! These include:
321 | 
322 | - Research assistant
323 | - Postdoctoral researcher
324 | - PhD candidate
325 | - Master's students
326 |
327 | Interested individuals are encouraged to contact us at [siyuzhu@fudan.edu.cn](mailto:siyuzhu@fudan.edu.cn) for further information.
328 |
329 | # Social Risks and Mitigations
330 |
331 | The development of portrait image animation technologies driven by audio inputs poses social risks, such as the ethical implications of creating realistic portraits that could be misused for deepfakes. To mitigate these risks, it is crucial to establish ethical guidelines and responsible use practices. Privacy and consent concerns also arise from using individuals' images and voices. Addressing these involves transparent data usage policies, informed consent, and safeguarding privacy rights. By addressing these risks and implementing mitigations, the research aims to ensure the responsible and ethical development of this technology.
332 |
--------------------------------------------------------------------------------
/hallo/models/resnet.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=E1120
2 | # pylint: disable=E1102
3 | # pylint: disable=W0237
4 |
5 | # src/models/resnet.py
6 |
7 | """
8 | This module defines various components used in the ResNet model, such as InflatedConv3D, InflatedGroupNorm,
9 | Upsample3D, Downsample3D, ResnetBlock3D, and Mish activation function. These components are used to construct
10 | a deep neural network model for image classification or other computer vision tasks.
11 |
12 | Classes:
13 | - InflatedConv3d: An inflated 3D convolutional layer, inheriting from nn.Conv2d.
14 | - InflatedGroupNorm: An inflated group normalization layer, inheriting from nn.GroupNorm.
15 | - Upsample3D: A 3D upsampling module, used to increase the resolution of the input tensor.
16 | - Downsample3D: A 3D downsampling module, used to decrease the resolution of the input tensor.
17 | - ResnetBlock3D: A 3D residual block, commonly used in ResNet architectures.
18 | - Mish: A Mish activation function, which is a smooth, non-monotonic activation function.
19 |
20 | To use this module, simply import the classes and functions you need and follow the instructions provided in
21 | the respective class and function docstrings.
22 | """
23 |
24 | import torch
25 | import torch.nn.functional as F
26 | from einops import rearrange
27 | from torch import nn
28 |
29 |
30 | class InflatedConv3d(nn.Conv2d):
31 | """
32 | InflatedConv3d is a class that inherits from torch.nn.Conv2d and overrides the forward method.
33 |
34 | This class is used to perform 3D convolution on input tensor x. It is a specialized type of convolutional layer
35 | commonly used in deep learning models for computer vision tasks. The main difference between a regular Conv2d and
36 | InflatedConv3d is that InflatedConv3d is designed to handle 3D input tensors, which are typically the result of
37 | inflating 2D convolutional layers to 3D for use in 3D deep learning tasks.
38 |
39 | Attributes:
40 | Same as torch.nn.Conv2d.
41 |
42 | Methods:
43 | forward(self, x):
44 | Performs 3D convolution on the input tensor x using the InflatedConv3d layer.
45 |
46 | Example:
47 | conv_layer = InflatedConv3d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
48 | output = conv_layer(input_tensor)
49 | """
50 | def forward(self, x):
51 | """
52 | Forward pass of the InflatedConv3d layer.
53 |
54 | Args:
55 | x (torch.Tensor): Input tensor to the layer.
56 |
57 | Returns:
58 | torch.Tensor: Output tensor after applying the InflatedConv3d layer.
59 | """
60 | video_length = x.shape[2]
61 |
62 | x = rearrange(x, "b c f h w -> (b f) c h w")
63 | x = super().forward(x)
64 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
65 |
66 | return x
67 |
68 |
69 | class InflatedGroupNorm(nn.GroupNorm):
70 | """
71 | InflatedGroupNorm is a custom class that inherits from torch.nn.GroupNorm.
72 | It is used to apply group normalization to 3D tensors.
73 |
74 | Args:
75 | num_groups (int): The number of groups to divide the channels into.
76 | num_channels (int): The number of channels in the input tensor.
77 | eps (float, optional): A small constant to add to the variance to avoid division by zero. Defaults to 1e-5.
78 | affine (bool, optional): If True, the module has learnable affine parameters. Defaults to True.
79 |
80 | Attributes:
81 | weight (torch.Tensor): The learnable weight tensor for scale.
82 | bias (torch.Tensor): The learnable bias tensor for shift.
83 |
84 | Forward method:
85 | x (torch.Tensor): Input tensor to be normalized.
86 | return (torch.Tensor): Normalized tensor.
87 | """
88 | def forward(self, x):
89 | """
90 | Performs a forward pass through the CustomClassName.
91 |
92 | :param x: Input tensor of shape (batch_size, channels, video_length, height, width).
93 | :return: Output tensor of shape (batch_size, channels, video_length, height, width).
94 | """
95 | video_length = x.shape[2]
96 |
97 | x = rearrange(x, "b c f h w -> (b f) c h w")
98 | x = super().forward(x)
99 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
100 |
101 | return x
102 |
103 |
104 | class Upsample3D(nn.Module):
105 | """
106 | Upsample3D is a PyTorch module that upsamples a 3D tensor.
107 |
108 | Args:
109 | channels (int): The number of channels in the input tensor.
110 | use_conv (bool): Whether to use a convolutional layer for upsampling.
111 | use_conv_transpose (bool): Whether to use a transposed convolutional layer for upsampling.
112 | out_channels (int): The number of channels in the output tensor.
113 | name (str): The name of the convolutional layer.
114 | """
115 | def __init__(
116 | self,
117 | channels,
118 | use_conv=False,
119 | use_conv_transpose=False,
120 | out_channels=None,
121 | name="conv",
122 | ):
123 | super().__init__()
124 | self.channels = channels
125 | self.out_channels = out_channels or channels
126 | self.use_conv = use_conv
127 | self.use_conv_transpose = use_conv_transpose
128 | self.name = name
129 |
130 | if use_conv_transpose:
131 | raise NotImplementedError
132 | if use_conv:
133 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
134 |
135 | def forward(self, hidden_states, output_size=None):
136 | """
137 | Forward pass of the Upsample3D class.
138 |
139 | Args:
140 | hidden_states (torch.Tensor): Input tensor to be upsampled.
141 | output_size (tuple, optional): Desired output size of the upsampled tensor.
142 |
143 | Returns:
144 | torch.Tensor: Upsampled tensor.
145 |
146 | Raises:
147 | AssertionError: If the number of channels in the input tensor does not match the expected channels.
148 | """
149 | assert hidden_states.shape[1] == self.channels
150 |
151 | if self.use_conv_transpose:
152 | raise NotImplementedError
153 |
154 |         # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
155 | dtype = hidden_states.dtype
156 | if dtype == torch.bfloat16:
157 | hidden_states = hidden_states.to(torch.float32)
158 |
159 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
160 | if hidden_states.shape[0] >= 64:
161 | hidden_states = hidden_states.contiguous()
162 |
163 | # if `output_size` is passed we force the interpolation output
164 | # size and do not make use of `scale_factor=2`
165 | if output_size is None:
166 | hidden_states = F.interpolate(
167 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
168 | )
169 | else:
170 | hidden_states = F.interpolate(
171 | hidden_states, size=output_size, mode="nearest"
172 | )
173 |
174 | # If the input is bfloat16, we cast back to bfloat16
175 | if dtype == torch.bfloat16:
176 | hidden_states = hidden_states.to(dtype)
177 |
178 | # if self.use_conv:
179 | # if self.name == "conv":
180 | # hidden_states = self.conv(hidden_states)
181 | # else:
182 | # hidden_states = self.Conv2d_0(hidden_states)
183 | hidden_states = self.conv(hidden_states)
184 |
185 | return hidden_states
186 |
187 |
188 | class Downsample3D(nn.Module):
189 | """
190 | The Downsample3D class is a PyTorch module for downsampling a 3D tensor, which is used to
191 | reduce the spatial resolution of feature maps, commonly in the encoder part of a neural network.
192 |
193 | Attributes:
194 | channels (int): Number of input channels.
195 | use_conv (bool): Flag to use a convolutional layer for downsampling.
196 | out_channels (int, optional): Number of output channels. Defaults to input channels if None.
197 | padding (int): Padding added to the input.
198 | name (str): Name of the convolutional layer used for downsampling.
199 |
200 | Methods:
201 | forward(self, hidden_states):
202 | Downsamples the input tensor hidden_states and returns the downsampled tensor.
203 | """
204 | def __init__(
205 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
206 | ):
207 | """
208 | Downsamples the given input in the 3D space.
209 |
210 | Args:
211 | channels: The number of input channels.
212 | use_conv: Whether to use a convolutional layer for downsampling.
213 | out_channels: The number of output channels. If None, the input channels are used.
214 | padding: The amount of padding to be added to the input.
215 | name: The name of the convolutional layer.
216 | """
217 | super().__init__()
218 | self.channels = channels
219 | self.out_channels = out_channels or channels
220 | self.use_conv = use_conv
221 | self.padding = padding
222 | stride = 2
223 | self.name = name
224 |
225 | if use_conv:
226 | self.conv = InflatedConv3d(
227 | self.channels, self.out_channels, 3, stride=stride, padding=padding
228 | )
229 | else:
230 | raise NotImplementedError
231 |
232 | def forward(self, hidden_states):
233 | """
234 | Forward pass for the Downsample3D class.
235 |
236 | Args:
237 | hidden_states (torch.Tensor): Input tensor to be downsampled.
238 |
239 | Returns:
240 | torch.Tensor: Downsampled tensor.
241 |
242 | Raises:
243 | AssertionError: If the number of channels in the input tensor does not match the expected channels.
244 | """
245 | assert hidden_states.shape[1] == self.channels
246 | if self.use_conv and self.padding == 0:
247 | raise NotImplementedError
248 |
249 | assert hidden_states.shape[1] == self.channels
250 | hidden_states = self.conv(hidden_states)
251 |
252 | return hidden_states
253 |
254 |
255 | class ResnetBlock3D(nn.Module):
256 | """
257 | The ResnetBlock3D class defines a 3D residual block, a common building block in ResNet
258 | architectures for both image and video modeling tasks.
259 |
260 | Attributes:
261 | in_channels (int): Number of input channels.
262 | out_channels (int, optional): Number of output channels, defaults to in_channels if None.
263 | conv_shortcut (bool): Flag to use a convolutional shortcut.
264 | dropout (float): Dropout rate.
265 | temb_channels (int): Number of channels in the time embedding tensor.
266 | groups (int): Number of groups for the group normalization layers.
267 | eps (float): Epsilon value for group normalization.
268 | non_linearity (str): Type of nonlinearity to apply after convolutions.
269 | time_embedding_norm (str): Type of normalization for the time embedding.
270 | output_scale_factor (float): Scaling factor for the output tensor.
271 | use_in_shortcut (bool): Flag to include the input tensor in the shortcut connection.
272 | use_inflated_groupnorm (bool): Flag to use inflated group normalization layers.
273 |
274 | Methods:
275 | forward(self, input_tensor, temb):
276 | Passes the input tensor and time embedding through the residual block and
277 | returns the output tensor.
278 | """
279 | def __init__(
280 | self,
281 | *,
282 | in_channels,
283 | out_channels=None,
284 | conv_shortcut=False,
285 | dropout=0.0,
286 | temb_channels=512,
287 | groups=32,
288 | groups_out=None,
289 | pre_norm=True,
290 | eps=1e-6,
291 | non_linearity="swish",
292 | time_embedding_norm="default",
293 | output_scale_factor=1.0,
294 | use_in_shortcut=None,
295 | use_inflated_groupnorm=None,
296 | ):
297 | super().__init__()
298 | self.pre_norm = pre_norm
299 | self.pre_norm = True
300 | self.in_channels = in_channels
301 | out_channels = in_channels if out_channels is None else out_channels
302 | self.out_channels = out_channels
303 | self.use_conv_shortcut = conv_shortcut
304 | self.time_embedding_norm = time_embedding_norm
305 | self.output_scale_factor = output_scale_factor
306 |
307 | if groups_out is None:
308 | groups_out = groups
309 |
310 | assert use_inflated_groupnorm is not None
311 | if use_inflated_groupnorm:
312 | self.norm1 = InflatedGroupNorm(
313 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
314 | )
315 | else:
316 | self.norm1 = torch.nn.GroupNorm(
317 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
318 | )
319 |
320 | self.conv1 = InflatedConv3d(
321 | in_channels, out_channels, kernel_size=3, stride=1, padding=1
322 | )
323 |
324 | if temb_channels is not None:
325 | if self.time_embedding_norm == "default":
326 | time_emb_proj_out_channels = out_channels
327 | elif self.time_embedding_norm == "scale_shift":
328 | time_emb_proj_out_channels = out_channels * 2
329 | else:
330 | raise ValueError(
331 | f"unknown time_embedding_norm : {self.time_embedding_norm} "
332 | )
333 |
334 | self.time_emb_proj = torch.nn.Linear(
335 | temb_channels, time_emb_proj_out_channels
336 | )
337 | else:
338 | self.time_emb_proj = None
339 |
340 | if use_inflated_groupnorm:
341 | self.norm2 = InflatedGroupNorm(
342 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
343 | )
344 | else:
345 | self.norm2 = torch.nn.GroupNorm(
346 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
347 | )
348 | self.dropout = torch.nn.Dropout(dropout)
349 | self.conv2 = InflatedConv3d(
350 | out_channels, out_channels, kernel_size=3, stride=1, padding=1
351 | )
352 |
353 | if non_linearity == "swish":
354 |             self.nonlinearity = F.silu  # assign the function itself, not the result of calling it
355 | elif non_linearity == "mish":
356 | self.nonlinearity = Mish()
357 | elif non_linearity == "silu":
358 | self.nonlinearity = nn.SiLU()
359 |
360 | self.use_in_shortcut = (
361 | self.in_channels != self.out_channels
362 | if use_in_shortcut is None
363 | else use_in_shortcut
364 | )
365 |
366 | self.conv_shortcut = None
367 | if self.use_in_shortcut:
368 | self.conv_shortcut = InflatedConv3d(
369 | in_channels, out_channels, kernel_size=1, stride=1, padding=0
370 | )
371 |
372 | def forward(self, input_tensor, temb):
373 | """
374 | Forward pass for the ResnetBlock3D class.
375 |
376 | Args:
377 | input_tensor (torch.Tensor): Input tensor to the ResnetBlock3D layer.
378 |             temb (torch.Tensor): Time embedding tensor.
379 |
380 | Returns:
381 | torch.Tensor: Output tensor after passing through the ResnetBlock3D layer.
382 | """
383 | hidden_states = input_tensor
384 |
385 | hidden_states = self.norm1(hidden_states)
386 | hidden_states = self.nonlinearity(hidden_states)
387 |
388 | hidden_states = self.conv1(hidden_states)
389 |
390 | if temb is not None:
391 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
392 |
393 | if temb is not None and self.time_embedding_norm == "default":
394 | hidden_states = hidden_states + temb
395 |
396 | hidden_states = self.norm2(hidden_states)
397 |
398 | if temb is not None and self.time_embedding_norm == "scale_shift":
399 | scale, shift = torch.chunk(temb, 2, dim=1)
400 | hidden_states = hidden_states * (1 + scale) + shift
401 |
402 | hidden_states = self.nonlinearity(hidden_states)
403 |
404 | hidden_states = self.dropout(hidden_states)
405 | hidden_states = self.conv2(hidden_states)
406 |
407 | if self.conv_shortcut is not None:
408 | input_tensor = self.conv_shortcut(input_tensor)
409 |
410 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
411 |
412 | return output_tensor
413 |
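# A minimal usage sketch for ResnetBlock3D (hypothetical shapes, not part of the original file;
# it assumes the defaults used elsewhere in this module):
#
#   block = ResnetBlock3D(in_channels=320, temb_channels=1280, use_inflated_groupnorm=True)
#   x = torch.randn(1, 320, 16, 64, 64)   # (batch, channels, frames, height, width)
#   temb = torch.randn(1, 1280)           # time embedding
#   out = block(x, temb)                  # same shape as x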
414 |
415 | class Mish(torch.nn.Module):
416 | """
417 | The Mish class implements the Mish activation function, a smooth, non-monotonic function
418 | that can be used in neural networks as an alternative to traditional activation functions like ReLU.
419 |
420 | Methods:
421 | forward(self, hidden_states):
422 | Applies the Mish activation function to the input tensor hidden_states and
423 | returns the resulting tensor.
424 | """
425 | def forward(self, hidden_states):
426 | """
427 | Mish activation function.
428 |
429 | Args:
430 | hidden_states (torch.Tensor): The input tensor to apply the Mish activation function to.
431 |
432 | Returns:
433 | hidden_states (torch.Tensor): The output tensor after applying the Mish activation function.
434 | """
435 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
436 |
--------------------------------------------------------------------------------
/scripts/inference.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=E1101
2 | # scripts/inference.py
3 |
4 | """
5 | This script contains the main inference pipeline for processing audio and image inputs to generate a video output.
6 |
7 | The script imports necessary packages and classes, defines a neural network model,
8 | and contains functions for processing audio embeddings and performing inference.
9 |
10 | The main inference process is outlined in the following steps:
11 | 1. Initialize the configuration.
12 | 2. Set up runtime variables.
13 | 3. Prepare the input data for inference (source image, face mask, and face embeddings).
14 | 4. Process the audio embeddings.
15 | 5. Build and freeze the model and scheduler.
16 | 6. Run the inference loop and save the result.
17 |
18 | Usage:
19 | This script can be run from the command line with the following arguments:
20 |     - config: Path to the inference configuration YAML (defaults to configs/inference/default.yaml).
21 |     - source_image: Path to the source image.
22 |     - driving_audio: Path to the driving audio file.
23 |     - output: Path to save the output video.
24 |     - pose_weight / face_weight / lip_weight: Motion weights for the pose, face, and lip regions.
25 |     - face_expand_ratio: Expansion ratio applied to the detected face region.
26 |
27 | Example:
28 |     python scripts/inference.py --source_image image.jpg --driving_audio audio.wav --output output.mp4
29 | """
30 |
31 | import argparse
32 | import gc
33 | import os
34 |
35 | import torch
36 | from diffusers import AutoencoderKL, DDIMScheduler
37 | from omegaconf import OmegaConf
38 | from torch import nn
39 |
40 | from hallo.animate.face_animate import FaceAnimatePipeline
41 | from hallo.datasets.audio_processor import AudioProcessor
42 | from hallo.datasets.image_processor import ImageProcessor
43 | from hallo.models.audio_proj import AudioProjModel
44 | from hallo.models.face_locator import FaceLocator
45 | from hallo.models.image_proj import ImageProjModel
46 | from hallo.models.unet_2d_condition import UNet2DConditionModel
47 | from hallo.models.unet_3d import UNet3DConditionModel
48 | from hallo.utils.util import tensor_to_video
49 |
50 | # from diffusers.utils.import_utils import is_xformers_available
51 |
52 |
53 | class Net(nn.Module):
54 | """
55 | The Net class combines all the necessary modules for the inference process.
56 |
57 | Args:
58 | reference_unet (UNet2DConditionModel): The UNet2DConditionModel used as a reference for inference.
59 |         denoising_unet (UNet3DConditionModel): The UNet3DConditionModel used to denoise the video latents.
60 |         face_locator (FaceLocator): The FaceLocator model used to encode the face region of the input image.
61 |         imageproj (nn.Module): The image projection model that maps the face embedding into cross-attention tokens.
62 |         audioproj (nn.Module): The audio projection model that maps the audio embeddings into cross-attention tokens.
63 | """
64 | def __init__(
65 | self,
66 | reference_unet: UNet2DConditionModel,
67 | denoising_unet: UNet3DConditionModel,
68 | face_locator: FaceLocator,
69 | imageproj,
70 | audioproj,
71 | ):
72 | super().__init__()
73 | self.reference_unet = reference_unet
74 | self.denoising_unet = denoising_unet
75 | self.face_locator = face_locator
76 | self.imageproj = imageproj
77 | self.audioproj = audioproj
78 |
79 |     def forward(self):
80 |         """
81 |         Empty forward pass; the sub-modules are invoked individually during inference.
82 | """
83 |
84 | def get_modules(self):
85 | """
86 | Simple method to avoid too-few-public-methods pylint error
87 | """
88 | return {
89 | "reference_unet": self.reference_unet,
90 | "denoising_unet": self.denoising_unet,
91 | "face_locator": self.face_locator,
92 | "imageproj": self.imageproj,
93 | "audioproj": self.audioproj,
94 | }
95 |
96 |
97 | def process_audio_emb(audio_emb):
98 | """
99 | Process the audio embedding to concatenate with other tensors.
100 |
101 | Parameters:
102 | audio_emb (torch.Tensor): The audio embedding tensor to process.
103 |
104 | Returns:
105 |         audio_emb (torch.Tensor): The stacked tensor in which every frame is grouped with its two preceding and two following frames (clamped at the sequence boundaries).
106 | """
107 | concatenated_tensors = []
108 |
109 | for i in range(audio_emb.shape[0]):
110 |         vectors_to_concat = [
111 |             audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)] for j in range(-2, 3)]
112 | concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))
113 |
114 | audio_emb = torch.stack(concatenated_tensors, dim=0)
115 |
116 | return audio_emb
117 |
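# A small sketch of the sliding-window behaviour above (hypothetical shapes, for illustration only):
#
#   emb = torch.randn(10, 12, 768)              # (frames, wav2vec blocks, channels)
#   windowed = process_audio_emb(emb)
#   assert windowed.shape == (10, 5, 12, 768)   # each frame grouped with its 2 previous / 2 next frames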
118 |
119 |
120 | def inference_process(args: argparse.Namespace, setting_steps=40, setting_cfg=3.5, settings_seed=42, settings_fps=25, settings_motion_pose_scale=1.1, settings_motion_face_scale=1.1, settings_motion_lip_scale=1.1, settings_n_motion_frames=2, settings_n_sample_frames=16):
121 | """
122 | Perform inference processing.
123 |
124 | Args:
125 |         args (argparse.Namespace): Command-line arguments. The remaining keyword arguments override the
126 |             corresponding values in the loaded config (inference steps, CFG scale, seed, fps, motion
127 |             weights, and motion/sample frame counts) whenever they are not None.
128 |     The function initializes the configuration, prepares the inputs and models, and runs the inference loop.
129 | """
130 | # 1. init config
131 | config = OmegaConf.load(args.config)
132 | config = OmegaConf.merge(config, vars(args))
133 |
134 |
135 | if setting_steps is not None:
136 | config.inference_steps = setting_steps
137 | if setting_cfg is not None:
138 | config.cfg_scale = setting_cfg
139 | if settings_seed is not None:
140 | config.seed = int(settings_seed)
141 | if settings_fps is not None:
142 | config.data.export_video.fps = settings_fps
143 | if settings_motion_pose_scale is not None:
144 | config.pose_weight = settings_motion_pose_scale
145 | if settings_motion_face_scale is not None:
146 | config.face_weight = settings_motion_face_scale
147 | if settings_motion_lip_scale is not None:
148 | config.lip_weight = settings_motion_lip_scale
149 | if settings_n_motion_frames is not None:
150 | config.data.n_motion_frames = settings_n_motion_frames
151 | if settings_n_sample_frames is not None:
152 | config.data.n_sample_frames = settings_n_sample_frames
153 |
154 |
155 | source_image_path = config.source_image
156 | driving_audio_path = config.driving_audio
157 | save_path = config.save_path
158 | if not os.path.exists(save_path):
159 | os.makedirs(save_path)
160 | motion_scale = [config.pose_weight, config.face_weight, config.lip_weight]
161 | if args.checkpoint is not None:
162 | config.audio_ckpt_dir = args.checkpoint
163 | # 2. runtime variables
164 | device = torch.device(
165 | "cuda") if torch.cuda.is_available() else torch.device("cpu")
166 | if config.weight_dtype == "fp16":
167 | weight_dtype = torch.float16
168 | elif config.weight_dtype == "bf16":
169 | weight_dtype = torch.bfloat16
170 | elif config.weight_dtype == "fp32":
171 | weight_dtype = torch.float32
172 | else:
173 | weight_dtype = torch.float32
174 |
175 | # 3. prepare inference data
176 | # 3.1 prepare source image, face mask, face embeddings
177 | img_size = (config.data.source_image.width,
178 | config.data.source_image.height)
179 | clip_length = config.data.n_sample_frames
180 | face_analysis_model_path = config.face_analysis.model_path
181 | with ImageProcessor(img_size, face_analysis_model_path) as image_processor:
182 | source_image_pixels, \
183 | source_image_face_region, \
184 | source_image_face_emb, \
185 | source_image_full_mask, \
186 | source_image_face_mask, \
187 | source_image_lip_mask = image_processor.preprocess(
188 | source_image_path, save_path, config.face_expand_ratio)
189 |
190 | # 3.2 prepare audio embeddings
191 | sample_rate = config.data.driving_audio.sample_rate
192 | assert sample_rate == 16000, "audio sample rate must be 16000"
193 | fps = config.data.export_video.fps
194 | wav2vec_model_path = config.wav2vec.model_path
195 | wav2vec_only_last_features = config.wav2vec.features == "last"
196 | audio_separator_model_file = config.audio_separator.model_path
197 | with AudioProcessor(
198 | sample_rate,
199 | fps,
200 | wav2vec_model_path,
201 | wav2vec_only_last_features,
202 | os.path.dirname(audio_separator_model_file),
203 | os.path.basename(audio_separator_model_file),
204 | os.path.join(save_path, "audio_preprocess")
205 | ) as audio_processor:
206 | audio_emb = audio_processor.preprocess(driving_audio_path)
207 |
208 | # Clear memory
209 | # del image_processor
210 | del audio_processor
211 | gc.collect()
212 | torch.cuda.empty_cache()
213 |
214 |
215 | # 4. build modules
216 | sched_kwargs = OmegaConf.to_container(config.noise_scheduler_kwargs)
217 | if config.enable_zero_snr:
218 | sched_kwargs.update(
219 | rescale_betas_zero_snr=True,
220 | timestep_spacing="trailing",
221 | prediction_type="v_prediction",
222 | )
223 | val_noise_scheduler = DDIMScheduler(**sched_kwargs)
224 | sched_kwargs.update({"beta_schedule": "scaled_linear"})
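    # Note: this update happens after the DDIMScheduler has already been constructed, so it only
    # changes the local dict and does not affect val_noise_scheduler.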
225 |
226 | vae = AutoencoderKL.from_pretrained(config.vae.model_path)
227 | reference_unet = UNet2DConditionModel.from_pretrained(
228 | config.base_model_path, subfolder="unet")
229 | denoising_unet = UNet3DConditionModel.from_pretrained_2d(
230 | config.base_model_path,
231 | config.motion_module_path,
232 | subfolder="unet",
233 | unet_additional_kwargs=OmegaConf.to_container(
234 | config.unet_additional_kwargs),
235 | use_landmark=False,
236 | )
237 | face_locator = FaceLocator(conditioning_embedding_channels=320)
238 | image_proj = ImageProjModel(
239 | cross_attention_dim=denoising_unet.config.cross_attention_dim,
240 | clip_embeddings_dim=512,
241 | clip_extra_context_tokens=4,
242 | )
243 |
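    # Note: seq_len=5 below corresponds to the 5-frame window (two previous, current, two next)
    # that process_audio_emb builds for every video frame.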
244 | audio_proj = AudioProjModel(
245 | seq_len=5,
246 | blocks=12, # use 12 layers' hidden states of wav2vec
247 | channels=768, # audio embedding channel
248 | intermediate_dim=512,
249 | output_dim=768,
250 | context_tokens=32,
251 | ).to(device=device, dtype=weight_dtype)
252 |
253 | audio_ckpt_dir = config.audio_ckpt_dir
254 |
255 |
256 | # Freeze
257 | vae.requires_grad_(False)
258 | image_proj.requires_grad_(False)
259 | reference_unet.requires_grad_(False)
260 | denoising_unet.requires_grad_(False)
261 | face_locator.requires_grad_(False)
262 | audio_proj.requires_grad_(False)
263 |
264 |     # xformers memory-efficient attention is currently not working in this setup, so it stays disabled:
265 | # if is_xformers_available():
266 | # reference_unet.enable_xformers_memory_efficient_attention()
267 | # denoising_unet.enable_xformers_memory_efficient_attention()
268 |
269 | reference_unet.enable_gradient_checkpointing()
270 | denoising_unet.enable_gradient_checkpointing()
271 |
272 | net = Net(
273 | reference_unet,
274 | denoising_unet,
275 | face_locator,
276 | image_proj,
277 | audio_proj,
278 | )
279 |
280 |     missing, unexpected = net.load_state_dict(
281 |         torch.load(
282 |             os.path.join(audio_ckpt_dir, "net.pth"),
283 |             map_location="cpu",
284 |         ),
285 |     )
286 |     assert len(missing) == 0 and len(unexpected) == 0, "Failed to load the checkpoint: missing or unexpected keys."
287 |     print("Loaded weights from", os.path.join(audio_ckpt_dir, "net.pth"))
288 |
289 | # 5. inference
290 | pipeline = FaceAnimatePipeline(
291 | vae=vae,
292 | reference_unet=net.reference_unet,
293 | denoising_unet=net.denoising_unet,
294 | face_locator=net.face_locator,
295 | scheduler=val_noise_scheduler,
296 | image_proj=net.imageproj,
297 | )
298 | pipeline.to(device=device, dtype=weight_dtype)
299 |
300 | audio_emb = process_audio_emb(audio_emb)
301 |
302 | source_image_pixels = source_image_pixels.unsqueeze(0)
303 | source_image_face_region = source_image_face_region.unsqueeze(0)
304 | source_image_face_emb = source_image_face_emb.reshape(1, -1)
305 | source_image_face_emb = torch.tensor(source_image_face_emb)
306 |
307 | source_image_full_mask = [
308 | (mask.repeat(clip_length, 1))
309 | for mask in source_image_full_mask
310 | ]
311 | source_image_face_mask = [
312 | (mask.repeat(clip_length, 1))
313 | for mask in source_image_face_mask
314 | ]
315 | source_image_lip_mask = [
316 | (mask.repeat(clip_length, 1))
317 | for mask in source_image_lip_mask
318 | ]
319 |
320 |
321 | times = audio_emb.shape[0] // clip_length
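    # Generate the video clip by clip: each iteration below produces clip_length frames, and the last
    # n_motion_frames of the previous clip are fed back in as motion context. Any trailing chunk of
    # audio shorter than clip_length is dropped by the integer division above.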
322 |
323 | tensor_result = []
324 |
325 |     generator = torch.manual_seed(config.seed)  # use the configured seed instead of a hard-coded one
326 |
327 | for t in range(times):
328 |
329 | if len(tensor_result) == 0:
330 | # The first iteration
331 | motion_zeros = source_image_pixels.repeat(
332 | config.data.n_motion_frames, 1, 1, 1)
333 | motion_zeros = motion_zeros.to(
334 | dtype=source_image_pixels.dtype, device=source_image_pixels.device)
335 | pixel_values_ref_img = torch.cat(
336 | [source_image_pixels, motion_zeros], dim=0) # concat the ref image and the first motion frames
337 | else:
338 | motion_frames = tensor_result[-1][0]
339 | motion_frames = motion_frames.permute(1, 0, 2, 3)
340 |             motion_frames = motion_frames[-config.data.n_motion_frames:]  # keep only the last n_motion_frames frames
341 | motion_frames = motion_frames * 2.0 - 1.0
342 | motion_frames = motion_frames.to(
343 | dtype=source_image_pixels.dtype, device=source_image_pixels.device)
344 | pixel_values_ref_img = torch.cat(
345 | [source_image_pixels, motion_frames], dim=0) # concat the ref image and the motion frames
346 |
347 | pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
348 |
349 | audio_tensor = audio_emb[
350 | t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
351 | ]
352 | audio_tensor = audio_tensor.unsqueeze(0)
353 | audio_tensor = audio_tensor.to(
354 | device=net.audioproj.device, dtype=net.audioproj.dtype)
355 | audio_tensor = net.audioproj(audio_tensor)
356 |
357 |         # Log clip-level progress
358 | print(
359 | f"""
360 | inference {t+1} / {times}
361 | """
362 | )
363 |
364 | pipeline_output = pipeline(
365 | ref_image=pixel_values_ref_img,
366 | audio_tensor=audio_tensor,
367 | face_emb=source_image_face_emb,
368 | face_mask=source_image_face_region,
369 | pixel_values_full_mask=source_image_full_mask,
370 | pixel_values_face_mask=source_image_face_mask,
371 | pixel_values_lip_mask=source_image_lip_mask,
372 | width=img_size[0],
373 | height=img_size[1],
374 | video_length=clip_length,
375 | num_inference_steps=config.inference_steps,
376 | guidance_scale=config.cfg_scale,
377 | generator=generator,
378 | motion_scale=motion_scale,
379 | )
380 |
381 | tensor_result.append(pipeline_output.videos)
382 |
383 | tensor_result = torch.cat(tensor_result, dim=2)
384 | tensor_result = tensor_result.squeeze(0)
385 |
386 | output_file = config.output
387 |     # save the result after all iterations have finished
388 | tensor_to_video(tensor_result, output_file, driving_audio_path)
389 |
390 |
391 | if __name__ == "__main__":
392 | parser = argparse.ArgumentParser()
393 |
394 | parser.add_argument(
395 | "-c", "--config", default="configs/inference/default.yaml")
396 | parser.add_argument("--source_image", type=str, required=False,
397 | help="source image", default="test_data/source_images/6.jpg")
398 | parser.add_argument("--driving_audio", type=str, required=False,
399 | help="driving audio", default="test_data/driving_audios/singing/sing_4.wav")
400 | parser.add_argument(
401 | "--output", type=str, help="output video file name", default=".cache/output.mp4")
402 | parser.add_argument(
403 | "--pose_weight", type=float, help="weight of pose", default=1.0)
404 | parser.add_argument(
405 | "--face_weight", type=float, help="weight of face", default=1.0)
406 | parser.add_argument(
407 | "--lip_weight", type=float, help="weight of lip", default=1.0)
408 | parser.add_argument(
409 | "--face_expand_ratio", type=float, help="face region", default=1.2)
410 | parser.add_argument(
411 | "--checkpoint", type=str, help="which checkpoint", default=None)
412 | parser.add_argument("--setting_steps", type=int, default=40)
413 | parser.add_argument("--setting_cfg", type=float, default=3.5)
414 | parser.add_argument("--settings_seed", type=int, default=42)
415 | parser.add_argument("--settings_fps", type=int, default=25)
416 | parser.add_argument("--settings_motion_pose_scale", type=float, default=1.1)
417 | parser.add_argument("--settings_motion_face_scale", type=float, default=1.1)
418 | parser.add_argument("--settings_motion_lip_scale", type=float, default=1.1)
419 | parser.add_argument("--settings_n_motion_frames", type=int, default=2)
420 | parser.add_argument("--settings_n_sample_frames", type=int, default=16)
421 |
422 | command_line_args = parser.parse_args()
423 |
424 | inference_process(
425 | command_line_args,
426 | command_line_args.setting_steps,
427 | command_line_args.setting_cfg,
428 | command_line_args.settings_seed,
429 | command_line_args.settings_fps,
430 | command_line_args.settings_motion_pose_scale,
431 | command_line_args.settings_motion_face_scale,
432 | command_line_args.settings_motion_lip_scale,
433 | command_line_args.settings_n_motion_frames,
434 | command_line_args.settings_n_sample_frames
435 | )
436 |
--------------------------------------------------------------------------------
/hallo/animate/face_animate_static.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R0801
2 | """
3 | This module is responsible for handling the animation of faces using a combination of deep learning models and image processing techniques.
4 | It provides a pipeline to generate realistic face animations by incorporating user-provided conditions such as facial expressions and environments.
5 | The module utilizes various schedulers and utilities to optimize the animation process and ensure efficient performance.
6 |
7 | Functions and Classes:
8 | - StaticPipelineOutput: A class that represents the output of the animation pipeline,
9 |   containing properties and methods related to the generated images.
10 | - prepare_latents: A function that prepares the initial noise for the animation process,
11 | scaling it according to the scheduler's requirements.
12 | - prepare_condition: A function that processes the user-provided conditions
13 | (e.g., facial expressions) and prepares them for use in the animation pipeline.
14 | - decode_latents: A function that decodes the latent representations of the face animations into
15 | their corresponding image formats.
16 | - prepare_extra_step_kwargs: A function that prepares additional parameters for each step of
17 | the animation process, such as the generator and eta values.
18 |
19 | Dependencies:
20 | - numpy: A library for numerical computing.
21 | - torch: The PyTorch deep learning library.
22 | - diffusers: A library of diffusion models and pipelines.
23 | - transformers: A library for pre-trained transformer models.
24 |
25 | Usage:
26 | - To create an instance of the animation pipeline, provide the necessary components such as
27 | the VAE, reference UNET, denoising UNET, face locator, and image processor.
28 | - Use the pipeline's methods to prepare the latents, conditions, and extra step arguments as
29 | required for the animation process.
30 | - Generate the face animations by decoding the latents and processing the conditions.
31 |
32 | Note:
33 | - The module is designed to work with the diffusers library, which is based on
34 | the paper "Diffusion Models for Image-to-Image Translation" (https://arxiv.org/abs/2102.02765).
35 | - The face animations generated by this module should be used for entertainment purposes
36 | only and should respect the rights and privacy of the individuals involved.
37 | """
38 | import inspect
39 | from dataclasses import dataclass
40 | from typing import Callable, List, Optional, Union
41 |
42 | import numpy as np
43 | import torch
44 | from diffusers import DiffusionPipeline
45 | from diffusers.image_processor import VaeImageProcessor
46 | from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
47 | EulerAncestralDiscreteScheduler,
48 | EulerDiscreteScheduler, LMSDiscreteScheduler,
49 | PNDMScheduler)
50 | from diffusers.utils import BaseOutput, is_accelerate_available
51 | from diffusers.utils.torch_utils import randn_tensor
52 | from einops import rearrange
53 | from tqdm import tqdm
54 | from transformers import CLIPImageProcessor
55 |
56 | from hallo.models.mutual_self_attention import ReferenceAttentionControl
57 |
58 | if is_accelerate_available():
59 | from accelerate import cpu_offload
60 | else:
61 | raise ImportError("Please install accelerate via `pip install accelerate`")
62 |
63 |
64 | @dataclass
65 | class StaticPipelineOutput(BaseOutput):
66 | """
67 | StaticPipelineOutput is a class that represents the output of the static pipeline.
68 | It contains the images generated by the pipeline as a union of torch.Tensor and np.ndarray.
69 |
70 | Attributes:
71 | images (Union[torch.Tensor, np.ndarray]): The generated images.
72 | """
73 | images: Union[torch.Tensor, np.ndarray]
74 |
75 |
76 | class StaticPipeline(DiffusionPipeline):
77 | """
78 |     StaticPipeline is a diffusion pipeline that generates a static face image conditioned on a
79 |     reference image, a face mask, and a face embedding.
80 |
81 |     Attributes:
82 |         vae, reference_unet, denoising_unet, face_locator, imageproj, scheduler: The sub-modules registered with the pipeline and used during sampling.
83 | """
84 | _optional_components = []
85 |
86 | def __init__(
87 | self,
88 | vae,
89 | reference_unet,
90 | denoising_unet,
91 | face_locator,
92 | imageproj,
93 | scheduler: Union[
94 | DDIMScheduler,
95 | PNDMScheduler,
96 | LMSDiscreteScheduler,
97 | EulerDiscreteScheduler,
98 | EulerAncestralDiscreteScheduler,
99 | DPMSolverMultistepScheduler,
100 | ],
101 | ):
102 | super().__init__()
103 |
104 | self.register_modules(
105 | vae=vae,
106 | reference_unet=reference_unet,
107 | denoising_unet=denoising_unet,
108 | face_locator=face_locator,
109 | scheduler=scheduler,
110 | imageproj=imageproj,
111 | )
112 | self.vae_scale_factor = 2 ** (
113 | len(self.vae.config.block_out_channels) - 1)
114 | self.clip_image_processor = CLIPImageProcessor()
115 | self.ref_image_processor = VaeImageProcessor(
116 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
117 | )
118 | self.cond_image_processor = VaeImageProcessor(
119 | vae_scale_factor=self.vae_scale_factor,
120 | do_convert_rgb=True,
121 | do_normalize=False,
122 | )
123 |
124 | def enable_vae_slicing(self):
125 | """
126 | Enable VAE slicing.
127 |
128 | This method enables slicing for the VAE model, which can help improve the performance of decoding latents when working with large images.
129 | """
130 | self.vae.enable_slicing()
131 |
132 | def disable_vae_slicing(self):
133 | """
134 | Disable vae slicing.
135 |
136 | This function disables the vae slicing for the StaticPipeline object.
137 | It calls the `disable_slicing()` method of the vae model.
138 | This is useful when you want to use the entire vae model for decoding latents
139 | instead of slicing it for better performance.
140 | """
141 | self.vae.disable_slicing()
142 |
143 | def enable_sequential_cpu_offload(self, gpu_id=0):
144 | """
145 |         Offloads the pipeline's sub-models to the CPU, moving each one to the given GPU only while it is being executed. This reduces peak GPU memory usage.
146 |
147 |         Args:
148 |             gpu_id (int, optional): The ID of the GPU used for execution. Defaults to 0.
149 | """
150 | device = torch.device(f"cuda:{gpu_id}")
151 |
152 |         for cpu_offloaded_model in [self.reference_unet, self.denoising_unet, self.vae]:
153 | if cpu_offloaded_model is not None:
154 | cpu_offload(cpu_offloaded_model, device)
155 |
156 | @property
157 | def _execution_device(self):
158 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
159 | return self.device
160 | for module in self.unet.modules():
161 | if (
162 | hasattr(module, "_hf_hook")
163 | and hasattr(module._hf_hook, "execution_device")
164 | and module._hf_hook.execution_device is not None
165 | ):
166 | return torch.device(module._hf_hook.execution_device)
167 | return self.device
168 |
169 | def decode_latents(self, latents):
170 | """
171 | Decode the given latents to video frames.
172 |
173 | Parameters:
174 | latents (torch.Tensor): The latents to be decoded. Shape: (batch_size, num_channels_latents, video_length, height, width).
175 |
176 | Returns:
177 | video (torch.Tensor): The decoded video frames. Shape: (batch_size, num_channels_latents, video_length, height, width).
178 | """
179 | video_length = latents.shape[2]
180 | latents = 1 / 0.18215 * latents
181 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
182 | # video = self.vae.decode(latents).sample
183 | video = []
184 | for frame_idx in tqdm(range(latents.shape[0])):
185 | video.append(self.vae.decode(
186 | latents[frame_idx: frame_idx + 1]).sample)
187 | video = torch.cat(video)
188 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
189 | video = (video / 2 + 0.5).clamp(0, 1)
190 |         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
191 | video = video.cpu().float().numpy()
192 | return video
193 |
194 | def prepare_extra_step_kwargs(self, generator, eta):
195 | """
196 | Prepare extra keyword arguments for the scheduler step.
197 |
198 | Since not all schedulers have the same signature, this function helps to create a consistent interface for the scheduler.
199 |
200 | Args:
201 | generator (Optional[torch.Generator]): A random number generator for reproducibility.
202 | eta (float): The eta parameter used with the DDIMScheduler. It should be between 0 and 1.
203 |
204 | Returns:
205 | dict: A dictionary containing the extra keyword arguments for the scheduler step.
206 | """
207 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
208 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
209 | # and should be between [0, 1]
210 |
211 | accepts_eta = "eta" in set(
212 | inspect.signature(self.scheduler.step).parameters.keys()
213 | )
214 | extra_step_kwargs = {}
215 | if accepts_eta:
216 | extra_step_kwargs["eta"] = eta
217 |
218 | # check if the scheduler accepts generator
219 | accepts_generator = "generator" in set(
220 | inspect.signature(self.scheduler.step).parameters.keys()
221 | )
222 | if accepts_generator:
223 | extra_step_kwargs["generator"] = generator
224 | return extra_step_kwargs
225 |
226 | def prepare_latents(
227 | self,
228 | batch_size,
229 | num_channels_latents,
230 | width,
231 | height,
232 | dtype,
233 | device,
234 | generator,
235 | latents=None,
236 | ):
237 | """
238 | Prepares the initial latents for the diffusion pipeline.
239 |
240 | Args:
241 | batch_size (int): The number of images to generate in one forward pass.
242 | num_channels_latents (int): The number of channels in the latents tensor.
243 | width (int): The width of the latents tensor.
244 | height (int): The height of the latents tensor.
245 | dtype (torch.dtype): The data type of the latents tensor.
246 | device (torch.device): The device to place the latents tensor on.
247 | generator (Optional[torch.Generator], optional): A random number generator
248 | for reproducibility. Defaults to None.
249 | latents (Optional[torch.Tensor], optional): Pre-computed latents to use as
250 | initial conditions for the diffusion process. Defaults to None.
251 |
252 | Returns:
253 | torch.Tensor: The prepared latents tensor.
254 | """
255 | shape = (
256 | batch_size,
257 | num_channels_latents,
258 | height // self.vae_scale_factor,
259 | width // self.vae_scale_factor,
260 | )
261 | if isinstance(generator, list) and len(generator) != batch_size:
262 | raise ValueError(
263 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
264 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
265 | )
266 |
267 | if latents is None:
268 | latents = randn_tensor(
269 | shape, generator=generator, device=device, dtype=dtype
270 | )
271 | else:
272 | latents = latents.to(device)
273 |
274 | # scale the initial noise by the standard deviation required by the scheduler
275 | latents = latents * self.scheduler.init_noise_sigma
276 | return latents
277 |
278 | def prepare_condition(
279 | self,
280 | cond_image,
281 | width,
282 | height,
283 | device,
284 | dtype,
285 |         do_classifier_free_guidance=False,
286 | ):
287 | """
288 | Prepares the condition for the face animation pipeline.
289 |
290 | Args:
291 | cond_image (torch.Tensor): The conditional image tensor.
292 | width (int): The width of the output image.
293 | height (int): The height of the output image.
294 | device (torch.device): The device to run the pipeline on.
295 | dtype (torch.dtype): The data type of the tensor.
296 |             do_classifier_free_guidance (bool, optional): Whether to use classifier-free guidance or not. Defaults to False.
297 |
298 | Returns:
299 |             torch.Tensor: The preprocessed condition image tensor, duplicated along the batch dimension when classifier-free guidance is enabled.
300 | """
301 | image = self.cond_image_processor.preprocess(
302 | cond_image, height=height, width=width
303 | ).to(dtype=torch.float32)
304 |
305 | image = image.to(device=device, dtype=dtype)
306 |
307 |         if do_classifier_free_guidance:
308 | image = torch.cat([image] * 2)
309 |
310 | return image
311 |
312 | @torch.no_grad()
313 | def __call__(
314 | self,
315 | ref_image,
316 | face_mask,
317 | width,
318 | height,
319 | num_inference_steps,
320 | guidance_scale,
321 | face_embedding,
322 | num_images_per_prompt=1,
323 | eta: float = 0.0,
324 | generator: Optional[Union[torch.Generator,
325 | List[torch.Generator]]] = None,
326 | output_type: Optional[str] = "tensor",
327 | return_dict: bool = True,
328 | callback: Optional[Callable[[
329 | int, int, torch.FloatTensor], None]] = None,
330 | callback_steps: Optional[int] = 1,
331 | **kwargs,
332 | ):
333 | # Default height and width to unet
334 | height = height or self.unet.config.sample_size * self.vae_scale_factor
335 | width = width or self.unet.config.sample_size * self.vae_scale_factor
336 |
337 | device = self._execution_device
338 |
339 | do_classifier_free_guidance = guidance_scale > 1.0
340 |
341 | # Prepare timesteps
342 | self.scheduler.set_timesteps(num_inference_steps, device=device)
343 | timesteps = self.scheduler.timesteps
344 |
345 | batch_size = 1
346 |
347 | image_prompt_embeds = self.imageproj(face_embedding)
348 | uncond_image_prompt_embeds = self.imageproj(
349 | torch.zeros_like(face_embedding))
350 |
351 | if do_classifier_free_guidance:
352 | image_prompt_embeds = torch.cat(
353 | [uncond_image_prompt_embeds, image_prompt_embeds], dim=0
354 | )
355 |
356 | reference_control_writer = ReferenceAttentionControl(
357 | self.reference_unet,
358 | do_classifier_free_guidance=do_classifier_free_guidance,
359 | mode="write",
360 | batch_size=batch_size,
361 | fusion_blocks="full",
362 | )
363 | reference_control_reader = ReferenceAttentionControl(
364 | self.denoising_unet,
365 | do_classifier_free_guidance=do_classifier_free_guidance,
366 | mode="read",
367 | batch_size=batch_size,
368 | fusion_blocks="full",
369 | )
370 |
371 | num_channels_latents = self.denoising_unet.in_channels
372 | latents = self.prepare_latents(
373 | batch_size * num_images_per_prompt,
374 | num_channels_latents,
375 | width,
376 | height,
377 | face_embedding.dtype,
378 | device,
379 | generator,
380 | )
381 | latents = latents.unsqueeze(2) # (bs, c, 1, h', w')
382 | # latents_dtype = latents.dtype
383 |
384 | # Prepare extra step kwargs.
385 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
386 |
387 | # Prepare ref image latents
388 | ref_image_tensor = self.ref_image_processor.preprocess(
389 | ref_image, height=height, width=width
390 | ) # (bs, c, width, height)
391 | ref_image_tensor = ref_image_tensor.to(
392 | dtype=self.vae.dtype, device=self.vae.device
393 | )
394 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
395 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
396 |
397 | # Prepare face mask image
398 | face_mask_tensor = self.cond_image_processor.preprocess(
399 | face_mask, height=height, width=width
400 | )
401 | face_mask_tensor = face_mask_tensor.unsqueeze(2) # (bs, c, 1, h, w)
402 | face_mask_tensor = face_mask_tensor.to(
403 | device=device, dtype=self.face_locator.dtype
404 | )
405 | mask_fea = self.face_locator(face_mask_tensor)
406 | mask_fea = (
407 | torch.cat(
408 | [mask_fea] * 2) if do_classifier_free_guidance else mask_fea
409 | )
410 |
411 | # denoising loop
412 | num_warmup_steps = len(timesteps) - \
413 | num_inference_steps * self.scheduler.order
414 | with self.progress_bar(total=num_inference_steps) as progress_bar:
415 | for i, t in enumerate(timesteps):
416 | # 1. Forward reference image
417 | if i == 0:
418 | self.reference_unet(
419 | ref_image_latents.repeat(
420 | (2 if do_classifier_free_guidance else 1), 1, 1, 1
421 | ),
422 | torch.zeros_like(t),
423 | encoder_hidden_states=image_prompt_embeds,
424 | return_dict=False,
425 | )
426 |
427 | # 2. Update reference unet feature into denosing net
428 | reference_control_reader.update(reference_control_writer)
429 |
430 | # 3.1 expand the latents if we are doing classifier free guidance
431 | latent_model_input = (
432 | torch.cat(
433 | [latents] * 2) if do_classifier_free_guidance else latents
434 | )
435 | latent_model_input = self.scheduler.scale_model_input(
436 | latent_model_input, t
437 | )
438 |
439 | noise_pred = self.denoising_unet(
440 | latent_model_input,
441 | t,
442 | encoder_hidden_states=image_prompt_embeds,
443 | mask_cond_fea=mask_fea,
444 | return_dict=False,
445 | )[0]
446 |
447 | # perform guidance
448 | if do_classifier_free_guidance:
449 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
450 | noise_pred = noise_pred_uncond + guidance_scale * (
451 | noise_pred_text - noise_pred_uncond
452 | )
453 |
454 | # compute the previous noisy sample x_t -> x_t-1
455 | latents = self.scheduler.step(
456 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False
457 | )[0]
458 |
459 | # call the callback, if provided
460 | if i == len(timesteps) - 1 or (
461 | (i + 1) > num_warmup_steps and (i +
462 | 1) % self.scheduler.order == 0
463 | ):
464 | progress_bar.update()
465 | if callback is not None and i % callback_steps == 0:
466 | step_idx = i // getattr(self.scheduler, "order", 1)
467 | callback(step_idx, t, latents)
468 | reference_control_reader.clear()
469 | reference_control_writer.clear()
470 |
471 | # Post-processing
472 | image = self.decode_latents(latents) # (b, c, 1, h, w)
473 |
474 | # Convert to tensor
475 | if output_type == "tensor":
476 | image = torch.from_numpy(image)
477 |
478 | if not return_dict:
479 | return image
480 |
481 | return StaticPipelineOutput(images=image)
482 |
--------------------------------------------------------------------------------
/hallo/animate/face_animate.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R0801
2 | """
3 | This module is responsible for animating faces in videos using a combination of deep learning techniques.
4 | It provides a pipeline for generating face animations by processing video frames and extracting face features.
5 | The module utilizes various schedulers and utilities for efficient face animation and supports different types
6 | of latents for more control over the animation process.
7 |
8 | Functions and Classes:
9 | - FaceAnimatePipeline: A class that extends the DiffusionPipeline class from the diffusers library to handle face animation tasks.
10 | - __init__: Initializes the pipeline with the necessary components (VAE, UNets, face locator, etc.).
11 | - prepare_latents: Generates or loads latents for the animation process, scaling them according to the scheduler's requirements.
12 | - prepare_extra_step_kwargs: Prepares extra keyword arguments for the scheduler step, ensuring compatibility with different schedulers.
13 | - decode_latents: Decodes the latents into video frames, ready for animation.
14 |
15 | Usage:
16 | - Import the necessary packages and classes.
17 | - Create a FaceAnimatePipeline instance with the required components.
18 | - Prepare the latents for the animation process.
19 | - Use the pipeline to generate the animated video.
20 |
21 | Note:
22 | - This module is designed to work with the diffusers library, which provides the underlying framework for face animation using deep learning.
23 | - The module is intended for research and development purposes, and further optimization and customization may be required for specific use cases.
24 | """
25 |
26 | import inspect
27 | from dataclasses import dataclass
28 | from typing import Callable, List, Optional, Union
29 |
30 | import numpy as np
31 | import torch
32 | from diffusers import (DDIMScheduler, DiffusionPipeline,
33 | DPMSolverMultistepScheduler,
34 | EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
35 | LMSDiscreteScheduler, PNDMScheduler)
36 | from diffusers.image_processor import VaeImageProcessor
37 | from diffusers.utils import BaseOutput
38 | from diffusers.utils.torch_utils import randn_tensor
39 | from einops import rearrange, repeat
40 | from tqdm import tqdm
41 |
42 | from hallo.models.mutual_self_attention import ReferenceAttentionControl
43 |
44 |
45 | @dataclass
46 | class FaceAnimatePipelineOutput(BaseOutput):
47 | """
48 | FaceAnimatePipelineOutput is a custom class that inherits from BaseOutput and represents the output of the FaceAnimatePipeline.
49 |
50 | Attributes:
51 | videos (Union[torch.Tensor, np.ndarray]): A tensor or numpy array containing the generated video frames.
52 |
53 | Methods:
54 | __init__(self, videos: Union[torch.Tensor, np.ndarray]): Initializes the FaceAnimatePipelineOutput object with the generated video frames.
55 | """
56 | videos: Union[torch.Tensor, np.ndarray]
57 |
58 | class FaceAnimatePipeline(DiffusionPipeline):
59 | """
60 | FaceAnimatePipeline is a custom DiffusionPipeline for animating faces.
61 |
62 | It inherits from the DiffusionPipeline class and is used to animate faces by
63 | utilizing a variational autoencoder (VAE), a reference UNet, a denoising UNet,
64 | a face locator, and an image processor. The pipeline is responsible for generating
65 | and animating face latents, and decoding the latents to produce the final video output.
66 |
67 | Attributes:
68 |         vae (AutoencoderKL): Variational autoencoder that encodes images into and decodes images from the latent space.
69 | reference_unet (nn.Module): Reference UNet for mutual self-attention.
70 | denoising_unet (nn.Module): Denoising UNet for image denoising.
71 |         face_locator (nn.Module): Face locator that encodes the face-region mask into conditioning features.
72 | image_proj (nn.Module): Image projector for processing images.
73 | scheduler (Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler,
74 | EulerDiscreteScheduler, EulerAncestralDiscreteScheduler,
75 | DPMSolverMultistepScheduler]): Diffusion scheduler for
76 | controlling the noise level.
77 |
78 | Methods:
79 | __init__(self, vae, reference_unet, denoising_unet, face_locator,
80 | image_proj, scheduler): Initializes the FaceAnimatePipeline
81 | with the given components and scheduler.
82 | prepare_latents(self, batch_size, num_channels_latents, width, height,
83 | video_length, dtype, device, generator=None, latents=None):
84 | Prepares the initial latents for video generation.
85 | prepare_extra_step_kwargs(self, generator, eta): Prepares extra keyword
86 | arguments for the scheduler step.
87 | decode_latents(self, latents): Decodes the latents to produce the final
88 | video output.
89 | """
90 | def __init__(
91 | self,
92 | vae,
93 | reference_unet,
94 | denoising_unet,
95 | face_locator,
96 | image_proj,
97 | scheduler: Union[
98 | DDIMScheduler,
99 | PNDMScheduler,
100 | LMSDiscreteScheduler,
101 | EulerDiscreteScheduler,
102 | EulerAncestralDiscreteScheduler,
103 | DPMSolverMultistepScheduler,
104 | ],
105 | ) -> None:
106 | super().__init__()
107 |
108 | self.register_modules(
109 | vae=vae,
110 | reference_unet=reference_unet,
111 | denoising_unet=denoising_unet,
112 | face_locator=face_locator,
113 | scheduler=scheduler,
114 | image_proj=image_proj,
115 | )
116 |
117 | self.vae_scale_factor: int = 2 ** (len(self.vae.config.block_out_channels) - 1)
118 |
119 | self.ref_image_processor = VaeImageProcessor(
120 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True,
121 | )
122 |
123 | @property
124 | def _execution_device(self):
125 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
126 | return self.device
127 | for module in self.unet.modules():
128 | if (
129 | hasattr(module, "_hf_hook")
130 | and hasattr(module._hf_hook, "execution_device")
131 | and module._hf_hook.execution_device is not None
132 | ):
133 | return torch.device(module._hf_hook.execution_device)
134 | return self.device
135 |
136 | def prepare_latents(
137 | self,
138 | batch_size: int, # Number of videos to generate in parallel
139 | num_channels_latents: int, # Number of channels in the latents
140 | width: int, # Width of the video frame
141 | height: int, # Height of the video frame
142 | video_length: int, # Length of the video in frames
143 | dtype: torch.dtype, # Data type of the latents
144 | device: torch.device, # Device to store the latents on
145 | generator: Optional[torch.Generator] = None, # Random number generator for reproducibility
146 | latents: Optional[torch.Tensor] = None # Pre-generated latents (optional)
147 | ):
148 | """
149 | Prepares the initial latents for video generation.
150 |
151 | Args:
152 | batch_size (int): Number of videos to generate in parallel.
153 | num_channels_latents (int): Number of channels in the latents.
154 | width (int): Width of the video frame.
155 | height (int): Height of the video frame.
156 | video_length (int): Length of the video in frames.
157 | dtype (torch.dtype): Data type of the latents.
158 | device (torch.device): Device to store the latents on.
159 | generator (Optional[torch.Generator]): Random number generator for reproducibility.
160 | latents (Optional[torch.Tensor]): Pre-generated latents (optional).
161 |
162 | Returns:
163 |             latents (torch.Tensor): Tensor of shape (batch_size, num_channels_latents, video_length,
164 |                 height // vae_scale_factor, width // vae_scale_factor) containing the initial noise latents for video generation.
165 | """
166 | shape = (
167 | batch_size,
168 | num_channels_latents,
169 | video_length,
170 | height // self.vae_scale_factor,
171 | width // self.vae_scale_factor,
172 | )
173 | if isinstance(generator, list) and len(generator) != batch_size:
174 | raise ValueError(
175 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
176 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
177 | )
178 |
179 | if latents is None:
180 | latents = randn_tensor(
181 | shape, generator=generator, device=device, dtype=dtype
182 | )
183 | else:
184 | latents = latents.to(device)
185 |
186 | # scale the initial noise by the standard deviation required by the scheduler
187 | latents = latents * self.scheduler.init_noise_sigma
188 | return latents
189 |
190 | def prepare_extra_step_kwargs(self, generator, eta):
191 | """
192 | Prepares extra keyword arguments for the scheduler step.
193 |
194 | Args:
195 | generator (Optional[torch.Generator]): Random number generator for reproducibility.
196 | eta (float): The eta (η) parameter used with the DDIMScheduler.
197 | It corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1].
198 |
199 | Returns:
200 | dict: A dictionary containing the extra keyword arguments for the scheduler step.
201 | """
202 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
203 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
204 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
205 | # and should be between [0, 1]
206 |
207 | accepts_eta = "eta" in set(
208 | inspect.signature(self.scheduler.step).parameters.keys()
209 | )
210 | extra_step_kwargs = {}
211 | if accepts_eta:
212 | extra_step_kwargs["eta"] = eta
213 |
214 | # check if the scheduler accepts generator
215 | accepts_generator = "generator" in set(
216 | inspect.signature(self.scheduler.step).parameters.keys()
217 | )
218 | if accepts_generator:
219 | extra_step_kwargs["generator"] = generator
220 | return extra_step_kwargs
221 |
222 | def decode_latents(self, latents):
223 | """
224 | Decode the latents to produce a video.
225 |
226 | Parameters:
227 | latents (torch.Tensor): The latents to be decoded.
228 |
229 | Returns:
230 |             video (np.ndarray): The decoded video frames as float32 values in [0, 1],
231 |                 with shape (batch_size, 3, video_length, height, width).
232 | """
233 | video_length = latents.shape[2]
234 | latents = 1 / 0.18215 * latents
235 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
236 | # video = self.vae.decode(latents).sample
237 | video = []
238 | for frame_idx in tqdm(range(latents.shape[0])):
239 | video.append(self.vae.decode(
240 | latents[frame_idx: frame_idx + 1]).sample)
241 | video = torch.cat(video)
242 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
243 | video = (video / 2 + 0.5).clamp(0, 1)
244 |         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
245 | video = video.cpu().float().numpy()
246 | return video
247 |
248 |
249 | @torch.no_grad()
250 | def __call__(
251 | self,
252 | ref_image,
253 | face_emb,
254 | audio_tensor,
255 | face_mask,
256 | pixel_values_full_mask,
257 | pixel_values_face_mask,
258 | pixel_values_lip_mask,
259 | width,
260 | height,
261 | video_length,
262 | num_inference_steps,
263 | guidance_scale,
264 | num_images_per_prompt=1,
265 | eta: float = 0.0,
266 | motion_scale: Optional[List[torch.Tensor]] = None,
267 | generator: Optional[Union[torch.Generator,
268 | List[torch.Generator]]] = None,
269 | output_type: Optional[str] = "tensor",
270 | return_dict: bool = True,
271 | callback: Optional[Callable[[
272 | int, int, torch.FloatTensor], None]] = None,
273 | callback_steps: Optional[int] = 1,
274 | **kwargs,
275 | ):
276 | # Default height and width to unet
277 | height = height or self.unet.config.sample_size * self.vae_scale_factor
278 | width = width or self.unet.config.sample_size * self.vae_scale_factor
279 |
280 | device = self._execution_device
281 |
282 | do_classifier_free_guidance = guidance_scale > 1.0
283 |
284 | # Prepare timesteps
285 | self.scheduler.set_timesteps(num_inference_steps, device=device)
286 | timesteps = self.scheduler.timesteps
287 |
288 | batch_size = 1
289 |
290 | # prepare clip image embeddings
291 | clip_image_embeds = face_emb
292 | clip_image_embeds = clip_image_embeds.to(self.image_proj.device, self.image_proj.dtype)
293 |
294 | encoder_hidden_states = self.image_proj(clip_image_embeds)
295 | uncond_encoder_hidden_states = self.image_proj(torch.zeros_like(clip_image_embeds))
296 |
297 | if do_classifier_free_guidance:
298 | encoder_hidden_states = torch.cat([uncond_encoder_hidden_states, encoder_hidden_states], dim=0)
299 |
300 | reference_control_writer = ReferenceAttentionControl(
301 | self.reference_unet,
302 | do_classifier_free_guidance=do_classifier_free_guidance,
303 | mode="write",
304 | batch_size=batch_size,
305 | fusion_blocks="full",
306 | )
307 | reference_control_reader = ReferenceAttentionControl(
308 | self.denoising_unet,
309 | do_classifier_free_guidance=do_classifier_free_guidance,
310 | mode="read",
311 | batch_size=batch_size,
312 | fusion_blocks="full",
313 | )
314 |
315 | num_channels_latents = self.denoising_unet.in_channels
316 |
317 | latents = self.prepare_latents(
318 | batch_size * num_images_per_prompt,
319 | num_channels_latents,
320 | width,
321 | height,
322 | video_length,
323 | clip_image_embeds.dtype,
324 | device,
325 | generator,
326 | )
327 |
328 | # Prepare extra step kwargs.
329 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
330 |
331 | # Prepare ref image latents
332 | ref_image_tensor = rearrange(ref_image, "b f c h w -> (b f) c h w")
333 | ref_image_tensor = self.ref_image_processor.preprocess(ref_image_tensor, height=height, width=width) # (bs, c, width, height)
334 | ref_image_tensor = ref_image_tensor.to(dtype=self.vae.dtype, device=self.vae.device)
335 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
336 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
337 |
338 |
339 | face_mask = face_mask.unsqueeze(1).to(dtype=self.face_locator.dtype, device=self.face_locator.device) # (bs, f, c, H, W)
340 | face_mask = repeat(face_mask, "b f c h w -> b (repeat f) c h w", repeat=video_length)
341 | face_mask = face_mask.transpose(1, 2) # (bs, c, f, H, W)
342 | face_mask = self.face_locator(face_mask)
343 | face_mask = torch.cat([torch.zeros_like(face_mask), face_mask], dim=0) if do_classifier_free_guidance else face_mask
344 |
345 | pixel_values_full_mask = (
346 | [torch.cat([mask] * 2) for mask in pixel_values_full_mask]
347 | if do_classifier_free_guidance
348 | else pixel_values_full_mask
349 | )
350 | pixel_values_face_mask = (
351 | [torch.cat([mask] * 2) for mask in pixel_values_face_mask]
352 | if do_classifier_free_guidance
353 | else pixel_values_face_mask
354 | )
355 | pixel_values_lip_mask = (
356 | [torch.cat([mask] * 2) for mask in pixel_values_lip_mask]
357 | if do_classifier_free_guidance
358 | else pixel_values_lip_mask
359 | )
360 | pixel_values_face_mask_ = []
361 | for mask in pixel_values_face_mask:
362 | pixel_values_face_mask_.append(
363 | mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype))
364 | pixel_values_face_mask = pixel_values_face_mask_
365 | pixel_values_lip_mask_ = []
366 | for mask in pixel_values_lip_mask:
367 | pixel_values_lip_mask_.append(
368 | mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype))
369 | pixel_values_lip_mask = pixel_values_lip_mask_
370 | pixel_values_full_mask_ = []
371 | for mask in pixel_values_full_mask:
372 | pixel_values_full_mask_.append(
373 | mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype))
374 | pixel_values_full_mask = pixel_values_full_mask_
375 |
376 |
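        # The unconditional branch uses an all-zero audio embedding; the audio batch is doubled here
        # unconditionally, which lines up with the doubled latents and masks when classifier-free
        # guidance is enabled (guidance_scale > 1).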
377 | uncond_audio_tensor = torch.zeros_like(audio_tensor)
378 | audio_tensor = torch.cat([uncond_audio_tensor, audio_tensor], dim=0)
379 | audio_tensor = audio_tensor.to(dtype=self.denoising_unet.dtype, device=self.denoising_unet.device)
380 |
381 | # denoising loop
382 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
383 | with self.progress_bar(total=num_inference_steps) as progress_bar:
384 | for i, t in enumerate(timesteps):
385 | # Forward reference image
386 | if i == 0:
387 | self.reference_unet(
388 | ref_image_latents.repeat(
389 | (2 if do_classifier_free_guidance else 1), 1, 1, 1
390 | ),
391 | torch.zeros_like(t),
392 | encoder_hidden_states=encoder_hidden_states,
393 | return_dict=False,
394 | )
395 | reference_control_reader.update(reference_control_writer)
396 |
397 | # expand the latents if we are doing classifier free guidance
398 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
399 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
400 |
401 | noise_pred = self.denoising_unet(
402 | latent_model_input,
403 | t,
404 | encoder_hidden_states=encoder_hidden_states,
405 | mask_cond_fea=face_mask,
406 | full_mask=pixel_values_full_mask,
407 | face_mask=pixel_values_face_mask,
408 | lip_mask=pixel_values_lip_mask,
409 | audio_embedding=audio_tensor,
410 | motion_scale=motion_scale,
411 | return_dict=False,
412 | )[0]
413 |
414 | # perform guidance
415 | if do_classifier_free_guidance:
416 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
417 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
418 |
419 | # compute the previous noisy sample x_t -> x_t-1
420 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
421 |
422 | # call the callback, if provided
423 | if i == len(timesteps) - 1 or (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
424 | progress_bar.update()
425 | if callback is not None and i % callback_steps == 0:
426 | step_idx = i // getattr(self.scheduler, "order", 1)
427 | callback(step_idx, t, latents)
428 |
429 | reference_control_reader.clear()
430 | reference_control_writer.clear()
431 |
432 | # Post-processing
433 | images = self.decode_latents(latents) # (b, c, f, h, w)
434 |
435 | # Convert to tensor
436 | if output_type == "tensor":
437 | images = torch.from_numpy(images)
438 |
439 | if not return_dict:
440 | return images
441 |
442 | return FaceAnimatePipelineOutput(videos=images)
443 |
--------------------------------------------------------------------------------
/hallo/models/transformer_2d.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=E1101
2 | # src/models/transformer_2d.py
3 |
4 | """
5 | This module defines the Transformer2DModel, a PyTorch model that extends ModelMixin and ConfigMixin. It includes
6 | methods for gradient checkpointing, forward propagation, and various utility functions. The model is designed for
7 | 2D image-related tasks and uses LoRA (Low-Rank Adaptation) compatible layers for efficient attention computation.
8 |
9 | The file includes the following import statements:
10 |
11 | - From dataclasses import dataclass
12 | - From typing import Any, Dict, Optional
13 | - Import torch
14 | - From diffusers.configuration_utils import ConfigMixin, register_to_config
15 | - From diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
16 | - From diffusers.models.modeling_utils import ModelMixin
17 | - From diffusers.models.normalization import AdaLayerNormSingle
18 | - From diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, deprecate,
19 | is_torch_version)
20 | - From torch import nn
21 | - From .attention import BasicTransformerBlock
22 |
23 | The file also includes the following classes and functions:
24 |
25 | - Transformer2DModel: A model class that extends ModelMixin and ConfigMixin. It includes methods for gradient
26 | checkpointing, forward propagation, and various utility functions.
27 | - _set_gradient_checkpointing: A utility function to set gradient checkpointing for a given module.
28 | - forward: The forward propagation method for the Transformer2DModel.
29 |
30 | To use this module, you can import the Transformer2DModel class and create an instance of the model with the desired
31 | configuration. Then, you can use the forward method to pass input tensors through the model and get the output tensors.
32 | """
33 |
34 | from dataclasses import dataclass
35 | from typing import Any, Dict, Optional
36 |
37 | import torch
38 | from diffusers.configuration_utils import ConfigMixin, register_to_config
39 | # from diffusers.models.embeddings import CaptionProjection
40 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
41 | from diffusers.models.modeling_utils import ModelMixin
42 | from diffusers.models.normalization import AdaLayerNormSingle
43 | from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, deprecate,
44 | is_torch_version)
45 | from torch import nn
46 |
47 | from .attention import BasicTransformerBlock
48 |
49 |
50 | @dataclass
51 | class Transformer2DModelOutput(BaseOutput):
52 | """
53 | The output of [`Transformer2DModel`].
54 |
55 | Args:
56 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`
57 | or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
58 | The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
59 | distributions for the unnoised latent pixels.
60 | """
61 |
62 | sample: torch.FloatTensor
63 | ref_feature: torch.FloatTensor
64 |
65 |
66 | class Transformer2DModel(ModelMixin, ConfigMixin):
67 | """
68 | A 2D Transformer model for image-like data.
69 |
70 | Parameters:
71 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
72 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
73 | in_channels (`int`, *optional*):
74 | The number of channels in the input and output (specify if the input is **continuous**).
75 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
76 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
77 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
78 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
79 | This is fixed during training since it is used to learn a number of position embeddings.
80 | num_vector_embeds (`int`, *optional*):
81 | The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
82 | Includes the class for the masked latent pixel.
83 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
84 | num_embeds_ada_norm ( `int`, *optional*):
85 | The number of diffusion steps used during training. Pass if at least one of the norm_layers is
86 | `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
87 | added to the hidden states.
88 |
89 | During inference, you can denoise for at most `num_embeds_ada_norm` steps.
90 | attention_bias (`bool`, *optional*):
91 | Configure if the `TransformerBlocks` attention should contain a bias parameter.
92 | """
93 |
94 | _supports_gradient_checkpointing = True
95 |
96 | @register_to_config
97 | def __init__(
98 | self,
99 | num_attention_heads: int = 16,
100 | attention_head_dim: int = 88,
101 | in_channels: Optional[int] = None,
102 | out_channels: Optional[int] = None,
103 | num_layers: int = 1,
104 | dropout: float = 0.0,
105 | norm_num_groups: int = 32,
106 | cross_attention_dim: Optional[int] = None,
107 | attention_bias: bool = False,
108 | num_vector_embeds: Optional[int] = None,
109 | patch_size: Optional[int] = None,
110 | activation_fn: str = "geglu",
111 | num_embeds_ada_norm: Optional[int] = None,
112 | use_linear_projection: bool = False,
113 | only_cross_attention: bool = False,
114 | double_self_attention: bool = False,
115 | upcast_attention: bool = False,
116 | norm_type: str = "layer_norm",
117 | norm_elementwise_affine: bool = True,
118 | norm_eps: float = 1e-5,
119 | attention_type: str = "default",
120 | ):
121 | super().__init__()
122 | self.use_linear_projection = use_linear_projection
123 | self.num_attention_heads = num_attention_heads
124 | self.attention_head_dim = attention_head_dim
125 | inner_dim = num_attention_heads * attention_head_dim
126 |
127 | conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
128 | linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
129 |
130 | # 1. Transformer2DModel can process both standard continuous images of
131 | # shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of
132 | # shape `(batch_size, num_image_vectors)`
133 | # Define whether input is continuous or discrete depending on configuration
134 | self.is_input_continuous = (in_channels is not None) and (patch_size is None)
135 | self.is_input_vectorized = num_vector_embeds is not None
136 | self.is_input_patches = in_channels is not None and patch_size is not None
137 |
138 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
139 | deprecation_message = (
140 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
141 | " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
142 | " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
143 | " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
144 | " would be very nice if you could open a Pull request for the `transformer/config.json` file"
145 | )
146 | deprecate(
147 | "norm_type!=num_embeds_ada_norm",
148 | "1.0.0",
149 | deprecation_message,
150 | standard_warn=False,
151 | )
152 | norm_type = "ada_norm"
153 |
154 | if self.is_input_continuous and self.is_input_vectorized:
155 | raise ValueError(
156 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
157 | " sure that either `in_channels` or `num_vector_embeds` is None."
158 | )
159 |
160 | if self.is_input_vectorized and self.is_input_patches:
161 | raise ValueError(
162 | f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
163 | " sure that either `num_vector_embeds` or `num_patches` is None."
164 | )
165 |
166 | if (
167 | not self.is_input_continuous
168 | and not self.is_input_vectorized
169 | and not self.is_input_patches
170 | ):
171 | raise ValueError(
172 | f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
173 | f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
174 | )
175 |
176 | # 2. Define input layers
177 | self.in_channels = in_channels
178 |
179 | self.norm = torch.nn.GroupNorm(
180 | num_groups=norm_num_groups,
181 | num_channels=in_channels,
182 | eps=1e-6,
183 | affine=True,
184 | )
185 | if use_linear_projection:
186 | self.proj_in = linear_cls(in_channels, inner_dim)
187 | else:
188 | self.proj_in = conv_cls(
189 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0
190 | )
191 |
192 | # 3. Define transformers blocks
193 | self.transformer_blocks = nn.ModuleList(
194 | [
195 | BasicTransformerBlock(
196 | inner_dim,
197 | num_attention_heads,
198 | attention_head_dim,
199 | dropout=dropout,
200 | cross_attention_dim=cross_attention_dim,
201 | activation_fn=activation_fn,
202 | num_embeds_ada_norm=num_embeds_ada_norm,
203 | attention_bias=attention_bias,
204 | only_cross_attention=only_cross_attention,
205 | double_self_attention=double_self_attention,
206 | upcast_attention=upcast_attention,
207 | norm_type=norm_type,
208 | norm_elementwise_affine=norm_elementwise_affine,
209 | norm_eps=norm_eps,
210 | attention_type=attention_type,
211 | )
212 | for d in range(num_layers)
213 | ]
214 | )
215 |
216 | # 4. Define output layers
217 | self.out_channels = in_channels if out_channels is None else out_channels
218 | # TODO: should use out_channels for continuous projections
219 | if use_linear_projection:
220 | self.proj_out = linear_cls(inner_dim, in_channels)
221 | else:
222 | self.proj_out = conv_cls(
223 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0
224 | )
225 |
226 | # 5. PixArt-Alpha blocks.
227 | self.adaln_single = None
228 | self.use_additional_conditions = False
229 | if norm_type == "ada_norm_single":
230 | self.use_additional_conditions = self.config.sample_size == 128
231 | # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
232 | # additional conditions until we find better name
233 | self.adaln_single = AdaLayerNormSingle(
234 | inner_dim, use_additional_conditions=self.use_additional_conditions
235 | )
236 |
237 | self.caption_projection = None
238 |
239 | self.gradient_checkpointing = False
240 |
241 | def _set_gradient_checkpointing(self, module, value=False):
242 | if hasattr(module, "gradient_checkpointing"):
243 | module.gradient_checkpointing = value
244 |
245 | def forward(
246 | self,
247 | hidden_states: torch.Tensor,
248 | encoder_hidden_states: Optional[torch.Tensor] = None,
249 | timestep: Optional[torch.LongTensor] = None,
250 | _added_cond_kwargs: Dict[str, torch.Tensor] = None,
251 | class_labels: Optional[torch.LongTensor] = None,
252 | cross_attention_kwargs: Dict[str, Any] = None,
253 | attention_mask: Optional[torch.Tensor] = None,
254 | encoder_attention_mask: Optional[torch.Tensor] = None,
255 | return_dict: bool = True,
256 | ):
257 | """
258 | The [`Transformer2DModel`] forward method.
259 |
260 | Args:
261 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete,
262 | `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
263 | Input `hidden_states`.
264 | encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
265 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
266 | self-attention.
267 | timestep ( `torch.LongTensor`, *optional*):
268 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
269 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
270 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
271 | `AdaLayerNormZero`.
272 | cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
273 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
274 | `self.processor` in
275 | [diffusers.models.attention_processor]
276 | (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
277 | attention_mask ( `torch.Tensor`, *optional*):
278 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
279 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
280 | negative values to the attention scores corresponding to "discard" tokens.
281 | encoder_attention_mask ( `torch.Tensor`, *optional*):
282 | Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
283 |
284 | * Mask `(batch, sequence_length)` True = keep, False = discard.
285 | * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
286 |
287 | If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
288 | above. This bias will be added to the cross-attention scores.
289 | return_dict (`bool`, *optional*, defaults to `True`):
290 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
291 | tuple.
292 |
293 | Returns:
294 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
295 | `tuple` where the first element is the sample tensor.
296 | """
297 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
298 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
299 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
300 | # expects mask of shape:
301 | # [batch, key_tokens]
302 | # adds singleton query_tokens dimension:
303 | # [batch, 1, key_tokens]
304 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
305 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
306 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
307 | if attention_mask is not None and attention_mask.ndim == 2:
308 | # assume that mask is expressed as:
309 | # (1 = keep, 0 = discard)
310 | # convert mask into a bias that can be added to attention scores:
311 | # (keep = +0, discard = -10000.0)
312 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
313 | attention_mask = attention_mask.unsqueeze(1)
314 |
315 | # convert encoder_attention_mask to a bias the same way we do for attention_mask
316 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
317 | encoder_attention_mask = (
318 | 1 - encoder_attention_mask.to(hidden_states.dtype)
319 | ) * -10000.0
320 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
321 |
322 | # Retrieve lora scale.
323 | lora_scale = (
324 | cross_attention_kwargs.get("scale", 1.0)
325 | if cross_attention_kwargs is not None
326 | else 1.0
327 | )
328 |
329 | # 1. Input
330 | batch, _, height, width = hidden_states.shape
331 | residual = hidden_states
332 |
333 | hidden_states = self.norm(hidden_states)
334 | if not self.use_linear_projection:
335 | hidden_states = (
336 | self.proj_in(hidden_states, scale=lora_scale)
337 | if not USE_PEFT_BACKEND
338 | else self.proj_in(hidden_states)
339 | )
340 | inner_dim = hidden_states.shape[1]
341 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
342 | batch, height * width, inner_dim
343 | )
344 | else:
345 | inner_dim = hidden_states.shape[1]
346 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
347 | batch, height * width, inner_dim
348 | )
349 | hidden_states = (
350 | self.proj_in(hidden_states, scale=lora_scale)
351 | if not USE_PEFT_BACKEND
352 | else self.proj_in(hidden_states)
353 | )
354 |
355 | # 2. Blocks
356 | if self.caption_projection is not None:
357 | batch_size = hidden_states.shape[0]
358 | encoder_hidden_states = self.caption_projection(encoder_hidden_states)
359 | encoder_hidden_states = encoder_hidden_states.view(
360 | batch_size, -1, hidden_states.shape[-1]
361 | )
362 |
363 | ref_feature = hidden_states.reshape(batch, height, width, inner_dim)
364 | for block in self.transformer_blocks:
365 | if self.training and self.gradient_checkpointing:
366 |
367 | def create_custom_forward(module, return_dict=None):
368 | def custom_forward(*inputs):
369 | if return_dict is not None:
370 | return module(*inputs, return_dict=return_dict)
371 |
372 | return module(*inputs)
373 |
374 | return custom_forward
375 |
376 | ckpt_kwargs: Dict[str, Any] = (
377 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
378 | )
379 | hidden_states = torch.utils.checkpoint.checkpoint(
380 | create_custom_forward(block),
381 | hidden_states,
382 | attention_mask,
383 | encoder_hidden_states,
384 | encoder_attention_mask,
385 | timestep,
386 | cross_attention_kwargs,
387 | class_labels,
388 | **ckpt_kwargs,
389 | )
390 | else:
391 | hidden_states = block(
392 | hidden_states, # shape [5, 4096, 320]
393 | attention_mask=attention_mask,
394 | encoder_hidden_states=encoder_hidden_states, # shape [1,4,768]
395 | encoder_attention_mask=encoder_attention_mask,
396 | timestep=timestep,
397 | cross_attention_kwargs=cross_attention_kwargs,
398 | class_labels=class_labels,
399 | )
400 |
401 | # 3. Output
402 | output = None
403 | if self.is_input_continuous:
404 | if not self.use_linear_projection:
405 | hidden_states = (
406 | hidden_states.reshape(batch, height, width, inner_dim)
407 | .permute(0, 3, 1, 2)
408 | .contiguous()
409 | )
410 | hidden_states = (
411 | self.proj_out(hidden_states, scale=lora_scale)
412 | if not USE_PEFT_BACKEND
413 | else self.proj_out(hidden_states)
414 | )
415 | else:
416 | hidden_states = (
417 | self.proj_out(hidden_states, scale=lora_scale)
418 | if not USE_PEFT_BACKEND
419 | else self.proj_out(hidden_states)
420 | )
421 | hidden_states = (
422 | hidden_states.reshape(batch, height, width, inner_dim)
423 | .permute(0, 3, 1, 2)
424 | .contiguous()
425 | )
426 |
427 | output = hidden_states + residual
428 | if not return_dict:
429 | return (output, ref_feature)
430 |
431 | return Transformer2DModelOutput(sample=output, ref_feature=ref_feature)
432 |
--------------------------------------------------------------------------------
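The module docstring above describes the intended usage: instantiate Transformer2DModel with a configuration and call forward. Below is a minimal usage sketch under assumed, illustrative hyperparameters (8 heads of 40 dims, i.e. inner_dim = 320, cross-attention width 768); in practice these values come from the UNet config, and BasicTransformerBlock is assumed to follow the standard diffusers block signature:

import torch
from hallo.models.transformer_2d import Transformer2DModel

# Illustrative configuration only; real values are taken from the UNet config.
model = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=40,       # inner_dim = 8 * 40 = 320
    in_channels=320,
    num_layers=1,
    cross_attention_dim=768,
)

hidden_states = torch.randn(2, 320, 64, 64)      # (batch, channels, height, width)
encoder_hidden_states = torch.randn(2, 4, 768)   # (batch, sequence length, cross-attention dim)

out = model(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.sample.shape)       # torch.Size([2, 320, 64, 64]) -- residual-added output
print(out.ref_feature.shape)  # torch.Size([2, 64, 64, 320]) -- features cached before the transformer blocks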
/hallo/models/mutual_self_attention.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=E1120
2 | """
3 | This module contains the implementation of mutual self-attention,
4 | which is a type of attention mechanism used in deep learning models.
5 | The module includes several classes and functions related to attention mechanisms,
6 | such as BasicTransformerBlock and TemporalBasicTransformerBlock.
7 | The main purpose of this module is to provide a comprehensive attention mechanism for various tasks in deep learning,
8 | such as image and video processing, natural language processing, and so on.
9 | """
10 |
11 | from typing import Any, Dict, Optional
12 |
13 | import torch
14 | from einops import rearrange
15 |
16 | from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock
17 |
18 |
19 | def torch_dfs(model: torch.nn.Module):
20 | """
21 | Perform a depth-first search (DFS) traversal on a PyTorch model's neural network architecture.
22 |
23 | This function recursively traverses all the children modules of a given PyTorch model and returns a list
24 | containing all the modules in the model's architecture. The DFS approach starts with the input model and
25 | explores its children modules depth-wise before backtracking and exploring other branches.
26 |
27 | Args:
28 | model (torch.nn.Module): The root module of the neural network to traverse.
29 |
30 | Returns:
31 | list: A list of all the modules in the model's architecture.
32 | """
33 | result = [model]
34 | for child in model.children():
35 | result += torch_dfs(child)
36 | return result
37 |
38 |
39 | class ReferenceAttentionControl:
40 | """
41 | This class is used to control the reference attention mechanism in a neural network model.
42 | It is responsible for managing the guidance and fusion blocks, and modifying the self-attention
43 | and group normalization mechanisms. The class also provides methods for registering reference hooks
44 | and updating/clearing the internal state of the attention control object.
45 |
46 | Attributes:
47 | unet: The UNet model associated with this attention control object.
48 | mode: The operating mode of the attention control object, either 'write' or 'read'.
49 | do_classifier_free_guidance: Whether to use classifier-free guidance in the attention mechanism.
50 | attention_auto_machine_weight: The weight assigned to the attention auto-machine.
51 | gn_auto_machine_weight: The weight assigned to the group normalization auto-machine.
52 | style_fidelity: The style fidelity parameter for the attention mechanism.
53 | reference_attn: Whether to use reference attention in the model.
54 | reference_adain: Whether to use reference AdaIN in the model.
55 | fusion_blocks: The type of fusion blocks to use in the model ('midup' or 'full').
56 | batch_size: The batch size used for processing video frames.
57 |
58 | Methods:
59 | register_reference_hooks: Registers the reference hooks for the attention control object.
60 | hacked_basic_transformer_inner_forward: The modified inner forward method for the basic transformer block.
61 | update: Updates the internal state of the attention control object using the provided writer and dtype.
62 | clear: Clears the internal state of the attention control object.
63 | """
64 | def __init__(
65 | self,
66 | unet,
67 | mode="write",
68 | do_classifier_free_guidance=False,
69 | attention_auto_machine_weight=float("inf"),
70 | gn_auto_machine_weight=1.0,
71 | style_fidelity=1.0,
72 | reference_attn=True,
73 | reference_adain=False,
74 | fusion_blocks="midup",
75 | batch_size=1,
76 | ) -> None:
77 | """
78 | Initializes the ReferenceAttentionControl class.
79 |
80 | Args:
81 | unet (torch.nn.Module): The UNet model.
82 | mode (str, optional): The mode of operation. Defaults to "write".
83 | do_classifier_free_guidance (bool, optional): Whether to do classifier-free guidance. Defaults to False.
84 | attention_auto_machine_weight (float, optional): The weight for attention auto-machine. Defaults to infinity.
85 | gn_auto_machine_weight (float, optional): The weight for group-norm auto-machine. Defaults to 1.0.
86 | style_fidelity (float, optional): The style fidelity. Defaults to 1.0.
87 | reference_attn (bool, optional): Whether to use reference attention. Defaults to True.
88 | reference_adain (bool, optional): Whether to use reference AdaIN. Defaults to False.
89 | fusion_blocks (str, optional): The fusion blocks to use. Defaults to "midup".
90 | batch_size (int, optional): The batch size. Defaults to 1.
91 |
92 | Raises:
93 | AssertionError: If `mode` is not "read" or "write".
94 | AssertionError: If `fusion_blocks` is not "midup" or "full".
95 | """
96 | # 10. Modify self attention and group norm
97 | self.unet = unet
98 | assert mode in ["read", "write"]
99 | assert fusion_blocks in ["midup", "full"]
100 | self.reference_attn = reference_attn
101 | self.reference_adain = reference_adain
102 | self.fusion_blocks = fusion_blocks
103 | self.register_reference_hooks(
104 | mode,
105 | do_classifier_free_guidance,
106 | attention_auto_machine_weight,
107 | gn_auto_machine_weight,
108 | style_fidelity,
109 | reference_attn,
110 | reference_adain,
111 | fusion_blocks,
112 | batch_size=batch_size,
113 | )
114 |
115 | def register_reference_hooks(
116 | self,
117 | mode,
118 | do_classifier_free_guidance,
119 | _attention_auto_machine_weight,
120 | _gn_auto_machine_weight,
121 | _style_fidelity,
122 | _reference_attn,
123 | _reference_adain,
124 | _dtype=torch.float16,
125 | batch_size=1,
126 | num_images_per_prompt=1,
127 | device=torch.device("cpu"),
128 | _fusion_blocks="midup",
129 | ):
130 | """
131 | Registers reference hooks for the model.
132 |
133 | This function registers the reference hooks in the model: it replaces the forward method of the
134 | relevant transformer blocks so that, depending on `mode`, they either write their normalized hidden
135 | states into a feature bank ("write") or read previously banked reference features and fuse them
136 | into their self-attention ("read"). The remaining parameters are documented in the Args section
137 | below; the underscore-prefixed ones are currently unused.
138 |
139 | Args:
140 | self: Reference to the instance of the class.
141 | mode: The mode of operation for the reference hooks.
142 | do_classifier_free_guidance: A boolean flag indicating whether to use classifier-free guidance.
143 | _attention_auto_machine_weight: The weight for the attention auto-machine.
144 | _gn_auto_machine_weight: The weight for the group normalization auto-machine.
145 | _style_fidelity: The style fidelity for the reference hooks.
146 | _reference_attn: A boolean flag indicating whether to use reference attention.
147 | _reference_adain: A boolean flag indicating whether to use reference AdaIN.
148 | _dtype: The data type for the reference hooks.
149 | batch_size: The batch size for the reference hooks.
150 | num_images_per_prompt: The number of images per prompt for the reference hooks.
151 | device: The device for the reference hooks.
152 | _fusion_blocks: The fusion blocks for the reference hooks.
153 |
154 | Returns:
155 | None
156 | """
157 | MODE = mode
158 | if do_classifier_free_guidance:
159 | uc_mask = (
160 | torch.Tensor(
161 | [1] * batch_size * num_images_per_prompt * 16
162 | + [0] * batch_size * num_images_per_prompt * 16
163 | )
164 | .to(device)
165 | .bool()
166 | )
167 | else:
168 | uc_mask = (
169 | torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
170 | .to(device)
171 | .bool()
172 | )
173 |
174 | def hacked_basic_transformer_inner_forward(
175 | self,
176 | hidden_states: torch.FloatTensor,
177 | attention_mask: Optional[torch.FloatTensor] = None,
178 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
179 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
180 | timestep: Optional[torch.LongTensor] = None,
181 | cross_attention_kwargs: Dict[str, Any] = None,
182 | class_labels: Optional[torch.LongTensor] = None,
183 | video_length=None,
184 | ):
185 | gate_msa = None
186 | shift_mlp = None
187 | scale_mlp = None
188 | gate_mlp = None
189 |
190 | if self.use_ada_layer_norm: # False
191 | norm_hidden_states = self.norm1(hidden_states, timestep)
192 | elif self.use_ada_layer_norm_zero:
193 | (
194 | norm_hidden_states,
195 | gate_msa,
196 | shift_mlp,
197 | scale_mlp,
198 | gate_mlp,
199 | ) = self.norm1(
200 | hidden_states,
201 | timestep,
202 | class_labels,
203 | hidden_dtype=hidden_states.dtype,
204 | )
205 | else:
206 | norm_hidden_states = self.norm1(hidden_states)
207 |
208 | # 1. Self-Attention
209 | # self.only_cross_attention = False
210 | cross_attention_kwargs = (
211 | cross_attention_kwargs if cross_attention_kwargs is not None else {}
212 | )
213 | if self.only_cross_attention:
214 | attn_output = self.attn1(
215 | norm_hidden_states,
216 | encoder_hidden_states=(
217 | encoder_hidden_states if self.only_cross_attention else None
218 | ),
219 | attention_mask=attention_mask,
220 | **cross_attention_kwargs,
221 | )
222 | else:
223 | if MODE == "write":
224 | self.bank.append(norm_hidden_states.clone())
225 | attn_output = self.attn1(
226 | norm_hidden_states,
227 | encoder_hidden_states=(
228 | encoder_hidden_states if self.only_cross_attention else None
229 | ),
230 | attention_mask=attention_mask,
231 | **cross_attention_kwargs,
232 | )
233 | if MODE == "read":
234 |
235 | bank_fea = [
236 | rearrange(
237 | rearrange(
238 | d,
239 | "(b s) l c -> b s l c",
240 | b=norm_hidden_states.shape[0] // video_length,
241 | )[:, 0, :, :]
242 | # .unsqueeze(1)
243 | .repeat(1, video_length, 1, 1),
244 | "b t l c -> (b t) l c",
245 | )
246 | for d in self.bank
247 | ]
248 | motion_frames_fea = [rearrange(
249 | d,
250 | "(b s) l c -> b s l c",
251 | b=norm_hidden_states.shape[0] // video_length,
252 | )[:, 1:, :, :] for d in self.bank]
253 | modify_norm_hidden_states = torch.cat(
254 | [norm_hidden_states] + bank_fea, dim=1
255 | )
256 | hidden_states_uc = (
257 | self.attn1(
258 | norm_hidden_states,
259 | encoder_hidden_states=modify_norm_hidden_states,
260 | attention_mask=attention_mask,
261 | )
262 | + hidden_states
263 | )
264 | if do_classifier_free_guidance:
265 | hidden_states_c = hidden_states_uc.clone()
266 | _uc_mask = uc_mask.clone()
267 | if hidden_states.shape[0] != _uc_mask.shape[0]:
268 | _uc_mask = (
269 | torch.Tensor(
270 | [1] * (hidden_states.shape[0] // 2)
271 | + [0] * (hidden_states.shape[0] // 2)
272 | )
273 | .to(device)
274 | .bool()
275 | )
276 | hidden_states_c[_uc_mask] = (
277 | self.attn1(
278 | norm_hidden_states[_uc_mask],
279 | encoder_hidden_states=norm_hidden_states[_uc_mask],
280 | attention_mask=attention_mask,
281 | )
282 | + hidden_states[_uc_mask]
283 | )
284 | hidden_states = hidden_states_c.clone()
285 | else:
286 | hidden_states = hidden_states_uc
287 |
288 | # self.bank.clear()
289 | if self.attn2 is not None:
290 | # Cross-Attention
291 | norm_hidden_states = (
292 | self.norm2(hidden_states, timestep)
293 | if self.use_ada_layer_norm
294 | else self.norm2(hidden_states)
295 | )
296 | hidden_states = (
297 | self.attn2(
298 | norm_hidden_states,
299 | encoder_hidden_states=encoder_hidden_states,
300 | attention_mask=attention_mask,
301 | )
302 | + hidden_states
303 | )
304 |
305 | # Feed-forward
306 | hidden_states = self.ff(self.norm3(
307 | hidden_states)) + hidden_states
308 |
309 | # Temporal-Attention
310 | if self.unet_use_temporal_attention:
311 | d = hidden_states.shape[1]
312 | hidden_states = rearrange(
313 | hidden_states, "(b f) d c -> (b d) f c", f=video_length
314 | )
315 | norm_hidden_states = (
316 | self.norm_temp(hidden_states, timestep)
317 | if self.use_ada_layer_norm
318 | else self.norm_temp(hidden_states)
319 | )
320 | hidden_states = (
321 | self.attn_temp(norm_hidden_states) + hidden_states
322 | )
323 | hidden_states = rearrange(
324 | hidden_states, "(b d) f c -> (b f) d c", d=d
325 | )
326 |
327 | return hidden_states, motion_frames_fea
328 |
329 | if self.use_ada_layer_norm_zero:
330 | attn_output = gate_msa.unsqueeze(1) * attn_output
331 | hidden_states = attn_output + hidden_states
332 |
333 | if self.attn2 is not None:
334 | norm_hidden_states = (
335 | self.norm2(hidden_states, timestep)
336 | if self.use_ada_layer_norm
337 | else self.norm2(hidden_states)
338 | )
339 |
340 | # 2. Cross-Attention
341 | tmp = norm_hidden_states.shape[0] // encoder_hidden_states.shape[0]
342 | attn_output = self.attn2(
343 | norm_hidden_states,
344 | # TODO: the repeat here needs to be reconsidered
345 | encoder_hidden_states=encoder_hidden_states.repeat(
346 | tmp, 1, 1),
347 | attention_mask=encoder_attention_mask,
348 | **cross_attention_kwargs,
349 | )
350 | hidden_states = attn_output + hidden_states
351 |
352 | # 3. Feed-forward
353 | norm_hidden_states = self.norm3(hidden_states)
354 |
355 | if self.use_ada_layer_norm_zero:
356 | norm_hidden_states = (
357 | norm_hidden_states *
358 | (1 + scale_mlp[:, None]) + shift_mlp[:, None]
359 | )
360 |
361 | ff_output = self.ff(norm_hidden_states)
362 |
363 | if self.use_ada_layer_norm_zero:
364 | ff_output = gate_mlp.unsqueeze(1) * ff_output
365 |
366 | hidden_states = ff_output + hidden_states
367 |
368 | return hidden_states
369 |
370 | if self.reference_attn:
371 | if self.fusion_blocks == "midup":
372 | attn_modules = [
373 | module
374 | for module in (
375 | torch_dfs(self.unet.mid_block) +
376 | torch_dfs(self.unet.up_blocks)
377 | )
378 | if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock))
379 | ]
380 | elif self.fusion_blocks == "full":
381 | attn_modules = [
382 | module
383 | for module in torch_dfs(self.unet)
384 | if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock))
385 | ]
386 | attn_modules = sorted(
387 | attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
388 | )
389 |
390 | for i, module in enumerate(attn_modules):
391 | module._original_inner_forward = module.forward
392 | if isinstance(module, BasicTransformerBlock):
393 | module.forward = hacked_basic_transformer_inner_forward.__get__(
394 | module,
395 | BasicTransformerBlock)
396 | if isinstance(module, TemporalBasicTransformerBlock):
397 | module.forward = hacked_basic_transformer_inner_forward.__get__(
398 | module,
399 | TemporalBasicTransformerBlock)
400 |
401 | module.bank = []
402 | module.attn_weight = float(i) / float(len(attn_modules))
403 |
404 | def update(self, writer, dtype=torch.float16):
405 | """
406 | Copy the attention feature banks from the writer's transformer blocks into this reader's blocks.
407 |
408 | Args:
409 | writer (ReferenceAttentionControl): The "write"-mode control object whose attention banks are copied.
410 | dtype (torch.dtype, optional): The data type to be used for the update. Defaults to torch.float16.
411 |
412 | Returns:
413 | None.
414 | """
415 | if self.reference_attn:
416 | if self.fusion_blocks == "midup":
417 | reader_attn_modules = [
418 | module
419 | for module in (
420 | torch_dfs(self.unet.mid_block) +
421 | torch_dfs(self.unet.up_blocks)
422 | )
423 | if isinstance(module, TemporalBasicTransformerBlock)
424 | ]
425 | writer_attn_modules = [
426 | module
427 | for module in (
428 | torch_dfs(writer.unet.mid_block)
429 | + torch_dfs(writer.unet.up_blocks)
430 | )
431 | if isinstance(module, BasicTransformerBlock)
432 | ]
433 | elif self.fusion_blocks == "full":
434 | reader_attn_modules = [
435 | module
436 | for module in torch_dfs(self.unet)
437 | if isinstance(module, TemporalBasicTransformerBlock)
438 | ]
439 | writer_attn_modules = [
440 | module
441 | for module in torch_dfs(writer.unet)
442 | if isinstance(module, BasicTransformerBlock)
443 | ]
444 |
445 | assert len(reader_attn_modules) == len(writer_attn_modules)
446 | reader_attn_modules = sorted(
447 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
448 | )
449 | writer_attn_modules = sorted(
450 | writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
451 | )
452 | for r, w in zip(reader_attn_modules, writer_attn_modules):
453 | r.bank = [v.clone().to(dtype) for v in w.bank]
454 |
455 |
456 | def clear(self):
457 | """
458 | Clears the attention bank of all reader attention modules.
459 |
460 | This method is used when the `reference_attn` attribute is set to `True`.
461 | It clears the attention bank of all reader attention modules inside the UNet
462 | model based on the selected `fusion_blocks` mode.
463 |
464 | If `fusion_blocks` is set to "midup", it searches for reader attention modules
465 | in both the mid block and up blocks of the UNet model. If `fusion_blocks` is set
466 | to "full", it searches for reader attention modules in the entire UNet model.
467 |
468 | It sorts the reader attention modules by their hidden dimension
469 | (`norm1.normalized_shape[0]`) in descending order, keeping the traversal order
470 | consistent with the one used in `update`.
471 |
472 | Finally, it iterates through the sorted list of reader attention modules and
473 | calls the `clear()` method on each module's `bank` attribute to clear the
474 | attention bank.
475 | """
476 | if self.reference_attn:
477 | if self.fusion_blocks == "midup":
478 | reader_attn_modules = [
479 | module
480 | for module in (
481 | torch_dfs(self.unet.mid_block) +
482 | torch_dfs(self.unet.up_blocks)
483 | )
484 | if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock))
485 | ]
486 | elif self.fusion_blocks == "full":
487 | reader_attn_modules = [
488 | module
489 | for module in torch_dfs(self.unet)
490 | if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock))
491 | ]
492 | reader_attn_modules = sorted(
493 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
494 | )
495 | for r in reader_attn_modules:
496 | r.bank.clear()
497 |
--------------------------------------------------------------------------------
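As the pipeline excerpt at the top of this listing shows (reference_control_reader.update(reference_control_writer), then clear() on both), ReferenceAttentionControl objects are used in write/read pairs. A rough sketch of that flow, assuming reference_unet and denoising_unet have already been constructed elsewhere (e.g. by the pipeline):

import torch
from hallo.models.mutual_self_attention import ReferenceAttentionControl

# reference_unet and denoising_unet are assumed to be pre-built UNet modules.
writer = ReferenceAttentionControl(
    reference_unet, mode="write", do_classifier_free_guidance=True, fusion_blocks="full",
)
reader = ReferenceAttentionControl(
    denoising_unet, mode="read", do_classifier_free_guidance=True, fusion_blocks="full",
)

# 1. Run the reference UNet once; each patched block appends its normalized
#    hidden states to module.bank ("write" mode).
# 2. Copy those banks into the denoising UNet's blocks before the denoising loop.
reader.update(writer, dtype=torch.float16)

# ... denoising loop: the "read"-mode blocks concatenate the banked reference
#     features into their self-attention ...

# 3. Drop the cached features once sampling is finished.
reader.clear()
writer.clear()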